lint: wsl only

This commit is contained in:
David Arnold 2020-06-10 02:41:44 -05:00
parent 67f30f86ef
commit 5209d31929
No known key found for this signature in database
GPG Key ID: 6D6A936E69C59D08
6 changed files with 61 additions and 5 deletions

View File

@ -19,14 +19,18 @@ func (l *Ldap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
if zone == "" { if zone == "" {
return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r)
} }
Zone, ok := l.Zones.Z[zone] Zone, ok := l.Zones.Z[zone]
if !ok || Zone == nil { if !ok || Zone == nil {
return dns.RcodeServerFailure, nil return dns.RcodeServerFailure, nil
} }
var result file.Result var result file.Result
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Authoritative = true m.Authoritative = true
l.zMu.RLock() l.zMu.RLock()
m.Answer, m.Ns, m.Extra, result = Zone.Lookup(ctx, state, state.Name()) m.Answer, m.Ns, m.Extra, result = Zone.Lookup(ctx, state, state.Name())
l.zMu.RUnlock() l.zMu.RUnlock()
@ -46,6 +50,7 @@ func (l *Ldap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
return dns.RcodeServerFailure, nil return dns.RcodeServerFailure, nil
} }
w.WriteMsg(m) w.WriteMsg(m)
return dns.RcodeSuccess, nil return dns.RcodeSuccess, nil
} }

View File

