aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server.go30
-rw-r--r--server_bind.go9
-rw-r--r--server_modify_test.go12
-rw-r--r--server_search.go7
-rw-r--r--server_search_test.go75
-rw-r--r--server_test.go99
6 files changed, 175 insertions, 57 deletions
diff --git a/server.go b/server.go
index d520414..28b5257 100644
--- a/server.go
+++ b/server.go
@@ -12,9 +12,12 @@ import (
"github.com/metala/ldap/internal/asn1-ber"
)
-type BindFunc func(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
-type SearchFunc func(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
-
+type Binder interface {
+ Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
+}
+type Searcher interface {
+ Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
+}
type Adder interface {
Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error)
}
@@ -45,9 +48,8 @@ type Closer interface {
//
type Server struct {
- Bind BindFunc
- Search SearchFunc
-
+ BindFns map[string]Binder
+ SearchFns map[string]Searcher
AddFns map[string]Adder
ModifyFns map[string]Modifier
DeleteFns map[string]Deleter
@@ -87,6 +89,8 @@ func NewServer() *Server {
s := new(Server)
d := defaultHandler{}
+ s.BindFns = make(map[string]Binder)
+ s.SearchFns = make(map[string]Searcher)
s.AddFns = make(map[string]Adder)
s.ModifyFns = make(map[string]Modifier)
s.DeleteFns = make(map[string]Deleter)
@@ -96,10 +100,8 @@ func NewServer() *Server {
s.ExtendedFns = make(map[string]Extender)
s.UnbindFns = make(map[string]Unbinder)
s.CloseFns = make(map[string]Closer)
-
- s.Bind = d.Bind
- s.Search = d.Search
-
+ s.BindFunc("", d)
+ s.SearchFunc("", d)
s.AddFunc("", d)
s.ModifyFunc("", d)
s.DeleteFunc("", d)
@@ -114,6 +116,12 @@ func NewServer() *Server {
s.closing = make(chan struct{})
return s
}
+func (server *Server) BindFunc(baseDN string, f Binder) {
+ server.BindFns[baseDN] = f
+}
+func (server *Server) SearchFunc(baseDN string, f Searcher) {
+ server.SearchFns[baseDN] = f
+}
func (server *Server) AddFunc(baseDN string, f Adder) {
server.AddFns[baseDN] = f
}
@@ -296,7 +304,7 @@ handler:
case ApplicationBindRequest:
server.Stats.countBinds(1)
- ldapResultCode := HandleBindRequest(req, server.Bind, conn)
+ ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
if ldapResultCode == LDAPResultSuccess {
boundDN, ok = req.Children[1].Value.(string)
if !ok {
diff --git a/server_bind.go b/server_bind.go
index c094b13..1684823 100644
--- a/server_bind.go
+++ b/server_bind.go
@@ -7,7 +7,7 @@ import (
"github.com/metala/ldap/internal/asn1-ber"
)
-func HandleBindRequest(req *ber.Packet, fn BindFunc, conn net.Conn) (resultCode LDAPResultCode) {
+func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) {
defer func() {
if r := recover(); r != nil {
resultCode = LDAPResultOperationsError
@@ -36,7 +36,12 @@ func HandleBindRequest(req *ber.Packet, fn BindFunc, conn net.Conn) (resultCode
return LDAPResultInappropriateAuthentication
case LDAPBindAuthSimple:
if len(req.Children) == 3 {
- resultCode, err := fn(bindDN, bindAuth.Data.String(), conn)
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(bindDN, fnNames)
+ resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn)
if err != nil {
log.Printf("BindFn Error %s", err.Error())
return LDAPResultOperationsError
diff --git a/server_modify_test.go b/server_modify_test.go
index f6f54cc..b13a98f 100644
--- a/server_modify_test.go
+++ b/server_modify_test.go
@@ -15,7 +15,7 @@ func TestAdd(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Bind = BindAnonOK
+ s.BindFunc("", modifyTestHandler{})
s.AddFunc("", modifyTestHandler{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -51,7 +51,7 @@ func TestDelete(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Bind = BindAnonOK
+ s.BindFunc("", modifyTestHandler{})
s.DeleteFunc("", modifyTestHandler{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -83,7 +83,7 @@ func TestModify(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Bind = BindAnonOK
+ s.BindFunc("", modifyTestHandler{})
s.ModifyFunc("", modifyTestHandler{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -152,6 +152,12 @@ func TestModifyDN(t *testing.T) {
type modifyTestHandler struct {
}
+func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ if bindDN == "" && bindSimplePw == "" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInvalidCredentials, nil
+}
func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
// only succeed on expected contents of add.ldif:
if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" &&
diff --git a/server_search.go b/server_search.go
index 170b758..e9ce1d4 100644
--- a/server_search.go
+++ b/server_search.go
@@ -26,7 +26,12 @@ func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64,
return NewError(LDAPResultOperationsError, err)
}
- searchResp, err := server.Search(boundDN, searchReq, conn)
+ fnNames := []string{}
+ for k := range server.SearchFns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(searchReq.BaseDN, fnNames)
+ searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn)
if err != nil {
return NewError(searchResp.ResultCode, err)
}
diff --git a/server_search_test.go b/server_search_test.go
index 4cf2089..86c5297 100644
--- a/server_search_test.go
+++ b/server_search_test.go
@@ -14,8 +14,8 @@ func TestSearchSimpleOK(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Search = SearchSimple
- s.Bind = BindSimple
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -56,8 +56,8 @@ func TestSearchSizelimit(t *testing.T) {
ln, addr := mustListen()
go func() {
s.EnforceLDAP = true
- s.Search = SearchSimple
- s.Bind = BindSimple
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -144,14 +144,59 @@ func TestSearchSizelimit(t *testing.T) {
}
/////////////////////////
+func TestBindSearchMulti(t *testing.T) {
+ done := make(chan bool)
+ s := NewServer()
+ defer s.Close()
+ ln, addr := mustListen()
+ go func() {
+ s.BindFunc("", bindSimple{})
+ s.BindFunc("c=testz", bindSimple2{})
+ s.SearchFunc("", searchSimple{})
+ s.SearchFunc("c=testz", searchSimple2{})
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("s.Serve failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=test",
+ "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("error routing default bind/search functions: %v", string(out))
+ }
+ if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+ t.Errorf("search default routing failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=testz",
+ "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("error routing custom bind/search functions: %v", string(out))
+ }
+ if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+ t.Errorf("search custom routing failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+}
+
+/////////////////////////
func TestSearchPanic(t *testing.T) {
done := make(chan bool)
s := NewServer()
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Search = SearchPanic
- s.Bind = BindAnonOK
+ s.SearchFunc("", searchPanic{})
+ s.BindFunc("", bindAnonOK{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -216,8 +261,8 @@ func TestSearchFiltering(t *testing.T) {
ln, addr := mustListen()
go func() {
s.EnforceLDAP = true
- s.Search = SearchSimple
- s.Bind = BindSimple
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -252,8 +297,8 @@ func TestSearchAttributes(t *testing.T) {
ln, addr := mustListen()
go func() {
s.EnforceLDAP = true
- s.Search = SearchSimple
- s.Bind = BindSimple
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -295,8 +340,8 @@ func TestSearchScope(t *testing.T) {
ln, addr := mustListen()
go func() {
s.EnforceLDAP = true
- s.Search = SearchSimple
- s.Bind = BindSimple
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -341,7 +386,7 @@ func TestSearchScope(t *testing.T) {
select {
case <-done:
- case <-time.After(2 * timeout):
+ case <-time.After(2*timeout):
t.Errorf("ldapsearch command timed out")
}
}
@@ -352,8 +397,8 @@ func TestSearchControls(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Search = SearchControls
- s.Bind = BindSimple
+ s.SearchFunc("", searchControls{})
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
diff --git a/server_test.go b/server_test.go
index 3ae15d4..cedfc54 100644
--- a/server_test.go
+++ b/server_test.go
@@ -199,8 +199,8 @@ func TestStartTLS(t *testing.T) {
s := NewServer()
defer s.Close()
- s.Bind = BindAnonOK
- s.Search = SearchSimple
+ s.BindFunc("", bindAnonOK{})
+ s.SearchFunc("", searchSimple{})
s.TLSConfig = cert.ServerTLSConfig()
ln, addr := mustListen()
@@ -241,8 +241,8 @@ func TestStartTLSWithoutTLSConfigDoesNotPanic(t *testing.T) {
s := NewServer()
defer s.Close()
- s.Bind = BindAnonOK
- s.Search = SearchSimple
+ s.BindFunc("", bindAnonOK{})
+ s.SearchFunc("", searchSimple{})
ln, addr := mustListen()
go func() {
@@ -278,8 +278,8 @@ func TestEnforcedTLSWithoutTLSConfig(t *testing.T) {
s := NewServer()
defer s.Close()
s.EnforceTLS = true
- s.Bind = BindAnonOK
- s.Search = SearchSimple
+ s.BindFunc("", bindAnonOK{})
+ s.SearchFunc("", searchSimple{})
ln, _ := mustListen()
done := make(chan error)
@@ -310,8 +310,8 @@ func TestEnforcedTLS(t *testing.T) {
s := NewServer()
defer s.Close()
s.EnforceTLS = true
- s.Bind = BindAnonOK
- s.Search = SearchSimple
+ s.BindFunc("", bindAnonOK{})
+ s.SearchFunc("", searchSimple{})
s.TLSConfig = cert.ServerTLSConfig()
ln, addr := mustListen()
@@ -356,8 +356,8 @@ func TestEnforcedTLSFail(t *testing.T) {
s := NewServer()
defer s.Close()
s.EnforceTLS = true
- s.Bind = BindAnonOK
- s.Search = SearchSimple
+ s.BindFunc("", bindAnonOK{})
+ s.SearchFunc("", searchSimple{})
s.TLSConfig = cert.ServerTLSConfig()
ln, addr := mustListen()
@@ -398,7 +398,7 @@ func TestBindAnonOK(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Bind = BindAnonOK
+ s.BindFunc("", bindAnonOK{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -456,8 +456,8 @@ func TestBindSimpleOK(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Search = SearchSimple
- s.Bind = BindSimple
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -489,7 +489,7 @@ func TestBindSimpleFailBadPw(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Bind = BindSimple
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -521,7 +521,7 @@ func TestBindSimpleFailBadDn(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Bind = BindSimple
+ s.BindFunc("", bindSimple{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -567,7 +567,7 @@ func TestBindSSL(t *testing.T) {
ldapURLSSL := "ldaps://" + ln.Addr().String()
go func() {
- s.Bind = BindAnonOK
+ s.BindFunc("", bindAnonOK{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -602,7 +602,7 @@ func TestBindPanic(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.Bind = BindPanic
+ s.BindFunc("", bindPanic{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -644,8 +644,8 @@ func TestSearchStats(t *testing.T) {
ln, addr := mustListen()
go func() {
- s.Search = SearchSimple
- s.Bind = BindAnonOK
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindAnonOK{})
s.SetStats(true)
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -674,25 +674,49 @@ func TestSearchStats(t *testing.T) {
}
}
-func BindAnonOK(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+/////////////////////////
+type bindAnonOK struct {
+}
+
+func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "" && bindSimplePw == "" {
return LDAPResultSuccess, nil
}
return LDAPResultInvalidCredentials, nil
}
-func BindSimple(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+type bindSimple struct {
+}
+
+func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" {
return LDAPResultSuccess, nil
}
return LDAPResultInvalidCredentials, nil
}
-func BindPanic(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+type bindSimple2 struct {
+}
+
+func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInvalidCredentials, nil
+}
+
+type bindPanic struct {
+}
+
+func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
panic("test panic at the disco")
+ return LDAPResultInvalidCredentials, nil
}
-func SearchSimple(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+type searchSimple struct {
+}
+
+func (s searchSimple) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
entries := []*Entry{
&Entry{"cn=ned,o=testers,c=test", []*EntryAttribute{
&EntryAttribute{"cn", []string{"ned"}},
@@ -724,11 +748,36 @@ func SearchSimple(boundDN string, searchReq SearchRequest, conn net.Conn) (Serve
return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
}
-func SearchPanic(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+type searchSimple2 struct {
+}
+
+func (s searchSimple2) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ entries := []*Entry{
+ &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{
+ &EntryAttribute{"cn", []string{"hamburger"}},
+ &EntryAttribute{"o", []string{"testers"}},
+ &EntryAttribute{"uidNumber", []string{"5000"}},
+ &EntryAttribute{"accountstatus", []string{"active"}},
+ &EntryAttribute{"uid", []string{"hamburger"}},
+ &EntryAttribute{"objectclass", []string{"posixaccount"}},
+ }},
+ }
+ return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchPanic struct {
+}
+
+func (s searchPanic) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ entries := []*Entry{}
panic("this is a test panic")
+ return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
+
+type searchControls struct {
}
-func SearchControls(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+func (s searchControls) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
entries := []*Entry{}
if len(searchReq.Controls) == 1 && searchReq.Controls[0].GetControlType() == "1.2.3.4.5" {
newEntry := &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{