diff options
| author | Mark Rushakoff <[email protected]> | 2018-02-23 12:06:02 -0800 |
|---|---|---|
| committer | Mark Rushakoff <[email protected]> | 2018-02-23 12:06:02 -0800 |
| commit | 2129d915cdf28ee389ae3f234136681c73bcd8be (patch) | |
| tree | d2cee01faf5bcc3e5c3d4df5356079db1d66bbfc | |
| parent | b7ccda6592ab6b7c821a5d9a5ab5a28cba8bb094 (diff) | |
Remove QuitChannel method, add Close method
| -rw-r--r-- | server.go | 25 | ||||
| -rw-r--r-- | server_modify_test.go | 18 | ||||
| -rw-r--r-- | server_search_test.go | 49 | ||||
| -rw-r--r-- | server_test.go | 47 |
4 files changed, 53 insertions, 86 deletions
@@ -58,9 +58,10 @@ type Server struct { ExtendedFns map[string]Extender UnbindFns map[string]Unbinder CloseFns map[string]Closer - Quit chan bool EnforceLDAP bool Stats *Stats + + closing chan struct{} } type Stats struct { @@ -81,7 +82,6 @@ type ServerSearchResult struct { // func NewServer() *Server { s := new(Server) - s.Quit = make(chan bool) d := defaultHandler{} s.BindFns = make(map[string]Binder) @@ -107,6 +107,8 @@ func NewServer() *Server { s.UnbindFunc("", d) s.CloseFunc("", d) s.Stats = nil + + s.closing = make(chan struct{}) return s } func (server *Server) BindFunc(baseDN string, f Binder) { @@ -142,9 +144,6 @@ func (server *Server) UnbindFunc(baseDN string, f Unbinder) { func (server *Server) CloseFunc(baseDN string, f Closer) { server.CloseFns[baseDN] = f } -func (server *Server) QuitChannel(quit chan bool) { - server.Quit = quit -} func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error { cert, err := tls.LoadX509KeyPair(certFile, keyFile) @@ -207,18 +206,26 @@ func (server *Server) Serve(ln net.Listener) error { } }() -listener: for { select { case c := <-newConn: server.Stats.countConns(1) go server.handleConnection(c) - case <-server.Quit: + case <-server.closing: ln.Close() - break listener + return nil } } - return nil +} + +// Close closes the underlying listener and exits the Serve method. +// Close is not safe for concurrent use. +func (server *Server) Close() { + select { + case <-server.closing: + default: + close(server.closing) + } } // diff --git a/server_modify_test.go b/server_modify_test.go index 378fbd1..7050a32 100644 --- a/server_modify_test.go +++ b/server_modify_test.go @@ -10,11 +10,10 @@ import ( // func TestAdd(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", modifyTestHandler{}) s.AddFunc("", modifyTestHandler{}) if err := s.ListenAndServe(listenString); err != nil { @@ -42,16 +41,14 @@ func TestAdd(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapadd command timed out") } - quit <- true } // func TestDelete(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", modifyTestHandler{}) s.DeleteFunc("", modifyTestHandler{}) if err := s.ListenAndServe(listenString); err != nil { @@ -76,15 +73,13 @@ func TestDelete(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapdelete command timed out") } - quit <- true } func TestModify(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", modifyTestHandler{}) s.ModifyFunc("", modifyTestHandler{}) if err := s.ListenAndServe(listenString); err != nil { @@ -109,7 +104,6 @@ func TestModify(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapadd command timed out") } - quit <- true } /* diff --git a/server_search_test.go b/server_search_test.go index ec66e10..f0345f7 100644 --- a/server_search_test.go +++ b/server_search_test.go @@ -9,11 +9,10 @@ import ( // func TestSearchSimpleOK(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.SearchFunc("", searchSimple{}) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { @@ -47,16 +46,14 @@ func TestSearchSimpleOK(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } func TestSearchSizelimit(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() s.EnforceLDAP = true - s.QuitChannel(quit) s.SearchFunc("", searchSimple{}) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { @@ -142,16 +139,14 @@ func TestSearchSizelimit(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// func TestBindSearchMulti(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", bindSimple{}) s.BindFunc("c=testz", bindSimple2{}) s.SearchFunc("", searchSimple{}) @@ -188,17 +183,14 @@ func TestBindSearchMulti(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - - quit <- true } ///////////////////////// func TestSearchPanic(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.SearchFunc("", searchPanic{}) s.BindFunc("", bindAnonOK{}) if err := s.ListenAndServe(listenString); err != nil { @@ -220,7 +212,6 @@ func TestSearchPanic(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// @@ -260,12 +251,11 @@ var searchFilterTestFilters = []compileSearchFilterTest{ ///////////////////////// func TestSearchFiltering(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() s.EnforceLDAP = true - s.QuitChannel(quit) s.SearchFunc("", searchSimple{}) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { @@ -292,17 +282,15 @@ func TestSearchFiltering(t *testing.T) { t.Errorf("ldapsearch command timed out") } } - quit <- true } ///////////////////////// func TestSearchAttributes(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() s.EnforceLDAP = true - s.QuitChannel(quit) s.SearchFunc("", searchSimple{}) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { @@ -336,17 +324,15 @@ func TestSearchAttributes(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// func TestSearchScope(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() s.EnforceLDAP = true - s.QuitChannel(quit) s.SearchFunc("", searchSimple{}) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { @@ -396,15 +382,13 @@ func TestSearchScope(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } func TestSearchControls(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.SearchFunc("", searchControls{}) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { @@ -449,5 +433,4 @@ func TestSearchControls(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } diff --git a/server_test.go b/server_test.go index dafe9a8..87beb62 100644 --- a/server_test.go +++ b/server_test.go @@ -17,11 +17,10 @@ var serverBaseDN = "o=testers,c=test" ///////////////////////// func TestBindAnonOK(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", bindAnonOK{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) @@ -42,16 +41,14 @@ func TestBindAnonOK(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// func TestBindAnonFail(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -72,17 +69,14 @@ func TestBindAnonFail(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - time.Sleep(timeout) - quit <- true } ///////////////////////// func TestBindSimpleOK(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.SearchFunc("", searchSimple{}) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { @@ -107,16 +101,14 @@ func TestBindSimpleOK(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// func TestBindSimpleFailBadPw(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) @@ -140,16 +132,14 @@ func TestBindSimpleFailBadPw(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// func TestBindSimpleFailBadDn(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) @@ -173,7 +163,6 @@ func TestBindSimpleFailBadDn(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// @@ -181,11 +170,10 @@ func TestBindSSL(t *testing.T) { t.Skip("unclear how to configure ldapsearch command to trust or skip verification of a custom SSL cert") ldapURLSSL := "ldaps://" + listenString longerTimeout := 300 * time.Millisecond - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", bindAnonOK{}) if err := s.ListenAndServeTLS(listenString, "tests/cert_DONOTUSE.pem", "tests/key_DONOTUSE.pem"); err != nil { t.Errorf("s.ListenAndServeTLS failed: %s", err.Error()) @@ -211,16 +199,14 @@ func TestBindSSL(t *testing.T) { case <-time.After(longerTimeout * 2): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// func TestBindPanic(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + defer s.Close() go func() { - s := NewServer() - s.QuitChannel(quit) s.BindFunc("", bindPanic{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) @@ -241,7 +227,6 @@ func TestBindPanic(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true } ///////////////////////// @@ -258,12 +243,11 @@ func TestSearchStats(t *testing.T) { w := testStatsWriter{&bytes.Buffer{}} log.SetOutput(w) - quit := make(chan bool) done := make(chan bool) s := NewServer() + defer s.Close() go func() { - s.QuitChannel(quit) s.SearchFunc("", searchSimple{}) s.BindFunc("", bindAnonOK{}) s.SetStats(true) @@ -292,7 +276,6 @@ func TestSearchStats(t *testing.T) { if stats.Conns != 1 || stats.Binds != 1 { t.Errorf("Stats data missing or incorrect: %v", w.buffer.String()) } - quit <- true } ///////////////////////// |
