imp: complete minimum unit test suite

components: tests
This commit is contained in:
David Arnold 2020-06-10 21:47:34 -05:00 committed by "David Arnold"
parent 0865a9015e
commit 75e79dedf2
No known key found for this signature in database
GPG Key ID: 6D6A936E69C59D08
5 changed files with 166 additions and 52 deletions

View File

@ -15,7 +15,7 @@ import (
) )
// nolint: gochecknoglobals // nolint: gochecknoglobals
var ldapTestCases = []test.Case{ var ldapHandlerTestCases = []test.Case{
{ {
// Simple case // Simple case
Qname: "a.example.org.", Qtype: dns.TypeA, Qname: "a.example.org.", Qtype: dns.TypeA,
@ -26,16 +26,16 @@ var ldapTestCases = []test.Case{
} }
// Create a new Ldap Plugin. Use the test.ErrorHandler as the next plugin. // Create a new Ldap Plugin. Use the test.ErrorHandler as the next plugin.
func newTestLdap() *Ldap { func newTestLdapHandler() *Ldap {
ldap := New([]string{"example.org.", "www.example.org.", "example.org.", "sample.example.org."}) ldap := New([]string{"example.org.", "www.example.org.", "example.org.", "sample.example.org."})
ldap.Zones.Z = newTestLdapZones() ldap.Zones.Z = newTestLdapHandlerZones()
ldap.Fall = fall.Zero ldap.Fall = fall.Zero
ldap.Next = test.ErrorHandler() ldap.Next = test.ErrorHandler()
return ldap return ldap
} }
func newTestLdapZones() map[string]*file.Zone { func newTestLdapHandlerZones() map[string]*file.Zone {
Zone := file.NewZone("example.org.", "") Zone := file.NewZone("example.org.", "")
if err := Zone.Insert(SOA("example.org.")); err != nil { if err := Zone.Insert(SOA("example.org.")); err != nil {
panic("omg") panic("omg")
@ -58,9 +58,9 @@ func newTestLdapZones() map[string]*file.Zone {
} }
func TestServeDNS(t *testing.T) { func TestServeDNS(t *testing.T) {
ldap := newTestLdap() ldap := newTestLdapHandler()
for i, tc := range ldapTestCases { for i, tc := range ldapHandlerTestCases {
req := tc.Msg() req := tc.Msg()
rec := dnstest.NewRecorder(&test.ResponseWriter{}) rec := dnstest.NewRecorder(&test.ResponseWriter{})

42
ldap.go
View File

@ -39,17 +39,19 @@ type Ldap struct {
Client ldap.Client Client ldap.Client
Zones file.Zones Zones file.Zones
searchRequest *ldap.SearchRequest // Exported for mocking in test
ldapURL string SearchRequest *ldap.SearchRequest
pagingLimit uint32 FqdnAttr string
syncInterval time.Duration Ip4Attr string
username string
password string ldapURL string
sasl bool pagingLimit uint32
fqdnAttr string syncInterval time.Duration
ip4Attr string username string
zMu sync.RWMutex password string
ttl time.Duration sasl bool
zMu sync.RWMutex
ttl time.Duration
} }
// New returns an initialized Ldap with defaults. // New returns an initialized Ldap with defaults.
@ -57,13 +59,19 @@ func New(zoneNames []string) *Ldap {
l := new(Ldap) l := new(Ldap)
l.Zones.Names = zoneNames l.Zones.Names = zoneNames
l.pagingLimit = 0 l.pagingLimit = 0
l.syncInterval = 60 * time.Second
// SearchRequest defaults // SearchRequest defaults
l.searchRequest = new(ldap.SearchRequest) l.SearchRequest = new(ldap.SearchRequest)
l.searchRequest.DerefAliases = ldap.NeverDerefAliases // TODO: Reason l.SearchRequest.DerefAliases = ldap.NeverDerefAliases // TODO: Reason
l.searchRequest.Scope = ldap.ScopeWholeSubtree // search whole subtree l.SearchRequest.Scope = ldap.ScopeWholeSubtree // search whole subtree
l.searchRequest.SizeLimit = 500 // TODO: Reason l.SearchRequest.SizeLimit = 500 // TODO: Reason
l.searchRequest.TimeLimit = 500 // TODO: Reason l.SearchRequest.TimeLimit = 500 // TODO: Reason
l.searchRequest.TypesOnly = false // TODO: Reason l.SearchRequest.TypesOnly = false // TODO: Reason
l.Zones.Z = make(map[string]*file.Zone, len(zoneNames))
for _, zn := range zoneNames {
l.Zones.Z[zn] = nil
}
return l return l
} }

View File

@ -125,13 +125,13 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.BaseDN = c.Val() // ou=ae-dir ldap.SearchRequest.BaseDN = c.Val() // ou=ae-dir
case "filter": case "filter":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.Filter = c.Val() // (objectClass=aeNwDevice) ldap.SearchRequest.Filter = c.Val() // (objectClass=aeNwDevice)
case "attributes": case "attributes":
c.Next() c.Next()
@ -142,15 +142,15 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) ldap.SearchRequest.Attributes = append(ldap.SearchRequest.Attributes, c.Val())
ldap.fqdnAttr = c.Val() // aeFqdn ldap.FqdnAttr = c.Val() // aeFqdn
case "ip4": case "ip4":
if !c.NextArg() { if !c.NextArg() {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ldap.searchRequest.Attributes = append(ldap.searchRequest.Attributes, c.Val()) ldap.SearchRequest.Attributes = append(ldap.SearchRequest.Attributes, c.Val())
ldap.ip4Attr = c.Val() // ipHostNumber ldap.Ip4Attr = c.Val() // ipHostNumber
default: default:
return nil, c.Errf("unknown attributes property '%s'", c.Val()) return nil, c.Errf("unknown attributes property '%s'", c.Val())
} }
@ -205,19 +205,19 @@ func ParseStanza(c *caddy.Controller) (*Ldap, error) {
return nil, c.Err("ldap_url cannot be empty") return nil, c.Err("ldap_url cannot be empty")
} }
if ldap.searchRequest.BaseDN == "" { if ldap.SearchRequest.BaseDN == "" {
return nil, c.Err("base_dn cannot be empty") return nil, c.Err("base_dn cannot be empty")
} }
if ldap.searchRequest.Filter == "" { if ldap.SearchRequest.Filter == "" {
return nil, c.Err("filter cannot be empty") return nil, c.Err("filter cannot be empty")
} }
if ldap.fqdnAttr == "" { if ldap.FqdnAttr == "" {
return nil, c.Err("fqdn attribute cannot be empty") return nil, c.Err("fqdn attribute cannot be empty")
} }
if ldap.ip4Attr == "" { if ldap.Ip4Attr == "" {
return nil, c.Err("ip4 attribute cannot be empty") return nil, c.Err("ip4 attribute cannot be empty")
} }

39
sync.go
View File

@ -3,6 +3,7 @@ package ldap
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"net" "net"
"time" "time"
@ -12,7 +13,7 @@ import (
// Run updates the zone from ldap. // Run updates the zone from ldap.
func (l *Ldap) Run(ctx context.Context) error { func (l *Ldap) Run(ctx context.Context) error {
if err := l.updateZones(); err != nil { if err := l.UpdateZones(); err != nil {
return err return err
} }
@ -23,7 +24,7 @@ func (l *Ldap) Run(ctx context.Context) error {
log.Infof("Breaking out of Ldap update loop: %v", ctx.Err()) log.Infof("Breaking out of Ldap update loop: %v", ctx.Err())
return return
case <-time.After(l.syncInterval): case <-time.After(l.syncInterval):
if err := l.updateZones(); err != nil && ctx.Err() == nil { if err := l.UpdateZones(); err != nil && ctx.Err() == nil {
log.Errorf("Failed to update zones: %v", err) log.Errorf("Failed to update zones: %v", err)
} }
} }
@ -34,10 +35,17 @@ func (l *Ldap) Run(ctx context.Context) error {
return nil return nil
} }
func (l *Ldap) updateZones() error { func (l *Ldap) UpdateZones() error {
zoneFileMap := make(map[string]*file.Zone, len(l.Zones.Names)) zoneFileMap := make(map[string]*file.Zone, len(l.Zones.Names))
for _, zn := range l.Zones.Names { for _, zn := range l.Zones.Names {
zoneFileMap[zn] = nil zoneFileMap[zn] = nil
zoneFileMap[zn] = file.NewZone(zn, "")
zoneFileMap[zn].Upstream = l.Upstream
err := zoneFileMap[zn].Insert(SOA(zn))
if err != nil {
return fmt.Errorf("updating zones: %w", err)
}
} }
ldapRecords, err := l.fetchLdapRecords() ldapRecords, err := l.fetchLdapRecords()
@ -46,20 +54,10 @@ func (l *Ldap) updateZones() error {
} }
for zn, lrpz := range l.mapLdapRecordsToZone(ldapRecords) { for zn, lrpz := range l.mapLdapRecordsToZone(ldapRecords) {
if lrpz == nil { if len(lrpz) == 0 {
continue continue
} }
if zoneFileMap[zn] == nil {
zoneFileMap[zn] = file.NewZone(zn, "")
zoneFileMap[zn].Upstream = l.Upstream
err = zoneFileMap[zn].Insert(SOA(zn))
if err != nil {
return fmt.Errorf("updating zones: %w", err)
}
}
for _, lr := range lrpz { for _, lr := range lrpz {
err = zoneFileMap[zn].Insert(lr.A()) err = zoneFileMap[zn].Insert(lr.A())
if err != nil { if err != nil {
@ -70,8 +68,7 @@ func (l *Ldap) updateZones() error {
l.zMu.Lock() l.zMu.Lock()
for zn, zf := range zoneFileMap { for zn, zf := range zoneFileMap {
// TODO: assignement copies lock value from file.Zone l.Zones.Z[zn] = zf
(*l.Zones.Z[zn]) = *zf
} }
l.zMu.Unlock() l.zMu.Unlock()
@ -95,16 +92,20 @@ func (l *Ldap) mapLdapRecordsToZone(ldapRecords []ldapRecord) (ldapRecordsPerZon
} }
func (l *Ldap) fetchLdapRecords() (ldapRecords []ldapRecord, err error) { func (l *Ldap) fetchLdapRecords() (ldapRecords []ldapRecord, err error) {
searchResult, err := l.Client.SearchWithPaging(l.searchRequest, l.pagingLimit) searchResult, err := l.Client.SearchWithPaging(l.SearchRequest, l.pagingLimit)
if err != nil { if err != nil {
return nil, fmt.Errorf("fetching data from server: %w", err) return nil, fmt.Errorf("fetching data from server: %w", err)
} }
ldapRecords = make([]ldapRecord, len(searchResult.Entries)) ldapRecords = make([]ldapRecord, len(searchResult.Entries))
for i := 0; i < len(ldapRecords); i++ { for i := 0; i < len(ldapRecords); i++ {
fqdn := searchResult.Entries[i].GetAttributeValue(l.FqdnAttr)
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
}
ldapRecords[i] = ldapRecord{ ldapRecords[i] = ldapRecord{
fqdn: searchResult.Entries[i].GetAttributeValue(l.fqdnAttr), fqdn: fqdn,
ip: net.ParseIP(searchResult.Entries[i].GetAttributeValue(l.ip4Attr)), ip: net.ParseIP(searchResult.Entries[i].GetAttributeValue(l.Ip4Attr)),
} }
} }

105
sync_test.go Normal file
View File

@ -0,0 +1,105 @@
package ldap_test
import (
"crypto/tls"
"testing"
"time"
"gopkg.in/ldap.v3"
. "github.com/xoe-labs/ldap/v0"
)
type mockClient struct{}
func (m *mockClient) Start() {}
func (m *mockClient) StartTLS(*tls.Config) error { return nil }
func (m *mockClient) Close() {}
func (m *mockClient) SetTimeout(time.Duration) {}
func (m *mockClient) Bind(username, password string) error { return nil }
func (m *mockClient) UnauthenticatedBind(username string) error { return nil }
func (m *mockClient) SimpleBind(*ldap.SimpleBindRequest) (*ldap.SimpleBindResult, error) {
return nil, nil
}
func (m *mockClient) ExternalBind() error { return nil }
func (m *mockClient) Add(*ldap.AddRequest) error { return nil }
func (m *mockClient) Del(*ldap.DelRequest) error { return nil }
func (m *mockClient) Modify(*ldap.ModifyRequest) error { return nil }
func (m *mockClient) ModifyDN(*ldap.ModifyDNRequest) error { return nil }
func (m *mockClient) Compare(dn, attribute, value string) (bool, error) { return false, nil }
func (m *mockClient) PasswordModify(*ldap.PasswordModifyRequest) (*ldap.PasswordModifyResult, error) {
return nil, nil
}
func (m *mockClient) Search(*ldap.SearchRequest) (*ldap.SearchResult, error) { return nil, nil }
func (m *mockClient) SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) {
return &ldap.SearchResult{
Entries: []*ldap.Entry{{
DN: "ou=ae-dir, cn=host1",
Attributes: []*ldap.EntryAttribute{{
Name: "aeFqdn",
Values: []string{
"host1.example.org.",
},
}, {
Name: "ipHostNumber",
Values: []string{
"1.2.3.4",
},
}},
}, {
DN: "ou=ae-dir, cn=host2",
Attributes: []*ldap.EntryAttribute{{
Name: "aeFqdn",
// Without ending "."
Values: []string{
"host2.example.org",
},
}, {
Name: "ipHostNumber",
Values: []string{
"1.2.3.5",
},
}},
}, {
DN: "ou=ae-dir, cn=host3",
Attributes: []*ldap.EntryAttribute{{
Name: "aeFqdn",
Values: []string{
"host3.sample.example.org.",
},
}, {
Name: "ipHostNumber",
Values: []string{
"1.2.3.6",
},
}},
}},
}, nil
}
// Create a new Ldap Plugin. Uses a mocked client.
func newTestLdapSync() *Ldap {
ldap := New([]string{"example.org.", "www.example.org.", "example.org.", "sample.example.org."})
ldap.Client = &mockClient{}
ldap.SearchRequest.Attributes = []string{
"aeFqdn", "ipHostNumber",
}
ldap.SearchRequest.BaseDN = "ou=ae-dir"
ldap.SearchRequest.Filter = "(objectClass=aeNwDevice)"
ldap.FqdnAttr = "aeFqdn"
ldap.Ip4Attr = "ipHostNumber"
return ldap
}
// TestUpdateZone tests a zone update.
func TestUpdateZone(t *testing.T) {
ldap := newTestLdapSync()
if err := ldap.UpdateZones(); err != nil {
t.Fatalf("error updating zones: %v", err)
}
}