diff options
| author | Mark Rushakoff <[email protected]> | 2018-02-23 15:35:58 -0800 |
|---|---|---|
| committer | Mark Rushakoff <[email protected]> | 2018-02-23 15:42:08 -0800 |
| commit | 82a8f44a2f4cf0686635d2a23ebb41a8f445194e (patch) | |
| tree | 50be4dd4cf5a89eb40f43393def27e14bc777034 | |
| parent | 0fce9cb1f0426d07ce0967ecf2ed82bb4834084c (diff) | |
Simplify server bind functions
For our purposes, it doesn't need to route multiple functions across
different DNs, so use a simple function instead.
| -rw-r--r-- | server.go | 18 | ||||
| -rw-r--r-- | server_bind.go | 12 | ||||
| -rw-r--r-- | server_modify_test.go | 12 | ||||
| -rw-r--r-- | server_search_test.go | 61 | ||||
| -rw-r--r-- | server_test.go | 36 |
5 files changed, 34 insertions, 105 deletions
@@ -11,9 +11,8 @@ import ( "github.com/mark-rushakoff/ldapserver/internal/asn1-ber" ) -type Binder interface { - Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) -} +type BindFunc func(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) + type Searcher interface { Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) } @@ -47,7 +46,8 @@ type Closer interface { // type Server struct { - BindFns map[string]Binder + Bind BindFunc + SearchFns map[string]Searcher AddFns map[string]Adder ModifyFns map[string]Modifier @@ -84,7 +84,6 @@ func NewServer() *Server { s := new(Server) d := defaultHandler{} - s.BindFns = make(map[string]Binder) s.SearchFns = make(map[string]Searcher) s.AddFns = make(map[string]Adder) s.ModifyFns = make(map[string]Modifier) @@ -95,7 +94,9 @@ func NewServer() *Server { s.ExtendedFns = make(map[string]Extender) s.UnbindFns = make(map[string]Unbinder) s.CloseFns = make(map[string]Closer) - s.BindFunc("", d) + + s.Bind = d.Bind + s.SearchFunc("", d) s.AddFunc("", d) s.ModifyFunc("", d) @@ -111,9 +112,6 @@ func NewServer() *Server { s.closing = make(chan struct{}) return s } -func (server *Server) BindFunc(baseDN string, f Binder) { - server.BindFns[baseDN] = f -} func (server *Server) SearchFunc(baseDN string, f Searcher) { server.SearchFns[baseDN] = f } @@ -282,7 +280,7 @@ handler: case ApplicationBindRequest: server.Stats.countBinds(1) - ldapResultCode := HandleBindRequest(req, server.BindFns, conn) + ldapResultCode := HandleBindRequest(req, server.Bind, conn) if ldapResultCode == LDAPResultSuccess { boundDN, ok = req.Children[1].Value.(string) if !ok { diff --git a/server_bind.go b/server_bind.go index da8b062..e0c19f4 100644 --- a/server_bind.go +++ b/server_bind.go @@ -1,12 +1,13 @@ package ldapserver import ( - "github.com/mark-rushakoff/ldapserver/internal/asn1-ber" "log" "net" + + "github.com/mark-rushakoff/ldapserver/internal/asn1-ber" ) -func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) { +func HandleBindRequest(req *ber.Packet, fn BindFunc, conn net.Conn) (resultCode LDAPResultCode) { defer func() { if r := recover(); r != nil { resultCode = LDAPResultOperationsError @@ -35,12 +36,7 @@ func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (r return LDAPResultInappropriateAuthentication case LDAPBindAuthSimple: if len(req.Children) == 3 { - fnNames := []string{} - for k := range fns { - fnNames = append(fnNames, k) - } - fn := routeFunc(bindDN, fnNames) - resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn) + resultCode, err := fn(bindDN, bindAuth.Data.String(), conn) if err != nil { log.Printf("BindFn Error %s", err.Error()) return LDAPResultOperationsError diff --git a/server_modify_test.go b/server_modify_test.go index 6705343..78c1fde 100644 --- a/server_modify_test.go +++ b/server_modify_test.go @@ -15,7 +15,7 @@ func TestAdd(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.BindFunc("", modifyTestHandler{}) + s.Bind = BindAnonOK s.AddFunc("", modifyTestHandler{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -51,7 +51,7 @@ func TestDelete(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.BindFunc("", modifyTestHandler{}) + s.Bind = BindAnonOK s.DeleteFunc("", modifyTestHandler{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -83,7 +83,7 @@ func TestModify(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.BindFunc("", modifyTestHandler{}) + s.Bind = BindAnonOK s.ModifyFunc("", modifyTestHandler{}) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -152,12 +152,6 @@ func TestModifyDN(t *testing.T) { type modifyTestHandler struct { } -func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { - if bindDN == "" && bindSimplePw == "" { - return LDAPResultSuccess, nil - } - return LDAPResultInvalidCredentials, nil -} func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) { // only succeed on expected contents of add.ldif: if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" && diff --git a/server_search_test.go b/server_search_test.go index 6325984..1a59940 100644 --- a/server_search_test.go +++ b/server_search_test.go @@ -15,7 +15,7 @@ func TestSearchSimpleOK(t *testing.T) { ln, addr := mustListen() go func() { s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -57,7 +57,7 @@ func TestSearchSizelimit(t *testing.T) { go func() { s.EnforceLDAP = true s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -144,51 +144,6 @@ func TestSearchSizelimit(t *testing.T) { } ///////////////////////// -func TestBindSearchMulti(t *testing.T) { - done := make(chan bool) - s := NewServer() - defer s.Close() - ln, addr := mustListen() - go func() { - s.BindFunc("", bindSimple{}) - s.BindFunc("c=testz", bindSimple2{}) - s.SearchFunc("", searchSimple{}) - s.SearchFunc("c=testz", searchSimple2{}) - if err := s.Serve(ln); err != nil { - t.Errorf("s.Serve failed: %s", err.Error()) - } - }() - - go func() { - cmd := exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=test", - "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned") - out, _ := cmd.CombinedOutput() - if !strings.Contains(string(out), "result: 0 Success") { - t.Errorf("error routing default bind/search functions: %v", string(out)) - } - if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") { - t.Errorf("search default routing failed: %v", string(out)) - } - cmd = exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=testz", - "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger") - out, _ = cmd.CombinedOutput() - if !strings.Contains(string(out), "result: 0 Success") { - t.Errorf("error routing custom bind/search functions: %v", string(out)) - } - if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") { - t.Errorf("search custom routing failed: %v", string(out)) - } - done <- true - }() - - select { - case <-done: - case <-time.After(timeout): - t.Errorf("ldapsearch command timed out") - } -} - -///////////////////////// func TestSearchPanic(t *testing.T) { done := make(chan bool) s := NewServer() @@ -196,7 +151,7 @@ func TestSearchPanic(t *testing.T) { ln, addr := mustListen() go func() { s.SearchFunc("", searchPanic{}) - s.BindFunc("", bindAnonOK{}) + s.Bind = BindAnonOK if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -262,7 +217,7 @@ func TestSearchFiltering(t *testing.T) { go func() { s.EnforceLDAP = true s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -298,7 +253,7 @@ func TestSearchAttributes(t *testing.T) { go func() { s.EnforceLDAP = true s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -341,7 +296,7 @@ func TestSearchScope(t *testing.T) { go func() { s.EnforceLDAP = true s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -386,7 +341,7 @@ func TestSearchScope(t *testing.T) { select { case <-done: - case <-time.After(2*timeout): + case <-time.After(2 * timeout): t.Errorf("ldapsearch command timed out") } } @@ -398,7 +353,7 @@ func TestSearchControls(t *testing.T) { ln, addr := mustListen() go func() { s.SearchFunc("", searchControls{}) - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } diff --git a/server_test.go b/server_test.go index d7527f4..79a1c07 100644 --- a/server_test.go +++ b/server_test.go @@ -21,7 +21,7 @@ func TestBindAnonOK(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.BindFunc("", bindAnonOK{}) + s.Bind = BindAnonOK if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -80,7 +80,7 @@ func TestBindSimpleOK(t *testing.T) { ln, addr := mustListen() go func() { s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -112,7 +112,7 @@ func TestBindSimpleFailBadPw(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -144,7 +144,7 @@ func TestBindSimpleFailBadDn(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.BindFunc("", bindSimple{}) + s.Bind = BindSimple if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -191,7 +191,7 @@ func TestBindSSL(t *testing.T) { ldapURLSSL := "ldaps://" + ln.Addr().String() go func() { - s.BindFunc("", bindAnonOK{}) + s.Bind = BindAnonOK if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -225,7 +225,7 @@ func TestBindPanic(t *testing.T) { defer s.Close() ln, addr := mustListen() go func() { - s.BindFunc("", bindPanic{}) + s.Bind = BindPanic if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) } @@ -268,7 +268,7 @@ func TestSearchStats(t *testing.T) { go func() { s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindAnonOK{}) + s.Bind = BindAnonOK s.SetStats(true) if err := s.Serve(ln); err != nil { t.Errorf("s.Serve failed: %s", err.Error()) @@ -297,43 +297,29 @@ func TestSearchStats(t *testing.T) { } } -///////////////////////// -type bindAnonOK struct { -} - -func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { +func BindAnonOK(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "" && bindSimplePw == "" { return LDAPResultSuccess, nil } return LDAPResultInvalidCredentials, nil } -type bindSimple struct { -} - -func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { +func BindSimple(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" { return LDAPResultSuccess, nil } return LDAPResultInvalidCredentials, nil } -type bindSimple2 struct { -} - -func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { +func BindSimple2(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" { return LDAPResultSuccess, nil } return LDAPResultInvalidCredentials, nil } -type bindPanic struct { -} - -func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { +func BindPanic(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) { panic("test panic at the disco") - return LDAPResultInvalidCredentials, nil } type searchSimple struct { |
