Bind identd to config address, set read/write deadlines in ident.Server, only reply to ident queries where the remote hosts match

This commit is contained in:
Ken-Håvard Lieng 2020-06-17 03:19:20 +02:00
parent 15ee5ce1c9
commit 04e6e8c7a2
4 changed files with 92 additions and 28 deletions

View file

@ -6,23 +6,37 @@ import (
"net"
"strings"
"sync"
"time"
)
var (
// DefaultAddr is the address a Server listens on when no Addr is specified
DefaultAddr = ":113"
// DefaultTimeout is the the time a Server will wait before failing
// reads and writes if no Timeout is specified
DefaultTimeout = 5 * time.Second
)
// Server implements the server-side of the Ident protocol
type Server struct {
// Addr is the host:port address to listen on
Addr string
// Timeout is the time to wait before failing reads and writes
Timeout time.Duration
idents map[string]string
entries map[string]entry
listener net.Listener
lock sync.Mutex
}
type entry struct {
remoteHost string
ident string
}
func NewServer() *Server {
return &Server{
idents: map[string]string{},
entries: map[string]entry{},
}
}
@ -54,35 +68,77 @@ func (s *Server) Stop() error {
return s.listener.Close()
}
func (s *Server) Add(local, remote, ident string) {
func (s *Server) Add(local, remote net.Addr, ident string) {
if local == nil || remote == nil {
return
}
_, localPort, err := net.SplitHostPort(local.String())
if err != nil {
return
}
remoteHost, remotePort, err := net.SplitHostPort(remote.String())
if err != nil {
return
}
s.lock.Lock()
s.idents[local+","+remote] = ident
s.entries[localPort+","+remotePort] = entry{
remoteHost: remoteHost,
ident: ident,
}
s.lock.Unlock()
}
func (s *Server) Remove(local, remote string) {
func (s *Server) Remove(local, remote net.Addr) {
if local == nil || remote == nil {
return
}
_, localPort, err := net.SplitHostPort(local.String())
if err != nil {
return
}
_, remotePort, err := net.SplitHostPort(remote.String())
if err != nil {
return
}
s.lock.Lock()
delete(s.idents, local+","+remote)
delete(s.entries, localPort+","+remotePort)
s.lock.Unlock()
}
func (s *Server) handle(conn net.Conn) {
defer conn.Close()
timeout := s.Timeout
if timeout == 0 {
timeout = DefaultTimeout
}
scan := bufio.NewScanner(conn)
scan.Buffer(make([]byte, 32), 32)
conn.SetReadDeadline(time.Now().Add(timeout))
if !scan.Scan() {
return
}
line := scan.Text()
ports := strings.ReplaceAll(line, " ", "")
query := scan.Text()
s.lock.Lock()
ident, ok := s.idents[ports]
entry, ok := s.entries[strings.ReplaceAll(query, " ", "")]
s.lock.Unlock()
if ok {
conn.Write([]byte(fmt.Sprintf("%s : USERID : Dispatch : %s\r\n", line, ident)))
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil || remoteHost != entry.remoteHost {
return
}
conn.SetWriteDeadline(time.Now().Add(timeout))
conn.Write([]byte(fmt.Sprintf("%s : USERID : Dispatch : %s\r\n", query, entry.ident)))
}
}