diff --git a/handler_test.go b/handler_test.go index 93598a5..2b01acc 100644 --- a/handler_test.go +++ b/handler_test.go @@ -15,7 +15,7 @@ import ( ) // nolint: gochecknoglobals -var ldapTestCases = []test.Case{ +var ldapHandlerTestCases = []test.Case{ { // Simple case Qname: "a.example.org.", Qtype: dns.TypeA, @@ -26,16 +26,16 @@ var ldapTestCases = []test.Case{ } // Create a new Ldap Plugin. Use the test.ErrorHandler as the next plugin. -func newTestLdap() *Ldap { +func newTestLdapHandler() *Ldap { ldap := New([]string{"example.org.", "www.example.org.", "example.org.", "sample.example.org."}) - ldap.Zones.Z = newTestLdapZones() + ldap.Zones.Z = newTestLdapHandlerZones() ldap.Fall = fall.Zero ldap.Next = test.ErrorHandler() return ldap } -func newTestLdapZones() map[string]*file.Zone { +func newTestLdapHandlerZones() map[string]*file.Zone { Zone := file.NewZone("example.org.", "") if err := Zone.Insert(SOA("example.org.")); err != nil { panic("omg") @@ -58,9 +58,9 @@ func newTestLdapZones() map[string]*file.Zone { } func TestServeDNS(t *testing.T) { - ldap := newTestLdap() + ldap := newTestLdapHandler() - for i, tc := range ldapTestCases { + for i, tc := range ldapHandlerTestCases { req := tc.Msg() rec := dnstest.NewRecorder(&test.ResponseWriter{}) diff --git a/ldap.go b/ldap.go index 186ebce..4cec209 100644 --- a/ldap.go +++ b/ldap.go @@ -39,17 +39,19 @@ type Ldap struct { Client ldap.Client Zones file.Zones - searchRequest *ldap.SearchRequest - ldapURL string - pagingLimit uint32 - syncInterval time.Duration - username string - password string - sasl bool - fqdnAttr string - ip4Attr string - zMu sync.RWMutex - ttl time.Duration + // Exported for mocking in test + SearchRequest *ldap.SearchRequest + FqdnAttr string + Ip4Attr string + + ldapURL string + pagingLimit uint32 + syncInterval time.Duration + username string + password string + sasl bool + zMu sync.RWMutex + ttl time.Duration } // New returns an initialized Ldap with defaults. @@ -57,13 +59,19 @@ func New(zoneNames []string) *Ldap { l := new(Ldap) l.Zones.Names = zoneNames l.pagingLimit = 0 + l.syncInterval = 60 * time.Second // SearchRequest defaults - l.searchRequest = new(ldap.SearchRequest) - l.searchRequest.DerefAliases = ldap.NeverDerefAliases // TODO: Reason - l.searchRequest.Scope = ldap.ScopeWholeSubtree // search whole subtree - l.searchRequest.SizeLimit = 500 // TODO: Reason - l.searchRequest.TimeLimit = 500 // TODO: Reason - l.searchRequest.TypesOnly = false // TODO: Reason + l.SearchRequest = new(ldap.SearchRequest) + l.SearchRequest.DerefAliases = ldap.NeverDerefAliases // TODO: Reason + l.SearchRequest.Scope = ldap.ScopeWholeSubtree // search whole subtree + l.SearchRequest.SizeLimit = 500 // TODO: Reason + l.SearchRequest.TimeLimit = 500 // TODO: Reason + l.SearchRequest.TypesOnly = false // TODO: Reason + l.Zones.Z = make(map[string]*file.Zone, len(zoneNames)) + + for _, zn := range zoneNames { + l.Zones.Z[zn] = nil + } return l } diff --git a/setup.go b/setup.go index 4083726..402b539 100644 --- a/setup.go +++ b/setup.go @@ -125,13 +125,13 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) { return nil, c.ArgErr() } - ldap.searchRequest.BaseDN = c.Val() // ou=ae-dir + ldap.SearchRequest.BaseDN = c.Val() // ou=ae-dir case "filter": if !c.NextArg() { return nil, c.ArgErr() } - ldap.searchRequest.Filter = c.Val() // (objectClass=aeNwDevice) + ldap.SearchRequest.Filter = c.Val() // (objectClass=aeNwDevice) case "attributes": c.Next() @@ -142,15 +142,15 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) { return nil, c.ArgErr() } - ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) - ldap.fqdnAttr = c.Val() // aeFqdn + ldap.SearchRequest.Attributes = append(ldap.SearchRequest.Attributes, c.Val()) + ldap.FqdnAttr = c.Val() // aeFqdn case "ip4": if !c.NextArg() { return nil, c.ArgErr() } - ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) - ldap.ip4Attr = c.Val() // ipHostNumber + ldap.SearchRequest.Attributes = append(ldap.SearchRequest.Attributes, c.Val()) + ldap.Ip4Attr = c.Val() // ipHostNumber default: return nil, c.Errf("unknown attributes property '%s'", c.Val()) } @@ -205,19 +205,19 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) { return nil, c.Err("ldap_url cannot be empty") } - if ldap.searchRequest.BaseDN == "" { + if ldap.SearchRequest.BaseDN == "" { return nil, c.Err("base_dn cannot be empty") } - if ldap.searchRequest.Filter == "" { + if ldap.SearchRequest.Filter == "" { return nil, c.Err("filter cannot be empty") } - if ldap.fqdnAttr == "" { + if ldap.FqdnAttr == "" { return nil, c.Err("fqdn attribute cannot be empty") } - if ldap.ip4Attr == "" { + if ldap.Ip4Attr == "" { return nil, c.Err("ip4 attribute cannot be empty") } diff --git a/sync.go b/sync.go index d735e16..e17282b 100644 --- a/sync.go +++ b/sync.go @@ -3,6 +3,7 @@ package ldap import ( "context" "fmt" + "strings" "net" "time" @@ -12,7 +13,7 @@ import ( // Run updates the zone from ldap. func (l *Ldap) Run(ctx context.Context) error { - if err := l.updateZones(); err != nil { + if err := l.UpdateZones(); err != nil { return err } @@ -23,7 +24,7 @@ func (l *Ldap) Run(ctx context.Context) error { log.Infof("Breaking out of Ldap update loop: %v", ctx.Err()) return case <-time.After(l.syncInterval): - if err := l.updateZones(); err != nil && ctx.Err() == nil { + if err := l.UpdateZones(); err != nil && ctx.Err() == nil { log.Errorf("Failed to update zones: %v", err) } } @@ -34,10 +35,17 @@ func (l *Ldap) Run(ctx context.Context) error { return nil } -func (l *Ldap) updateZones() error { +func (l *Ldap) UpdateZones() error { zoneFileMap := make(map[string]*file.Zone, len(l.Zones.Names)) for _, zn := range l.Zones.Names { zoneFileMap[zn] = nil + zoneFileMap[zn] = file.NewZone(zn, "") + zoneFileMap[zn].Upstream = l.Upstream + + err := zoneFileMap[zn].Insert(SOA(zn)) + if err != nil { + return fmt.Errorf("updating zones: %w", err) + } } ldapRecords, err := l.fetchLdapRecords() @@ -46,20 +54,10 @@ func (l *Ldap) updateZones() error { } for zn, lrpz := range l.mapLdapRecordsToZone(ldapRecords) { - if lrpz == nil { + if len(lrpz) == 0 { continue } - if zoneFileMap[zn] == nil { - zoneFileMap[zn] = file.NewZone(zn, "") - zoneFileMap[zn].Upstream = l.Upstream - - err = zoneFileMap[zn].Insert(SOA(zn)) - if err != nil { - return fmt.Errorf("updating zones: %w", err) - } - } - for _, lr := range lrpz { err = zoneFileMap[zn].Insert(lr.A()) if err != nil { @@ -70,8 +68,7 @@ func (l *Ldap) updateZones() error { l.zMu.Lock() for zn, zf := range zoneFileMap { - // TODO: assignement copies lock value from file.Zone - (*l.Zones.Z[zn]) = *zf + l.Zones.Z[zn] = zf } l.zMu.Unlock() @@ -95,16 +92,20 @@ func (l *Ldap) mapLdapRecordsToZone(ldapRecords []ldapRecord) (ldapRecordsPerZon } func (l *Ldap) fetchLdapRecords() (ldapRecords []ldapRecord, err error) { - searchResult, err := l.Client.SearchWithPaging(l.searchRequest, l.pagingLimit) + searchResult, err := l.Client.SearchWithPaging(l.SearchRequest, l.pagingLimit) if err != nil { return nil, fmt.Errorf("fetching data from server: %w", err) } ldapRecords = make([]ldapRecord, len(searchResult.Entries)) for i := 0; i < len(ldapRecords); i++ { + fqdn := searchResult.Entries[i].GetAttributeValue(l.FqdnAttr) + if !strings.HasSuffix(fqdn, ".") { + fqdn = fqdn + "." + } ldapRecords[i] = ldapRecord{ - fqdn: searchResult.Entries[i].GetAttributeValue(l.fqdnAttr), - ip: net.ParseIP(searchResult.Entries[i].GetAttributeValue(l.ip4Attr)), + fqdn: fqdn, + ip: net.ParseIP(searchResult.Entries[i].GetAttributeValue(l.Ip4Attr)), } } diff --git a/sync_test.go b/sync_test.go new file mode 100644 index 0000000..6eea205 --- /dev/null +++ b/sync_test.go @@ -0,0 +1,105 @@ +package ldap_test + +import ( + "crypto/tls" + "testing" + "time" + + "gopkg.in/ldap.v3" + + . "github.com/xoe-labs/ldap/v0" +) + +type mockClient struct{} + +func (m *mockClient) Start() {} +func (m *mockClient) StartTLS(*tls.Config) error { return nil } +func (m *mockClient) Close() {} +func (m *mockClient) SetTimeout(time.Duration) {} + +func (m *mockClient) Bind(username, password string) error { return nil } +func (m *mockClient) UnauthenticatedBind(username string) error { return nil } +func (m *mockClient) SimpleBind(*ldap.SimpleBindRequest) (*ldap.SimpleBindResult, error) { + return nil, nil +} +func (m *mockClient) ExternalBind() error { return nil } + +func (m *mockClient) Add(*ldap.AddRequest) error { return nil } +func (m *mockClient) Del(*ldap.DelRequest) error { return nil } +func (m *mockClient) Modify(*ldap.ModifyRequest) error { return nil } +func (m *mockClient) ModifyDN(*ldap.ModifyDNRequest) error { return nil } + +func (m *mockClient) Compare(dn, attribute, value string) (bool, error) { return false, nil } +func (m *mockClient) PasswordModify(*ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) { + return nil, nil +} + +func (m *mockClient) Search(*ldap.SearchRequest) (*ldap.SearchResult, error) { return nil, nil } +func (m *mockClient) SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) { + return &ldap.SearchResult{ + Entries: []*ldap.Entry{{ + DN: "ou=ae-dir, cn=host1", + Attributes: []*ldap.EntryAttribute{{ + Name: "aeFqdn", + Values: []string{ + "host1.example.org.", + }, + }, { + Name: "ipHostNumber", + Values: []string{ + "1.2.3.4", + }, + }}, + }, { + DN: "ou=ae-dir, cn=host2", + Attributes: []*ldap.EntryAttribute{{ + Name: "aeFqdn", + // Without ending "." + Values: []string{ + "host2.example.org", + }, + }, { + Name: "ipHostNumber", + Values: []string{ + "1.2.3.5", + }, + }}, + }, { + DN: "ou=ae-dir, cn=host3", + Attributes: []*ldap.EntryAttribute{{ + Name: "aeFqdn", + Values: []string{ + "host3.sample.example.org.", + }, + }, { + Name: "ipHostNumber", + Values: []string{ + "1.2.3.6", + }, + }}, + }}, + }, nil +} + +// Create a new Ldap Plugin. Uses a mocked client. +func newTestLdapSync() *Ldap { + ldap := New([]string{"example.org.", "www.example.org.", "example.org.", "sample.example.org."}) + ldap.Client = &mockClient{} + ldap.SearchRequest.Attributes = []string{ + "aeFqdn", "ipHostNumber", + } + ldap.SearchRequest.BaseDN = "ou=ae-dir" + ldap.SearchRequest.Filter = "(objectClass=aeNwDevice)" + ldap.FqdnAttr = "aeFqdn" + ldap.Ip4Attr = "ipHostNumber" + + return ldap +} + +// TestUpdateZone tests a zone update. +func TestUpdateZone(t *testing.T) { + ldap := newTestLdapSync() + if err := ldap.UpdateZones(); err != nil { + t.Fatalf("error updating zones: %v", err) + } +}