From 5209d319299c9ac750669f6c5079b414194189ef Mon Sep 17 00:00:00 2001 From: David Arnold Date: Wed, 10 Jun 2020 02:41:44 -0500 Subject: [PATCH] lint: wsl only --- handler.go | 5 +++++ handler_test.go | 12 +++++++++++- ldap.go | 3 +++ setup.go | 25 ++++++++++++++++++++++++- setup_test.go | 4 +++- sync.go | 17 +++++++++++++++-- 6 files changed, 61 insertions(+), 5 deletions(-) diff --git a/handler.go b/handler.go index ecd4bca..fd391f0 100644 --- a/handler.go +++ b/handler.go @@ -19,14 +19,18 @@ func (l *Ldap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( if zone == "" { return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) } + Zone, ok := l.Zones.Z[zone] if !ok || Zone == nil { return dns.RcodeServerFailure, nil } + var result file.Result + m := new(dns.Msg) m.SetReply(r) m.Authoritative = true + l.zMu.RLock() m.Answer, m.Ns, m.Extra, result = Zone.Lookup(ctx, state, state.Name()) l.zMu.RUnlock() @@ -46,6 +50,7 @@ func (l *Ldap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) ( return dns.RcodeServerFailure, nil } w.WriteMsg(m) + return dns.RcodeSuccess, nil } diff --git a/handler_test.go b/handler_test.go index f28c54a..3ea2e9b 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,4 +1,4 @@ -package ldap +package ldap_test import ( "context" @@ -10,6 +10,8 @@ import ( "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" + + . "github.com/xoe-labs/ldap/v0" ) var ldapTestCases = []test.Case{ @@ -28,12 +30,14 @@ func newTestLdap() *Ldap { ldap.Zones.Z = newTestLdapZones() ldap.Fall = fall.Zero ldap.Next = test.ErrorHandler() + return ldap } func newTestLdapZones() map[string]*file.Zone { Zone := file.NewZone("example.org.", "") Zone.Insert(SOA("example.org.")) + for _, rr := range []string{ "example.org. " + defaultA, "a.example.org. " + defaultA, @@ -41,25 +45,31 @@ func newTestLdapZones() map[string]*file.Zone { r, _ := dns.NewRR(rr) Zone.Insert(r) } + zones := make(map[string]*file.Zone) zones["example.org."] = Zone + return zones } func TestServeDNS(t *testing.T) { ldap := newTestLdap() + for i, tc := range ldapTestCases { req := tc.Msg() rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := ldap.ServeDNS(context.Background(), rec, req) if err != nil { t.Errorf("Expected no error, got %v", err) continue } + resp := rec.Msg if resp == nil { t.Fatalf("Test %d, got nil message and no error for %q", i, req.Question[0].Name) } + if err := test.SortAndCheck(resp, tc); err != nil { t.Error(err) } diff --git a/ldap.go b/ldap.go index 06cb56e..6839343 100644 --- a/ldap.go +++ b/ldap.go @@ -64,6 +64,7 @@ func New(zoneNames []string) *Ldap { l.searchRequest.SizeLimit = 500 // TODO: Reason l.searchRequest.TimeLimit = 500 // TODO: Reason l.searchRequest.TypesOnly = false // TODO: Reason + return l } @@ -75,6 +76,7 @@ func (l *Ldap) InitClient() (err error) { return err } defer l.Client.Close() + return nil } @@ -85,6 +87,7 @@ func SOA(zone string) dns.RR { Mbox := hostmaster + "." Ns := "ns.dns." + if zone[0] != '.' { Mbox += zone Ns += zone diff --git a/setup.go b/setup.go index a38b7f3..690c1e4 100644 --- a/setup.go +++ b/setup.go @@ -75,6 +75,7 @@ func ldapParse(c *caddy.Controller) (*Ldap, error) { return ldap, err } } + return ldap, nil } @@ -96,63 +97,73 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) { ldap.Upstream = upstream.New() for c.NextBlock() { - fmt.Printf("111 %#v\n", c.Val()) switch c.Val() { // RFC 4516 URL case "ldap_url": if !c.NextArg() { return nil, c.ArgErr() } + ldap.ldapURL = c.Val() case "paging_limit": if !c.NextArg() { return nil, c.ArgErr() } + pagingLimit, err := strconv.ParseUint(c.Val(), 10, 0) if err != nil { return nil, c.Errf("paging_limit: %w", err) } + ldap.pagingLimit = uint32(pagingLimit) case "base_dn": if !c.NextArg() { return nil, c.ArgErr() } + ldap.searchRequest.BaseDN = c.Val() // ou=ae-dir case "filter": if !c.NextArg() { return nil, c.ArgErr() } + ldap.searchRequest.Filter = c.Val() // (objectClass=aeNwDevice) case "attributes": c.Next() + for c.NextBlock() { switch c.Val() { case "fqdn": if !c.NextArg() { return nil, c.ArgErr() } + ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) ldap.fqdnAttr = c.Val() // aeFqdn case "ip4": if !c.NextArg() { return nil, c.ArgErr() } + ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) ldap.ip4Attr = c.Val() // ipHostNumber default: return nil, c.Errf("unknown attributes property '%s'", c.Val()) } } + continue case "username": if !c.NextArg() { return nil, c.ArgErr() } + ldap.username = c.Val() case "password": if !c.NextArg() { return nil, c.ArgErr() } + ldap.password = c.Val() case "sasl": ldap.sasl = true @@ -160,19 +171,23 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) { if !c.NextArg() { return nil, c.ArgErr() } + ttl, err := time.ParseDuration(c.Val()) if err != nil { return nil, c.Errf("ttl: %w", err) } + ldap.ttl = ttl case "sync_interval": if !c.NextArg() { return nil, c.ArgErr() } + syncInterval, err := time.ParseDuration(c.Val()) if err != nil { return nil, c.Errf("sync_interval: %w", err) } + ldap.syncInterval = syncInterval case "fallthrough": ldap.Fall.SetZonesFromArgs(c.RemainingArgs()) @@ -180,31 +195,39 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) { return nil, c.Errf("unknown property '%s'", c.Val()) } } + // validate non-default ldap values ... if ldap.ldapURL == "" { return nil, c.Err("ldap_url cannot be empty") } + if ldap.searchRequest.BaseDN == "" { return nil, c.Err("base_dn cannot be empty") } + if ldap.searchRequest.Filter == "" { return nil, c.Err("filter cannot be empty") } + if ldap.fqdnAttr == "" { return nil, c.Err("fqdn attribute cannot be empty") } + if ldap.ip4Attr == "" { return nil, c.Err("ip4 attribute cannot be empty") } + // if only one of password and username set if (ldap.username == "") != (ldap.password == "") { return nil, c.Err("if not using sasl, both, username and password must be set") } + // if both username/password and sasl are set if ldap.username != "" && ldap.sasl { fmt.Printf("666 %#v\t%#v", ldap.username, ldap.sasl) return nil, c.Err("cannot use sasl and username based authentication at the same time") } + // if neither username/password nor sasl are set if ldap.username == "" && !ldap.sasl { return nil, c.Err("authenticate either via username/pwassword or sasl") diff --git a/setup_test.go b/setup_test.go index ed484ee..00a8b01 100644 --- a/setup_test.go +++ b/setup_test.go @@ -1,9 +1,11 @@ -package ldap +package ldap_test import ( "testing" "github.com/caddyserver/caddy" + + . "github.com/xoe-labs/ldap/v0" ) // TestSetup tests the various things that should be parsed by setup. diff --git a/sync.go b/sync.go index ac1b99a..e0e0d06 100644 --- a/sync.go +++ b/sync.go @@ -15,7 +15,8 @@ func (l *Ldap) Run(ctx context.Context) error { if err := l.updateZones(); err != nil { return err } - go func() { + + loop := func() { for { select { case <-ctx.Done(): @@ -27,7 +28,9 @@ func (l *Ldap) Run(ctx context.Context) error { } } } - }() + } + go loop() + return nil } @@ -36,29 +39,35 @@ func (l *Ldap) updateZones() error { for _, zn := range l.Zones.Names { zoneFileMap[zn] = nil } + ldapRecords, err := l.fetchLdapRecords() if err != nil { return fmt.Errorf("updating zones: %w", err) } + for zn, lrpz := range l.mapLdapRecordsToZone(ldapRecords) { if lrpz == nil { continue } + if zoneFileMap[zn] == nil { zoneFileMap[zn] = file.NewZone(zn, "") zoneFileMap[zn].Upstream = l.Upstream zoneFileMap[zn].Insert(SOA(zn)) } + for _, lr := range lrpz { zoneFileMap[zn].Insert(lr.A()) } } + l.zMu.Lock() for zn, zf := range zoneFileMap { // TODO: assignement copies lock value from file.Zone (*l.Zones.Z[zn]) = *zf } l.zMu.Unlock() + return nil } @@ -67,12 +76,14 @@ func (l *Ldap) mapLdapRecordsToZone(ldapRecords []ldapRecord) (ldapRecordsPerZon for _, zn := range l.Zones.Names { lrpz[zn] = nil } + for _, lr := range ldapRecords { zone := plugin.Zones(l.Zones.Names).Matches(lr.fqdn) if zone != "" { lrpz[zone] = append(lrpz[zone], lr) } } + return lrpz } @@ -81,6 +92,7 @@ func (l *Ldap) fetchLdapRecords() (ldapRecords []ldapRecord, err error) { if err != nil { return nil, fmt.Errorf("fetching data from server: %w", err) } + ldapRecords = make([]ldapRecord, len(searchResult.Entries)) for i := 0; i < len(ldapRecords); i++ { ldapRecords[i] = ldapRecord{ @@ -88,5 +100,6 @@ func (l *Ldap) fetchLdapRecords() (ldapRecords []ldapRecord, err error) { ip: net.ParseIP(searchResult.Entries[i].GetAttributeValue(l.ip4Attr)), } } + return ldapRecords, nil }