aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Rushakoff <[email protected]>2018-02-23 15:35:58 -0800
committerMark Rushakoff <[email protected]>2018-02-23 15:42:08 -0800
commit82a8f44a2f4cf0686635d2a23ebb41a8f445194e (patch)
tree50be4dd4cf5a89eb40f43393def27e14bc777034
parent0fce9cb1f0426d07ce0967ecf2ed82bb4834084c (diff)
Simplify server bind functions
For our purposes, it doesn't need to route multiple functions across different DNs, so use a simple function instead.
-rw-r--r--server.go18
-rw-r--r--server_bind.go12
-rw-r--r--server_modify_test.go12
-rw-r--r--server_search_test.go61
-rw-r--r--server_test.go36
5 files changed, 34 insertions, 105 deletions
diff --git a/server.go b/server.go
index fd7bbf3..ffae8e7 100644
--- a/server.go
+++ b/server.go
@@ -11,9 +11,8 @@ import (
"github.com/mark-rushakoff/ldapserver/internal/asn1-ber"
)
-type Binder interface {
- Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
-}
+type BindFunc func(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
+
type Searcher interface {
Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
}
@@ -47,7 +46,8 @@ type Closer interface {
//
type Server struct {
- BindFns map[string]Binder
+ Bind BindFunc
+
SearchFns map[string]Searcher
AddFns map[string]Adder
ModifyFns map[string]Modifier
@@ -84,7 +84,6 @@ func NewServer() *Server {
s := new(Server)
d := defaultHandler{}
- s.BindFns = make(map[string]Binder)
s.SearchFns = make(map[string]Searcher)
s.AddFns = make(map[string]Adder)
s.ModifyFns = make(map[string]Modifier)
@@ -95,7 +94,9 @@ func NewServer() *Server {
s.ExtendedFns = make(map[string]Extender)
s.UnbindFns = make(map[string]Unbinder)
s.CloseFns = make(map[string]Closer)
- s.BindFunc("", d)
+
+ s.Bind = d.Bind
+
s.SearchFunc("", d)
s.AddFunc("", d)
s.ModifyFunc("", d)
@@ -111,9 +112,6 @@ func NewServer() *Server {
s.closing = make(chan struct{})
return s
}
-func (server *Server) BindFunc(baseDN string, f Binder) {
- server.BindFns[baseDN] = f
-}
func (server *Server) SearchFunc(baseDN string, f Searcher) {
server.SearchFns[baseDN] = f
}
@@ -282,7 +280,7 @@ handler:
case ApplicationBindRequest:
server.Stats.countBinds(1)
- ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
+ ldapResultCode := HandleBindRequest(req, server.Bind, conn)
if ldapResultCode == LDAPResultSuccess {
boundDN, ok = req.Children[1].Value.(string)
if !ok {
diff --git a/server_bind.go b/server_bind.go
index da8b062..e0c19f4 100644
--- a/server_bind.go
+++ b/server_bind.go
@@ -1,12 +1,13 @@
package ldapserver
import (
- "github.com/mark-rushakoff/ldapserver/internal/asn1-ber"
"log"
"net"
+
+ "github.com/mark-rushakoff/ldapserver/internal/asn1-ber"
)
-func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) {
+func HandleBindRequest(req *ber.Packet, fn BindFunc, conn net.Conn) (resultCode LDAPResultCode) {
defer func() {
if r := recover(); r != nil {
resultCode = LDAPResultOperationsError
@@ -35,12 +36,7 @@ func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (r
return LDAPResultInappropriateAuthentication
case LDAPBindAuthSimple:
if len(req.Children) == 3 {
- fnNames := []string{}
- for k := range fns {
- fnNames = append(fnNames, k)
- }
- fn := routeFunc(bindDN, fnNames)
- resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn)
+ resultCode, err := fn(bindDN, bindAuth.Data.String(), conn)
if err != nil {
log.Printf("BindFn Error %s", err.Error())
return LDAPResultOperationsError
diff --git a/server_modify_test.go b/server_modify_test.go
index 6705343..78c1fde 100644
--- a/server_modify_test.go
+++ b/server_modify_test.go
@@ -15,7 +15,7 @@ func TestAdd(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.BindFunc("", modifyTestHandler{})
+ s.Bind = BindAnonOK
s.AddFunc("", modifyTestHandler{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -51,7 +51,7 @@ func TestDelete(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.BindFunc("", modifyTestHandler{})
+ s.Bind = BindAnonOK
s.DeleteFunc("", modifyTestHandler{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -83,7 +83,7 @@ func TestModify(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.BindFunc("", modifyTestHandler{})
+ s.Bind = BindAnonOK
s.ModifyFunc("", modifyTestHandler{})
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -152,12 +152,6 @@ func TestModifyDN(t *testing.T) {
type modifyTestHandler struct {
}
-func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
- if bindDN == "" && bindSimplePw == "" {
- return LDAPResultSuccess, nil
- }
- return LDAPResultInvalidCredentials, nil
-}
func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
// only succeed on expected contents of add.ldif:
if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" &&
diff --git a/server_search_test.go b/server_search_test.go
index 6325984..1a59940 100644
--- a/server_search_test.go
+++ b/server_search_test.go
@@ -15,7 +15,7 @@ func TestSearchSimpleOK(t *testing.T) {
ln, addr := mustListen()
go func() {
s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -57,7 +57,7 @@ func TestSearchSizelimit(t *testing.T) {
go func() {
s.EnforceLDAP = true
s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -144,51 +144,6 @@ func TestSearchSizelimit(t *testing.T) {
}
/////////////////////////
-func TestBindSearchMulti(t *testing.T) {
- done := make(chan bool)
- s := NewServer()
- defer s.Close()
- ln, addr := mustListen()
- go func() {
- s.BindFunc("", bindSimple{})
- s.BindFunc("c=testz", bindSimple2{})
- s.SearchFunc("", searchSimple{})
- s.SearchFunc("c=testz", searchSimple2{})
- if err := s.Serve(ln); err != nil {
- t.Errorf("s.Serve failed: %s", err.Error())
- }
- }()
-
- go func() {
- cmd := exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=test",
- "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned")
- out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "result: 0 Success") {
- t.Errorf("error routing default bind/search functions: %v", string(out))
- }
- if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
- t.Errorf("search default routing failed: %v", string(out))
- }
- cmd = exec.Command("ldapsearch", "-H", "ldap://"+addr, "-x", "-b", "o=testers,c=testz",
- "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger")
- out, _ = cmd.CombinedOutput()
- if !strings.Contains(string(out), "result: 0 Success") {
- t.Errorf("error routing custom bind/search functions: %v", string(out))
- }
- if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
- t.Errorf("search custom routing failed: %v", string(out))
- }
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
-}
-
-/////////////////////////
func TestSearchPanic(t *testing.T) {
done := make(chan bool)
s := NewServer()
@@ -196,7 +151,7 @@ func TestSearchPanic(t *testing.T) {
ln, addr := mustListen()
go func() {
s.SearchFunc("", searchPanic{})
- s.BindFunc("", bindAnonOK{})
+ s.Bind = BindAnonOK
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -262,7 +217,7 @@ func TestSearchFiltering(t *testing.T) {
go func() {
s.EnforceLDAP = true
s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -298,7 +253,7 @@ func TestSearchAttributes(t *testing.T) {
go func() {
s.EnforceLDAP = true
s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -341,7 +296,7 @@ func TestSearchScope(t *testing.T) {
go func() {
s.EnforceLDAP = true
s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -386,7 +341,7 @@ func TestSearchScope(t *testing.T) {
select {
case <-done:
- case <-time.After(2*timeout):
+ case <-time.After(2 * timeout):
t.Errorf("ldapsearch command timed out")
}
}
@@ -398,7 +353,7 @@ func TestSearchControls(t *testing.T) {
ln, addr := mustListen()
go func() {
s.SearchFunc("", searchControls{})
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
diff --git a/server_test.go b/server_test.go
index d7527f4..79a1c07 100644
--- a/server_test.go
+++ b/server_test.go
@@ -21,7 +21,7 @@ func TestBindAnonOK(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.BindFunc("", bindAnonOK{})
+ s.Bind = BindAnonOK
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -80,7 +80,7 @@ func TestBindSimpleOK(t *testing.T) {
ln, addr := mustListen()
go func() {
s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -112,7 +112,7 @@ func TestBindSimpleFailBadPw(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -144,7 +144,7 @@ func TestBindSimpleFailBadDn(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.BindFunc("", bindSimple{})
+ s.Bind = BindSimple
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -191,7 +191,7 @@ func TestBindSSL(t *testing.T) {
ldapURLSSL := "ldaps://" + ln.Addr().String()
go func() {
- s.BindFunc("", bindAnonOK{})
+ s.Bind = BindAnonOK
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -225,7 +225,7 @@ func TestBindPanic(t *testing.T) {
defer s.Close()
ln, addr := mustListen()
go func() {
- s.BindFunc("", bindPanic{})
+ s.Bind = BindPanic
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
}
@@ -268,7 +268,7 @@ func TestSearchStats(t *testing.T) {
go func() {
s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindAnonOK{})
+ s.Bind = BindAnonOK
s.SetStats(true)
if err := s.Serve(ln); err != nil {
t.Errorf("s.Serve failed: %s", err.Error())
@@ -297,43 +297,29 @@ func TestSearchStats(t *testing.T) {
}
}
-/////////////////////////
-type bindAnonOK struct {
-}
-
-func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+func BindAnonOK(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "" && bindSimplePw == "" {
return LDAPResultSuccess, nil
}
return LDAPResultInvalidCredentials, nil
}
-type bindSimple struct {
-}
-
-func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+func BindSimple(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" {
return LDAPResultSuccess, nil
}
return LDAPResultInvalidCredentials, nil
}
-type bindSimple2 struct {
-}
-
-func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+func BindSimple2(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" {
return LDAPResultSuccess, nil
}
return LDAPResultInvalidCredentials, nil
}
-type bindPanic struct {
-}
-
-func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+func BindPanic(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
panic("test panic at the disco")
- return LDAPResultInvalidCredentials, nil
}
type searchSimple struct {