@ -1,4 +1,4 @@
package ldap package ldap_test
import ( import (
"context" "context"
@ -10,6 +10,8 @@ import (
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns" "github.com/miekg/dns"
. "github.com/xoe-labs/ldap/v0"
) )
var ldapTestCases = []test.Case{ var ldapTestCases = []test.Case{
@ -28,12 +30,14 @@ func newTestLdap() *Ldap {
ldap.Zones.Z = newTestLdapZones() ldap.Zones.Z = newTestLdapZones()
ldap.Fall = fall.Zero ldap.Fall = fall.Zero
ldap.Next = test.ErrorHandler() ldap.Next = test.ErrorHandler()
return ldap return ldap
} }
func newTestLdapZones() map[string]*file.Zone { func newTestLdapZones() map[string]*file.Zone {
Zone := file.NewZone("example.org.", "") Zone := file.NewZone("example.org.", "")
Zone.Insert(SOA("example.org.")) Zone.Insert(SOA("example.org."))
for _, rr := range []string{ for _, rr := range []string{
"example.org. " + defaultA, "example.org. " + defaultA,
"a.example.org. " + defaultA, "a.example.org. " + defaultA,
@ -41,25 +45,31 @@ func newTestLdapZones() map[string]*file.Zone {
r, _ := dns.NewRR(rr) r, _ := dns.NewRR(rr)
Zone.Insert(r) Zone.Insert(r)
} }
zones := make(map[string]*file.Zone) zones := make(map[string]*file.Zone)
zones["example.org."] = Zone zones["example.org."] = Zone
return zones return zones
} }
func TestServeDNS(t *testing.T) { func TestServeDNS(t *testing.T) {
ldap := newTestLdap() ldap := newTestLdap()
for i, tc := range ldapTestCases { for i, tc := range ldapTestCases {
req := tc.Msg() req := tc.Msg()
rec := dnstest.NewRecorder(&test.ResponseWriter{}) rec := dnstest.NewRecorder(&test.ResponseWriter{})
_, err := ldap.ServeDNS(context.Background(), rec, req) _, err := ldap.ServeDNS(context.Background(), rec, req)
if err != nil { if err != nil {
t.Errorf("Expected no error, got %v", err) t.Errorf("Expected no error, got %v", err)
continue continue
} }
resp := rec.Msg resp := rec.Msg
if resp == nil { if resp == nil {
t.Fatalf("Test %d, got nil message and no error for %q", i, req.Question[0].Name) t.Fatalf("Test %d, got nil message and no error for %q", i, req.Question[0].Name)
} }
if err := test.SortAndCheck(resp, tc); err != nil { if err := test.SortAndCheck(resp, tc); err != nil {
t.Error(err) t.Error(err)
} }

View File

@ -64,6 +64,7 @@ func New(zoneNames []string) *Ldap {
l.searchRequest.SizeLimit = 500 // TODO: Reason l.searchRequest.SizeLimit = 500 // TODO: Reason
l.searchRequest.TimeLimit = 500 // TODO: Reason l.searchRequest.TimeLimit = 500 // TODO: Reason
l.searchRequest.TypesOnly = false // TODO: Reason l.searchRequest.TypesOnly = false // TODO: Reason
return l return l
} }
@ -75,6 +76,7 @@ func (l *Ldap) InitClient() (err error) {
return err return err
} }
defer l.Client.Close() defer l.Client.Close()
return nil return nil
} }
@ -85,6 +87,7 @@ func SOA(zone string) dns.RR {
Mbox := hostmaster + "." Mbox := hostmaster + "."
Ns := "ns.dns." Ns := "ns.dns."
if zone[0] != '.' { if zone[0] != '.' {
Mbox += zone Mbox += zone
Ns += zone Ns += zone

View File

@ -75,6 +75,7 @@ func ldapParse(c *caddy.Controller) (*Ldap, error) {
return ldap, err return ldap, err
} }
} }
return ldap, nil return ldap, nil
} }
@ -96,63 +97,73 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) {
ldap.Upstream = upstream.New() ldap.Upstream = upstream.New()
for c.NextBlock() { for c.NextBlock() {
fmt.Printf("111 %#v\n", c.Val())
switch c.Val() { switch c.Val() {
// RFC 4516 URL // RFC 4516 URL
case "ldap_url": case "ldap_url":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.ldapURL = c.Val() ldap.ldapURL = c.Val()
case "paging_limit": case "paging_limit":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
pagingLimit, err := strconv.ParseUint(c.Val(), 10, 0) pagingLimit, err := strconv.ParseUint(c.Val(), 10, 0)
if err != nil { if err != nil {
return nil, c.Errf("paging_limit: %w", err) return nil, c.Errf("paging_limit: %w", err)
} }
ldap.pagingLimit = uint32(pagingLimit) ldap.pagingLimit = uint32(pagingLimit)
case "base_dn": case "base_dn":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.BaseDN = c.Val() // ou=ae-dir ldap.searchRequest.BaseDN = c.Val() // ou=ae-dir
case "filter": case "filter":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.Filter = c.Val() // (objectClass=aeNwDevice) ldap.searchRequest.Filter = c.Val() // (objectClass=aeNwDevice)
case "attributes": case "attributes":
c.Next() c.Next()
for c.NextBlock() { for c.NextBlock() {
switch c.Val() { switch c.Val() {
case "fqdn": case "fqdn":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val())
ldap.fqdnAttr = c.Val() // aeFqdn ldap.fqdnAttr = c.Val() // aeFqdn
case "ip4": case "ip4":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val())
ldap.ip4Attr = c.Val() // ipHostNumber ldap.ip4Attr = c.Val() // ipHostNumber
default: default:
return nil, c.Errf("unknown attributes property '%s'", c.Val()) return nil, c.Errf("unknown attributes property '%s'", c.Val())
} }
} }
continue continue
case "username": case "username":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.username = c.Val() ldap.username = c.Val()
case "password": case "password":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.password = c.Val() ldap.password = c.Val()
case "sasl": case "sasl":
ldap.sasl = true ldap.sasl = true
@ -160,19 +171,23 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) {
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ttl, err := time.ParseDuration(c.Val()) ttl, err := time.ParseDuration(c.Val())
if err != nil { if err != nil {
return nil, c.Errf("ttl: %w", err) return nil, c.Errf("ttl: %w", err)
} }
ldap.ttl = ttl ldap.ttl = ttl
case "sync_interval": case "sync_interval":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
syncInterval, err := time.ParseDuration(c.Val()) syncInterval, err := time.ParseDuration(c.Val())
if err != nil { if err != nil {
return nil, c.Errf("sync_interval: %w", err) return nil, c.Errf("sync_interval: %w", err)
} }
ldap.syncInterval = syncInterval ldap.syncInterval = syncInterval
case "fallthrough": case "fallthrough":
ldap.Fall.SetZonesFromArgs(c.RemainingArgs()) ldap.Fall.SetZonesFromArgs(c.RemainingArgs())
@ -180,31 +195,39 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) {
return nil, c.Errf("unknown property '%s'", c.Val()) return nil, c.Errf("unknown property '%s'", c.Val())
} }
} }
// validate non-default ldap values ... // validate non-default ldap values ...
if ldap.ldapURL == "" { if ldap.ldapURL == "" {
return nil, c.Err("ldap_url cannot be empty") return nil, c.Err("ldap_url cannot be empty")
} }
if ldap.searchRequest.BaseDN == "" { if ldap.searchRequest.BaseDN == "" {
return nil, c.Err("base_dn cannot be empty") return nil, c.Err("base_dn cannot be empty")
} }
if ldap.searchRequest.Filter == "" { if ldap.searchRequest.Filter == "" {
return nil, c.Err("filter cannot be empty") return nil, c.Err("filter cannot be empty")
} }
if ldap.fqdnAttr == "" { if ldap.fqdnAttr == "" {
return nil, c.Err("fqdn attribute cannot be empty") return nil, c.Err("fqdn attribute cannot be empty")
} }
if ldap.ip4Attr == "" { if ldap.ip4Attr == "" {
return nil, c.Err("ip4 attribute cannot be empty") return nil, c.Err("ip4 attribute cannot be empty")
} }
// if only one of password and username set // if only one of password and username set
if (ldap.username == "") != (ldap.password == "") { if (ldap.username == "") != (ldap.password == "") {
return nil, c.Err("if not using sasl, both, username and password must be set") return nil, c.Err("if not using sasl, both, username and password must be set")
} }
// if both username/password and sasl are set // if both username/password and sasl are set
if ldap.username != "" && ldap.sasl { if ldap.username != "" && ldap.sasl {
fmt.Printf("666 %#v\t%#v", ldap.username, ldap.sasl) fmt.Printf("666 %#v\t%#v", ldap.username, ldap.sasl)
return nil, c.Err("cannot use sasl and username based authentication at the same time") return nil, c.Err("cannot use sasl and username based authentication at the same time")
} }
// if neither username/password nor sasl are set // if neither username/password nor sasl are set
if ldap.username == "" && !ldap.sasl { if ldap.username == "" && !ldap.sasl {
return nil, c.Err("authenticate either via username/pwassword or sasl") return nil, c.Err("authenticate either via username/pwassword or sasl")

View File

@ -1,9 +1,11 @@
package ldap package ldap_test
import ( import (
"testing" "testing"
"github.com/caddyserver/caddy" "github.com/caddyserver/caddy"
. "github.com/xoe-labs/ldap/v0"
) )
// TestSetup tests the various things that should be parsed by setup. // TestSetup tests the various things that should be parsed by setup.

17
sync.go
View File

@ -15,7 +15,8 @@ func (l *Ldap) Run(ctx context.Context) error {
if err := l.updateZones(); err != nil { if err := l.updateZones(); err != nil {
return err return err
} }
go func() {
loop := func() {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -27,7 +28,9 @@ func (l *Ldap) Run(ctx context.Context) error {
} }
} }
} }
}() }
go loop()
return nil return nil
} }
@ -36,29 +39,35 @@ func (l *Ldap) updateZones() error {
for _, zn := range l.Zones.Names { for _, zn := range l.Zones.Names {
zoneFileMap[zn] = nil zoneFileMap[zn] = nil
} }
ldapRecords, err := l.fetchLdapRecords() ldapRecords, err := l.fetchLdapRecords()
if err != nil { if err != nil {
return fmt.Errorf("updating zones: %w", err) return fmt.Errorf("updating zones: %w", err)
} }
for zn, lrpz := range l.mapLdapRecordsToZone(ldapRecords) { for zn, lrpz := range l.mapLdapRecordsToZone(ldapRecords) {
if lrpz == nil { if lrpz == nil {
continue continue
} }
if zoneFileMap[zn] == nil { if zoneFileMap[zn] == nil {
zoneFileMap[zn] = file.NewZone(zn, "") zoneFileMap[zn] = file.NewZone(zn, "")
zoneFileMap[zn].Upstream = l.Upstream zoneFileMap[zn].Upstream = l.Upstream
zoneFileMap[zn].Insert(SOA(zn)) zoneFileMap[zn].Insert(SOA(zn))
} }
for _, lr := range lrpz { for _, lr := range lrpz {
zoneFileMap[zn].Insert(lr.A()) zoneFileMap[zn].Insert(lr.A())
} }
} }
l.zMu.Lock() l.zMu.Lock()
for zn, zf := range zoneFileMap { for zn, zf := range zoneFileMap {
// TODO: assignement copies lock value from file.Zone // TODO: assignement copies lock value from file.Zone
(*l.Zones.Z[zn]) = *zf (*l.Zones.Z[zn]) = *zf
} }
l.zMu.Unlock() l.zMu.Unlock()
return nil return nil
} }
@ -67,12 +76,14 @@ func (l *Ldap) mapLdapRecordsToZone(ldapRecords []ldapRecord) (ldapRecordsPerZon
for _, zn := range l.Zones.Names { for _, zn := range l.Zones.Names {
lrpz[zn] = nil lrpz[zn] = nil
} }
for _, lr := range ldapRecords { for _, lr := range ldapRecords {
zone := plugin.Zones(l.Zones.Names).Matches(lr.fqdn) zone := plugin.Zones(l.Zones.Names).Matches(lr.fqdn)
if zone != "" { if zone != "" {
lrpz[zone] = append(lrpz[zone], lr) lrpz[zone] = append(lrpz[zone], lr)
} }
} }
return lrpz return lrpz
} }
@ -81,6 +92,7 @@ func (l *Ldap) fetchLdapRecords() (ldapRecords []ldapRecord, err error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("fetching data from server: %w", err) return nil, fmt.Errorf("fetching data from server: %w", err)
} }
ldapRecords = make([]ldapRecord, len(searchResult.Entries)) ldapRecords = make([]ldapRecord, len(searchResult.Entries))
for i := 0; i < len(ldapRecords); i++ { for i := 0; i < len(ldapRecords); i++ {
ldapRecords[i] = ldapRecord{ ldapRecords[i] = ldapRecord{
@ -88,5 +100,6 @@ func (l *Ldap) fetchLdapRecords() (ldapRecords []ldapRecord, err error) {
ip: net.ParseIP(searchResult.Entries[i].GetAttributeValue(l.ip4Attr)), ip: net.ParseIP(searchResult.Entries[i].GetAttributeValue(l.ip4Attr)),
} }
} }
return ldapRecords, nil return ldapRecords, nil
} }