diff options
| -rw-r--r-- | conn.go | 4 | ||||
| -rw-r--r-- | server.go | 22 | ||||
| -rw-r--r-- | server_modify.go | 3 | ||||
| -rw-r--r-- | server_test.go | 217 |
4 files changed, 242 insertions, 4 deletions
@@ -22,6 +22,8 @@ const ( MessageFinish = 3 ) +const oidStartTLS = "1.3.6.1.4.1.1466.20037" + type messagePacket struct { Op int MessageID uint64 @@ -150,7 +152,7 @@ func (l *Conn) StartTLS(config *tls.Config) error { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") - request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) + request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, oidStartTLS, "TLS Extended Command")) packet.AppendChild(request) l.Debug.PrintPacket(packet) @@ -44,7 +44,7 @@ type Closer interface { // type Server struct { - Bind BindFunc + Bind BindFunc Search SearchFunc AddFns map[string]Adder @@ -59,6 +59,9 @@ type Server struct { EnforceLDAP bool Stats *Stats + // If set, server will accept StartTLS. + TLSConfig *tls.Config + closing chan struct{} } @@ -307,12 +310,27 @@ handler: server.Stats.countUnbinds(1) break handler // simply disconnect case ApplicationExtendedRequest: - ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn) + var tlsConn *tls.Conn + if n := len(req.Children); n == 1 || n == 2 { + if name := ber.DecodeString(req.Children[0].Data.Bytes()); name == oidStartTLS { + tlsConn = tls.Server(conn, server.TLSConfig) + } + } + var ldapResultCode LDAPResultCode + if tlsConn == nil { + // Wasn't an upgrade. Pass through. + ldapResultCode = HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn) + } else { + ldapResultCode = LDAPResultSuccess + } responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode]) if err = sendPacket(conn, responsePacket); err != nil { log.Printf("sendPacket error %s", err.Error()) break handler } + if tlsConn != nil { + conn = tlsConn + } case ApplicationAbandonRequest: HandleAbandonRequest(req, boundDN, server.AbandonFns, conn) break handler diff --git a/server_modify.go b/server_modify.go index 56f45df..009fefd 100644 --- a/server_modify.go +++ b/server_modify.go @@ -1,9 +1,10 @@ package ldapserver import ( - "github.com/mark-rushakoff/ldapserver/internal/asn1-ber" "log" "net" + + "github.com/mark-rushakoff/ldapserver/internal/asn1-ber" ) func HandleAddRequest(req *ber.Packet, boundDN string, fns map[string]Adder, conn net.Conn) (resultCode LDAPResultCode) { diff --git a/server_test.go b/server_test.go index 29df06b..991d927 100644 --- a/server_test.go +++ b/server_test.go @@ -2,18 +2,235 @@ package ldapserver import ( "bytes" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "io/ioutil" "log" + "math/big" "net" + "os" "os/exec" + "runtime" "strings" "testing" "time" + + "github.com/golang/go/src/crypto/rand" ) var timeout = 400 * time.Millisecond var serverBaseDN = "o=testers,c=test" +type selfSignedCert struct { + // Path to the SSL certificates. + CACertPath, CertPath string + + // Path to the private keys for the SSL certificates. + CAKeyPath, KeyPath string +} + +func newSelfSignedCert() *selfSignedCert { + capk, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + + caSerial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + panic(err) + } + + caTemplate := x509.Certificate{ + SerialNumber: caSerial, + NotBefore: time.Now(), + NotAfter: time.Now().Add(7 * 24 * time.Hour), + + KeyUsage: x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + + BasicConstraintsValid: true, + + Subject: pkix.Name{ + Organization: []string{"my_test_ca"}, + CommonName: "My Test CA", + }, + + IsCA: true, + } + + caCert, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, capk.Public(), capk) + if err != nil { + panic(err) + } + // fmt.Printf("CA CERT\n%#v\n", caCert) + caCertPEM := &pem.Block{Type: "CERTIFICATE", Bytes: caCert} + caCertFile, err := ioutil.TempFile("", "cacert-*.pem") + if err != nil { + panic(err) + } + if err := pem.Encode(caCertFile, caCertPEM); err != nil { + panic(err) + } + caCertFile.Close() + + caKeyFile, err := ioutil.TempFile("", "cakey-*.pem") + if err != nil { + panic(err) + } + if err := pem.Encode(caKeyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(capk)}); err != nil { + panic(err) + } + caKeyFile.Close() + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + panic(err) + } + // Basically the same as the CA template, but its own serial, and with ip addresses and dns names. + template := x509.Certificate{ + SerialNumber: serial, + NotBefore: time.Now(), + NotAfter: time.Now().Add(7 * 24 * time.Hour), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + + BasicConstraintsValid: true, + + Subject: pkix.Name{ + CommonName: "localhost", + }, + + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, + DNSNames: []string{"localhost"}, + } + + pk, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + cert, err := x509.CreateCertificate(rand.Reader, &template, &caTemplate, pk.Public(), capk) + if err != nil { + panic(err) + } + certPEM := &pem.Block{Type: "CERTIFICATE", Bytes: cert} + certFile, err := ioutil.TempFile("", "sslcert-*.pem") + if err != nil { + panic(err) + } + if err := pem.Encode(certFile, certPEM); err != nil { + panic(err) + } + certFile.Close() + + keyFile, err := ioutil.TempFile("", "key-*.pem") + if err != nil { + panic(err) + } + if err := pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(pk)}); err != nil { + panic(err) + } + keyFile.Close() + + return &selfSignedCert{ + CACertPath: caCertFile.Name(), + CAKeyPath: caKeyFile.Name(), + CertPath: certFile.Name(), + KeyPath: keyFile.Name(), + } +} + +func (c *selfSignedCert) cleanup() { + os.RemoveAll(c.CertPath) + os.RemoveAll(c.CACertPath) + os.RemoveAll(c.KeyPath) + os.RemoveAll(c.CAKeyPath) +} + +func (c *selfSignedCert) ClientTLSConfig() *tls.Config { + cert, err := ioutil.ReadFile(c.CACertPath) + if err != nil { + panic(err) + } + + // Return a TLS config that trusts our self-generated CA. + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(cert) { + panic("failed to append certificate") + } + return &tls.Config{ + RootCAs: pool, + } +} + +func (c *selfSignedCert) ServerTLSConfig() *tls.Config { + cert, err := tls.LoadX509KeyPair(c.CertPath, c.KeyPath) + if err != nil { + panic(err) + } + return &tls.Config{ + ServerName: "localhost", + Certificates: []tls.Certificate{cert}, + } +} + +func TestStartTLS(t *testing.T) { + if runtime.GOOS == "darwin" { + defer func() { + if t.Failed() { + t.Logf(`NOTE: this test won't pass with the built-in Mac ldap utilities. +Work around this by using brew install openldap, and running the test as PATH=/usr/local/opt/openldap/bin:$PATH go test. + +This test uses environment variables that are respected by OpenLDAP, but the Mac utilities don't let you override +security settings through environment variables; they expect certificates to be added to the system keychain, +which is very heavy-handed for a test like this. +`) + } + }() + } + cert := newSelfSignedCert() + defer cert.cleanup() + + s := NewServer() + defer s.Close() + s.Bind = BindAnonOK + s.Search = SearchSimple + s.TLSConfig = cert.ServerTLSConfig() + + ln, addr := mustListen() + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("s.Serve failed: %s", err.Error()) + } + }() + + done := make(chan struct{}) + go func() { + cmd := exec.Command("env", + "LDAPTLS_CACERT="+cert.CACertPath, + "ldapsearch", "-H", "ldap://"+addr, "-ZZ", "-d", "-1", "-x", "-b", "o=testers,c=test") + out, err := cmd.CombinedOutput() + if err != nil { + t.Error(err) + } + + if !strings.Contains(string(out), "# numEntries: 3") || !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("search did not succeed:\n%s", out) + } + + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + t.Error("ldapsearch command timed out") + } +} + ///////////////////////// func TestBindAnonOK(t *testing.T) { done := make(chan bool) |
