aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore4
-rw-r--r--LICENSE27
-rw-r--r--README.md105
-rw-r--r--bind.go99
-rw-r--r--conn.go340
-rw-r--r--control.go160
-rw-r--r--debug.go24
-rw-r--r--examples/enterprise.ldif63
-rw-r--r--examples/modify.go91
-rw-r--r--examples/proxy.go106
-rw-r--r--examples/search.go54
-rw-r--r--examples/searchSSL.go47
-rw-r--r--examples/searchTLS.go47
-rw-r--r--examples/server.go66
-rw-r--r--examples/slapd.conf67
-rw-r--r--filter.go402
-rw-r--r--filter_test.go137
-rw-r--r--ldap.go340
-rw-r--r--ldap_test.go123
-rw-r--r--modify.go162
-rw-r--r--search.go350
-rw-r--r--server.go473
-rw-r--r--server_bind.go73
-rw-r--r--server_modify.go231
-rw-r--r--server_modify_test.go191
-rw-r--r--server_search.go217
-rw-r--r--server_search_test.go453
-rw-r--r--server_test.go410
-rw-r--r--tests/add.ldif6
-rw-r--r--tests/add2.ldif6
-rw-r--r--tests/cert_DONOTUSE.pem18
-rw-r--r--tests/key_DONOTUSE.pem27
-rw-r--r--tests/modify.ldif16
-rw-r--r--tests/modify2.ldif10
34 files changed, 4945 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..87275bf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+examples/modify
+examples/search
+examples/searchSSL
+examples/searchTLS
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..7448756
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..2418eab
--- /dev/null
+++ b/README.md
@@ -0,0 +1,105 @@
+# LDAP for Golang
+
+This library provides basic LDAP v3 functionality for the GO programming language.
+
+The **client** portion is limited, but sufficient to perform LDAP authentication and directory lookups (binds and searches) against any modern LDAP server (tested with OpenLDAP and AD).
+
+The **server** portion implements Bind and Search from [RFC4510](http://tools.ietf.org/html/rfc4510), has good testing coverage, and is compatible with any LDAPv3 client. It provides the building blocks for a custom LDAP server, but you must implement the backend datastore of your choice.
+
+
+## LDAP client notes:
+
+### A simple LDAP bind operation:
+```go
+l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+// be sure to add error checking!
+defer l.Close()
+err = l.Bind(user, passwd)
+if err==nil {
+ // authenticated
+} else {
+ // invalid authentication
+}
+```
+
+### A simple LDAP search operation:
+```go
+search := &SearchRequest{
+ BaseDN: "dc=example,dc=com",
+ Filter: "(objectclass=*)",
+}
+searchResults, err := l.Search(search)
+// be sure to add error checking!
+```
+
+### Implemented:
+* Connecting, binding to LDAP server
+* Searching for entries with filtering and paging controls
+* Compiling string filters to LDAP filters
+* Modify Requests / Responses
+
+### Not implemented:
+* Add, Delete, Modify DN, Compare operations
+* Most tests / benchmarks
+
+### LDAP client examples:
+* examples/search.go: **Basic client bind and search**
+* examples/searchSSL.go: **Client bind and search over SSL**
+* examples/searchTLS.go: **Client bind and search over TLS**
+* examples/modify.go: **Client modify operation**
+
+*Client library by: [mmitton](https://github.com/mmitton), with contributions from: [uavila](https://github.com/uavila), [vanackere](https://github.com/vanackere), [juju2013](https://github.com/juju2013), [johnweldon](https://github.com/johnweldon), [marcsauter](https://github.com/marcsauter), and [nmcclain](https://github.com/nmcclain)*
+
+## LDAP server notes:
+The server library is modeled after net/http - you designate handlers for the LDAP operations you want to support (Bind/Search/etc.), then start the server with ListenAndServe(). You can specify different handlers for different baseDNs - they must implement the interfaces of the operations you want to support:
+```go
+type Binder interface {
+ Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
+}
+type Searcher interface {
+ Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error)
+}
+type Closer interface {
+ Close(conn net.Conn) error
+}
+```
+
+### A basic bind-only LDAP server
+```go
+func main() {
+ s := ldap.NewServer()
+ handler := ldapHandler{}
+ s.BindFunc("", handler)
+ if err := s.ListenAndServe("localhost:389"); err != nil {
+ log.Fatal("LDAP Server Failed: %s", err.Error())
+ }
+}
+type ldapHandler struct {
+}
+func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAPResultCode, error) {
+ if bindDN == "" && bindSimplePw == "" {
+ return ldap.LDAPResultSuccess, nil
+ }
+ return ldap.LDAPResultInvalidCredentials, nil
+}
+```
+
+* Server.EnforceLDAP: Normally, the LDAP server will return whatever results your handler provides. Set the **Server.EnforceLDAP** flag to **true** and the server will apply the LDAP **search filter**, **attributes limits**, **size/time limits**, **search scope**, and **base DN matching** to your handler's dataset. This makes it a lot simpler to write a custom LDAP server without worrying about LDAP internals.
+
+### LDAP server examples:
+* examples/server.go: **Basic LDAP authentication (bind and search only)**
+* examples/proxy.go: **Simple LDAP proxy server.**
+* server_test.go: **The _test.go files have examples of all server functions.**
+
+### Known limitations:
+
+* Golang's TLS implementation does not support SSLv2. Some old OSs require SSLv2, and are not able to connect to an LDAP server created with this library's ListenAndServeTLS() function. If you *must* support legacy (read: *insecure*) SSLv2 clients, run your LDAP server behind HAProxy.
+
+### Not implemented:
+From the server perspective, all of [RFC4510](http://tools.ietf.org/html/rfc4510) is implemented **except**:
+* 4.5.1.3. SearchRequest.derefAliases
+* 4.5.1.5. SearchRequest.timeLimit
+* 4.5.1.6. SearchRequest.typesOnly
+* 4.14. StartTLS Operation
+
+*Server library by: [nmcclain](https://github.com/nmcclain)*
diff --git a/bind.go b/bind.go
new file mode 100644
index 0000000..171a2e9
--- /dev/null
+++ b/bind.go
@@ -0,0 +1,99 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+ "errors"
+
+ "github.com/nmcclain/asn1-ber"
+)
+
+func (l *Conn) Bind(username, password string) error {
+ messageID := l.nextMessageID()
+
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
+ bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
+ bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name"))
+ bindRequest.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, password, "Password"))
+ packet.AppendChild(bindRequest)
+
+ if l.Debug {
+ ber.PrintPacket(packet)
+ }
+
+ channel, err := l.sendMessage(packet)
+ if err != nil {
+ return err
+ }
+ if channel == nil {
+ return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+ }
+ defer l.finishMessage(messageID)
+
+ packet = <-channel
+ if packet == nil {
+ return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
+ }
+
+ if l.Debug {
+ if err := addLDAPDescriptions(packet); err != nil {
+ return err
+ }
+ ber.PrintPacket(packet)
+ }
+
+ resultCode, resultDescription := getLDAPResultCode(packet)
+ if resultCode != 0 {
+ return NewError(resultCode, errors.New(resultDescription))
+ }
+
+ return nil
+}
+
+func (l *Conn) Unbind() error {
+ defer l.Close()
+
+ messageID := l.nextMessageID()
+
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ unbindRequest := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationUnbindRequest, nil, "Unbind Request")
+ packet.AppendChild(unbindRequest)
+
+ if l.Debug {
+ ber.PrintPacket(packet)
+ }
+
+ channel, err := l.sendMessage(packet)
+ if err != nil {
+ return err
+ }
+ if channel == nil {
+ return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+ }
+ defer l.finishMessage(messageID)
+
+ packet = <-channel
+ if packet == nil {
+ return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
+ }
+
+ if l.Debug {
+ if err := addLDAPDescriptions(packet); err != nil {
+ return err
+ }
+ ber.PrintPacket(packet)
+ }
+
+ resultCode, resultDescription := getLDAPResultCode(packet)
+ if resultCode != 0 {
+ return NewError(resultCode, errors.New(resultDescription))
+ }
+
+ return nil
+}
+
diff --git a/conn.go b/conn.go
new file mode 100644
index 0000000..253e58e
--- /dev/null
+++ b/conn.go
@@ -0,0 +1,340 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+ "crypto/tls"
+ "errors"
+ "log"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/nmcclain/asn1-ber"
+)
+
+const (
+ MessageQuit = 0
+ MessageRequest = 1
+ MessageResponse = 2
+ MessageFinish = 3
+)
+
+type messagePacket struct {
+ Op int
+ MessageID uint64
+ Packet *ber.Packet
+ Channel chan *ber.Packet
+}
+
+// Conn represents an LDAP Connection
+type Conn struct {
+ conn net.Conn
+ isTLS bool
+ Debug debugging
+ chanConfirm chan bool
+ chanResults map[uint64]chan *ber.Packet
+ chanMessage chan *messagePacket
+ chanMessageID chan uint64
+ wgSender sync.WaitGroup
+ chanDone chan struct{}
+ once sync.Once
+}
+
+// Dial connects to the given address on the given network using net.Dial
+// and then returns a new Conn for the connection.
+func Dial(network, addr string) (*Conn, error) {
+ c, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, NewError(ErrorNetwork, err)
+ }
+ conn := NewConn(c)
+ conn.start()
+ return conn, nil
+}
+
+// DialTimeout connects to the given address on the given network using net.DialTimeout
+// and then returns a new Conn for the connection. Acts like Dial but takes a timeout.
+func DialTimeout(network, addr string, timeout time.Duration) (*Conn, error) {
+ c, err := net.DialTimeout(network, addr, timeout)
+ if err != nil {
+ return nil, NewError(ErrorNetwork, err)
+ }
+ conn := NewConn(c)
+ conn.start()
+ return conn, nil
+}
+
+// DialTLS connects to the given address on the given network using tls.Dial
+// and then returns a new Conn for the connection.
+func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
+ c, err := tls.Dial(network, addr, config)
+ if err != nil {
+ return nil, NewError(ErrorNetwork, err)
+ }
+ conn := NewConn(c)
+ conn.isTLS = true
+ conn.start()
+ return conn, nil
+}
+
+// DialTLSDialer connects to the given address on the given network using tls.DialWithDialer
+// and then returns a new Conn for the connection.
+func DialTLSDialer(network, addr string, config *tls.Config, dialer *net.Dialer) (*Conn, error) {
+ c, err := tls.DialWithDialer(dialer, network, addr, config)
+ if err != nil {
+ return nil, NewError(ErrorNetwork, err)
+ }
+ conn := NewConn(c)
+ conn.isTLS = true
+ conn.start()
+ return conn, nil
+}
+
+// NewConn returns a new Conn using conn for network I/O.
+func NewConn(conn net.Conn) *Conn {
+ return &Conn{
+ conn: conn,
+ chanConfirm: make(chan bool),
+ chanMessageID: make(chan uint64),
+ chanMessage: make(chan *messagePacket, 10),
+ chanResults: map[uint64]chan *ber.Packet{},
+ chanDone: make(chan struct{}),
+ }
+}
+
+func (l *Conn) start() {
+ go l.reader()
+ go l.processMessages()
+}
+
+// Close closes the connection.
+func (l *Conn) Close() {
+ l.once.Do(func() {
+ close(l.chanDone)
+ l.wgSender.Wait()
+
+ l.Debug.Printf("Sending quit message and waiting for confirmation")
+ l.chanMessage <- &messagePacket{Op: MessageQuit}
+ <-l.chanConfirm
+ close(l.chanMessage)
+
+ l.Debug.Printf("Closing network connection")
+ if err := l.conn.Close(); err != nil {
+ log.Print(err)
+ }
+ })
+ <-l.chanDone
+}
+
+// Returns the next available messageID
+func (l *Conn) nextMessageID() uint64 {
+ if l.chanMessageID != nil {
+ if messageID, ok := <-l.chanMessageID; ok {
+ return messageID
+ }
+ }
+ return 0
+}
+
+// StartTLS sends the command to start a TLS session and then creates a new TLS Client
+func (l *Conn) StartTLS(config *tls.Config) error {
+ messageID := l.nextMessageID()
+
+ if l.isTLS {
+ return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
+ }
+
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
+ request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
+ packet.AppendChild(request)
+ l.Debug.PrintPacket(packet)
+
+ _, err := l.conn.Write(packet.Bytes())
+ if err != nil {
+ return NewError(ErrorNetwork, err)
+ }
+
+ packet, err = ber.ReadPacket(l.conn)
+ if err != nil {
+ return NewError(ErrorNetwork, err)
+ }
+
+ if l.Debug {
+ if err := addLDAPDescriptions(packet); err != nil {
+ return err
+ }
+ ber.PrintPacket(packet)
+ }
+
+ if packet.Children[1].Children[0].Value.(uint64) == 0 {
+ conn := tls.Client(l.conn, config)
+ l.isTLS = true
+ l.conn = conn
+ }
+
+ return nil
+}
+
+func (l *Conn) closing() bool {
+ select {
+ case <-l.chanDone:
+ return true
+ default:
+ return false
+ }
+}
+
+func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
+ if l.closing() {
+ return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
+ }
+ out := make(chan *ber.Packet)
+ message := &messagePacket{
+ Op: MessageRequest,
+ MessageID: packet.Children[0].Value.(uint64),
+ Packet: packet,
+ Channel: out,
+ }
+ l.sendProcessMessage(message)
+ return out, nil
+}
+
+func (l *Conn) finishMessage(messageID uint64) {
+ if l.closing() {
+ return
+ }
+ message := &messagePacket{
+ Op: MessageFinish,
+ MessageID: messageID,
+ }
+ l.sendProcessMessage(message)
+}
+
+func (l *Conn) sendProcessMessage(message *messagePacket) bool {
+ l.wgSender.Add(1)
+ defer l.wgSender.Done()
+
+ if l.closing() {
+ return false
+ }
+ l.chanMessage <- message
+ return true
+}
+
+func (l *Conn) processMessages() {
+ defer func() {
+ for messageID, channel := range l.chanResults {
+ l.Debug.Printf("Closing channel for MessageID %d", messageID)
+ close(channel)
+ delete(l.chanResults, messageID)
+ }
+ close(l.chanMessageID)
+ l.chanConfirm <- true
+ close(l.chanConfirm)
+ }()
+
+ var messageID uint64 = 1
+ for {
+ select {
+ case l.chanMessageID <- messageID:
+ messageID++
+ case messagePacket, ok := <-l.chanMessage:
+ if !ok {
+ l.Debug.Printf("Shutting down - message channel is closed")
+ return
+ }
+ switch messagePacket.Op {
+ case MessageQuit:
+ l.Debug.Printf("Shutting down - quit message received")
+ return
+ case MessageRequest:
+ // Add to message list and write to network
+ l.Debug.Printf("Sending message %d", messagePacket.MessageID)
+ l.chanResults[messagePacket.MessageID] = messagePacket.Channel
+ // go routine
+ buf := messagePacket.Packet.Bytes()
+
+ _, err := l.conn.Write(buf)
+ if err != nil {
+ l.Debug.Printf("Error Sending Message: %s", err.Error())
+ break
+ }
+ case MessageResponse:
+ l.Debug.Printf("Receiving message %d", messagePacket.MessageID)
+ if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok {
+ chanResult <- messagePacket.Packet
+ } else {
+ log.Printf("Received unexpected message %d", messagePacket.MessageID)
+ ber.PrintPacket(messagePacket.Packet)
+ }
+ case MessageFinish:
+ // Remove from message list
+ l.Debug.Printf("Finished message %d", messagePacket.MessageID)
+ close(l.chanResults[messagePacket.MessageID])
+ delete(l.chanResults, messagePacket.MessageID)
+ }
+ }
+ }
+}
+
+func (l *Conn) reader() {
+ defer func() {
+ l.Close()
+ }()
+
+ for {
+ packet, err := ber.ReadPacket(l.conn)
+ if err != nil {
+ l.Debug.Printf("reader: %s", err.Error())
+ return
+ }
+ addLDAPDescriptions(packet)
+ message := &messagePacket{
+ Op: MessageResponse,
+ MessageID: packet.Children[0].Value.(uint64),
+ Packet: packet,
+ }
+ if !l.sendProcessMessage(message) {
+ return
+ }
+
+ }
+}
+
+// Use Abandon operation to perform connection keepalives
+func (l *Conn) Ping() error {
+
+ messageID := l.nextMessageID()
+
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ abandonRequest := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationAbandonRequest, nil, "Abandon Request")
+ packet.AppendChild(abandonRequest)
+
+ if l.Debug {
+ ber.PrintPacket(packet)
+ }
+
+ channel, err := l.sendMessage(packet)
+ if err != nil {
+ return err
+ }
+ if channel == nil {
+ return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+ }
+ defer l.finishMessage(messageID)
+
+ if l.Debug {
+ if err := addLDAPDescriptions(packet); err != nil {
+ return err
+ }
+ ber.PrintPacket(packet)
+ }
+
+ return nil
+}
diff --git a/control.go b/control.go
new file mode 100644
index 0000000..60fde91
--- /dev/null
+++ b/control.go
@@ -0,0 +1,160 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+ "strings"
+ "fmt"
+ "github.com/nmcclain/asn1-ber"
+)
+
+const (
+ ControlTypePaging = "1.2.840.113556.1.4.319"
+)
+
+var ControlTypeMap = map[string]string{
+ ControlTypePaging: "Paging",
+}
+
+type Control interface {
+ GetControlType() string
+ Encode() *ber.Packet
+ String() string
+}
+
+type ControlString struct {
+ ControlType string
+ Criticality bool
+ ControlValue string
+}
+
+func (c *ControlString) GetControlType() string {
+ return c.ControlType
+}
+
+func (c *ControlString) Encode() *ber.Packet {
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")"))
+ if c.Criticality {
+ packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality"))
+ }
+ if strings.TrimSpace(c.ControlValue) != "" {
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlValue, "Control Value"))
+ }
+ return packet
+}
+
+func (c *ControlString) String() string {
+ return fmt.Sprintf("Control Type: %s (%q) Criticality: %t Control Value: %s", ControlTypeMap[c.ControlType], c.ControlType, c.Criticality, c.ControlValue)
+}
+
+type ControlPaging struct {
+ PagingSize uint32
+ Cookie []byte
+}
+
+func (c *ControlPaging) GetControlType() string {
+ return ControlTypePaging
+}
+
+func (c *ControlPaging) Encode() *ber.Packet {
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")"))
+
+ p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Paging)")
+ seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value")
+ seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(c.PagingSize), "Paging Size"))
+ cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie")
+ cookie.Value = c.Cookie
+ cookie.Data.Write(c.Cookie)
+ seq.AppendChild(cookie)
+ p2.AppendChild(seq)
+
+ packet.AppendChild(p2)
+ return packet
+}
+
+func (c *ControlPaging) String() string {
+ return fmt.Sprintf(
+ "Control Type: %s (%q) Criticality: %t PagingSize: %d Cookie: %q",
+ ControlTypeMap[ControlTypePaging],
+ ControlTypePaging,
+ false,
+ c.PagingSize,
+ c.Cookie)
+}
+
+func (c *ControlPaging) SetCookie(cookie []byte) {
+ c.Cookie = cookie
+}
+
+func FindControl(controls []Control, controlType string) Control {
+ for _, c := range controls {
+ if c.GetControlType() == controlType {
+ return c
+ }
+ }
+ return nil
+}
+
+func DecodeControl(packet *ber.Packet) Control {
+ ControlType := packet.Children[0].Value.(string)
+ packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")"
+ c := new(ControlString)
+ c.ControlType = ControlType
+ c.Criticality = false
+
+ if len(packet.Children) > 1 {
+ value := packet.Children[1]
+ if len(packet.Children) == 3 {
+ value = packet.Children[2]
+ packet.Children[1].Description = "Criticality"
+ c.Criticality = packet.Children[1].Value.(bool)
+ }
+
+ value.Description = "Control Value"
+ switch ControlType {
+ case ControlTypePaging:
+ value.Description += " (Paging)"
+ c := new(ControlPaging)
+ if value.Value != nil {
+ valueChildren := ber.DecodePacket(value.Data.Bytes())
+ value.Data.Truncate(0)
+ value.Value = nil
+ value.AppendChild(valueChildren)
+ }
+ value = value.Children[0]
+ value.Description = "Search Control Value"
+ value.Children[0].Description = "Paging Size"
+ value.Children[1].Description = "Cookie"
+ c.PagingSize = uint32(value.Children[0].Value.(uint64))
+ c.Cookie = value.Children[1].Data.Bytes()
+ value.Children[1].Value = c.Cookie
+ return c
+ }
+ c.ControlValue = value.Value.(string)
+ }
+ return c
+}
+
+func NewControlString(controlType string, criticality bool, controlValue string) *ControlString {
+ return &ControlString{
+ ControlType: controlType,
+ Criticality: criticality,
+ ControlValue: controlValue,
+ }
+}
+
+func NewControlPaging(pagingSize uint32) *ControlPaging {
+ return &ControlPaging{PagingSize: pagingSize}
+}
+
+func encodeControls(controls []Control) *ber.Packet {
+ packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls")
+ for _, control := range controls {
+ packet.AppendChild(control.Encode())
+ }
+ return packet
+}
diff --git a/debug.go b/debug.go
new file mode 100644
index 0000000..de9bc5a
--- /dev/null
+++ b/debug.go
@@ -0,0 +1,24 @@
+package ldap
+
+import (
+ "log"
+
+ "github.com/nmcclain/asn1-ber"
+)
+
+// debbuging type
+// - has a Printf method to write the debug output
+type debugging bool
+
+// write debug output
+func (debug debugging) Printf(format string, args ...interface{}) {
+ if debug {
+ log.Printf(format, args...)
+ }
+}
+
+func (debug debugging) PrintPacket(packet *ber.Packet) {
+ if debug {
+ ber.PrintPacket(packet)
+ }
+}
diff --git a/examples/enterprise.ldif b/examples/enterprise.ldif
new file mode 100644
index 0000000..f0ec28f
--- /dev/null
+++ b/examples/enterprise.ldif
@@ -0,0 +1,63 @@
+dn: dc=enterprise,dc=org
+objectClass: dcObject
+objectClass: organization
+o: acme
+
+dn: cn=admin,dc=enterprise,dc=org
+objectClass: person
+cn: admin
+sn: admin
+description: "LDAP Admin"
+
+dn: ou=crew,dc=enterprise,dc=org
+ou: crew
+objectClass: organizationalUnit
+
+
+dn: cn=kirkj,ou=crew,dc=enterprise,dc=org
+cn: kirkj
+sn: Kirk
+gn: James Tiberius
+objectClass: inetOrgPerson
+
+dn: cn=spock,ou=crew,dc=enterprise,dc=org
+cn: spock
+sn: Spock
+objectClass: inetOrgPerson
+
+dn: cn=mccoyl,ou=crew,dc=enterprise,dc=org
+cn: mccoyl
+sn: McCoy
+gn: Leonard
+objectClass: inetOrgPerson
+
+dn: cn=scottm,ou=crew,dc=enterprise,dc=org
+cn: scottm
+sn: Scott
+gn: Montgomery
+objectClass: inetOrgPerson
+
+dn: cn=uhuran,ou=crew,dc=enterprise,dc=org
+cn: uhuran
+sn: Uhura
+gn: Nyota
+objectClass: inetOrgPerson
+
+dn: cn=suluh,ou=crew,dc=enterprise,dc=org
+cn: suluh
+sn: Sulu
+gn: Hikaru
+objectClass: inetOrgPerson
+
+dn: cn=chekovp,ou=crew,dc=enterprise,dc=org
+cn: chekovp
+sn: Chekov
+gn: pavel
+objectClass: inetOrgPerson
diff --git a/examples/modify.go b/examples/modify.go
new file mode 100644
index 0000000..87d1119
--- /dev/null
+++ b/examples/modify.go
@@ -0,0 +1,91 @@
+// +build ignore
+
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "errors"
+ "fmt"
+ "log"
+
+ "github.com/nmcclain/ldap"
+)
+
+var (
+ LdapServer string = "localhost"
+ LdapPort uint16 = 389
+ BaseDN string = "dc=enterprise,dc=org"
+ BindDN string = "cn=admin,dc=enterprise,dc=org"
+ BindPW string = "enterprise"
+ Filter string = "(cn=kirkj)"
+)
+
+func search(l *ldap.Conn, filter string, attributes []string) (*ldap.Entry, *ldap.Error) {
+ search := ldap.NewSearchRequest(
+ BaseDN,
+ ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+ filter,
+ attributes,
+ nil)
+
+ sr, err := l.Search(search)
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err)
+ return nil, err
+ }
+
+ log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries))
+ if len(sr.Entries) == 0 {
+ return nil, ldap.NewError(ldap.ErrorDebugging, errors.New(fmt.Sprintf("no entries found for: %s", filter)))
+ }
+ return sr.Entries[0], nil
+}
+
+func main() {
+ l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", LdapServer, LdapPort))
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err.Error())
+ }
+ defer l.Close()
+ // l.Debug = true
+
+ l.Bind(BindDN, BindPW)
+
+ log.Printf("The Search for Kirk ... %s\n", Filter)
+ entry, err := search(l, Filter, []string{})
+ if err != nil {
+ log.Fatal("could not get entry")
+ }
+ entry.PrettyPrint(0)
+
+ log.Printf("modify the mail address and add a description ... \n")
+ modify := ldap.NewModifyRequest(entry.DN)
+ modify.Add("description", []string{"Captain of the USS Enterprise"})
+ modify.Replace("mail", []string{"[email protected]"})
+ if err := l.Modify(modify); err != nil {
+ log.Fatalf("ERROR: %s\n", err.Error())
+ }
+
+ entry, err = search(l, Filter, []string{})
+ if err != nil {
+ log.Fatal("could not get entry")
+ }
+ entry.PrettyPrint(0)
+
+ log.Printf("reset the entry ... \n")
+ modify = ldap.NewModifyRequest(entry.DN)
+ modify.Delete("description", []string{})
+ modify.Replace("mail", []string{"[email protected]"})
+ if err := l.Modify(modify); err != nil {
+ log.Fatalf("ERROR: %s\n", err.Error())
+ }
+
+ entry, err = search(l, Filter, []string{})
+ if err != nil {
+ log.Fatal("could not get entry")
+ }
+ entry.PrettyPrint(0)
+}
diff --git a/examples/proxy.go b/examples/proxy.go
new file mode 100644
index 0000000..d6b01d0
--- /dev/null
+++ b/examples/proxy.go
@@ -0,0 +1,106 @@
+package main
+
+import (
+ "crypto/sha256"
+ "fmt"
+ "github.com/nmcclain/ldap"
+ "log"
+ "net"
+ "sync"
+)
+
+type ldapHandler struct {
+ sessions map[string]session
+ lock sync.Mutex
+ ldapServer string
+ ldapPort int
+}
+
+///////////// Run a simple LDAP proxy
+func main() {
+ s := ldap.NewServer()
+
+ handler := ldapHandler{
+ sessions: make(map[string]session),
+ ldapServer: "localhost",
+ ldapPort: 3389,
+ }
+ s.BindFunc("", handler)
+ s.SearchFunc("", handler)
+ s.CloseFunc("", handler)
+
+ // start the server
+ if err := s.ListenAndServe("localhost:3388"); err != nil {
+ log.Fatal("LDAP Server Failed: %s", err.Error())
+ }
+}
+
+/////////////
+type session struct {
+ id string
+ c net.Conn
+ ldap *ldap.Conn
+}
+
+func (h ldapHandler) getSession(conn net.Conn) (session, error) {
+ id := connID(conn)
+ h.lock.Lock()
+ s, ok := h.sessions[id] // use server connection if it exists
+ h.lock.Unlock()
+ if !ok { // open a new server connection if not
+ l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", h.ldapServer, h.ldapPort))
+ if err != nil {
+ return session{}, err
+ }
+ s = session{id: id, c: conn, ldap: l}
+ h.lock.Lock()
+ h.sessions[s.id] = s
+ h.lock.Unlock()
+ }
+ return s, nil
+}
+
+/////////////
+func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+ s, err := h.getSession(conn)
+ if err != nil {
+ return ldap.LDAPResultOperationsError, err
+ }
+ if err := s.ldap.Bind(bindDN, bindSimplePw); err != nil {
+ return ldap.LDAPResultOperationsError, err
+ }
+ return ldap.LDAPResultSuccess, nil
+}
+
+/////////////
+func (h ldapHandler) Search(boundDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) {
+ s, err := h.getSession(conn)
+ if err != nil {
+ return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, nil
+ }
+ search := ldap.NewSearchRequest(
+ searchReq.BaseDN,
+ ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+ searchReq.Filter,
+ searchReq.Attributes,
+ nil)
+ sr, err := s.ldap.Search(search)
+ if err != nil {
+ return ldap.ServerSearchResult{}, err
+ }
+ //log.Printf("P: Search OK: %s -> num of entries = %d\n", search.Filter, len(sr.Entries))
+ return ldap.ServerSearchResult{sr.Entries, []string{}, []ldap.Control{}, ldap.LDAPResultSuccess}, nil
+}
+func (h ldapHandler) Close(conn net.Conn) error {
+ conn.Close() // close connection to the server when then client is closed
+ h.lock.Lock()
+ defer h.lock.Unlock()
+ delete(h.sessions, connID(conn))
+ return nil
+}
+func connID(conn net.Conn) string {
+ h := sha256.New()
+ h.Write([]byte(conn.LocalAddr().String() + conn.RemoteAddr().String()))
+ sha := fmt.Sprintf("% x", h.Sum(nil))
+ return string(sha)
+}
diff --git a/examples/search.go b/examples/search.go
new file mode 100644
index 0000000..08b364a
--- /dev/null
+++ b/examples/search.go
@@ -0,0 +1,54 @@
+// +build ignore
+
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ "log"
+
+ "github.com/nmcclain/ldap"
+)
+
+var (
+ ldapServer string = "adserver"
+ ldapPort uint16 = 3268
+ baseDN string = "dc=*,dc=*"
+ filter string = "(&(objectClass=user)(sAMAccountName=*)(memberOf=CN=*,OU=*,DC=*,DC=*))"
+ Attributes []string = []string{"memberof"}
+ user string = "*"
+ passwd string = "*"
+)
+
+func main() {
+ l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err.Error())
+ }
+ defer l.Close()
+ // l.Debug = true
+
+ err = l.Bind(user, passwd)
+ if err != nil {
+ log.Printf("ERROR: Cannot bind: %s\n", err.Error())
+ return
+ }
+ search := ldap.NewSearchRequest(
+ baseDN,
+ ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+ filter,
+ Attributes,
+ nil)
+
+ sr, err := l.Search(search)
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err.Error())
+ return
+ }
+
+ log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries))
+ sr.PrettyPrint(0)
+}
diff --git a/examples/searchSSL.go b/examples/searchSSL.go
new file mode 100644
index 0000000..75c8395
--- /dev/null
+++ b/examples/searchSSL.go
@@ -0,0 +1,47 @@
+// +build ignore
+
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ "log"
+
+ "github.com/nmcclain/ldap"
+)
+
+var (
+ LdapServer string = "localhost"
+ LdapPort uint16 = 636
+ BaseDN string = "dc=enterprise,dc=org"
+ Filter string = "(cn=kirkj)"
+ Attributes []string = []string{"mail"}
+)
+
+func main() {
+ l, err := ldap.DialSSL("tcp", fmt.Sprintf("%s:%d", LdapServer, LdapPort), nil)
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err.String())
+ }
+ defer l.Close()
+ // l.Debug = true
+
+ search := ldap.NewSearchRequest(
+ BaseDN,
+ ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+ Filter,
+ Attributes,
+ nil)
+
+ sr, err := l.Search(search)
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err.String())
+ return
+ }
+
+ log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries))
+ sr.PrettyPrint(0)
+}
diff --git a/examples/searchTLS.go b/examples/searchTLS.go
new file mode 100644
index 0000000..56b3d27
--- /dev/null
+++ b/examples/searchTLS.go
@@ -0,0 +1,47 @@
+// +build ignore
+
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+ "fmt"
+ "log"
+
+ "github.com/nmcclain/ldap"
+)
+
+var (
+ LdapServer string = "localhost"
+ LdapPort uint16 = 389
+ BaseDN string = "dc=enterprise,dc=org"
+ Filter string = "(cn=kirkj)"
+ Attributes []string = []string{"mail"}
+)
+
+func main() {
+ l, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", LdapServer, LdapPort), nil)
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err.Error())
+ }
+ defer l.Close()
+ // l.Debug = true
+
+ search := ldap.NewSearchRequest(
+ BaseDN,
+ ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+ Filter,
+ Attributes,
+ nil)
+
+ sr, err := l.Search(search)
+ if err != nil {
+ log.Fatalf("ERROR: %s\n", err.Error())
+ return
+ }
+
+ log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries))
+ sr.PrettyPrint(0)
+}
diff --git a/examples/server.go b/examples/server.go
new file mode 100644
index 0000000..3341991
--- /dev/null
+++ b/examples/server.go
@@ -0,0 +1,66 @@
+package main
+
+import (
+ "github.com/nmcclain/ldap"
+ "log"
+ "net"
+)
+
+/////////////
+// Sample searches you can try against this simple LDAP server:
+//
+// ldapsearch -H ldap://localhost:3389 -x -b 'dn=test,dn=com'
+// ldapsearch -H ldap://localhost:3389 -x -b 'dn=test,dn=com' 'cn=ned'
+// ldapsearch -H ldap://localhost:3389 -x -b 'dn=test,dn=com' 'uidnumber=5000'
+/////////////
+
+///////////// Run a simple LDAP server
+func main() {
+ s := ldap.NewServer()
+
+ // register Bind and Search function handlers
+ handler := ldapHandler{}
+ s.BindFunc("", handler)
+ s.SearchFunc("", handler)
+
+ // start the server
+ listen := "localhost:3389"
+ log.Printf("Starting example LDAP server on %s", listen)
+ if err := s.ListenAndServe(listen); err != nil {
+ log.Fatal("LDAP Server Failed: %s", err.Error())
+ }
+}
+
+type ldapHandler struct {
+}
+
+///////////// Allow anonymous binds only
+func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAPResultCode, error) {
+ if bindDN == "" && bindSimplePw == "" {
+ return ldap.LDAPResultSuccess, nil
+ }
+ return ldap.LDAPResultInvalidCredentials, nil
+}
+
+///////////// Return some hardcoded search results - we'll respond to any baseDN for testing
+func (h ldapHandler) Search(boundDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) {
+ entries := []*ldap.Entry{
+ &ldap.Entry{"cn=ned," + searchReq.BaseDN, []*ldap.EntryAttribute{
+ &ldap.EntryAttribute{"cn", []string{"ned"}},
+ &ldap.EntryAttribute{"uidNumber", []string{"5000"}},
+ &ldap.EntryAttribute{"accountStatus", []string{"active"}},
+ &ldap.EntryAttribute{"uid", []string{"ned"}},
+ &ldap.EntryAttribute{"description", []string{"ned"}},
+ &ldap.EntryAttribute{"objectClass", []string{"posixAccount"}},
+ }},
+ &ldap.Entry{"cn=trent," + searchReq.BaseDN, []*ldap.EntryAttribute{
+ &ldap.EntryAttribute{"cn", []string{"trent"}},
+ &ldap.EntryAttribute{"uidNumber", []string{"5005"}},
+ &ldap.EntryAttribute{"accountStatus", []string{"active"}},
+ &ldap.EntryAttribute{"uid", []string{"trent"}},
+ &ldap.EntryAttribute{"description", []string{"trent"}},
+ &ldap.EntryAttribute{"objectClass", []string{"posixAccount"}},
+ }},
+ }
+ return ldap.ServerSearchResult{entries, []string{}, []ldap.Control{}, ldap.LDAPResultSuccess}, nil
+}
diff --git a/examples/slapd.conf b/examples/slapd.conf
new file mode 100644
index 0000000..5a66be0
--- /dev/null
+++ b/examples/slapd.conf
@@ -0,0 +1,67 @@
+#
+# See slapd.conf(5) for details on configuration options.
+# This file should NOT be world readable.
+#
+include /private/etc/openldap/schema/core.schema
+include /private/etc/openldap/schema/cosine.schema
+include /private/etc/openldap/schema/inetorgperson.schema
+
+# Define global ACLs to disable default read access.
+
+# Do not enable referrals until AFTER you have a working directory
+# service AND an understanding of referrals.
+#referral ldap://root.openldap.org
+
+pidfile /private/var/db/openldap/run/slapd.pid
+argsfile /private/var/db/openldap/run/slapd.args
+
+# Load dynamic backend modules:
+# modulepath /usr/libexec/openldap
+# moduleload back_bdb.la
+# moduleload back_hdb.la
+# moduleload back_ldap.la
+
+# Sample security restrictions
+# Require integrity protection (prevent hijacking)
+# Require 112-bit (3DES or better) encryption for updates
+# Require 63-bit encryption for simple bind
+# security ssf=1 update_ssf=112 simple_bind=64
+
+# Sample access control policy:
+# Root DSE: allow anyone to read it
+# Subschema (sub)entry DSE: allow anyone to read it
+# Other DSEs:
+# Allow self write access
+# Allow authenticated users read access
+# Allow anonymous users to authenticate
+# Directives needed to implement policy:
+# access to dn.base="" by * read
+# access to dn.base="cn=Subschema" by * read
+# access to *
+# by self write
+# by users read
+# by anonymous auth
+#
+# if no access controls are present, the default policy
+# allows anyone and everyone to read anything but restricts
+# updates to rootdn. (e.g., "access to * by * read")
+#
+# rootdn can always read and write EVERYTHING!
+
+#######################################################################
+# BDB database definitions
+#######################################################################
+
+database bdb
+suffix "dc=enterprise,dc=org"
+rootdn "cn=admin,dc=enterprise,dc=org"
+# Cleartext passwords, especially for the rootdn, should
+# be avoid. See slappasswd(8) and slapd.conf(5) for details.
+# Use of strong authentication encouraged.
+rootpw {SSHA}laO00HsgszhK1O0Z5qR0/i/US69Osfeu
+# The database directory MUST exist prior to running slapd AND
+# should only be accessible by the slapd and slap tools.
+# Mode 700 recommended.
+directory /private/var/db/openldap/openldap-data
+# Indices to maintain
+index objectClass eq
diff --git a/filter.go b/filter.go
new file mode 100644
index 0000000..df3c86a
--- /dev/null
+++ b/filter.go
@@ -0,0 +1,402 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+ "errors"
+ "fmt"
+ "github.com/nmcclain/asn1-ber"
+ "strings"
+)
+
+const (
+ FilterAnd = 0
+ FilterOr = 1
+ FilterNot = 2
+ FilterEqualityMatch = 3
+ FilterSubstrings = 4
+ FilterGreaterOrEqual = 5
+ FilterLessOrEqual = 6
+ FilterPresent = 7
+ FilterApproxMatch = 8
+ FilterExtensibleMatch = 9
+)
+
+var FilterMap = map[uint8]string{
+ FilterAnd: "And",
+ FilterOr: "Or",
+ FilterNot: "Not",
+ FilterEqualityMatch: "Equality Match",
+ FilterSubstrings: "Substrings",
+ FilterGreaterOrEqual: "Greater Or Equal",
+ FilterLessOrEqual: "Less Or Equal",
+ FilterPresent: "Present",
+ FilterApproxMatch: "Approx Match",
+ FilterExtensibleMatch: "Extensible Match",
+}
+
+const (
+ FilterSubstringsInitial = 0
+ FilterSubstringsAny = 1
+ FilterSubstringsFinal = 2
+)
+
+func CompileFilter(filter string) (*ber.Packet, error) {
+ if len(filter) == 0 || filter[0] != '(' {
+ return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
+ }
+ packet, pos, err := compileFilter(filter, 1)
+ if err != nil {
+ return nil, err
+ }
+ if pos != len(filter) {
+ return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
+ }
+ return packet, nil
+}
+
+func DecompileFilter(packet *ber.Packet) (ret string, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
+ }
+ }()
+ ret = "("
+ err = nil
+ childStr := ""
+
+ switch packet.Tag {
+ case FilterAnd:
+ ret += "&"
+ for _, child := range packet.Children {
+ childStr, err = DecompileFilter(child)
+ if err != nil {
+ return
+ }
+ ret += childStr
+ }
+ case FilterOr:
+ ret += "|"
+ for _, child := range packet.Children {
+ childStr, err = DecompileFilter(child)
+ if err != nil {
+ return
+ }
+ ret += childStr
+ }
+ case FilterNot:
+ ret += "!"
+ childStr, err = DecompileFilter(packet.Children[0])
+ if err != nil {
+ return
+ }
+ ret += childStr
+
+ case FilterSubstrings:
+ ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+ ret += "="
+ switch packet.Children[1].Children[0].Tag {
+ case FilterSubstringsInitial:
+ ret += ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
+ case FilterSubstringsAny:
+ ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
+ case FilterSubstringsFinal:
+ ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes())
+ }
+ case FilterEqualityMatch:
+ ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+ ret += "="
+ ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+ case FilterGreaterOrEqual:
+ ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+ ret += ">="
+ ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+ case FilterLessOrEqual:
+ ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+ ret += "<="
+ ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+ case FilterPresent:
+ ret += ber.DecodeString(packet.Data.Bytes())
+ ret += "=*"
+ case FilterApproxMatch:
+ ret += ber.DecodeString(packet.Children[0].Data.Bytes())
+ ret += "~="
+ ret += ber.DecodeString(packet.Children[1].Data.Bytes())
+ }
+
+ ret += ")"
+ return
+}
+
+func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
+ for pos < len(filter) && filter[pos] == '(' {
+ child, newPos, err := compileFilter(filter, pos+1)
+ if err != nil {
+ return pos, err
+ }
+ pos = newPos
+ parent.AppendChild(child)
+ }
+ if pos == len(filter) {
+ return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
+ }
+
+ return pos + 1, nil
+}
+
+func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
+ var packet *ber.Packet
+ var err error
+
+ defer func() {
+ if r := recover(); r != nil {
+ err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter"))
+ }
+ }()
+
+ newPos := pos
+ switch filter[pos] {
+ case '(':
+ packet, newPos, err = compileFilter(filter, pos+1)
+ newPos++
+ return packet, newPos, err
+ case '&':
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
+ newPos, err = compileFilterSet(filter, pos+1, packet)
+ return packet, newPos, err
+ case '|':
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
+ newPos, err = compileFilterSet(filter, pos+1, packet)
+ return packet, newPos, err
+ case '!':
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
+ var child *ber.Packet
+ child, newPos, err = compileFilter(filter, pos+1)
+ packet.AppendChild(child)
+ return packet, newPos, err
+ default:
+ attribute := ""
+ condition := ""
+ for newPos < len(filter) && filter[newPos] != ')' {
+ switch {
+ case packet != nil:
+ condition += fmt.Sprintf("%c", filter[newPos])
+ case filter[newPos] == '=':
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
+ case filter[newPos] == '>' && filter[newPos+1] == '=':
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
+ newPos++
+ case filter[newPos] == '<' && filter[newPos+1] == '=':
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
+ newPos++
+ case filter[newPos] == '~' && filter[newPos+1] == '=':
+ packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual])
+ newPos++
+ case packet == nil:
+ attribute += fmt.Sprintf("%c", filter[newPos])
+ }
+ newPos++
+ }
+ if newPos == len(filter) {
+ err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
+ return packet, newPos, err
+ }
+ if packet == nil {
+ err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
+ return packet, newPos, err
+ }
+ // Handle FilterEqualityMatch as a separate case (is primitive, not constructed like the other filters)
+ if packet.Tag == FilterEqualityMatch && condition == "*" {
+ packet.TagType = ber.TypePrimitive
+ packet.Tag = FilterPresent
+ packet.Description = FilterMap[packet.Tag]
+ packet.Data.WriteString(attribute)
+ return packet, newPos + 1, nil
+ }
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
+ switch {
+ case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*':
+ // Any
+ packet.Tag = FilterSubstrings
+ packet.Description = FilterMap[packet.Tag]
+ seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
+ seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring"))
+ packet.AppendChild(seq)
+ case packet.Tag == FilterEqualityMatch && condition[0] == '*':
+ // Final
+ packet.Tag = FilterSubstrings
+ packet.Description = FilterMap[packet.Tag]
+ seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
+ seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring"))
+ packet.AppendChild(seq)
+ case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*':
+ // Initial
+ packet.Tag = FilterSubstrings
+ packet.Description = FilterMap[packet.Tag]
+ seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
+ seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring"))
+ packet.AppendChild(seq)
+ default:
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, condition, "Condition"))
+ }
+ newPos++
+ return packet, newPos, err
+ }
+}
+
+func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, LDAPResultCode) {
+ switch FilterMap[f.Tag] {
+ default:
+ //log.Fatalf("Unknown LDAP filter code: %d", f.Tag)
+ return false, LDAPResultOperationsError
+ case "Equality Match":
+ if len(f.Children) != 2 {
+ return false, LDAPResultOperationsError
+ }
+ attribute := f.Children[0].Value.(string)
+ value := f.Children[1].Value.(string)
+ for _, a := range entry.Attributes {
+ if strings.ToLower(a.Name) == strings.ToLower(attribute) {
+ for _, v := range a.Values {
+ if strings.ToLower(v) == strings.ToLower(value) {
+ return true, LDAPResultSuccess
+ }
+ }
+ }
+ }
+ case "Present":
+ for _, a := range entry.Attributes {
+ if strings.ToLower(a.Name) == strings.ToLower(f.Data.String()) {
+ return true, LDAPResultSuccess
+ }
+ }
+ case "And":
+ for _, child := range f.Children {
+ ok, exitCode := ServerApplyFilter(child, entry)
+ if exitCode != LDAPResultSuccess {
+ return false, exitCode
+ }
+ if !ok {
+ return false, LDAPResultSuccess
+ }
+ }
+ return true, LDAPResultSuccess
+ case "Or":
+ anyOk := false
+ for _, child := range f.Children {
+ ok, exitCode := ServerApplyFilter(child, entry)
+ if exitCode != LDAPResultSuccess {
+ return false, exitCode
+ } else if ok {
+ anyOk = true
+ }
+ }
+ if anyOk {
+ return true, LDAPResultSuccess
+ }
+ case "Not":
+ if len(f.Children) != 1 {
+ return false, LDAPResultOperationsError
+ }
+ ok, exitCode := ServerApplyFilter(f.Children[0], entry)
+ if exitCode != LDAPResultSuccess {
+ return false, exitCode
+ } else if !ok {
+ return true, LDAPResultSuccess
+ }
+ case "Substrings":
+ if len(f.Children) != 2 {
+ return false, LDAPResultOperationsError
+ }
+ attribute := f.Children[0].Value.(string)
+ bytes := f.Children[1].Children[0].Data.Bytes()
+ value := string(bytes[:])
+ for _, a := range entry.Attributes {
+ if strings.ToLower(a.Name) == strings.ToLower(attribute) {
+ for _, v := range a.Values {
+ switch f.Children[1].Children[0].Tag {
+ case FilterSubstringsInitial:
+ if strings.HasPrefix(v, value) {
+ return true, LDAPResultSuccess
+ }
+ case FilterSubstringsAny:
+ if strings.Contains(v, value) {
+ return true, LDAPResultSuccess
+ }
+ case FilterSubstringsFinal:
+ if strings.HasSuffix(v, value) {
+ return true, LDAPResultSuccess
+ }
+ }
+ }
+ }
+ }
+ case "FilterGreaterOrEqual": // TODO
+ return false, LDAPResultOperationsError
+ case "FilterLessOrEqual": // TODO
+ return false, LDAPResultOperationsError
+ case "FilterApproxMatch": // TODO
+ return false, LDAPResultOperationsError
+ case "FilterExtensibleMatch": // TODO
+ return false, LDAPResultOperationsError
+ }
+
+ return false, LDAPResultSuccess
+}
+
+func GetFilterObjectClass(filter string) (string, error) {
+ f, err := CompileFilter(filter)
+ if err != nil {
+ return "", err
+ }
+ return parseFilterObjectClass(f)
+}
+func parseFilterObjectClass(f *ber.Packet) (string, error) {
+ objectClass := ""
+ switch FilterMap[f.Tag] {
+ case "Equality Match":
+ if len(f.Children) != 2 {
+ return "", errors.New("Equality match must have only two children")
+ }
+ attribute := strings.ToLower(f.Children[0].Value.(string))
+ value := f.Children[1].Value.(string)
+ if attribute == "objectclass" {
+ objectClass = strings.ToLower(value)
+ }
+ case "And":
+ for _, child := range f.Children {
+ subType, err := parseFilterObjectClass(child)
+ if err != nil {
+ return "", err
+ }
+ if len(subType) > 0 {
+ objectClass = subType
+ }
+ }
+ case "Or":
+ for _, child := range f.Children {
+ subType, err := parseFilterObjectClass(child)
+ if err != nil {
+ return "", err
+ }
+ if len(subType) > 0 {
+ objectClass = subType
+ }
+ }
+ case "Not":
+ if len(f.Children) != 1 {
+ return "", errors.New("Not filter must have only one child")
+ }
+ subType, err := parseFilterObjectClass(f.Children[0])
+ if err != nil {
+ return "", err
+ }
+ if len(subType) > 0 {
+ objectClass = subType
+ }
+
+ }
+ return strings.ToLower(objectClass), nil
+}
diff --git a/filter_test.go b/filter_test.go
new file mode 100644
index 0000000..2e62f25
--- /dev/null
+++ b/filter_test.go
@@ -0,0 +1,137 @@
+package ldap
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/nmcclain/asn1-ber"
+)
+
+type compileTest struct {
+ filterStr string
+ filterType uint8
+}
+
+var testFilters = []compileTest{
+ compileTest{filterStr: "(&(sn=Miller)(givenName=Bob))", filterType: FilterAnd},
+ compileTest{filterStr: "(|(sn=Miller)(givenName=Bob))", filterType: FilterOr},
+ compileTest{filterStr: "(!(sn=Miller))", filterType: FilterNot},
+ compileTest{filterStr: "(sn=Miller)", filterType: FilterEqualityMatch},
+ compileTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings},
+ compileTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings},
+ compileTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings},
+ compileTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual},
+ compileTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual},
+ compileTest{filterStr: "(sn=*)", filterType: FilterPresent},
+ compileTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch},
+ // compileTest{ filterStr: "()", filterType: FilterExtensibleMatch },
+}
+
+func TestFilter(t *testing.T) {
+ // Test Compiler and Decompiler
+ for _, i := range testFilters {
+ filter, err := CompileFilter(i.filterStr)
+ if err != nil {
+ t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error())
+ } else if filter.Tag != uint8(i.filterType) {
+ t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[i.filterType], FilterMap[filter.Tag])
+ } else {
+ o, err := DecompileFilter(filter)
+ if err != nil {
+ t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error())
+ } else if i.filterStr != o {
+ t.Errorf("%q expected, got %q", i.filterStr, o)
+ }
+ }
+ }
+}
+
+type binTestFilter struct {
+ bin []byte
+ str string
+}
+
+var binTestFilters = []binTestFilter{
+ {bin: []byte{0x87, 0x06, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72}, str: "(member=*)"},
+}
+
+func TestFiltersDecode(t *testing.T) {
+ for i, test := range binTestFilters {
+ p := ber.DecodePacket(test.bin)
+ if filter, err := DecompileFilter(p); err != nil {
+ t.Errorf("binTestFilters[%d], DecompileFilter returned : %s", i, err)
+ } else if filter != test.str {
+ t.Errorf("binTestFilters[%d], %q expected, got %q", i, test.str, filter)
+ }
+ }
+}
+
+func TestFiltersEncode(t *testing.T) {
+ for i, test := range binTestFilters {
+ p, err := CompileFilter(test.str)
+ if err != nil {
+ t.Errorf("binTestFilters[%d], CompileFilter returned : %s", i, err)
+ continue
+ }
+ b := p.Bytes()
+ if !reflect.DeepEqual(b, test.bin) {
+ t.Errorf("binTestFilters[%d], %q expected for CompileFilter(%q), got %q", i, test.bin, test.str, b)
+ }
+ }
+}
+
+func BenchmarkFilterCompile(b *testing.B) {
+ b.StopTimer()
+ filters := make([]string, len(testFilters))
+
+ // Test Compiler and Decompiler
+ for idx, i := range testFilters {
+ filters[idx] = i.filterStr
+ }
+
+ maxIdx := len(filters)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ CompileFilter(filters[i%maxIdx])
+ }
+}
+
+func BenchmarkFilterDecompile(b *testing.B) {
+ b.StopTimer()
+ filters := make([]*ber.Packet, len(testFilters))
+
+ // Test Compiler and Decompiler
+ for idx, i := range testFilters {
+ filters[idx], _ = CompileFilter(i.filterStr)
+ }
+
+ maxIdx := len(filters)
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ DecompileFilter(filters[i%maxIdx])
+ }
+}
+
+func TestGetFilterObjectClass(t *testing.T) {
+ c, err := GetFilterObjectClass("(objectClass=*)")
+ if err != nil {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ if c != "" {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ c, err = GetFilterObjectClass("(objectClass=posixAccount)")
+ if err != nil {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ if c != "posixaccount" {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ c, err = GetFilterObjectClass("(&(cn=awesome)(objectClass=posixGroup))")
+ if err != nil {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ if c != "posixgroup" {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+}
diff --git a/ldap.go b/ldap.go
new file mode 100644
index 0000000..e6d6d52
--- /dev/null
+++ b/ldap.go
@@ -0,0 +1,340 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ldap
+
+import (
+ "errors"
+ "fmt"
+ "io/ioutil"
+
+ "github.com/nmcclain/asn1-ber"
+)
+
+// LDAP Application Codes
+const (
+ ApplicationBindRequest = 0
+ ApplicationBindResponse = 1
+ ApplicationUnbindRequest = 2
+ ApplicationSearchRequest = 3
+ ApplicationSearchResultEntry = 4
+ ApplicationSearchResultDone = 5
+ ApplicationModifyRequest = 6
+ ApplicationModifyResponse = 7
+ ApplicationAddRequest = 8
+ ApplicationAddResponse = 9
+ ApplicationDelRequest = 10
+ ApplicationDelResponse = 11
+ ApplicationModifyDNRequest = 12
+ ApplicationModifyDNResponse = 13
+ ApplicationCompareRequest = 14
+ ApplicationCompareResponse = 15
+ ApplicationAbandonRequest = 16
+ ApplicationSearchResultReference = 19
+ ApplicationExtendedRequest = 23
+ ApplicationExtendedResponse = 24
+)
+
+var ApplicationMap = map[uint8]string{
+ ApplicationBindRequest: "Bind Request",
+ ApplicationBindResponse: "Bind Response",
+ ApplicationUnbindRequest: "Unbind Request",
+ ApplicationSearchRequest: "Search Request",
+ ApplicationSearchResultEntry: "Search Result Entry",
+ ApplicationSearchResultDone: "Search Result Done",
+ ApplicationModifyRequest: "Modify Request",
+ ApplicationModifyResponse: "Modify Response",
+ ApplicationAddRequest: "Add Request",
+ ApplicationAddResponse: "Add Response",
+ ApplicationDelRequest: "Del Request",
+ ApplicationDelResponse: "Del Response",
+ ApplicationModifyDNRequest: "Modify DN Request",
+ ApplicationModifyDNResponse: "Modify DN Response",
+ ApplicationCompareRequest: "Compare Request",
+ ApplicationCompareResponse: "Compare Response",
+ ApplicationAbandonRequest: "Abandon Request",
+ ApplicationSearchResultReference: "Search Result Reference",
+ ApplicationExtendedRequest: "Extended Request",
+ ApplicationExtendedResponse: "Extended Response",
+}
+
+// LDAP Result Codes
+const (
+ LDAPResultSuccess = 0
+ LDAPResultOperationsError = 1
+ LDAPResultProtocolError = 2
+ LDAPResultTimeLimitExceeded = 3
+ LDAPResultSizeLimitExceeded = 4
+ LDAPResultCompareFalse = 5
+ LDAPResultCompareTrue = 6
+ LDAPResultAuthMethodNotSupported = 7
+ LDAPResultStrongAuthRequired = 8
+ LDAPResultReferral = 10
+ LDAPResultAdminLimitExceeded = 11
+ LDAPResultUnavailableCriticalExtension = 12
+ LDAPResultConfidentialityRequired = 13
+ LDAPResultSaslBindInProgress = 14
+ LDAPResultNoSuchAttribute = 16
+ LDAPResultUndefinedAttributeType = 17
+ LDAPResultInappropriateMatching = 18
+ LDAPResultConstraintViolation = 19
+ LDAPResultAttributeOrValueExists = 20
+ LDAPResultInvalidAttributeSyntax = 21
+ LDAPResultNoSuchObject = 32
+ LDAPResultAliasProblem = 33
+ LDAPResultInvalidDNSyntax = 34
+ LDAPResultAliasDereferencingProblem = 36
+ LDAPResultInappropriateAuthentication = 48
+ LDAPResultInvalidCredentials = 49
+ LDAPResultInsufficientAccessRights = 50
+ LDAPResultBusy = 51
+ LDAPResultUnavailable = 52
+ LDAPResultUnwillingToPerform = 53
+ LDAPResultLoopDetect = 54
+ LDAPResultNamingViolation = 64
+ LDAPResultObjectClassViolation = 65
+ LDAPResultNotAllowedOnNonLeaf = 66
+ LDAPResultNotAllowedOnRDN = 67
+ LDAPResultEntryAlreadyExists = 68
+ LDAPResultObjectClassModsProhibited = 69
+ LDAPResultAffectsMultipleDSAs = 71
+ LDAPResultOther = 80
+
+ ErrorNetwork = 200
+ ErrorFilterCompile = 201
+ ErrorFilterDecompile = 202
+ ErrorDebugging = 203
+)
+
+var LDAPResultCodeMap = map[LDAPResultCode]string{
+ LDAPResultSuccess: "Success",
+ LDAPResultOperationsError: "Operations Error",
+ LDAPResultProtocolError: "Protocol Error",
+ LDAPResultTimeLimitExceeded: "Time Limit Exceeded",
+ LDAPResultSizeLimitExceeded: "Size Limit Exceeded",
+ LDAPResultCompareFalse: "Compare False",
+ LDAPResultCompareTrue: "Compare True",
+ LDAPResultAuthMethodNotSupported: "Auth Method Not Supported",
+ LDAPResultStrongAuthRequired: "Strong Auth Required",
+ LDAPResultReferral: "Referral",
+ LDAPResultAdminLimitExceeded: "Admin Limit Exceeded",
+ LDAPResultUnavailableCriticalExtension: "Unavailable Critical Extension",
+ LDAPResultConfidentialityRequired: "Confidentiality Required",
+ LDAPResultSaslBindInProgress: "Sasl Bind In Progress",
+ LDAPResultNoSuchAttribute: "No Such Attribute",
+ LDAPResultUndefinedAttributeType: "Undefined Attribute Type",
+ LDAPResultInappropriateMatching: "Inappropriate Matching",
+ LDAPResultConstraintViolation: "Constraint Violation",
+ LDAPResultAttributeOrValueExists: "Attribute Or Value Exists",
+ LDAPResultInvalidAttributeSyntax: "Invalid Attribute Syntax",
+ LDAPResultNoSuchObject: "No Such Object",
+ LDAPResultAliasProblem: "Alias Problem",
+ LDAPResultInvalidDNSyntax: "Invalid DN Syntax",
+ LDAPResultAliasDereferencingProblem: "Alias Dereferencing Problem",
+ LDAPResultInappropriateAuthentication: "Inappropriate Authentication",
+ LDAPResultInvalidCredentials: "Invalid Credentials",
+ LDAPResultInsufficientAccessRights: "Insufficient Access Rights",
+ LDAPResultBusy: "Busy",
+ LDAPResultUnavailable: "Unavailable",
+ LDAPResultUnwillingToPerform: "Unwilling To Perform",
+ LDAPResultLoopDetect: "Loop Detect",
+ LDAPResultNamingViolation: "Naming Violation",
+ LDAPResultObjectClassViolation: "Object Class Violation",
+ LDAPResultNotAllowedOnNonLeaf: "Not Allowed On Non Leaf",
+ LDAPResultNotAllowedOnRDN: "Not Allowed On RDN",
+ LDAPResultEntryAlreadyExists: "Entry Already Exists",
+ LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited",
+ LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs",
+ LDAPResultOther: "Other",
+}
+
+// Other LDAP constants
+const (
+ LDAPBindAuthSimple = 0
+ LDAPBindAuthSASL = 3
+)
+
+type LDAPResultCode uint8
+
+type Attribute struct {
+ attrType string
+ attrVals []string
+}
+type AddRequest struct {
+ dn string
+ attributes []Attribute
+}
+type DeleteRequest struct {
+ dn string
+}
+type ModifyDNRequest struct {
+ dn string
+ newrdn string
+ deleteoldrdn bool
+ newSuperior string
+}
+type AttributeValueAssertion struct {
+ attributeDesc string
+ assertionValue string
+}
+type CompareRequest struct {
+ dn string
+ ava []AttributeValueAssertion
+}
+type ExtendedRequest struct {
+ requestName string
+ requestValue string
+}
+
+// Adds descriptions to an LDAP Response packet for debugging
+func addLDAPDescriptions(packet *ber.Packet) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ err = NewError(ErrorDebugging, errors.New("ldap: cannot process packet to add descriptions"))
+ }
+ }()
+ packet.Description = "LDAP Response"
+ packet.Children[0].Description = "Message ID"
+
+ application := packet.Children[1].Tag
+ packet.Children[1].Description = ApplicationMap[application]
+
+ switch application {
+ case ApplicationBindRequest:
+ addRequestDescriptions(packet)
+ case ApplicationBindResponse:
+ addDefaultLDAPResponseDescriptions(packet)
+ case ApplicationUnbindRequest:
+ addRequestDescriptions(packet)
+ case ApplicationSearchRequest:
+ addRequestDescriptions(packet)
+ case ApplicationSearchResultEntry:
+ packet.Children[1].Children[0].Description = "Object Name"
+ packet.Children[1].Children[1].Description = "Attributes"
+ for _, child := range packet.Children[1].Children[1].Children {
+ child.Description = "Attribute"
+ child.Children[0].Description = "Attribute Name"
+ child.Children[1].Description = "Attribute Values"
+ for _, grandchild := range child.Children[1].Children {
+ grandchild.Description = "Attribute Value"
+ }
+ }
+ if len(packet.Children) == 3 {
+ addControlDescriptions(packet.Children[2])
+ }
+ case ApplicationSearchResultDone:
+ addDefaultLDAPResponseDescriptions(packet)
+ case ApplicationModifyRequest:
+ addRequestDescriptions(packet)
+ case ApplicationModifyResponse:
+ case ApplicationAddRequest:
+ addRequestDescriptions(packet)
+ case ApplicationAddResponse:
+ case ApplicationDelRequest:
+ addRequestDescriptions(packet)
+ case ApplicationDelResponse:
+ case ApplicationModifyDNRequest:
+ addRequestDescriptions(packet)
+ case ApplicationModifyDNResponse:
+ case ApplicationCompareRequest:
+ addRequestDescriptions(packet)
+ case ApplicationCompareResponse:
+ case ApplicationAbandonRequest:
+ addRequestDescriptions(packet)
+ case ApplicationSearchResultReference:
+ case ApplicationExtendedRequest:
+ addRequestDescriptions(packet)
+ case ApplicationExtendedResponse:
+ }
+
+ return nil
+}
+
+func addControlDescriptions(packet *ber.Packet) {
+ packet.Description = "Controls"
+ for _, child := range packet.Children {
+ child.Description = "Control"
+ child.Children[0].Description = "Control Type (" + ControlTypeMap[child.Children[0].Value.(string)] + ")"
+ value := child.Children[1]
+ if len(child.Children) == 3 {
+ child.Children[1].Description = "Criticality"
+ value = child.Children[2]
+ }
+ value.Description = "Control Value"
+
+ switch child.Children[0].Value.(string) {
+ case ControlTypePaging:
+ value.Description += " (Paging)"
+ if value.Value != nil {
+ valueChildren := ber.DecodePacket(value.Data.Bytes())
+ value.Data.Truncate(0)
+ value.Value = nil
+ valueChildren.Children[1].Value = valueChildren.Children[1].Data.Bytes()
+ value.AppendChild(valueChildren)
+ }
+ value.Children[0].Description = "Real Search Control Value"
+ value.Children[0].Children[0].Description = "Paging Size"
+ value.Children[0].Children[1].Description = "Cookie"
+ }
+ }
+}
+
+func addRequestDescriptions(packet *ber.Packet) {
+ packet.Description = "LDAP Request"
+ packet.Children[0].Description = "Message ID"
+ packet.Children[1].Description = ApplicationMap[packet.Children[1].Tag]
+ if len(packet.Children) == 3 {
+ addControlDescriptions(packet.Children[2])
+ }
+}
+
+func addDefaultLDAPResponseDescriptions(packet *ber.Packet) {
+ resultCode := packet.Children[1].Children[0].Value.(uint64)
+ packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[LDAPResultCode(resultCode)] + ")"
+ packet.Children[1].Children[1].Description = "Matched DN"
+ packet.Children[1].Children[2].Description = "Error Message"
+ if len(packet.Children[1].Children) > 3 {
+ packet.Children[1].Children[3].Description = "Referral"
+ }
+ if len(packet.Children) == 3 {
+ addControlDescriptions(packet.Children[2])
+ }
+}
+
+func DebugBinaryFile(fileName string) error {
+ file, err := ioutil.ReadFile(fileName)
+ if err != nil {
+ return NewError(ErrorDebugging, err)
+ }
+ ber.PrintBytes(file, "")
+ packet := ber.DecodePacket(file)
+ addLDAPDescriptions(packet)
+ ber.PrintPacket(packet)
+
+ return nil
+}
+
+type Error struct {
+ Err error
+ ResultCode LDAPResultCode
+}
+
+func (e *Error) Error() string {
+ return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error())
+}
+
+func NewError(resultCode LDAPResultCode, err error) error {
+ return &Error{ResultCode: resultCode, Err: err}
+}
+
+func getLDAPResultCode(packet *ber.Packet) (code LDAPResultCode, description string) {
+ if len(packet.Children) >= 2 {
+ response := packet.Children[1]
+ if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 {
+ return LDAPResultCode(response.Children[0].Value.(uint64)), response.Children[2].Value.(string)
+ }
+ }
+
+ return ErrorNetwork, "Invalid packet format"
+}
diff --git a/ldap_test.go b/ldap_test.go
new file mode 100644
index 0000000..31cfbf0
--- /dev/null
+++ b/ldap_test.go
@@ -0,0 +1,123 @@
+package ldap
+
+import (
+ "fmt"
+ "testing"
+)
+
+var ldapServer = "ldap.itd.umich.edu"
+var ldapPort = uint16(389)
+var baseDN = "dc=umich,dc=edu"
+var filter = []string{
+ "(cn=cis-fac)",
+ "(&(objectclass=rfc822mailgroup)(cn=*Computer*))",
+ "(&(objectclass=rfc822mailgroup)(cn=*Mathematics*))"}
+var attributes = []string{
+ "cn",
+ "description"}
+
+func TestConnect(t *testing.T) {
+ fmt.Printf("TestConnect: starting...\n")
+ l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+ if err != nil {
+ t.Errorf(err.Error())
+ return
+ }
+ defer l.Close()
+ fmt.Printf("TestConnect: finished...\n")
+}
+
+func TestSearch(t *testing.T) {
+ fmt.Printf("TestSearch: starting...\n")
+ l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+ if err != nil {
+ t.Errorf(err.Error())
+ return
+ }
+ defer l.Close()
+
+ searchRequest := NewSearchRequest(
+ baseDN,
+ ScopeWholeSubtree, DerefAlways, 0, 0, false,
+ filter[0],
+ attributes,
+ nil)
+
+ sr, err := l.Search(searchRequest)
+ if err != nil {
+ t.Errorf(err.Error())
+ return
+ }
+
+ fmt.Printf("TestSearch: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
+}
+
+func TestSearchWithPaging(t *testing.T) {
+ fmt.Printf("TestSearchWithPaging: starting...\n")
+ l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+ if err != nil {
+ t.Errorf(err.Error())
+ return
+ }
+ defer l.Close()
+
+ err = l.Bind("", "")
+ if err != nil {
+ t.Errorf(err.Error())
+ return
+ }
+
+ searchRequest := NewSearchRequest(
+ baseDN,
+ ScopeWholeSubtree, DerefAlways, 0, 0, false,
+ filter[1],
+ attributes,
+ nil)
+ sr, err := l.SearchWithPaging(searchRequest, 5)
+ if err != nil {
+ t.Errorf(err.Error())
+ return
+ }
+
+ fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
+}
+
+func testMultiGoroutineSearch(t *testing.T, l *Conn, results chan *SearchResult, i int) {
+ searchRequest := NewSearchRequest(
+ baseDN,
+ ScopeWholeSubtree, DerefAlways, 0, 0, false,
+ filter[i],
+ attributes,
+ nil)
+ sr, err := l.Search(searchRequest)
+ if err != nil {
+ t.Errorf(err.Error())
+ results <- nil
+ return
+ }
+ results <- sr
+}
+
+func TestMultiGoroutineSearch(t *testing.T) {
+ fmt.Printf("TestMultiGoroutineSearch: starting...\n")
+ l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
+ if err != nil {
+ t.Errorf(err.Error())
+ return
+ }
+ defer l.Close()
+
+ results := make([]chan *SearchResult, len(filter))
+ for i := range filter {
+ results[i] = make(chan *SearchResult)
+ go testMultiGoroutineSearch(t, l, results[i], i)
+ }
+ for i := range filter {
+ sr := <-results[i]
+ if sr == nil {
+ t.Errorf("Did not receive results from goroutine for %q", filter[i])
+ } else {
+ fmt.Printf("TestMultiGoroutineSearch(%d): %s -> num of entries = %d\n", i, filter[i], len(sr.Entries))
+ }
+ }
+}
diff --git a/modify.go b/modify.go
new file mode 100644
index 0000000..6ffe314
--- /dev/null
+++ b/modify.go
@@ -0,0 +1,162 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+// File contains Modify functionality
+//
+// https://tools.ietf.org/html/rfc4511
+//
+// ModifyRequest ::= [APPLICATION 6] SEQUENCE {
+// object LDAPDN,
+// changes SEQUENCE OF change SEQUENCE {
+// operation ENUMERATED {
+// add (0),
+// delete (1),
+// replace (2),
+// ... },
+// modification PartialAttribute } }
+//
+// PartialAttribute ::= SEQUENCE {
+// type AttributeDescription,
+// vals SET OF value AttributeValue }
+//
+// AttributeDescription ::= LDAPString
+// -- Constrained to <attributedescription>
+// -- [RFC4512]
+//
+// AttributeValue ::= OCTET STRING
+//
+
+package ldap
+
+import (
+ "errors"
+ "log"
+
+ "github.com/nmcclain/asn1-ber"
+)
+
+const (
+ AddAttribute = 0
+ DeleteAttribute = 1
+ ReplaceAttribute = 2
+)
+
+var LDAPModifyAttributeMap = map[uint64]string{
+ AddAttribute: "Add",
+ DeleteAttribute: "Delete",
+ ReplaceAttribute: "Replace",
+}
+
+type PartialAttribute struct {
+ attrType string
+ attrVals []string
+}
+
+func (p *PartialAttribute) encode() *ber.Packet {
+ seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute")
+ seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.attrType, "Type"))
+ set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
+ for _, value := range p.attrVals {
+ set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals"))
+ }
+ seq.AppendChild(set)
+ return seq
+}
+
+type ModifyRequest struct {
+ dn string
+ addAttributes []PartialAttribute
+ deleteAttributes []PartialAttribute
+ replaceAttributes []PartialAttribute
+}
+
+func (m *ModifyRequest) Add(attrType string, attrVals []string) {
+ m.addAttributes = append(m.addAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals})
+}
+
+func (m *ModifyRequest) Delete(attrType string, attrVals []string) {
+ m.deleteAttributes = append(m.deleteAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals})
+}
+
+func (m *ModifyRequest) Replace(attrType string, attrVals []string) {
+ m.replaceAttributes = append(m.replaceAttributes, PartialAttribute{attrType: attrType, attrVals: attrVals})
+}
+
+func (m ModifyRequest) encode() *ber.Packet {
+ request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request")
+ request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.dn, "DN"))
+ changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes")
+ for _, attribute := range m.addAttributes {
+ change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
+ change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(AddAttribute), "Operation"))
+ change.AppendChild(attribute.encode())
+ changes.AppendChild(change)
+ }
+ for _, attribute := range m.deleteAttributes {
+ change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
+ change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(DeleteAttribute), "Operation"))
+ change.AppendChild(attribute.encode())
+ changes.AppendChild(change)
+ }
+ for _, attribute := range m.replaceAttributes {
+ change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
+ change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ReplaceAttribute), "Operation"))
+ change.AppendChild(attribute.encode())
+ changes.AppendChild(change)
+ }
+ request.AppendChild(changes)
+ return request
+}
+
+func NewModifyRequest(
+ dn string,
+) *ModifyRequest {
+ return &ModifyRequest{
+ dn: dn,
+ }
+}
+
+func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
+ messageID := l.nextMessageID()
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ packet.AppendChild(modifyRequest.encode())
+
+ l.Debug.PrintPacket(packet)
+
+ channel, err := l.sendMessage(packet)
+ if err != nil {
+ return err
+ }
+ if channel == nil {
+ return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+ }
+ defer l.finishMessage(messageID)
+
+ l.Debug.Printf("%d: waiting for response", messageID)
+ packet = <-channel
+ l.Debug.Printf("%d: got response %p", messageID, packet)
+ if packet == nil {
+ return NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
+ }
+
+ if l.Debug {
+ if err := addLDAPDescriptions(packet); err != nil {
+ return err
+ }
+ ber.PrintPacket(packet)
+ }
+
+ if packet.Children[1].Tag == ApplicationModifyResponse {
+ resultCode, resultDescription := getLDAPResultCode(packet)
+ if resultCode != 0 {
+ return NewError(resultCode, errors.New(resultDescription))
+ }
+ } else {
+ log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
+ }
+
+ l.Debug.Printf("%d: returning", messageID)
+ return nil
+}
diff --git a/search.go b/search.go
new file mode 100644
index 0000000..45b26b8
--- /dev/null
+++ b/search.go
@@ -0,0 +1,350 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+// File contains Search functionality
+//
+// https://tools.ietf.org/html/rfc4511
+//
+// SearchRequest ::= [APPLICATION 3] SEQUENCE {
+// baseObject LDAPDN,
+// scope ENUMERATED {
+// baseObject (0),
+// singleLevel (1),
+// wholeSubtree (2),
+// ... },
+// derefAliases ENUMERATED {
+// neverDerefAliases (0),
+// derefInSearching (1),
+// derefFindingBaseObj (2),
+// derefAlways (3) },
+// sizeLimit INTEGER (0 .. maxInt),
+// timeLimit INTEGER (0 .. maxInt),
+// typesOnly BOOLEAN,
+// filter Filter,
+// attributes AttributeSelection }
+//
+// AttributeSelection ::= SEQUENCE OF selector LDAPString
+// -- The LDAPString is constrained to
+// -- <attributeSelector> in Section 4.5.1.8
+//
+// Filter ::= CHOICE {
+// and [0] SET SIZE (1..MAX) OF filter Filter,
+// or [1] SET SIZE (1..MAX) OF filter Filter,
+// not [2] Filter,
+// equalityMatch [3] AttributeValueAssertion,
+// substrings [4] SubstringFilter,
+// greaterOrEqual [5] AttributeValueAssertion,
+// lessOrEqual [6] AttributeValueAssertion,
+// present [7] AttributeDescription,
+// approxMatch [8] AttributeValueAssertion,
+// extensibleMatch [9] MatchingRuleAssertion,
+// ... }
+//
+// SubstringFilter ::= SEQUENCE {
+// type AttributeDescription,
+// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE {
+// initial [0] AssertionValue, -- can occur at most once
+// any [1] AssertionValue,
+// final [2] AssertionValue } -- can occur at most once
+// }
+//
+// MatchingRuleAssertion ::= SEQUENCE {
+// matchingRule [1] MatchingRuleId OPTIONAL,
+// type [2] AttributeDescription OPTIONAL,
+// matchValue [3] AssertionValue,
+// dnAttributes [4] BOOLEAN DEFAULT FALSE }
+//
+//
+
+package ldap
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/nmcclain/asn1-ber"
+)
+
+const (
+ ScopeBaseObject = 0
+ ScopeSingleLevel = 1
+ ScopeWholeSubtree = 2
+)
+
+var ScopeMap = map[int]string{
+ ScopeBaseObject: "Base Object",
+ ScopeSingleLevel: "Single Level",
+ ScopeWholeSubtree: "Whole Subtree",
+}
+
+const (
+ NeverDerefAliases = 0
+ DerefInSearching = 1
+ DerefFindingBaseObj = 2
+ DerefAlways = 3
+)
+
+var DerefMap = map[int]string{
+ NeverDerefAliases: "NeverDerefAliases",
+ DerefInSearching: "DerefInSearching",
+ DerefFindingBaseObj: "DerefFindingBaseObj",
+ DerefAlways: "DerefAlways",
+}
+
+type Entry struct {
+ DN string
+ Attributes []*EntryAttribute
+}
+
+func (e *Entry) GetAttributeValues(attribute string) []string {
+ for _, attr := range e.Attributes {
+ if attr.Name == attribute {
+ return attr.Values
+ }
+ }
+ return []string{}
+}
+
+func (e *Entry) GetAttributeValue(attribute string) string {
+ values := e.GetAttributeValues(attribute)
+ if len(values) == 0 {
+ return ""
+ }
+ return values[0]
+}
+
+func (e *Entry) Print() {
+ fmt.Printf("DN: %s\n", e.DN)
+ for _, attr := range e.Attributes {
+ attr.Print()
+ }
+}
+
+func (e *Entry) PrettyPrint(indent int) {
+ fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN)
+ for _, attr := range e.Attributes {
+ attr.PrettyPrint(indent + 2)
+ }
+}
+
+type EntryAttribute struct {
+ Name string
+ Values []string
+}
+
+func (e *EntryAttribute) Print() {
+ fmt.Printf("%s: %s\n", e.Name, e.Values)
+}
+
+func (e *EntryAttribute) PrettyPrint(indent int) {
+ fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values)
+}
+
+type SearchResult struct {
+ Entries []*Entry
+ Referrals []string
+ Controls []Control
+}
+
+func (s *SearchResult) Print() {
+ for _, entry := range s.Entries {
+ entry.Print()
+ }
+}
+
+func (s *SearchResult) PrettyPrint(indent int) {
+ for _, entry := range s.Entries {
+ entry.PrettyPrint(indent)
+ }
+}
+
+type SearchRequest struct {
+ BaseDN string
+ Scope int
+ DerefAliases int
+ SizeLimit int
+ TimeLimit int
+ TypesOnly bool
+ Filter string
+ Attributes []string
+ Controls []Control
+}
+
+func (s *SearchRequest) encode() (*ber.Packet, error) {
+ request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request")
+ request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, s.BaseDN, "Base DN"))
+ request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.Scope), "Scope"))
+ request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.DerefAliases), "Deref Aliases"))
+ request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.SizeLimit), "Size Limit"))
+ request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.TimeLimit), "Time Limit"))
+ request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, s.TypesOnly, "Types Only"))
+ // compile and encode filter
+ filterPacket, err := CompileFilter(s.Filter)
+ if err != nil {
+ return nil, err
+ }
+ request.AppendChild(filterPacket)
+ // encode attributes
+ attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
+ for _, attribute := range s.Attributes {
+ attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
+ }
+ request.AppendChild(attributesPacket)
+ return request, nil
+}
+
+func NewSearchRequest(
+ BaseDN string,
+ Scope, DerefAliases, SizeLimit, TimeLimit int,
+ TypesOnly bool,
+ Filter string,
+ Attributes []string,
+ Controls []Control,
+) *SearchRequest {
+ return &SearchRequest{
+ BaseDN: BaseDN,
+ Scope: Scope,
+ DerefAliases: DerefAliases,
+ SizeLimit: SizeLimit,
+ TimeLimit: TimeLimit,
+ TypesOnly: TypesOnly,
+ Filter: Filter,
+ Attributes: Attributes,
+ Controls: Controls,
+ }
+}
+
+func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) {
+ if searchRequest.Controls == nil {
+ searchRequest.Controls = make([]Control, 0)
+ }
+
+ pagingControl := NewControlPaging(pagingSize)
+ searchRequest.Controls = append(searchRequest.Controls, pagingControl)
+ searchResult := new(SearchResult)
+ for {
+ result, err := l.Search(searchRequest)
+ l.Debug.Printf("Looking for Paging Control...")
+ if err != nil {
+ return searchResult, err
+ }
+ if result == nil {
+ return searchResult, NewError(ErrorNetwork, errors.New("ldap: packet not received"))
+ }
+
+ for _, entry := range result.Entries {
+ searchResult.Entries = append(searchResult.Entries, entry)
+ }
+ for _, referral := range result.Referrals {
+ searchResult.Referrals = append(searchResult.Referrals, referral)
+ }
+ for _, control := range result.Controls {
+ searchResult.Controls = append(searchResult.Controls, control)
+ }
+
+ l.Debug.Printf("Looking for Paging Control...")
+ pagingResult := FindControl(result.Controls, ControlTypePaging)
+ if pagingResult == nil {
+ pagingControl = nil
+ l.Debug.Printf("Could not find paging control. Breaking...")
+ break
+ }
+
+ cookie := pagingResult.(*ControlPaging).Cookie
+ if len(cookie) == 0 {
+ pagingControl = nil
+ l.Debug.Printf("Could not find cookie. Breaking...")
+ break
+ }
+ pagingControl.SetCookie(cookie)
+ }
+
+ if pagingControl != nil {
+ l.Debug.Printf("Abandoning Paging...")
+ pagingControl.PagingSize = 0
+ l.Search(searchRequest)
+ }
+
+ return searchResult, nil
+}
+
+func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
+ messageID := l.nextMessageID()
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
+ // encode search request
+ encodedSearchRequest, err := searchRequest.encode()
+ if err != nil {
+ return nil, err
+ }
+ packet.AppendChild(encodedSearchRequest)
+ // encode search controls
+ if searchRequest.Controls != nil {
+ packet.AppendChild(encodeControls(searchRequest.Controls))
+ }
+
+ l.Debug.PrintPacket(packet)
+
+ channel, err := l.sendMessage(packet)
+ if err != nil {
+ return nil, err
+ }
+ if channel == nil {
+ return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
+ }
+ defer l.finishMessage(messageID)
+
+ result := &SearchResult{
+ Entries: make([]*Entry, 0),
+ Referrals: make([]string, 0),
+ Controls: make([]Control, 0)}
+
+ foundSearchResultDone := false
+ for !foundSearchResultDone {
+ l.Debug.Printf("%d: waiting for response", messageID)
+ packet = <-channel
+ l.Debug.Printf("%d: got response %p", messageID, packet)
+ if packet == nil {
+ return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
+ }
+
+ if l.Debug {
+ if err := addLDAPDescriptions(packet); err != nil {
+ return nil, err
+ }
+ ber.PrintPacket(packet)
+ }
+
+ switch packet.Children[1].Tag {
+ case 4:
+ entry := new(Entry)
+ entry.DN = packet.Children[1].Children[0].Value.(string)
+ for _, child := range packet.Children[1].Children[1].Children {
+ attr := new(EntryAttribute)
+ attr.Name = child.Children[0].Value.(string)
+ for _, value := range child.Children[1].Children {
+ attr.Values = append(attr.Values, value.Value.(string))
+ }
+ entry.Attributes = append(entry.Attributes, attr)
+ }
+ result.Entries = append(result.Entries, entry)
+ case 5:
+ resultCode, resultDescription := getLDAPResultCode(packet)
+ if resultCode != 0 {
+ return result, NewError(resultCode, errors.New(resultDescription))
+ }
+ if len(packet.Children) == 3 {
+ for _, child := range packet.Children[2].Children {
+ result.Controls = append(result.Controls, DecodeControl(child))
+ }
+ }
+ foundSearchResultDone = true
+ case 19:
+ result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string))
+ }
+ }
+ l.Debug.Printf("%d: returning", messageID)
+ return result, nil
+}
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..dcb6406
--- /dev/null
+++ b/server.go
@@ -0,0 +1,473 @@
+package ldap
+
+import (
+ "crypto/tls"
+ "github.com/nmcclain/asn1-ber"
+ "io"
+ "log"
+ "net"
+ "strings"
+ "sync"
+)
+
+type Binder interface {
+ Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
+}
+type Searcher interface {
+ Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
+}
+type Adder interface {
+ Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Modifier interface {
+ Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Deleter interface {
+ Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error)
+}
+type ModifyDNr interface {
+ ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Comparer interface {
+ Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Abandoner interface {
+ Abandon(boundDN string, conn net.Conn) error
+}
+type Extender interface {
+ Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Unbinder interface {
+ Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error)
+}
+type Closer interface {
+ Close(boundDN string, conn net.Conn) error
+}
+
+//
+type Server struct {
+ BindFns map[string]Binder
+ SearchFns map[string]Searcher
+ AddFns map[string]Adder
+ ModifyFns map[string]Modifier
+ DeleteFns map[string]Deleter
+ ModifyDNFns map[string]ModifyDNr
+ CompareFns map[string]Comparer
+ AbandonFns map[string]Abandoner
+ ExtendedFns map[string]Extender
+ UnbindFns map[string]Unbinder
+ CloseFns map[string]Closer
+ Quit chan bool
+ EnforceLDAP bool
+ Stats *Stats
+}
+
+type Stats struct {
+ Conns int
+ Binds int
+ Unbinds int
+ Searches int
+ statsMutex sync.Mutex
+}
+
+type ServerSearchResult struct {
+ Entries []*Entry
+ Referrals []string
+ Controls []Control
+ ResultCode LDAPResultCode
+}
+
+//
+func NewServer() *Server {
+ s := new(Server)
+ s.Quit = make(chan bool)
+
+ d := defaultHandler{}
+ s.BindFns = make(map[string]Binder)
+ s.SearchFns = make(map[string]Searcher)
+ s.AddFns = make(map[string]Adder)
+ s.ModifyFns = make(map[string]Modifier)
+ s.DeleteFns = make(map[string]Deleter)
+ s.ModifyDNFns = make(map[string]ModifyDNr)
+ s.CompareFns = make(map[string]Comparer)
+ s.AbandonFns = make(map[string]Abandoner)
+ s.ExtendedFns = make(map[string]Extender)
+ s.UnbindFns = make(map[string]Unbinder)
+ s.CloseFns = make(map[string]Closer)
+ s.BindFunc("", d)
+ s.SearchFunc("", d)
+ s.AddFunc("", d)
+ s.ModifyFunc("", d)
+ s.DeleteFunc("", d)
+ s.ModifyDNFunc("", d)
+ s.CompareFunc("", d)
+ s.AbandonFunc("", d)
+ s.ExtendedFunc("", d)
+ s.UnbindFunc("", d)
+ s.CloseFunc("", d)
+ s.Stats = nil
+ return s
+}
+func (server *Server) BindFunc(baseDN string, f Binder) {
+ server.BindFns[baseDN] = f
+}
+func (server *Server) SearchFunc(baseDN string, f Searcher) {
+ server.SearchFns[baseDN] = f
+}
+func (server *Server) AddFunc(baseDN string, f Adder) {
+ server.AddFns[baseDN] = f
+}
+func (server *Server) ModifyFunc(baseDN string, f Modifier) {
+ server.ModifyFns[baseDN] = f
+}
+func (server *Server) DeleteFunc(baseDN string, f Deleter) {
+ server.DeleteFns[baseDN] = f
+}
+func (server *Server) ModifyDNFunc(baseDN string, f ModifyDNr) {
+ server.ModifyDNFns[baseDN] = f
+}
+func (server *Server) CompareFunc(baseDN string, f Comparer) {
+ server.CompareFns[baseDN] = f
+}
+func (server *Server) AbandonFunc(baseDN string, f Abandoner) {
+ server.AbandonFns[baseDN] = f
+}
+func (server *Server) ExtendedFunc(baseDN string, f Extender) {
+ server.ExtendedFns[baseDN] = f
+}
+func (server *Server) UnbindFunc(baseDN string, f Unbinder) {
+ server.UnbindFns[baseDN] = f
+}
+func (server *Server) CloseFunc(baseDN string, f Closer) {
+ server.CloseFns[baseDN] = f
+}
+func (server *Server) QuitChannel(quit chan bool) {
+ server.Quit = quit
+}
+
+func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error {
+ cert, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+ tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}}
+ tlsConfig.ServerName = "localhost"
+ ln, err := tls.Listen("tcp", listenString, &tlsConfig)
+ if err != nil {
+ return err
+ }
+ err = server.serve(ln)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (server *Server) SetStats(enable bool) {
+ if enable {
+ server.Stats = &Stats{}
+ } else {
+ server.Stats = nil
+ }
+}
+
+func (server *Server) GetStats() Stats {
+ defer func() {
+ server.Stats.statsMutex.Unlock()
+ }()
+ server.Stats.statsMutex.Lock()
+ return *server.Stats
+}
+
+func (server *Server) ListenAndServe(listenString string) error {
+ ln, err := net.Listen("tcp", listenString)
+ if err != nil {
+ return err
+ }
+ err = server.serve(ln)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (server *Server) serve(ln net.Listener) error {
+ newConn := make(chan net.Conn)
+ go func() {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ if !strings.HasSuffix(err.Error(), "use of closed network connection") {
+ log.Printf("Error accepting network connection: %s", err.Error())
+ }
+ break
+ }
+ newConn <- conn
+ }
+ }()
+
+listener:
+ for {
+ select {
+ case c := <-newConn:
+ server.Stats.countConns(1)
+ go server.handleConnection(c)
+ case <-server.Quit:
+ ln.Close()
+ break listener
+ }
+ }
+ return nil
+}
+
+//
+func (server *Server) handleConnection(conn net.Conn) {
+ boundDN := "" // "" == anonymous
+
+handler:
+ for {
+ // read incoming LDAP packet
+ packet, err := ber.ReadPacket(conn)
+ if err == io.EOF { // Client closed connection
+ break
+ } else if err != nil {
+ log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error())
+ break
+ }
+
+ // sanity check this packet
+ if len(packet.Children) < 2 {
+ log.Print("len(packet.Children) < 2")
+ break
+ }
+ // check the message ID and ClassType
+ messageID, ok := packet.Children[0].Value.(uint64)
+ if !ok {
+ log.Print("malformed messageID")
+ break
+ }
+ req := packet.Children[1]
+ if req.ClassType != ber.ClassApplication {
+ log.Print("req.ClassType != ber.ClassApplication")
+ break
+ }
+ // handle controls if present
+ controls := []Control{}
+ if len(packet.Children) > 2 {
+ for _, child := range packet.Children[2].Children {
+ controls = append(controls, DecodeControl(child))
+ }
+ }
+
+ //log.Printf("DEBUG: handling operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ //ber.PrintPacket(packet) // DEBUG
+
+ // dispatch the LDAP operation
+ switch req.Tag { // ldap op code
+ default:
+ responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultOperationsError, "Unsupported operation: add")
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ }
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+
+ case ApplicationBindRequest:
+ server.Stats.countBinds(1)
+ ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
+ if ldapResultCode == LDAPResultSuccess {
+ boundDN, ok = req.Children[1].Value.(string)
+ if !ok {
+ log.Printf("Malformed Bind DN")
+ break handler
+ }
+ }
+ responsePacket := encodeBindResponse(messageID, ldapResultCode)
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationSearchRequest:
+ server.Stats.countSearches(1)
+ if err := HandleSearchRequest(req, &controls, messageID, boundDN, server, conn); err != nil {
+ log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks?
+ e := err.(*Error)
+ if err = sendPacket(conn, encodeSearchDone(messageID, e.ResultCode)); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ break handler
+ } else {
+ if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ }
+ case ApplicationUnbindRequest:
+ server.Stats.countUnbinds(1)
+ break handler // simply disconnect
+ case ApplicationExtendedRequest:
+ ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationAbandonRequest:
+ HandleAbandonRequest(req, boundDN, server.AbandonFns, conn)
+ break handler
+
+ case ApplicationAddRequest:
+ ldapResultCode := HandleAddRequest(req, boundDN, server.AddFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationModifyRequest:
+ ldapResultCode := HandleModifyRequest(req, boundDN, server.ModifyFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationModifyResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationDelRequest:
+ ldapResultCode := HandleDeleteRequest(req, boundDN, server.DeleteFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationDelResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationModifyDNRequest:
+ ldapResultCode := HandleModifyDNRequest(req, boundDN, server.ModifyDNFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationModifyDNResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationCompareRequest:
+ ldapResultCode := HandleCompareRequest(req, boundDN, server.CompareFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationCompareResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ }
+ }
+
+ for _, c := range server.CloseFns {
+ c.Close(boundDN, conn)
+ }
+
+ conn.Close()
+}
+
+//
+func sendPacket(conn net.Conn, packet *ber.Packet) error {
+ _, err := conn.Write(packet.Bytes())
+ if err != nil {
+ log.Printf("Error Sending Message: %s", err.Error())
+ return err
+ }
+ return nil
+}
+
+//
+func routeFunc(dn string, funcNames []string) string {
+ bestPick := ""
+ for _, fn := range funcNames {
+ if strings.HasSuffix(dn, fn) {
+ l := len(strings.Split(bestPick, ","))
+ if bestPick == "" {
+ l = 0
+ }
+ if len(strings.Split(fn, ",")) > l {
+ bestPick = fn
+ }
+ }
+ }
+ return bestPick
+}
+
+//
+func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
+ reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
+ responsePacket.AppendChild(reponse)
+ return responsePacket
+}
+
+//
+type defaultHandler struct {
+}
+
+func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInvalidCredentials, nil
+}
+func (h defaultHandler) Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+func (h defaultHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Abandon(boundDN string, conn net.Conn) error {
+ return nil
+}
+func (h defaultHandler) Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultProtocolError, nil
+}
+func (h defaultHandler) Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultSuccess, nil
+}
+func (h defaultHandler) Close(boundDN string, conn net.Conn) error {
+ conn.Close()
+ return nil
+}
+
+//
+func (stats *Stats) countConns(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Conns += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countBinds(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Binds += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countUnbinds(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Unbinds += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countSearches(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Searches += delta
+ stats.statsMutex.Unlock()
+ }
+}
+
+//
diff --git a/server_bind.go b/server_bind.go
new file mode 100644
index 0000000..5a80bf5
--- /dev/null
+++ b/server_bind.go
@@ -0,0 +1,73 @@
+package ldap
+
+import (
+ "github.com/nmcclain/asn1-ber"
+ "log"
+ "net"
+)
+
+func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) {
+ defer func() {
+ if r := recover(); r != nil {
+ resultCode = LDAPResultOperationsError
+ }
+ }()
+
+ // we only support ldapv3
+ ldapVersion, ok := req.Children[0].Value.(uint64)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ if ldapVersion != 3 {
+ log.Printf("Unsupported LDAP version: %d", ldapVersion)
+ return LDAPResultInappropriateAuthentication
+ }
+
+ // auth types
+ bindDN, ok := req.Children[1].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ bindAuth := req.Children[2]
+ switch bindAuth.Tag {
+ default:
+ log.Print("Unknown LDAP authentication method")
+ return LDAPResultInappropriateAuthentication
+ case LDAPBindAuthSimple:
+ if len(req.Children) == 3 {
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(bindDN, fnNames)
+ resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn)
+ if err != nil {
+ log.Printf("BindFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+ } else {
+ log.Print("Simple bind request has wrong # children. len(req.Children) != 3")
+ return LDAPResultInappropriateAuthentication
+ }
+ case LDAPBindAuthSASL:
+ log.Print("SASL authentication is not supported")
+ return LDAPResultInappropriateAuthentication
+ }
+ return LDAPResultOperationsError
+}
+
+func encodeBindResponse(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+ bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
+ bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+
+ responsePacket.AppendChild(bindReponse)
+
+ // ber.PrintPacket(responsePacket)
+ return responsePacket
+}
diff --git a/server_modify.go b/server_modify.go
new file mode 100644
index 0000000..0dca219
--- /dev/null
+++ b/server_modify.go
@@ -0,0 +1,231 @@
+package ldap
+
+import (
+ "github.com/nmcclain/asn1-ber"
+ "log"
+ "net"
+)
+
+func HandleAddRequest(req *ber.Packet, boundDN string, fns map[string]Adder, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ addReq := AddRequest{}
+ addReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ addReq.attributes = []Attribute{}
+ for _, attr := range req.Children[1].Children {
+ if len(attr.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+
+ a := Attribute{}
+ a.attrType, ok = attr.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ a.attrVals = []string{}
+ for _, val := range attr.Children[1].Children {
+ v, ok := val.Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ a.attrVals = append(a.attrVals, v)
+ }
+ addReq.attributes = append(addReq.attributes, a)
+ }
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Add(boundDN, addReq, conn)
+ if err != nil {
+ log.Printf("AddFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleDeleteRequest(req *ber.Packet, boundDN string, fns map[string]Deleter, conn net.Conn) (resultCode LDAPResultCode) {
+ deleteDN := ber.DecodeString(req.Data.Bytes())
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Delete(boundDN, deleteDN, conn)
+ if err != nil {
+ log.Printf("DeleteFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleModifyRequest(req *ber.Packet, boundDN string, fns map[string]Modifier, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ modReq := ModifyRequest{}
+ modReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ for _, change := range req.Children[1].Children {
+ if len(change.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ attr := PartialAttribute{}
+ attrs := change.Children[1].Children
+ if len(attrs) != 2 {
+ return LDAPResultProtocolError
+ }
+ attr.attrType, ok = attrs[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ for _, val := range attrs[1].Children {
+ v, ok := val.Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ attr.attrVals = append(attr.attrVals, v)
+ }
+ op, ok := change.Children[0].Value.(uint64)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ switch op {
+ default:
+ log.Printf("Unrecognized Modify attribute %d", op)
+ return LDAPResultProtocolError
+ case AddAttribute:
+ modReq.Add(attr.attrType, attr.attrVals)
+ case DeleteAttribute:
+ modReq.Delete(attr.attrType, attr.attrVals)
+ case ReplaceAttribute:
+ modReq.Replace(attr.attrType, attr.attrVals)
+ }
+ }
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Modify(boundDN, modReq, conn)
+ if err != nil {
+ log.Printf("ModifyFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleCompareRequest(req *ber.Packet, boundDN string, fns map[string]Comparer, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ compReq := CompareRequest{}
+ compReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ ava := req.Children[1]
+ if len(ava.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ attr, ok := ava.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ val, ok := ava.Children[1].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ compReq.ava = []AttributeValueAssertion{AttributeValueAssertion{attr, val}}
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Compare(boundDN, compReq, conn)
+ if err != nil {
+ log.Printf("CompareFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleExtendedRequest(req *ber.Packet, boundDN string, fns map[string]Extender, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 1 && len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ name := ber.DecodeString(req.Children[0].Data.Bytes())
+ var val string
+ if len(req.Children) == 2 {
+ val = ber.DecodeString(req.Children[1].Data.Bytes())
+ }
+ extReq := ExtendedRequest{name, val}
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Extended(boundDN, extReq, conn)
+ if err != nil {
+ log.Printf("ExtendedFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleAbandonRequest(req *ber.Packet, boundDN string, fns map[string]Abandoner, conn net.Conn) error {
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ err := fns[fn].Abandon(boundDN, conn)
+ return err
+}
+
+func HandleModifyDNRequest(req *ber.Packet, boundDN string, fns map[string]ModifyDNr, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 3 && len(req.Children) != 4 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ mdnReq := ModifyDNRequest{}
+ mdnReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ mdnReq.newrdn, ok = req.Children[1].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ mdnReq.deleteoldrdn, ok = req.Children[2].Value.(bool)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ if len(req.Children) == 4 {
+ mdnReq.newSuperior, ok = req.Children[3].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ }
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].ModifyDN(boundDN, mdnReq, conn)
+ if err != nil {
+ log.Printf("ModifyDN Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
diff --git a/server_modify_test.go b/server_modify_test.go
new file mode 100644
index 0000000..d45b810
--- /dev/null
+++ b/server_modify_test.go
@@ -0,0 +1,191 @@
+package ldap
+
+import (
+ "net"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+)
+
+//
+func TestAdd(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.AddFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_add: Insufficient access") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ if strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapadd command timed out")
+ }
+ quit <- true
+}
+
+//
+func TestDelete(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.DeleteFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Delete Me,dc=example,dc=com")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "Delete Result: Success (0)") || !strings.Contains(string(out), "Additional info: Success") {
+ t.Errorf("ldapdelete failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Bob,dc=example,dc=com")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "Success") || !strings.Contains(string(out), "ldap_delete: Insufficient access") {
+ t.Errorf("ldapdelete should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapdelete command timed out")
+ }
+ quit <- true
+}
+
+func TestModify(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.ModifyFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify.ldif")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapmodify failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify2.ldif")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_modify: Insufficient access") || strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapmodify should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapadd command timed out")
+ }
+ quit <- true
+}
+
+/*
+func TestModifyDN(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.AddFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif")
+ //ldapmodrdn -H ldap://localhost:3389 -x "uid=babs,dc=example,dc=com" "uid=babsy,dc=example,dc=com"
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_add: Insufficient access") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ if strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapadd command timed out")
+ }
+ quit <- true
+}
+*/
+
+//
+type modifyTestHandler struct {
+}
+
+func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ if bindDN == "" && bindSimplePw == "" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInvalidCredentials, nil
+}
+func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
+ // only succeed on expected contents of add.ldif:
+ if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" &&
+ req.attributes[2].attrType == "sn" && len(req.attributes[2].attrVals) == 1 &&
+ req.attributes[2].attrVals[0] == "Jensen" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
+ // only succeed on expected deleteDN
+ if deleteDN == "cn=Delete Me,dc=example,dc=com" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
+ // only succeed on expected contents of modify.ldif:
+ if req.dn == "cn=testy,dc=example,dc=com" && len(req.addAttributes) == 1 &&
+ len(req.deleteAttributes) == 3 && len(req.replaceAttributes) == 2 &&
+ req.deleteAttributes[2].attrType == "details" && len(req.deleteAttributes[2].attrVals) == 0 {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
diff --git a/server_search.go b/server_search.go
new file mode 100644
index 0000000..3fc91c5
--- /dev/null
+++ b/server_search.go
@@ -0,0 +1,217 @@
+package ldap
+
+import (
+ "errors"
+ "fmt"
+ "github.com/nmcclain/asn1-ber"
+ "net"
+ "strings"
+)
+
+func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) {
+ defer func() {
+ if r := recover(); r != nil {
+ resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r))
+ }
+ }()
+
+ searchReq, err := parseSearchRequest(boundDN, req, controls)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ filterPacket, err := CompileFilter(searchReq.Filter)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ fnNames := []string{}
+ for k := range server.SearchFns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(searchReq.BaseDN, fnNames)
+ searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn)
+ if err != nil {
+ return NewError(searchResp.ResultCode, err)
+ }
+
+ if server.EnforceLDAP {
+ if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find}
+ // TODO: Server DerefAliases not supported: RFC4511 4.5.1.3
+ }
+ if searchReq.TimeLimit > 0 {
+ // TODO: Server TimeLimit not implemented
+ }
+ }
+
+ i := 0
+ for _, entry := range searchResp.Entries {
+ if server.EnforceLDAP {
+ // filter
+ keep, resultCode := ServerApplyFilter(filterPacket, entry)
+ if resultCode != LDAPResultSuccess {
+ return NewError(resultCode, errors.New("ServerApplyFilter error"))
+ }
+ if !keep {
+ continue
+ }
+
+ // constrained search scope
+ switch searchReq.Scope {
+ case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates.
+ case ScopeBaseObject: // The scope is constrained to the entry named by baseObject.
+ if entry.DN != searchReq.BaseDN {
+ continue
+ }
+ case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject.
+ parts := strings.Split(entry.DN, ",")
+ if len(parts) < 2 && entry.DN != searchReq.BaseDN {
+ continue
+ }
+ if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN {
+ continue
+ }
+ }
+
+ // attributes
+ if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) {
+ entry, err = filterAttributes(entry, searchReq.Attributes)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+
+ // size limit
+ if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit {
+ break
+ }
+ i++
+ }
+
+ // respond
+ responsePacket := encodeSearchResponse(messageID, searchReq, entry)
+ if err = sendPacket(conn, responsePacket); err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+ return nil
+}
+
+/////////////////////////
+func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (SearchRequest, error) {
+ if len(req.Children) != 8 {
+ return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request"))
+ }
+
+ // Parse the request
+ baseObject, ok := req.Children[0].Value.(string)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ s, ok := req.Children[1].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ scope := int(s)
+ d, ok := req.Children[2].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ derefAliases := int(d)
+ s, ok = req.Children[3].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ sizeLimit := int(s)
+ t, ok := req.Children[4].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ timeLimit := int(t)
+ typesOnly := false
+ if req.Children[5].Value != nil {
+ typesOnly, ok = req.Children[5].Value.(bool)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ }
+ filter, err := DecompileFilter(req.Children[6])
+ if err != nil {
+ return SearchRequest{}, err
+ }
+ attributes := []string{}
+ for _, attr := range req.Children[7].Children {
+ a, ok := attr.Value.(string)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ attributes = append(attributes, a)
+ }
+ searchReq := SearchRequest{baseObject, scope,
+ derefAliases, sizeLimit, timeLimit,
+ typesOnly, filter, attributes, *controls}
+
+ return searchReq, nil
+}
+
+/////////////////////////
+func filterAttributes(entry *Entry, attributes []string) (*Entry, error) {
+ // only return requested attributes
+ newAttributes := []*EntryAttribute{}
+
+ for _, attr := range entry.Attributes {
+ for _, requested := range attributes {
+ if strings.ToLower(attr.Name) == strings.ToLower(requested) {
+ newAttributes = append(newAttributes, attr)
+ }
+ }
+ }
+ entry.Attributes = newAttributes
+
+ return entry, nil
+}
+
+/////////////////////////
+func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+ searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry")
+ searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name"))
+
+ attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:")
+ for _, attribute := range res.Attributes {
+ attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values))
+ }
+
+ searchEntry.AppendChild(attrs)
+ responsePacket.AppendChild(searchEntry)
+
+ return responsePacket
+}
+
+func encodeSearchAttribute(name string, values []string) *ber.Packet {
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name"))
+
+ valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values")
+ for _, value := range values {
+ valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value"))
+ }
+
+ packet.AppendChild(valuesPacket)
+
+ return packet
+}
+
+func encodeSearchDone(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done")
+ donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+ responsePacket.AppendChild(donePacket)
+
+ return responsePacket
+}
diff --git a/server_search_test.go b/server_search_test.go
new file mode 100644
index 0000000..8b8fa65
--- /dev/null
+++ b/server_search_test.go
@@ -0,0 +1,453 @@
+package ldap
+
+import (
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+)
+
+//
+func TestSearchSimpleOK(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ serverBaseDN := "o=testers,c=test"
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "uidNumber: 5000") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numResponses: 4") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+func TestSearchSizelimit(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test") // no limit for this test
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 3") {
+ t.Errorf("ldapsearch sizelimit unlimited failed - not enough entries: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "9") // effectively no limit for this test
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 3") {
+ t.Errorf("ldapsearch sizelimit 9 failed - not enough entries: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "2")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 2") {
+ t.Errorf("ldapsearch sizelimit 2 failed - too many entries: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "1")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 1") {
+ t.Errorf("ldapsearch sizelimit 1 failed - too many entries: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "0")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 3") {
+ t.Errorf("ldapsearch sizelimit 0 failed - wrong number of entries: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "1", "(uid=trent)")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 1") {
+ t.Errorf("ldapsearch sizelimit 1 with filter failed - wrong number of entries: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "0", "(uid=trent)")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 1") {
+ t.Errorf("ldapsearch sizelimit 0 with filter failed - wrong number of entries: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestBindSearchMulti(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", bindSimple{})
+ s.BindFunc("c=testz", bindSimple2{})
+ s.SearchFunc("", searchSimple{})
+ s.SearchFunc("c=testz", searchSimple2{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test",
+ "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("error routing default bind/search functions: %v", string(out))
+ }
+ if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+ t.Errorf("search default routing failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=testz",
+ "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("error routing custom bind/search functions: %v", string(out))
+ }
+ if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+ t.Errorf("search custom routing failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+
+ quit <- true
+}
+
+/////////////////////////
+func TestSearchPanic(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchPanic{})
+ s.BindFunc("", bindAnonOK{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 1 Operations error") {
+ t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+type compileSearchFilterTest struct {
+ name string
+ filterStr string
+ numResponses string
+}
+
+var searchFilterTestFilters = []compileSearchFilterTest{
+ compileSearchFilterTest{name: "equalityOk", filterStr: "(uid=ned)", numResponses: "2"},
+ compileSearchFilterTest{name: "equalityNo", filterStr: "(uid=foo)", numResponses: "1"},
+ compileSearchFilterTest{name: "equalityOk", filterStr: "(objectclass=posixaccount)", numResponses: "4"},
+ compileSearchFilterTest{name: "presentEmptyOk", filterStr: "", numResponses: "4"},
+ compileSearchFilterTest{name: "presentOk", filterStr: "(objectclass=*)", numResponses: "4"},
+ compileSearchFilterTest{name: "presentOk", filterStr: "(description=*)", numResponses: "3"},
+ compileSearchFilterTest{name: "presentNo", filterStr: "(foo=*)", numResponses: "1"},
+ compileSearchFilterTest{name: "andOk", filterStr: "(&(uid=ned)(objectclass=posixaccount))", numResponses: "2"},
+ compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(objectclass=posixgroup))", numResponses: "1"},
+ compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(uid=trent))", numResponses: "1"},
+ compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(uid=trent))", numResponses: "3"},
+ compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(objectclass=posixaccount))", numResponses: "4"},
+ compileSearchFilterTest{name: "orNo", filterStr: "(|(uid=foo)(objectclass=foo))", numResponses: "1"},
+ compileSearchFilterTest{name: "andOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(objectclass=posixaccount))", numResponses: "3"},
+ compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=ned))", numResponses: "3"},
+ compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=foo))", numResponses: "4"},
+ compileSearchFilterTest{name: "notAndOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(!(objectclass=posixgroup)))", numResponses: "3"},
+ /*
+ compileSearchFilterTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings},
+ compileSearchFilterTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings},
+ compileSearchFilterTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings},
+ compileSearchFilterTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual},
+ compileSearchFilterTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual},
+ compileSearchFilterTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch},
+ */
+}
+
+/////////////////////////
+func TestSearchFiltering(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ for _, i := range searchFilterTestFilters {
+ t.Log(i.name)
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", i.filterStr)
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "numResponses: "+i.numResponses) {
+ t.Errorf("ldapsearch failed - expected numResponses==%d: %v", i.numResponses, string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestSearchAttributes(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ filterString := ""
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "cn")
+ out, _ := cmd.CombinedOutput()
+
+ if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+ t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out))
+ }
+ if !strings.Contains(string(out), "cn: ned") {
+ t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out))
+ }
+ if strings.Contains(string(out), "uidNumber") {
+ t.Errorf("ldapsearch failed - uidNumber attr should not be displayed: %v", string(out))
+ }
+ if strings.Contains(string(out), "accountstatus") {
+ t.Errorf("ldapsearch failed - accountstatus attr should not be displayed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestSearchScope(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "sub", "cn=trent")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'sub' scope failed - didn't find expected DN: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'one' scope failed - didn't find expected DN: %v", string(out))
+ }
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'one' scope failed - found unexpected DN: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "cn=trent,o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'base' scope failed - didn't find expected DN: %v", string(out))
+ }
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'base' scope failed - found unexpected DN: %v", string(out))
+ }
+
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+func TestSearchControls(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchControls{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ serverBaseDN := "o=testers,c=test"
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-e", "1.2.3.4.5")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+ t.Errorf("ldapsearch with control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch with control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numResponses: 2") {
+ t.Errorf("ldapsearch with control failed: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+ t.Errorf("ldapsearch without control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch without control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numResponses: 1") {
+ t.Errorf("ldapsearch without control failed: %v", string(out))
+ }
+
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
diff --git a/server_test.go b/server_test.go
new file mode 100644
index 0000000..88c47bf
--- /dev/null
+++ b/server_test.go
@@ -0,0 +1,410 @@
+package ldap
+
+import (
+ "bytes"
+ "log"
+ "net"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+)
+
+var listenString = "localhost:3389"
+var ldapURL = "ldap://" + listenString
+var timeout = 400 * time.Millisecond
+var serverBaseDN = "o=testers,c=test"
+
+/////////////////////////
+func TestBindAnonOK(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", bindAnonOK{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestBindAnonFail(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ time.Sleep(timeout)
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_bind: Invalid credentials (49)") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ time.Sleep(timeout)
+ quit <- true
+}
+
+/////////////////////////
+func TestBindSimpleOK(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ serverBaseDN := "o=testers,c=test"
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestBindSimpleFailBadPw(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ serverBaseDN := "o=testers,c=test"
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "BADPassword")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_bind: Invalid credentials (49)") {
+ t.Errorf("ldapsearch succeeded - should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestBindSimpleFailBadDn(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ serverBaseDN := "o=testers,c=test"
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testoy,"+serverBaseDN, "-w", "iLike2test")
+ out, _ := cmd.CombinedOutput()
+ if string(out) != "ldap_bind: Invalid credentials (49)\n" {
+ t.Errorf("ldapsearch succeeded - should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestBindSSL(t *testing.T) {
+ ldapURLSSL := "ldaps://" + listenString
+ longerTimeout := 300 * time.Millisecond
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", bindAnonOK{})
+ if err := s.ListenAndServeTLS(listenString, "tests/cert_DONOTUSE.pem", "tests/key_DONOTUSE.pem"); err != nil {
+ t.Errorf("s.ListenAndServeTLS failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ time.Sleep(longerTimeout * 2)
+ cmd := exec.Command("ldapsearch", "-H", ldapURLSSL, "-x", "-b", "o=testers,c=test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(longerTimeout * 2):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestBindPanic(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", bindPanic{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_bind: Operations error") {
+ t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+type testStatsWriter struct {
+ buffer *bytes.Buffer
+}
+
+func (tsw testStatsWriter) Write(buf []byte) (int, error) {
+ tsw.buffer.Write(buf)
+ return len(buf), nil
+}
+
+func TestSearchStats(t *testing.T) {
+ w := testStatsWriter{&bytes.Buffer{}}
+ log.SetOutput(w)
+
+ quit := make(chan bool)
+ done := make(chan bool)
+ s := NewServer()
+
+ go func() {
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindAnonOK{})
+ s.SetStats(true)
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+
+ stats := s.GetStats()
+ log.Println(stats)
+ if stats.Conns != 1 || stats.Binds != 1 {
+ t.Errorf("Stats data missing or incorrect: %v", w.buffer.String())
+ }
+ quit <- true
+}
+
+/////////////////////////
+type bindAnonOK struct {
+}
+
+func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ if bindDN == "" && bindSimplePw == "" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInvalidCredentials, nil
+}
+
+type bindSimple struct {
+}
+
+func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInvalidCredentials, nil
+}
+
+type bindSimple2 struct {
+}
+
+func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInvalidCredentials, nil
+}
+
+type bindPanic struct {
+}
+
+func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ panic("test panic at the disco")
+ return LDAPResultInvalidCredentials, nil
+}
+
+type searchSimple struct {
+}
+
+func (s searchSimple) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ entries := []*Entry{
+ &Entry{"cn=ned,o=testers,c=test", []*EntryAttribute{
+ &EntryAttribute{"cn", []string{"ned"}},
+ &EntryAttribute{"o", []string{"ate"}},
+ &EntryAttribute{"uidNumber", []string{"5000"}},
+ &EntryAttribute{"accountstatus", []string{"active"}},
+ &EntryAttribute{"uid", []string{"ned"}},
+ &EntryAttribute{"description", []string{"ned via sa"}},
+ &EntryAttribute{"objectclass", []string{"posixaccount"}},
+ }},
+ &Entry{"cn=trent,o=testers,c=test", []*EntryAttribute{
+ &EntryAttribute{"cn", []string{"trent"}},
+ &EntryAttribute{"o", []string{"ate"}},
+ &EntryAttribute{"uidNumber", []string{"5005"}},
+ &EntryAttribute{"accountstatus", []string{"active"}},
+ &EntryAttribute{"uid", []string{"trent"}},
+ &EntryAttribute{"description", []string{"trent via sa"}},
+ &EntryAttribute{"objectclass", []string{"posixaccount"}},
+ }},
+ &Entry{"cn=randy,o=testers,c=test", []*EntryAttribute{
+ &EntryAttribute{"cn", []string{"randy"}},
+ &EntryAttribute{"o", []string{"ate"}},
+ &EntryAttribute{"uidNumber", []string{"5555"}},
+ &EntryAttribute{"accountstatus", []string{"active"}},
+ &EntryAttribute{"uid", []string{"randy"}},
+ &EntryAttribute{"objectclass", []string{"posixaccount"}},
+ }},
+ }
+ return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchSimple2 struct {
+}
+
+func (s searchSimple2) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ entries := []*Entry{
+ &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{
+ &EntryAttribute{"cn", []string{"hamburger"}},
+ &EntryAttribute{"o", []string{"testers"}},
+ &EntryAttribute{"uidNumber", []string{"5000"}},
+ &EntryAttribute{"accountstatus", []string{"active"}},
+ &EntryAttribute{"uid", []string{"hamburger"}},
+ &EntryAttribute{"objectclass", []string{"posixaccount"}},
+ }},
+ }
+ return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchPanic struct {
+}
+
+func (s searchPanic) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ entries := []*Entry{}
+ panic("this is a test panic")
+ return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchControls struct {
+}
+
+func (s searchControls) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ entries := []*Entry{}
+ if len(searchReq.Controls) == 1 && searchReq.Controls[0].GetControlType() == "1.2.3.4.5" {
+ newEntry := &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{
+ &EntryAttribute{"cn", []string{"hamburger"}},
+ &EntryAttribute{"o", []string{"testers"}},
+ &EntryAttribute{"uidNumber", []string{"5000"}},
+ &EntryAttribute{"accountstatus", []string{"active"}},
+ &EntryAttribute{"uid", []string{"hamburger"}},
+ &EntryAttribute{"objectclass", []string{"posixaccount"}},
+ }}
+ entries = append(entries, newEntry)
+ }
+ return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
diff --git a/tests/add.ldif b/tests/add.ldif
new file mode 100644
index 0000000..f8cdf71
--- /dev/null
+++ b/tests/add.ldif
@@ -0,0 +1,6 @@
+dn: cn=Barbara Jensen,dc=example,dc=com
+objectClass: person
+cn: Barbara Jensen
+sn: Jensen
+uid: bjensen
diff --git a/tests/add2.ldif b/tests/add2.ldif
new file mode 100644
index 0000000..ccb71ad
--- /dev/null
+++ b/tests/add2.ldif
@@ -0,0 +1,6 @@
+dn: cn=Big Bob,dc=example,dc=com
+objectClass: person
+cn: Big Bob
+sn: Bob
+uid: bob
diff --git a/tests/cert_DONOTUSE.pem b/tests/cert_DONOTUSE.pem
new file mode 100644
index 0000000..ee14324
--- /dev/null
+++ b/tests/cert_DONOTUSE.pem
@@ -0,0 +1,18 @@
+-----BEGIN CERTIFICATE-----
+MIIC9jCCAeCgAwIBAgIRAOG6xrSjAWQvJl9xFTIte/owCwYJKoZIhvcNAQELMBIx
+EDAOBgNVBAoTB0FjbWUgQ28wHhcNMTQwODIwMTY1MjQ4WhcNMTUwODIwMTY1MjQ4
+WjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
+CgKCAQEA30gcjawL7RXQ5B5IfAcCPsJkG3GBbfhkbBRI22VxktBoNvqh2TWyECG3
+WsB/N1WMmATnLxamBZ5mfouNbd120gbO1M06Ti57NP1YTmMp8AU18Dm4OjZ6IeQf
+ip1xYSSSb6UyucFN6zIt+5PY2o4DoGb6fSNKb1ybgu91LmC1O/TDlyYUWn2TtF73
+FOUwSt+A6t3/Jhjhlp4n5Oobw1rrAgf7DPhWFg0Thj1yknPzWALY2LPREOMWob0D
+EgR5C3WS2eYPyHkeMZWoSY6BiWTIU+hFqQUkdOvrWhflFoiZIsOl6iXmQpo2EQlg
+j3Oy2zyZk1ndAfHlFoAgPIIbnBc+2QIDAQABo0swSTAOBgNVHQ8BAf8EBAMCAKAw
+EwYDVR0lBAwwCgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAUBgNVHREEDTALggls
+b2NhbGhvc3QwCwYJKoZIhvcNAQELA4IBAQBB9xNt3rDrBA9tCLCjdlnIQuUu9Uf0
+tHsSH6keBkhEoAylzHjmkNlerhTaLkRgB0D8qjE5+1APz42TuRpHRunYHSTNN0aF
+N6zKlpXS0g+J/ViCh/Zw7xQI4mpSFqYzTgn4T733FqwLrmtKsj0IOOkDYSZc7qfh
+qwXp/SB1J0Kp8G8S3G73dCZZYuW8y/eYMEoSkjNwNLAXzEAmFkGd8f1xhWTvnOxz
+ZBbOOjggdRLxr7cMZ8GaVWFgEG93y3AYMhFxZYRwWTcWJvSTNP3xC/CWqxXkiKdO
+2BROqmTw8zdqjXCIbgX4B5G5njMq9fk0gc4SiTAQkCOF6Xo0wQUvBAbN
+-----END CERTIFICATE-----
diff --git a/tests/key_DONOTUSE.pem b/tests/key_DONOTUSE.pem
new file mode 100644
index 0000000..7feaa11
--- /dev/null
+++ b/tests/key_DONOTUSE.pem
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEoQIBAAKCAQEA30gcjawL7RXQ5B5IfAcCPsJkG3GBbfhkbBRI22VxktBoNvqh
+2TWyECG3WsB/N1WMmATnLxamBZ5mfouNbd120gbO1M06Ti57NP1YTmMp8AU18Dm4
+OjZ6IeQfip1xYSSSb6UyucFN6zIt+5PY2o4DoGb6fSNKb1ybgu91LmC1O/TDlyYU
+Wn2TtF73FOUwSt+A6t3/Jhjhlp4n5Oobw1rrAgf7DPhWFg0Thj1yknPzWALY2LPR
+EOMWob0DEgR5C3WS2eYPyHkeMZWoSY6BiWTIU+hFqQUkdOvrWhflFoiZIsOl6iXm
+Qpo2EQlgj3Oy2zyZk1ndAfHlFoAgPIIbnBc+2QIDAQABAoIBAGXECi+QEMd4QAMY
+wlS1JRLRqqrPavxiT/Lqs+I7NC6EClu0k/vZ+1Ra6aTVQ6ZGuZO3+F5/5h99eJ2I
+oWdHnxZOwApBl6d2i/U02wCvNbgNx+27gPoXRkcYIEAfTkPGVW/JTXtYXVkrP8YA
+NsA2JfT/un86jHyBKufcl/4RWcj/BddduEWZRAV7UH4iUVzr6kTHJHyIkrJrME4w
+0Njrd/+tvjX/AsBfj+4IqZo6yMSVQuGhnJYEWt+eU3+5OlLsWe8oy9PJN2hVPxvG
+bWaMx0/I88gSf9DQRnpIL2Kr715c4NUqy+82DP+Tg6h/lQ1oba68redtZoctDrmq
+IXox6PkCgYEA+fElqxWqdDTOw+yhD4CeZ6yxipI/iUQu2t7atRmTgi2WAgaT5aAP
+1lWhTmWPW7IzjYzFe/CK7OAe0P7JgmvBI2SMlSg/NWOx+pnTmOVO3buz46+B3VI9
+IFhMqAkVfnXukT0YgLsdvRTZb7irYeelVImZ0VVvn0+HzqZ1JUj8kwMCgYEA5LGK
+jZ0NmBQT/yxsp1GWlmIJnTizlxEGoc7ftL4SNGcshMKbSK6aYnUvjnnRdFwjQUu4
+mAN0LlDn7SalEL7PUnotUMNA278o4Zj88VcWQqjukkThDpOFVcreXuGkqIfxYyn6
+jJxLouF6L0zspEhFiEWjNYOVDzZw7Fh0tOCxkfMCf1L8voUPrIjo/74N02xSSEYk
+EM7xwCbTfLsvQ27eDxwqBqSlinWzr4564BQnpHHNuVBGbUu5kmcUAydhcYbcQESA
+Hi1oL5SKhY2vhZI+kPEOYaw3mebiZ2lV6B3i5kAW6B9RKdGUT0t4oLl3l2/qefqX
+tXrL40QCJBV5L2wxz6sCgYEAoDRXUTkSCtUV5Q3j15pqGVL4VTEhbdQ5hyR6xgzY
+h+k24JHLYjEeaZaaB/8CYbch413+JE9XFhMLRbBqtb5VUfvQvuDpEIdrRg58MzzE
+lVHuPn0OA74IC7+f42vCg2UoDkWcBOCAg8vcYkJLDBKs0veli5lv1EZY+NhGeWdm
+PU0CgYAhEHZnVC8DuKAUxuIXEpDil3F7iGYs1rAnGd3GkKofiait+9YlsCxFx/4F
+95VQjHm6Fdc+vwGUa2Z986wKmocWzVP3TbznMdbvr0/8LCOhQKDrtCkFWHtGsP/d
+PnCkwIdaTEen0E52PkMK8GNq6wjitINzRp5hpV23WFtGQxmmlA==
+-----END RSA PRIVATE KEY-----
diff --git a/tests/modify.ldif b/tests/modify.ldif
new file mode 100644
index 0000000..ac969cc
--- /dev/null
+++ b/tests/modify.ldif
@@ -0,0 +1,16 @@
+dn: cn=testy,dc=example,dc=com
+changetype: modify
+replace: mail
+-
+delete: manager
+-
+add: title
+title: Grand Poobah
+-
+delete: description
+-
+delete: details
+-
+replace: fullname
+fullname: Test Testerson
diff --git a/tests/modify2.ldif b/tests/modify2.ldif
new file mode 100644
index 0000000..794d7f4
--- /dev/null
+++ b/tests/modify2.ldif
@@ -0,0 +1,10 @@
+dn: cn=testo,dc=example,dc=com
+changetype: modify
+replace: mail
+-
+delete: manager
+-
+add: title
+title: Other Poobah
+-