aboutsummaryrefslogtreecommitdiff
path: root/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'server.go')
-rw-r--r--server.go473
1 files changed, 473 insertions, 0 deletions
diff --git a/server.go b/server.go
new file mode 100644
index 0000000..dcb6406
--- /dev/null
+++ b/server.go
@@ -0,0 +1,473 @@
+package ldap
+
+import (
+ "crypto/tls"
+ "github.com/nmcclain/asn1-ber"
+ "io"
+ "log"
+ "net"
+ "strings"
+ "sync"
+)
+
+type Binder interface {
+ Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
+}
+type Searcher interface {
+ Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
+}
+type Adder interface {
+ Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Modifier interface {
+ Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Deleter interface {
+ Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error)
+}
+type ModifyDNr interface {
+ ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Comparer interface {
+ Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Abandoner interface {
+ Abandon(boundDN string, conn net.Conn) error
+}
+type Extender interface {
+ Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Unbinder interface {
+ Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error)
+}
+type Closer interface {
+ Close(boundDN string, conn net.Conn) error
+}
+
+//
+type Server struct {
+ BindFns map[string]Binder
+ SearchFns map[string]Searcher
+ AddFns map[string]Adder
+ ModifyFns map[string]Modifier
+ DeleteFns map[string]Deleter
+ ModifyDNFns map[string]ModifyDNr
+ CompareFns map[string]Comparer
+ AbandonFns map[string]Abandoner
+ ExtendedFns map[string]Extender
+ UnbindFns map[string]Unbinder
+ CloseFns map[string]Closer
+ Quit chan bool
+ EnforceLDAP bool
+ Stats *Stats
+}
+
+type Stats struct {
+ Conns int
+ Binds int
+ Unbinds int
+ Searches int
+ statsMutex sync.Mutex
+}
+
+type ServerSearchResult struct {
+ Entries []*Entry
+ Referrals []string
+ Controls []Control
+ ResultCode LDAPResultCode
+}
+
+//
+func NewServer() *Server {
+ s := new(Server)
+ s.Quit = make(chan bool)
+
+ d := defaultHandler{}
+ s.BindFns = make(map[string]Binder)
+ s.SearchFns = make(map[string]Searcher)
+ s.AddFns = make(map[string]Adder)
+ s.ModifyFns = make(map[string]Modifier)
+ s.DeleteFns = make(map[string]Deleter)
+ s.ModifyDNFns = make(map[string]ModifyDNr)
+ s.CompareFns = make(map[string]Comparer)
+ s.AbandonFns = make(map[string]Abandoner)
+ s.ExtendedFns = make(map[string]Extender)
+ s.UnbindFns = make(map[string]Unbinder)
+ s.CloseFns = make(map[string]Closer)
+ s.BindFunc("", d)
+ s.SearchFunc("", d)
+ s.AddFunc("", d)
+ s.ModifyFunc("", d)
+ s.DeleteFunc("", d)
+ s.ModifyDNFunc("", d)
+ s.CompareFunc("", d)
+ s.AbandonFunc("", d)
+ s.ExtendedFunc("", d)
+ s.UnbindFunc("", d)
+ s.CloseFunc("", d)
+ s.Stats = nil
+ return s
+}
+func (server *Server) BindFunc(baseDN string, f Binder) {
+ server.BindFns[baseDN] = f
+}
+func (server *Server) SearchFunc(baseDN string, f Searcher) {
+ server.SearchFns[baseDN] = f
+}
+func (server *Server) AddFunc(baseDN string, f Adder) {
+ server.AddFns[baseDN] = f
+}
+func (server *Server) ModifyFunc(baseDN string, f Modifier) {
+ server.ModifyFns[baseDN] = f
+}
+func (server *Server) DeleteFunc(baseDN string, f Deleter) {
+ server.DeleteFns[baseDN] = f
+}
+func (server *Server) ModifyDNFunc(baseDN string, f ModifyDNr) {
+ server.ModifyDNFns[baseDN] = f
+}
+func (server *Server) CompareFunc(baseDN string, f Comparer) {
+ server.CompareFns[baseDN] = f
+}
+func (server *Server) AbandonFunc(baseDN string, f Abandoner) {
+ server.AbandonFns[baseDN] = f
+}
+func (server *Server) ExtendedFunc(baseDN string, f Extender) {
+ server.ExtendedFns[baseDN] = f
+}
+func (server *Server) UnbindFunc(baseDN string, f Unbinder) {
+ server.UnbindFns[baseDN] = f
+}
+func (server *Server) CloseFunc(baseDN string, f Closer) {
+ server.CloseFns[baseDN] = f
+}
+func (server *Server) QuitChannel(quit chan bool) {
+ server.Quit = quit
+}
+
+func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error {
+ cert, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+ tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}}
+ tlsConfig.ServerName = "localhost"
+ ln, err := tls.Listen("tcp", listenString, &tlsConfig)
+ if err != nil {
+ return err
+ }
+ err = server.serve(ln)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (server *Server) SetStats(enable bool) {
+ if enable {
+ server.Stats = &Stats{}
+ } else {
+ server.Stats = nil
+ }
+}
+
+func (server *Server) GetStats() Stats {
+ defer func() {
+ server.Stats.statsMutex.Unlock()
+ }()
+ server.Stats.statsMutex.Lock()
+ return *server.Stats
+}
+
+func (server *Server) ListenAndServe(listenString string) error {
+ ln, err := net.Listen("tcp", listenString)
+ if err != nil {
+ return err
+ }
+ err = server.serve(ln)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (server *Server) serve(ln net.Listener) error {
+ newConn := make(chan net.Conn)
+ go func() {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ if !strings.HasSuffix(err.Error(), "use of closed network connection") {
+ log.Printf("Error accepting network connection: %s", err.Error())
+ }
+ break
+ }
+ newConn <- conn
+ }
+ }()
+
+listener:
+ for {
+ select {
+ case c := <-newConn:
+ server.Stats.countConns(1)
+ go server.handleConnection(c)
+ case <-server.Quit:
+ ln.Close()
+ break listener
+ }
+ }
+ return nil
+}
+
+//
+func (server *Server) handleConnection(conn net.Conn) {
+ boundDN := "" // "" == anonymous
+
+handler:
+ for {
+ // read incoming LDAP packet
+ packet, err := ber.ReadPacket(conn)
+ if err == io.EOF { // Client closed connection
+ break
+ } else if err != nil {
+ log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error())
+ break
+ }
+
+ // sanity check this packet
+ if len(packet.Children) < 2 {
+ log.Print("len(packet.Children) < 2")
+ break
+ }
+ // check the message ID and ClassType
+ messageID, ok := packet.Children[0].Value.(uint64)
+ if !ok {
+ log.Print("malformed messageID")
+ break
+ }
+ req := packet.Children[1]
+ if req.ClassType != ber.ClassApplication {
+ log.Print("req.ClassType != ber.ClassApplication")
+ break
+ }
+ // handle controls if present
+ controls := []Control{}
+ if len(packet.Children) > 2 {
+ for _, child := range packet.Children[2].Children {
+ controls = append(controls, DecodeControl(child))
+ }
+ }
+
+ //log.Printf("DEBUG: handling operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ //ber.PrintPacket(packet) // DEBUG
+
+ // dispatch the LDAP operation
+ switch req.Tag { // ldap op code
+ default:
+ responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultOperationsError, "Unsupported operation: add")
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ }
+ log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ break handler
+
+ case ApplicationBindRequest:
+ server.Stats.countBinds(1)
+ ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
+ if ldapResultCode == LDAPResultSuccess {
+ boundDN, ok = req.Children[1].Value.(string)
+ if !ok {
+ log.Printf("Malformed Bind DN")
+ break handler
+ }
+ }
+ responsePacket := encodeBindResponse(messageID, ldapResultCode)
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationSearchRequest:
+ server.Stats.countSearches(1)
+ if err := HandleSearchRequest(req, &controls, messageID, boundDN, server, conn); err != nil {
+ log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks?
+ e := err.(*Error)
+ if err = sendPacket(conn, encodeSearchDone(messageID, e.ResultCode)); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ break handler
+ } else {
+ if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ }
+ case ApplicationUnbindRequest:
+ server.Stats.countUnbinds(1)
+ break handler // simply disconnect
+ case ApplicationExtendedRequest:
+ ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationAbandonRequest:
+ HandleAbandonRequest(req, boundDN, server.AbandonFns, conn)
+ break handler
+
+ case ApplicationAddRequest:
+ ldapResultCode := HandleAddRequest(req, boundDN, server.AddFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationModifyRequest:
+ ldapResultCode := HandleModifyRequest(req, boundDN, server.ModifyFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationModifyResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationDelRequest:
+ ldapResultCode := HandleDeleteRequest(req, boundDN, server.DeleteFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationDelResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationModifyDNRequest:
+ ldapResultCode := HandleModifyDNRequest(req, boundDN, server.ModifyDNFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationModifyDNResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ case ApplicationCompareRequest:
+ ldapResultCode := HandleCompareRequest(req, boundDN, server.CompareFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationCompareResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
+ }
+ }
+ }
+
+ for _, c := range server.CloseFns {
+ c.Close(boundDN, conn)
+ }
+
+ conn.Close()
+}
+
+//
+func sendPacket(conn net.Conn, packet *ber.Packet) error {
+ _, err := conn.Write(packet.Bytes())
+ if err != nil {
+ log.Printf("Error Sending Message: %s", err.Error())
+ return err
+ }
+ return nil
+}
+
+//
+func routeFunc(dn string, funcNames []string) string {
+ bestPick := ""
+ for _, fn := range funcNames {
+ if strings.HasSuffix(dn, fn) {
+ l := len(strings.Split(bestPick, ","))
+ if bestPick == "" {
+ l = 0
+ }
+ if len(strings.Split(fn, ",")) > l {
+ bestPick = fn
+ }
+ }
+ }
+ return bestPick
+}
+
+//
+func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
+ reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
+ responsePacket.AppendChild(reponse)
+ return responsePacket
+}
+
+//
+type defaultHandler struct {
+}
+
+func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInvalidCredentials, nil
+}
+func (h defaultHandler) Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+func (h defaultHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h defaultHandler) Abandon(boundDN string, conn net.Conn) error {
+ return nil
+}
+func (h defaultHandler) Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultProtocolError, nil
+}
+func (h defaultHandler) Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultSuccess, nil
+}
+func (h defaultHandler) Close(boundDN string, conn net.Conn) error {
+ conn.Close()
+ return nil
+}
+
+//
+func (stats *Stats) countConns(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Conns += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countBinds(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Binds += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countUnbinds(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Unbinds += delta
+ stats.statsMutex.Unlock()
+ }
+}
+func (stats *Stats) countSearches(delta int) {
+ if stats != nil {
+ stats.statsMutex.Lock()
+ stats.Searches += delta
+ stats.statsMutex.Unlock()
+ }
+}
+
+//