diff --git a/bind.go b/bind.go index 171a2e9..2f0060a 100644 --- a/bind.go +++ b/bind.go @@ -7,7 +7,7 @@ package ldap import ( "errors" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) func (l *Conn) Bind(username, password string) error { @@ -55,45 +55,44 @@ func (l *Conn) Bind(username, password string) error { } func (l *Conn) Unbind() error { - defer l.Close() - - messageID := l.nextMessageID() - - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) - unbindRequest := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationUnbindRequest, nil, "Unbind Request") - packet.AppendChild(unbindRequest) - - if l.Debug { - ber.PrintPacket(packet) - } - - channel, err := l.sendMessage(packet) - if err != nil { - return err - } - if channel == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not send message")) - } - defer l.finishMessage(messageID) - - packet = <-channel - if packet == nil { - return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - - resultCode, resultDescription := getLDAPResultCode(packet) - if resultCode != 0 { - return NewError(resultCode, errors.New(resultDescription)) - } - - return nil -} + defer l.Close() + + messageID := l.nextMessageID() + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) + unbindRequest := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationUnbindRequest, nil, "Unbind Request") + packet.AppendChild(unbindRequest) + + if l.Debug { + ber.PrintPacket(packet) + } + + channel, err := l.sendMessage(packet) + if err != nil { + return err + } + if channel == nil { + return NewError(ErrorNetwork, errors.New("ldap: could not send message")) + } + defer l.finishMessage(messageID) + packet = <-channel + if packet == nil { + return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return err + } + ber.PrintPacket(packet) + } + + resultCode, resultDescription := getLDAPResultCode(packet) + if resultCode != 0 { + return NewError(resultCode, errors.New(resultDescription)) + } + + return nil +} diff --git a/conn.go b/conn.go index 253e58e..a842b1e 100644 --- a/conn.go +++ b/conn.go @@ -12,7 +12,7 @@ import ( "sync" "time" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -296,7 +296,7 @@ func (l *Conn) reader() { addLDAPDescriptions(packet) message := &messagePacket{ Op: MessageResponse, - MessageID: packet.Children[0].Value.(uint64), + MessageID: uint64(packet.Children[0].Value.(int64)), //figure out if its really unsigned Packet: packet, } if !l.sendProcessMessage(message) { diff --git a/control.go b/control.go index 60fde91..a42b5bb 100644 --- a/control.go +++ b/control.go @@ -5,9 +5,10 @@ package ldap import ( - "strings" "fmt" - "github.com/nmcclain/asn1-ber" + "strings" + + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -129,7 +130,7 @@ func DecodeControl(packet *ber.Packet) Control { value.Description = "Search Control Value" value.Children[0].Description = "Paging Size" value.Children[1].Description = "Cookie" - c.PagingSize = uint32(value.Children[0].Value.(uint64)) + c.PagingSize = uint32(value.Children[0].Value.(int64)) c.Cookie = value.Children[1].Data.Bytes() value.Children[1].Value = c.Cookie return c diff --git a/debug.go b/debug.go index de9bc5a..77429fb 100644 --- a/debug.go +++ b/debug.go @@ -3,7 +3,7 @@ package ldap import ( "log" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) // debbuging type diff --git a/filter.go b/filter.go index 05c4bb2..248b58d 100644 --- a/filter.go +++ b/filter.go @@ -10,7 +10,7 @@ import ( "strings" "unicode/utf8" - ber "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -26,7 +26,7 @@ const ( FilterExtensibleMatch = 9 ) -var FilterMap = map[uint8]string{ +var FilterMap = map[ber.Tag]string{ FilterAnd: "And", FilterOr: "Or", FilterNot: "Not", diff --git a/filter_test.go b/filter_test.go index 5244c8a..ea2bcc8 100644 --- a/filter_test.go +++ b/filter_test.go @@ -4,12 +4,12 @@ import ( "reflect" "testing" - ber "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) type compileTest struct { filterStr string - filterType uint8 + filterType ber.Tag } var testFilters = []compileTest{ @@ -33,7 +33,7 @@ func TestFilter(t *testing.T) { filter, err := CompileFilter(i.filterStr) if err != nil { t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error()) - } else if filter.Tag != uint8(i.filterType) { + } else if filter.Tag != i.filterType { t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[i.filterType], FilterMap[filter.Tag]) } else { o, err := DecompileFilter(filter) diff --git a/go.mod b/go.mod index 430dfa0..3e1240e 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/nmcclain/ldap go 1.14 -require github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484 +require github.com/go-asn1-ber/asn1-ber v1.5.4 diff --git a/go.sum b/go.sum index f925ae6..20d3e3b 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484 h1:D9EvfGQvlkKaDr2CRKN++7HbSXbefUNDrPq60T+g24s= -github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484/go.mod h1:O1EljZ+oHprtxDDPHiMWVo/5dBT6PlvWX5PSwj80aBA= +github.com/go-asn1-ber/asn1-ber v1.5.4 h1:vXT6d/FNDiELJnLb6hGNa309LMsrCoYFvpwHDF0+Y1A= +github.com/go-asn1-ber/asn1-ber v1.5.4/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= diff --git a/ldap.go b/ldap.go index e6d6d52..aa8013a 100644 --- a/ldap.go +++ b/ldap.go @@ -8,8 +8,9 @@ import ( "errors" "fmt" "io/ioutil" + "os" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) // LDAP Application Codes @@ -36,7 +37,7 @@ const ( ApplicationExtendedResponse = 24 ) -var ApplicationMap = map[uint8]string{ +var ApplicationMap = map[ber.Tag]string{ ApplicationBindRequest: "Bind Request", ApplicationBindResponse: "Bind Response", ApplicationUnbindRequest: "Unbind Request", @@ -307,7 +308,7 @@ func DebugBinaryFile(fileName string) error { if err != nil { return NewError(ErrorDebugging, err) } - ber.PrintBytes(file, "") + ber.PrintBytes(os.Stdout, file, "") packet := ber.DecodePacket(file) addLDAPDescriptions(packet) ber.PrintPacket(packet) @@ -332,7 +333,7 @@ func getLDAPResultCode(packet *ber.Packet) (code LDAPResultCode, description str if len(packet.Children) >= 2 { response := packet.Children[1] if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 { - return LDAPResultCode(response.Children[0].Value.(uint64)), response.Children[2].Value.(string) + return LDAPResultCode(response.Children[0].Value.(int64)), response.Children[2].Value.(string) } } diff --git a/modify.go b/modify.go index 1137e2f..635a3c6 100644 --- a/modify.go +++ b/modify.go @@ -33,7 +33,7 @@ import ( "errors" "log" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( diff --git a/search.go b/search.go index 45b26b8..86a6e79 100644 --- a/search.go +++ b/search.go @@ -64,7 +64,7 @@ import ( "fmt" "strings" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( diff --git a/server.go b/server.go index 24b1e77..d716fe3 100644 --- a/server.go +++ b/server.go @@ -8,7 +8,7 @@ import ( "strings" "sync" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) type Binder interface { @@ -157,11 +157,7 @@ func (server *Server) ListenAndServeTLS(listenString string, certFile string, ke if err != nil { return err } - err = server.Serve(ln) - if err != nil { - return err - } - return nil + return server.Serve(ln) } func (server *Server) SetStats(enable bool) { @@ -185,11 +181,7 @@ func (server *Server) ListenAndServe(listenString string) error { if err != nil { return err } - err = server.Serve(ln) - if err != nil { - return err - } - return nil + return server.Serve(ln) } func (server *Server) Serve(ln net.Listener) error { @@ -215,12 +207,19 @@ listener: go server.handleConnection(c) case <-server.Quit: ln.Close() + close(server.Quit) break listener } } return nil } +//Close closes the underlying net.Listener, and waits for confirmation +func (server *Server) Close() { + server.Quit <- true + <-server.Quit +} + // func (server *Server) handleConnection(conn net.Conn) { boundDN := "" // "" == anonymous @@ -229,7 +228,7 @@ handler: for { // read incoming LDAP packet packet, err := ber.ReadPacket(conn) - if err == io.EOF { // Client closed connection + if err == io.EOF || err == io.ErrUnexpectedEOF { // Client closed connection break } else if err != nil { log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error()) @@ -242,11 +241,12 @@ handler: break } // check the message ID and ClassType - messageID, ok := packet.Children[0].Value.(uint64) + messageID64, ok := packet.Children[0].Value.(int64) if !ok { log.Print("malformed messageID") break } + messageID := uint64(messageID64) req := packet.Children[1] if req.ClassType != ber.ClassApplication { log.Print("req.ClassType != ber.ClassApplication") @@ -380,7 +380,7 @@ func routeFunc(dn string, funcNames []string) string { dnMatch := "," + strings.ToLower(dn) var weight int for _, fn := range funcNames { - if strings.HasSuffix(dnMatch, "," + fn) { + if strings.HasSuffix(dnMatch, ","+fn) { // empty string as 0, no-comma string 1 , etc if fn == "" { weight = 0 @@ -400,7 +400,7 @@ func routeFunc(dn string, funcNames []string) string { func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet { responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) - reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType]) + reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ber.Tag(responseType), nil, ApplicationMap[ber.Tag(responseType)]) reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: ")) reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: ")) diff --git a/server_bind.go b/server_bind.go index 5a80bf5..61aa691 100644 --- a/server_bind.go +++ b/server_bind.go @@ -1,9 +1,10 @@ package ldap import ( - "github.com/nmcclain/asn1-ber" "log" "net" + + ber "github.com/go-asn1-ber/asn1-ber" ) func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) { @@ -14,7 +15,7 @@ func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (r }() // we only support ldapv3 - ldapVersion, ok := req.Children[0].Value.(uint64) + ldapVersion, ok := req.Children[0].Value.(int64) if !ok { return LDAPResultProtocolError } diff --git a/server_modify.go b/server_modify.go index ca68e40..5e5d249 100644 --- a/server_modify.go +++ b/server_modify.go @@ -4,7 +4,7 @@ import ( "log" "net" - "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) func HandleAddRequest(req *ber.Packet, boundDN string, fns map[string]Adder, conn net.Conn) (resultCode LDAPResultCode) { @@ -96,7 +96,7 @@ func HandleModifyRequest(req *ber.Packet, boundDN string, fns map[string]Modifie } attr.AttrVals = append(attr.AttrVals, v) } - op, ok := change.Children[0].Value.(uint64) + op, ok := change.Children[0].Value.(int64) if !ok { return LDAPResultProtocolError } diff --git a/server_modify_test.go b/server_modify_test.go index 4501119..0818998 100644 --- a/server_modify_test.go +++ b/server_modify_test.go @@ -10,13 +10,11 @@ import ( // func TestAdd(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", modifyTestHandler{}) + s.AddFunc("", modifyTestHandler{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", modifyTestHandler{}) - s.AddFunc("", modifyTestHandler{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -42,18 +40,16 @@ func TestAdd(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapadd command timed out") } - quit <- true + s.Close() } // func TestDelete(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", modifyTestHandler{}) + s.DeleteFunc("", modifyTestHandler{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", modifyTestHandler{}) - s.DeleteFunc("", modifyTestHandler{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -76,17 +72,15 @@ func TestDelete(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapdelete command timed out") } - quit <- true + s.Close() } func TestModify(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", modifyTestHandler{}) + s.ModifyFunc("", modifyTestHandler{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", modifyTestHandler{}) - s.ModifyFunc("", modifyTestHandler{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -109,7 +103,7 @@ func TestModify(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapadd command timed out") } - quit <- true + s.Close() } /* diff --git a/server_search.go b/server_search.go index 12a6caf..ff43a44 100644 --- a/server_search.go +++ b/server_search.go @@ -6,7 +6,7 @@ import ( "net" "strings" - ber "github.com/nmcclain/asn1-ber" + ber "github.com/go-asn1-ber/asn1-ber" ) func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) { @@ -109,22 +109,22 @@ func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (S if !ok { return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) } - s, ok := req.Children[1].Value.(uint64) + s, ok := req.Children[1].Value.(int64) if !ok { return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) } scope := int(s) - d, ok := req.Children[2].Value.(uint64) + d, ok := req.Children[2].Value.(int64) if !ok { return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) } derefAliases := int(d) - s, ok = req.Children[3].Value.(uint64) + s, ok = req.Children[3].Value.(int64) if !ok { return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) } sizeLimit := int(s) - t, ok := req.Children[4].Value.(uint64) + t, ok := req.Children[4].Value.(int64) if !ok { return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) } diff --git a/server_search_test.go b/server_search_test.go index 09e2b14..0580440 100644 --- a/server_search_test.go +++ b/server_search_test.go @@ -9,13 +9,11 @@ import ( // func TestSearchSimpleOK(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -47,18 +45,16 @@ func TestSearchSimpleOK(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } func TestSearchSizelimit(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.EnforceLDAP = true + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -142,20 +138,18 @@ func TestSearchSizelimit(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// func TestBindSearchMulti(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", bindSimple{}) + s.BindFunc("c=testz", bindSimple2{}) + s.SearchFunc("", searchSimple{}) + s.SearchFunc("c=testz", searchSimple2{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", bindSimple{}) - s.BindFunc("c=testz", bindSimple2{}) - s.SearchFunc("", searchSimple{}) - s.SearchFunc("c=testz", searchSimple2{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -188,19 +182,16 @@ func TestBindSearchMulti(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - - quit <- true + s.Close() } ///////////////////////// func TestSearchPanic(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.SearchFunc("", searchPanic{}) + s.BindFunc("", bindAnonOK{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.SearchFunc("", searchPanic{}) - s.BindFunc("", bindAnonOK{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -220,7 +211,7 @@ func TestSearchPanic(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// @@ -260,14 +251,12 @@ var searchFilterTestFilters = []compileSearchFilterTest{ ///////////////////////// func TestSearchFiltering(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.EnforceLDAP = true + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -292,19 +281,17 @@ func TestSearchFiltering(t *testing.T) { t.Errorf("ldapsearch command timed out") } } - quit <- true + s.Close() } ///////////////////////// func TestSearchAttributes(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.EnforceLDAP = true + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -336,18 +323,16 @@ func TestSearchAttributes(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } func TestSearchAllUserAttributes(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.EnforceLDAP = true + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -388,19 +373,17 @@ func TestSearchAllUserAttributes(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// func TestSearchScope(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.EnforceLDAP = true + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -448,20 +431,17 @@ func TestSearchScope(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } - ///////////////////////// func TestSearchScopeCaseInsensitive(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.EnforceLDAP = true + s.SearchFunc("", searchCaseInsensitive{}) + s.BindFunc("", bindCaseInsensitive{}) go func() { - s := NewServer() - s.EnforceLDAP = true - s.QuitChannel(quit) - s.SearchFunc("", searchCaseInsensitive{}) - s.BindFunc("", bindCaseInsensitive{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -483,18 +463,15 @@ func TestSearchScopeCaseInsensitive(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } - func TestSearchControls(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.SearchFunc("", searchControls{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.SearchFunc("", searchControls{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -537,5 +514,5 @@ func TestSearchControls(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } diff --git a/server_test.go b/server_test.go index 233f7ea..636d041 100644 --- a/server_test.go +++ b/server_test.go @@ -17,12 +17,10 @@ var serverBaseDN = "o=testers,c=test" ///////////////////////// func TestBindAnonOK(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", bindAnonOK{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", bindAnonOK{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -42,16 +40,14 @@ func TestBindAnonOK(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// func TestBindAnonFail(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() go func() { - s := NewServer() - s.QuitChannel(quit) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -72,19 +68,16 @@ func TestBindAnonFail(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - time.Sleep(timeout) - quit <- true + s.Close() } ///////////////////////// func TestBindSimpleOK(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -107,17 +100,15 @@ func TestBindSimpleOK(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// func TestBindSimpleFailBadPw(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -140,17 +131,15 @@ func TestBindSimpleFailBadPw(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// func TestBindSimpleFailBadDn(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", bindSimple{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", bindSimple{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -173,27 +162,25 @@ func TestBindSimpleFailBadDn(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// func TestBindSSL(t *testing.T) { ldapURLSSL := "ldaps://" + listenString longerTimeout := 300 * time.Millisecond - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", bindAnonOK{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", bindAnonOK{}) if err := s.ListenAndServeTLS(listenString, "tests/cert_DONOTUSE.pem", "tests/key_DONOTUSE.pem"); err != nil { t.Errorf("s.ListenAndServeTLS failed: %s", err.Error()) } }() go func() { - time.Sleep(longerTimeout * 2) cmd := exec.Command("ldapsearch", "-H", ldapURLSSL, "-x", "-b", "o=testers,c=test") + cmd.Env = []string{"LDAPTLS_REQCERT=ALLOW"} out, _ := cmd.CombinedOutput() if !strings.Contains(string(out), "result: 0 Success") { t.Errorf("ldapsearch failed: %v", string(out)) @@ -206,17 +193,15 @@ func TestBindSSL(t *testing.T) { case <-time.After(longerTimeout * 2): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// func TestBindPanic(t *testing.T) { - quit := make(chan bool) done := make(chan bool) + s := NewServer() + s.BindFunc("", bindPanic{}) go func() { - s := NewServer() - s.QuitChannel(quit) - s.BindFunc("", bindPanic{}) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -236,7 +221,7 @@ func TestBindPanic(t *testing.T) { case <-time.After(timeout): t.Errorf("ldapsearch command timed out") } - quit <- true + s.Close() } ///////////////////////// @@ -253,15 +238,13 @@ func TestSearchStats(t *testing.T) { w := testStatsWriter{&bytes.Buffer{}} log.SetOutput(w) - quit := make(chan bool) done := make(chan bool) s := NewServer() + s.SearchFunc("", searchSimple{}) + s.BindFunc("", bindAnonOK{}) + s.SetStats(true) go func() { - s.QuitChannel(quit) - s.SearchFunc("", searchSimple{}) - s.BindFunc("", bindAnonOK{}) - s.SetStats(true) if err := s.ListenAndServe(listenString); err != nil { t.Errorf("s.ListenAndServe failed: %s", err.Error()) } @@ -287,7 +270,7 @@ func TestSearchStats(t *testing.T) { if stats.Conns != 1 || stats.Binds != 1 { t.Errorf("Stats data missing or incorrect: %v", w.buffer.String()) } - quit <- true + s.Close() } ///////////////////////// @@ -339,7 +322,6 @@ func (b bindCaseInsensitive) Bind(bindDN, bindSimplePw string, conn net.Conn) (L return LDAPResultInvalidCredentials, nil } - type searchSimple struct { } @@ -420,7 +402,6 @@ func (s searchControls) Search(boundDN string, searchReq SearchRequest, conn net return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil } - type searchCaseInsensitive struct { } @@ -439,7 +420,6 @@ func (s searchCaseInsensitive) Search(boundDN string, searchReq SearchRequest, c return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil } - func TestRouteFunc(t *testing.T) { if routeFunc("", []string{"a", "xyz", "tt"}) != "" { t.Error("routeFunc failed")