diff options
| author | Marin Ivanov <[email protected]> | 2019-02-13 07:20:14 +0200 |
|---|---|---|
| committer | Marin Ivanov <[email protected]> | 2019-02-13 07:26:13 +0200 |
| commit | a1a0a3aae7ef762250b9295985be1eee41d7a49e (patch) | |
| tree | 39accaebbd5c1fad072fc7ecb2360c0db6b21bef | |
| parent | cb4f041b8be79b49eb046466ceb1bea9cfcaeb87 (diff) | |
Revert the simplification of Searcher and Binder interfaces
* Revert "Simplify sever search functions" commit 9402a7d580c2dd929c68cf8b3038a1e6496f607f.
* Revert "Simplify server bind functions" commit 82a8f44a2f4cf0686635d2a23ebb41a8f445194e.
* Fix tests
| -rw-r--r-- | server.go | 30 | ||||
| -rw-r--r-- | server_bind.go | 9 | ||||
| -rw-r--r-- | server_modify_test.go | 12 | ||||
| -rw-r--r-- | server_search.go | 7 | ||||
| -rw-r--r-- | server_search_test.go | 75 | ||||
| -rw-r--r-- | server_test.go | 99 |
6 files changed, 175 insertions, 57 deletions
@@ -12,9 +12,12 @@ import ( "github.com/metala/ldap/internal/asn1-ber" ) -type BindFunc func(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) -type SearchFunc func(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) - +type Binder interface { + Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) +} +type Searcher interface { + Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) +} type Adder interface { Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) } @@ -45,9 +48,8 @@ type Closer interface { // type Server struct { - Bind BindFunc - Search SearchFunc - + BindFns map[string]Binder + SearchFns map[string]Searcher AddFns map[string]Adder ModifyFns map[string]Modifier DeleteFns map[string]Deleter @@ -87,6 +89,8 @@ func NewServer() *Server { s := new(Server) d := defaultHandler{} + s.BindFns = make(map[string]Binder) + s.SearchFns = make(map[string]Searcher) s.AddFns = make(map[string]Adder) s.ModifyFns = make(map[string]Modifier) s.DeleteFns = make(map[string]Deleter) @@ -96,10 +100,8 @@ func NewServer() *Server { s.ExtendedFns = make(map[string]Extender) s.UnbindFns = make(map[string]Unbinder) s.CloseFns = make(map[string]Closer) - - s.Bind = d.Bind - s.Search = d.Search - + s.BindFunc("", d) + s.SearchFunc("", d) s.AddFunc("", d) s.ModifyFunc("", d) s.DeleteFunc("", d) @@ -114,6 +116,12 @@ func NewServer() *Server { s.closing = make(chan struct{}) return s } +func (server *Server) BindFunc(baseDN string, f Binder) { + server.BindFns[baseDN] = f +} +func (server *Server) SearchFunc(baseDN string, f Searcher) { + server.SearchFns[baseDN] = f +} func (server *Server) AddFunc(baseDN string, f Adder) { server.AddFns[baseDN] = f } @@ -296,7 +304,7 @@ handler: case ApplicationBindRequest: server.Stats.countBinds(1) - ldapResultCode := HandleBindRequest(req, server.Bind, conn) + ldapResultCode := HandleBindRequest(req, server.BindFns, conn) if ldapResultCode == LDAPResultSuccess { boundDN, ok = req.Children[1].Value.(string) if !ok { diff --git a/server_bind.go b/server_bind.go index c094b13..1684823 100644 --- a/server_bind.go +++ b/server_bind.go @@ -7,7 +7,7 @@ import ( "github.com/metala/ldap/internal/asn1-ber" ) -func HandleBindRequest(req *ber.Packet, fn BindFunc, conn net.Conn) (resultCode LDAPResultCode) { +func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) { defer func() { if r := recover(); r != nil { resultCode = LDAPResultOperationsError @@ -36,7 +36,12 @@ func HandleBindRequest(req *ber.Packet, fn BindFunc, conn net.Conn) (resultCode return LDAPResultInappropriateAuthentication case LDAPBindAuthSimple: if len(req.Children) == 3 { - resultCode, err := fn(bindDN, bindAuth.Data.String(), conn) + fnNames := []string{} + for k := range fns { + fnNames = append(fnNames, k) + } + fn := routeFunc(bindDN, fnNames) + resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn) if err != nil { log.Printf("BindFn Error %s", err.Error()) return LDAPResultOperationsError diff --git a/server_modify_test.go b/server_modify_test.go index f6f54cc..b13a98f 100644 --- a/server_modify_test.go +++ b/server_modify_test.go @@ -15,7 +15,7 @@ func TestAdd(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Bind = BindAnonOK + s.BindFunc("", modifyTestHandler{}) s.AddFunc("", modifyTestHandler{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -51,7 +51,7 @@ func TestDelete(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Bind = BindAnonOK + s.BindFunc("", modifyTestHandler{}) s.DeleteFunc("", modifyTestHandler{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -83,7 +83,7 @@ func TestModify(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Bind = BindAnonOK + s.BindFunc("", modifyTestHandler{}) s.ModifyFunc("", modifyTestHandler{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -152,6 +152,12 @@ func TestModifyDN(t *testing.T) { type modifyTestHandler struct { } +func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { + if bindDN == "" && bindSimplePw == "" { + return LDAPResultSuccess, nil + } + return LDAPResultInvalidCredentials, nil +} func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) { // only succeed on expected contents of add.ldif: if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" && diff --git a/server_search.go b/server_search.go index 170b758..e9ce1d4 100644 --- a/server_search.go +++ b/server_search.go @@ -26,7 +26,12 @@ func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, return NewError(LDAPResultOperationsError, err) } - searchResp, err := server.Search(boundDN, searchReq, conn) + fnNames := []string{} + for k := range server.SearchFns { + fnNames = append(fnNames, k) + } + fn := routeFunc(searchReq.BaseDN, fnNames) + searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn) if err != nil { return NewError(searchResp.ResultCode, err) } diff --git a/server_search_test.go b/server_search_test.go index 4cf2089..86c5297 100644 --- a/server_search_test.go +++ b/server_search_test.go @@ -14,8 +14,8 @@ func TestSearchSimpleOK(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Search = SearchSimple - s.Bind = BindSimple + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -56,8 +56,8 @@ func TestSearchSizelimit(t *testing.T) { ln, addr := mustListen() go func() { s.EnforceLDAP = true - s.Search = SearchSimple - s.Bind = BindSimple + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -144,14 +144,59 @@ func TestSearchSizelimit(t *testing.T) { } ///////////////////////// +func TestBindSearchMulti(t *testing.T) { + done := make(chan bool) + s := NewServer() + defer s.Close() + ln, addr := mustListen() + go func() { + s.BindFunc("", bindSimple{}) + s.BindFunc("c=testz", bindSimple2{}) + s.SearchFunc("", searchSimple{}) + s.SearchFunc("c=testz", searchSimple2{}) + if err := s.Serve(ln); err != nil { + t.Errorf("s.Serve failed: %s", err.Error()) + } + }() + + go func() { + cmd := exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=test", + "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned") + out, _ := cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("error routing default bind/search functions: %v", string(out)) + } + if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { + t.Errorf("search default routing failed: %v", string(out)) + } + cmd = exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=testz", + "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger") + out, _ = cmd.CombinedOutput() + if !strings.Contains(string(out), "result: 0 Success") { + t.Errorf("error routing custom bind/search functions: %v", string(out)) + } + if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") { + t.Errorf("search custom routing failed: %v", string(out)) + } + done <- true + }() + + select { + case <-done: + case <-time.After(timeout): + t.Errorf("ldapsearch command timed out") + } +} + +///////////////////////// func TestSearchPanic(t *testing.T) { done := make(chan bool) s := NewServer() defer s.Close() ln, addr := mustListen() go func() { - s.Search = SearchPanic - s.Bind = BindAnonOK + s.SearchFunc("", searchPanic{}) + s.BindFunc("", bindAnonOK{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -216,8 +261,8 @@ func TestSearchFiltering(t *testing.T) { ln, addr := mustListen() go func() { s.EnforceLDAP = true - s.Search = SearchSimple - s.Bind = BindSimple + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -252,8 +297,8 @@ func TestSearchAttributes(t *testing.T) { ln, addr := mustListen() go func() { s.EnforceLDAP = true - s.Search = SearchSimple - s.Bind = BindSimple + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -295,8 +340,8 @@ func TestSearchScope(t *testing.T) { ln, addr := mustListen() go func() { s.EnforceLDAP = true - s.Search = SearchSimple - s.Bind = BindSimple + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -341,7 +386,7 @@ func TestSearchScope(t *testing.T) { select { case <-done: - case <-time.After(2 * timeout): + case <-time.After(2*timeout): t.Errorf("ldapsearch command timed out") } } @@ -352,8 +397,8 @@ func TestSearchControls(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Search = SearchControls - s.Bind = BindSimple + s.SearchFunc("", searchControls{}) + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } diff --git a/server_test.go b/server_test.go index 3ae15d4..cedfc54 100644 --- a/server_test.go +++ b/server_test.go @@ -199,8 +199,8 @@ func TestStartTLS(t *testing.T) { s := NewServer() defer s.Close() - s.Bind = BindAnonOK - s.Search = SearchSimple + s.BindFunc("", bindAnonOK{}) + s.SearchFunc("", searchSimple{}) s.TLSConfig = cert.ServerTLSConfig() ln, addr := mustListen() @@ -241,8 +241,8 @@ func TestStartTLSWithoutTLSConfigDoesNotPanic(t *testing.T) { s := NewServer() defer s.Close() - s.Bind = BindAnonOK - s.Search = SearchSimple + s.BindFunc("", bindAnonOK{}) + s.SearchFunc("", searchSimple{}) ln, addr := mustListen() go func() { @@ -278,8 +278,8 @@ func TestEnforcedTLSWithoutTLSConfig(t *testing.T) { s := NewServer() defer s.Close() s.EnforceTLS = true - s.Bind = BindAnonOK - s.Search = SearchSimple + s.BindFunc("", bindAnonOK{}) + s.SearchFunc("", searchSimple{}) ln, _ := mustListen() done := make(chan error) @@ -310,8 +310,8 @@ func TestEnforcedTLS(t *testing.T) { s := NewServer() defer s.Close() s.EnforceTLS = true - s.Bind = BindAnonOK - s.Search = SearchSimple + s.BindFunc("", bindAnonOK{}) + s.SearchFunc("", searchSimple{}) s.TLSConfig = cert.ServerTLSConfig() ln, addr := mustListen() @@ -356,8 +356,8 @@ func TestEnforcedTLSFail(t *testing.T) { s := NewServer() defer s.Close() s.EnforceTLS = true - s.Bind = BindAnonOK - s.Search = SearchSimple + s.BindFunc("", bindAnonOK{}) + s.SearchFunc("", searchSimple{}) s.TLSConfig = cert.ServerTLSConfig() ln, addr := mustListen() @@ -398,7 +398,7 @@ func TestBindAnonOK(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Bind = BindAnonOK + s.BindFunc("", bindAnonOK{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -456,8 +456,8 @@ func TestBindSimpleOK(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Search = SearchSimple - s.Bind = BindSimple + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -489,7 +489,7 @@ func TestBindSimpleFailBadPw(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Bind = BindSimple + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -521,7 +521,7 @@ func TestBindSimpleFailBadDn(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Bind = BindSimple + s.BindFunc("", bindSimple{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -567,7 +567,7 @@ func TestBindSSL(t *testing.T) { ldapURLSSL := "ldaps://" + ln.Addr().String() go func() { - s.Bind = BindAnonOK + s.BindFunc("", bindAnonOK{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -602,7 +602,7 @@ func TestBindPanic(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.Bind = BindPanic + s.BindFunc("", bindPanic{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -644,8 +644,8 @@ func TestSearchStats(t *testing.T) { ln, addr := mustListen() go func() { - s.Search = SearchSimple - s.Bind = BindAnonOK + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindAnonOK{}) s.SetStats(true) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -674,25 +674,49 @@ func TestSearchStats(t *testing.T) { } } -func BindAnonOK(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { +///////////////////////// +type bindAnonOK struct { +} + +func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "" && bindSimplePw == "" { return LDAPResultSuccess, nil } return LDAPResultInvalidCredentials, nil } -func BindSimple(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { +type bindSimple struct { +} + +func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" { return LDAPResultSuccess, nil } return LDAPResultInvalidCredentials, nil } -func BindPanic(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { +type bindSimple2 struct { +} + +func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { + if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" { + return LDAPResultSuccess, nil + } + return LDAPResultInvalidCredentials, nil +} + +type bindPanic struct { +} + +func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { panic("test panic at the disco") + return LDAPResultInvalidCredentials, nil } -func SearchSimple(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { +type searchSimple struct { +} + +func (s searchSimple) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { entries := []*Entry{ &Entry{"cn=ned,o=testers,c=test", []*EntryAttribute{ &EntryAttribute{"cn", []string{"ned"}}, @@ -724,11 +748,36 @@ func SearchSimple(boundDN string, searchReq SearchRequest, conn net.Conn) (Serve return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil } -func SearchPanic(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { +type searchSimple2 struct { +} + +func (s searchSimple2) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + entries := []*Entry{ + &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{ + &EntryAttribute{"cn", []string{"hamburger"}}, + &EntryAttribute{"o", []string{"testers"}}, + &EntryAttribute{"uidNumber", []string{"5000"}}, + &EntryAttribute{"accountstatus", []string{"active"}}, + &EntryAttribute{"uid", []string{"hamburger"}}, + &EntryAttribute{"objectclass", []string{"posixaccount"}}, + }}, + } + return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil +} + +type searchPanic struct { +} + +func (s searchPanic) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + entries := []*Entry{} panic("this is a test panic") + return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil +} + +type searchControls struct { } -func SearchControls(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { +func (s searchControls) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { entries := []*Entry{} if len(searchReq.Controls) == 1 && searchReq.Controls[0].GetControlType() == "1.2.3.4.5" { newEntry := &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{ |
