diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..624bdd7 --- /dev/null +++ b/handler.go @@ -0,0 +1,83 @@ +package ldap + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + + +// ServeDNS implements the plugin.Handler interface. +func (l Ldap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + opt := plugin.Options{} + state := request.Request{W: w, Req: r} + + zone := plugin.Zones(l.Zones).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) + } + var ( + records []dns.RR + extra []dns.RR + err error + ) + + switch state.QType() { + case dns.TypeA: + records, err = plugin.A(ctx, l, zone, state, nil, opt) + case dns.TypeAAAA: + records, err = plugin.AAAA(ctx, l, zone, state, nil, opt) + case dns.TypeTXT: + records, err = plugin.TXT(ctx, l, zone, state, nil, opt) + case dns.TypeCNAME: + records, err = plugin.CNAME(ctx, l, zone, state, opt) + case dns.TypePTR: + records, err = plugin.PTR(ctx, l, zone, state, opt) + case dns.TypeMX: + records, extra, err = plugin.MX(ctx, l, zone, state, opt) + case dns.TypeSRV: + records, extra, err = plugin.SRV(ctx, l, zone, state, opt) + case dns.TypeSOA: + records, err = plugin.SOA(ctx, l, zone, state, opt) + case dns.TypeNS: + if state.Name() == zone { + records, extra, err = plugin.NS(ctx, l, zone, state, opt) + break + } + fallthrough + default: + // Do a fake A lookup, so we can distinguish between NODATA and NXDOMAIN + _, err = plugin.A(ctx, l, zone, state, nil, opt) + } + + if l.IsNameError(err) { + if l.Fall.Through(state.Name()) { + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) + } + // Make err nil when returning here, so we don't log spam for NXDOMAIN. + return plugin.BackendError(ctx, &l, zone, dns.RcodeNameError, state, nil /* err */, opt) + + } + if err != nil { + return plugin.BackendError(ctx, &l, zone, dns.RcodeServerFailure, state, err, opt) + } + + if len(records) == 0 { + return plugin.BackendError(ctx, &l, zone, dns.RcodeSuccess, state, err, opt) + } + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + m.Answer = append(m.Answer, records...) + m.Extra = append(m.Extra, extra...) + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +// Name implements the Handler interface. +func (l Ldap) Name() string { return "ldap" } \ No newline at end of file diff --git a/ldap.go b/ldap.go index c891da0..5cb1b39 100644 --- a/ldap.go +++ b/ldap.go @@ -10,66 +10,71 @@ package ldap import ( "context" + "errors" "fmt" + "net" "io" "os" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/metrics" - clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/fall" "github.com/miekg/dns" "gopkg.in/ldap.v2" ) -// Define log to be a logger with the plugin name in it. This way we can just use log.Info and -// friends to log. -var log = clog.NewWithPlugin("ldap") - // Ldap is an ldap plugin to serve zone entries from a ldap backend. type Ldap struct { Next plugin.Handler + Fall fall.F + Zones []string + Client *ldap.Client + } -// ServeDNS implements the plugin.Handler interface. This method gets called when ldap is used -// in a Server. -func (l Ldap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - // This function could be simpler. I.e. just fmt.Println("ldap") here, but we want to show - // a slightly more complex ldap as to make this more interesting. - // Here we wrap the dns.ResponseWriter in a new ResponseWriter and call the next plugin, when the - // answer comes back, it will print "ldap". +var ( + errNoItems = errors.New("no items found") + errNsNotExposed = errors.New("namespace is not exposed") + errInvalidRequest = errors.New("invalid query name") +) - // Debug log that we've have seen the query. This will only be shown when the debug plugin is loaded. - log.Debug("Received response") +// Services implements the ServiceBackend interface. +func (l *Ldap) Services(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (services []msg.Service, err error) { + services, err = l.Records(ctx, state, exact) + if err != nil { + return + } - // Wrap. - pw := NewResponsePrinter(w) - - // Export metric with the server label set to the current server handling the request. - requestCount.WithLabelValues(metrics.WithServer(ctx)).Inc() - - // Call next plugin (if any). - return plugin.NextOrFailure(e.Name(), e.Next, ctx, pw, r) + services = msg.Group(services) + return } -// Name implements the Handler interface. -func (l Ldap) Name() string { return "ldap" } - -// ResponsePrinter wrap a dns.ResponseWriter and will write ldap to standard output when WriteMsg is called. -type ResponsePrinter struct { - dns.ResponseWriter +// Reverse implements the ServiceBackend interface. +func (l *Ldap) Reverse(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (services []msg.Service, err error) { + return l.Services(ctx, state, exact, opt) } -// NewResponsePrinter returns ResponseWriter. -func NewResponsePrinter(w dns.ResponseWriter) *ResponsePrinter { - return &ResponsePrinter{ResponseWriter: w} +// Lookup implements the ServiceBackend interface. +func (l *Ldap) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) { + return l.Upstream.Lookup(ctx, state, name, typ) } -// WriteMsg calls the underlying ResponseWriter's WriteMsg method and prints "ldap" to standard output. -func (r *ResponsePrinter) WriteMsg(res *dns.Msg) error { - fmt.Fprintln(out, "ldap") - return r.ResponseWriter.WriteMsg(res) +// IsNameError implements the ServiceBackend interface. +func (l *Ldap) IsNameError(err error) bool { + return err == errNoItems || err == errNsNotExposed || err == errInvalidRequest } -// Make out a reference to os.Stdout so we can easily overwrite it for testing. -var out io.Writer = os.Stdout +// Records looks up records in ldap. If exact is true, it will lookup just this +// name. This is used when find matches when completing SRV lookups for instance. +func (l *Ldap) Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error) { + name := state.Name() + + path, star := msg.PathWithWildcard(name, l.PathPrefix) + r, err := l.get(ctx, path, !exact) + if err != nil { + return nil, err + } + segments := strings.Split(msg.Path(name, l.PathPrefix), "/") + return l.loopNodes(r.Kvs, segments, star, state.QType()) +} \ No newline at end of file diff --git a/metrics.go b/metrics.go index eb898e4..eedaa7e 100644 --- a/metrics.go +++ b/metrics.go @@ -1,7 +1,6 @@ package ldap import ( - "sync" "github.com/coredns/coredns/plugin" @@ -16,4 +15,3 @@ var requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{ Help: "Counter of requests made.", }, []string{"server"}) -var once sync.Once