aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--conn.go4
-rw-r--r--server.go22
-rw-r--r--server_modify.go3
-rw-r--r--server_test.go217
4 files changed, 242 insertions, 4 deletions
diff --git a/conn.go b/conn.go
index cd154f7..257f5d0 100644
--- a/conn.go
+++ b/conn.go
@@ -22,6 +22,8 @@ const (
MessageFinish = 3
)
+const oidStartTLS = "1.3.6.1.4.1.1466.20037"
+
type messagePacket struct {
Op int
MessageID uint64
@@ -150,7 +152,7 @@ func (l *Conn) StartTLS(config *tls.Config) error {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
- request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
+ request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, oidStartTLS, "TLS Extended Command"))
packet.AppendChild(request)
l.Debug.PrintPacket(packet)
diff --git a/server.go b/server.go
index 071286c..cafeec5 100644
--- a/server.go
+++ b/server.go
@@ -44,7 +44,7 @@ type Closer interface {
//
type Server struct {
- Bind BindFunc
+ Bind BindFunc
Search SearchFunc
AddFns map[string]Adder
@@ -59,6 +59,9 @@ type Server struct {
EnforceLDAP bool
Stats *Stats
+ // If set, server will accept StartTLS.
+ TLSConfig *tls.Config
+
closing chan struct{}
}
@@ -307,12 +310,27 @@ handler:
server.Stats.countUnbinds(1)
break handler // simply disconnect
case ApplicationExtendedRequest:
- ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
+ var tlsConn *tls.Conn
+ if n := len(req.Children); n == 1 || n == 2 {
+ if name := ber.DecodeString(req.Children[0].Data.Bytes()); name == oidStartTLS {
+ tlsConn = tls.Server(conn, server.TLSConfig)
+ }
+ }
+ var ldapResultCode LDAPResultCode
+ if tlsConn == nil {
+ // Wasn't an upgrade. Pass through.
+ ldapResultCode = HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
+ } else {
+ ldapResultCode = LDAPResultSuccess
+ }
responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
if err = sendPacket(conn, responsePacket); err != nil {
log.Printf("sendPacket error %s", err.Error())
break handler
}
+ if tlsConn != nil {
+ conn = tlsConn
+ }
case ApplicationAbandonRequest:
HandleAbandonRequest(req, boundDN, server.AbandonFns, conn)
break handler
diff --git a/server_modify.go b/server_modify.go
index 56f45df..009fefd 100644
--- a/server_modify.go
+++ b/server_modify.go
@@ -1,9 +1,10 @@
package ldapserver
import (
- "github.com/mark-rushakoff/ldapserver/internal/asn1-ber"
"log"
"net"
+
+ "github.com/mark-rushakoff/ldapserver/internal/asn1-ber"
)
func HandleAddRequest(req *ber.Packet, boundDN string, fns map[string]Adder, conn net.Conn) (resultCode LDAPResultCode) {
diff --git a/server_test.go b/server_test.go
index 29df06b..991d927 100644
--- a/server_test.go
+++ b/server_test.go
@@ -2,18 +2,235 @@ package ldapserver
import (
"bytes"
+ "crypto/rsa"
"crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "io/ioutil"
"log"
+ "math/big"
"net"
+ "os"
"os/exec"
+ "runtime"
"strings"
"testing"
"time"
+
+ "github.com/golang/go/src/crypto/rand"
)
var timeout = 400 * time.Millisecond
var serverBaseDN = "o=testers,c=test"
+type selfSignedCert struct {
+ // Path to the SSL certificates.
+ CACertPath, CertPath string
+
+ // Path to the private keys for the SSL certificates.
+ CAKeyPath, KeyPath string
+}
+
+func newSelfSignedCert() *selfSignedCert {
+ capk, err := rsa.GenerateKey(rand.Reader, 1024)
+ if err != nil {
+ panic(err)
+ }
+
+ caSerial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
+ if err != nil {
+ panic(err)
+ }
+
+ caTemplate := x509.Certificate{
+ SerialNumber: caSerial,
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(7 * 24 * time.Hour),
+
+ KeyUsage: x509.KeyUsageCertSign,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+
+ BasicConstraintsValid: true,
+
+ Subject: pkix.Name{
+ Organization: []string{"my_test_ca"},
+ CommonName: "My Test CA",
+ },
+
+ IsCA: true,
+ }
+
+ caCert, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, capk.Public(), capk)
+ if err != nil {
+ panic(err)
+ }
+ // fmt.Printf("CA CERT\n%#v\n", caCert)
+ caCertPEM := &pem.Block{Type: "CERTIFICATE", Bytes: caCert}
+ caCertFile, err := ioutil.TempFile("", "cacert-*.pem")
+ if err != nil {
+ panic(err)
+ }
+ if err := pem.Encode(caCertFile, caCertPEM); err != nil {
+ panic(err)
+ }
+ caCertFile.Close()
+
+ caKeyFile, err := ioutil.TempFile("", "cakey-*.pem")
+ if err != nil {
+ panic(err)
+ }
+ if err := pem.Encode(caKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(capk)}); err != nil {
+ panic(err)
+ }
+ caKeyFile.Close()
+
+ serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
+ if err != nil {
+ panic(err)
+ }
+ // Basically the same as the CA template, but its own serial, and with ip addresses and dns names.
+ template := x509.Certificate{
+ SerialNumber: serial,
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(7 * 24 * time.Hour),
+
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+
+ BasicConstraintsValid: true,
+
+ Subject: pkix.Name{
+ CommonName: "localhost",
+ },
+
+ IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
+ DNSNames: []string{"localhost"},
+ }
+
+ pk, err := rsa.GenerateKey(rand.Reader, 1024)
+ if err != nil {
+ panic(err)
+ }
+ cert, err := x509.CreateCertificate(rand.Reader, &template, &caTemplate, pk.Public(), capk)
+ if err != nil {
+ panic(err)
+ }
+ certPEM := &pem.Block{Type: "CERTIFICATE", Bytes: cert}
+ certFile, err := ioutil.TempFile("", "sslcert-*.pem")
+ if err != nil {
+ panic(err)
+ }
+ if err := pem.Encode(certFile, certPEM); err != nil {
+ panic(err)
+ }
+ certFile.Close()
+
+ keyFile, err := ioutil.TempFile("", "key-*.pem")
+ if err != nil {
+ panic(err)
+ }
+ if err := pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pk)}); err != nil {
+ panic(err)
+ }
+ keyFile.Close()
+
+ return &selfSignedCert{
+ CACertPath: caCertFile.Name(),
+ CAKeyPath: caKeyFile.Name(),
+ CertPath: certFile.Name(),
+ KeyPath: keyFile.Name(),
+ }
+}
+
+func (c *selfSignedCert) cleanup() {
+ os.RemoveAll(c.CertPath)
+ os.RemoveAll(c.CACertPath)
+ os.RemoveAll(c.KeyPath)
+ os.RemoveAll(c.CAKeyPath)
+}
+
+func (c *selfSignedCert) ClientTLSConfig() *tls.Config {
+ cert, err := ioutil.ReadFile(c.CACertPath)
+ if err != nil {
+ panic(err)
+ }
+
+ // Return a TLS config that trusts our self-generated CA.
+ pool := x509.NewCertPool()
+ if !pool.AppendCertsFromPEM(cert) {
+ panic("failed to append certificate")
+ }
+ return &tls.Config{
+ RootCAs: pool,
+ }
+}
+
+func (c *selfSignedCert) ServerTLSConfig() *tls.Config {
+ cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath)
+ if err != nil {
+ panic(err)
+ }
+ return &tls.Config{
+ ServerName: "localhost",
+ Certificates: []tls.Certificate{cert},
+ }
+}
+
+func TestStartTLS(t *testing.T) {
+ if runtime.GOOS == "darwin" {
+ defer func() {
+ if t.Failed() {
+ t.Logf(`NOTE: this test won't pass with the built-in Mac ldap utilities.
+Work around this by using brew install openldap, and running the test as PATH=/usr/local/opt/openldap/bin:$PATH go test.
+
+This test uses environment variables that are respected by OpenLDAP, but the Mac utilities don't let you override
+security settings through environment variables; they expect certificates to be added to the system keychain,
+which is very heavy-handed for a test like this.
+`)
+ }
+ }()
+ }
+ cert := newSelfSignedCert()
+ defer cert.cleanup()
+
+ s := NewServer()
+ defer s.Close()
+ s.Bind = BindAnonOK
+ s.Search = SearchSimple
+ s.TLSConfig = cert.ServerTLSConfig()
+
+ ln, addr := mustListen()
+ go func() {
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("s.Serve failed: %s", err.Error())
+ }
+ }()
+
+ done := make(chan struct{})
+ go func() {
+ cmd := exec.Command("env",
+ "LDAPTLS_CACERT="+cert.CACertPath,
+ "ldapsearch", "-H", "ldap://"+addr, "-ZZ", "-d", "-1", "-x", "-b", "o=testers,c=test")
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !strings.Contains(string(out), "# numEntries: 3") || !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("search did not succeed:\n%s", out)
+ }
+
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Error("ldapsearch command timed out")
+ }
+}
+
/////////////////////////
func TestBindAnonOK(t *testing.T) {
done := make(chan bool)