diff --git a/pkg/ident/server.go b/pkg/ident/server.go index 58914879..baaa9d91 100644 --- a/pkg/ident/server.go +++ b/pkg/ident/server.go @@ -6,23 +6,37 @@ import ( "net" "strings" "sync" + "time" ) var ( + // DefaultAddr is the address a Server listens on when no Addr is specified DefaultAddr = ":113" + // DefaultTimeout is the the time a Server will wait before failing + // reads and writes if no Timeout is specified + DefaultTimeout = 5 * time.Second ) +// Server implements the server-side of the Ident protocol type Server struct { + // Addr is the host:port address to listen on Addr string + // Timeout is the time to wait before failing reads and writes + Timeout time.Duration - idents map[string]string + entries map[string]entry listener net.Listener lock sync.Mutex } +type entry struct { + remoteHost string + ident string +} + func NewServer() *Server { return &Server{ - idents: map[string]string{}, + entries: map[string]entry{}, } } @@ -54,35 +68,77 @@ func (s *Server) Stop() error { return s.listener.Close() } -func (s *Server) Add(local, remote, ident string) { +func (s *Server) Add(local, remote net.Addr, ident string) { + if local == nil || remote == nil { + return + } + + _, localPort, err := net.SplitHostPort(local.String()) + if err != nil { + return + } + + remoteHost, remotePort, err := net.SplitHostPort(remote.String()) + if err != nil { + return + } + s.lock.Lock() - s.idents[local+","+remote] = ident + s.entries[localPort+","+remotePort] = entry{ + remoteHost: remoteHost, + ident: ident, + } s.lock.Unlock() } -func (s *Server) Remove(local, remote string) { +func (s *Server) Remove(local, remote net.Addr) { + if local == nil || remote == nil { + return + } + + _, localPort, err := net.SplitHostPort(local.String()) + if err != nil { + return + } + + _, remotePort, err := net.SplitHostPort(remote.String()) + if err != nil { + return + } + s.lock.Lock() - delete(s.idents, local+","+remote) + delete(s.entries, localPort+","+remotePort) s.lock.Unlock() } func (s *Server) handle(conn net.Conn) { defer conn.Close() + timeout := s.Timeout + if timeout == 0 { + timeout = DefaultTimeout + } + scan := bufio.NewScanner(conn) scan.Buffer(make([]byte, 32), 32) + + conn.SetReadDeadline(time.Now().Add(timeout)) if !scan.Scan() { return } - - line := scan.Text() - ports := strings.ReplaceAll(line, " ", "") + query := scan.Text() s.lock.Lock() - ident, ok := s.idents[ports] + entry, ok := s.entries[strings.ReplaceAll(query, " ", "")] s.lock.Unlock() if ok { - conn.Write([]byte(fmt.Sprintf("%s : USERID : Dispatch : %s\r\n", line, ident))) + remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil || remoteHost != entry.remoteHost { + return + } + + conn.SetWriteDeadline(time.Now().Add(timeout)) + conn.Write([]byte(fmt.Sprintf("%s : USERID : Dispatch : %s\r\n", query, entry.ident))) } } diff --git a/pkg/irc/client.go b/pkg/irc/client.go index c5bc1600..184c8b48 100644 --- a/pkg/irc/client.go +++ b/pkg/irc/client.go @@ -195,17 +195,24 @@ func (c *Client) Host() string { return c.Config.Host } -func (c *Client) LocalPort() string { +func (c *Client) LocalAddr() net.Addr { c.lock.Lock() defer c.lock.Unlock() if c.conn != nil { - _, local, err := net.SplitHostPort(c.conn.LocalAddr().String()) - if err == nil { - return local - } + return c.conn.LocalAddr() } - return "" + return nil +} + +func (c *Client) RemoteAddr() net.Addr { + c.lock.Lock() + defer c.lock.Unlock() + + if c.conn != nil { + return c.conn.RemoteAddr() + } + return nil } func (c *Client) MOTD() []string { diff --git a/server/irc_handler.go b/server/irc_handler.go index 687a52ca..dbac8f9e 100644 --- a/server/irc_handler.go +++ b/server/irc_handler.go @@ -3,6 +3,7 @@ package server import ( "fmt" "log" + "net" "os" "strconv" "strings" @@ -46,7 +47,8 @@ func newIRCHandler(client *irc.Client, state *State) *ircHandler { func (i *ircHandler) run() { var lastConnErr error - var localPort string + var localAddr net.Addr + var remoteAddr net.Addr for { select { @@ -61,11 +63,10 @@ func (i *ircHandler) run() { case state := <-i.client.ConnectionChanged: if identd := i.state.srv.identd; identd != nil { if state.Connected { - if localPort = i.client.LocalPort(); localPort != "" { - identd.Add(localPort, i.client.Config.Port, i.client.Config.Username) - } + localAddr, remoteAddr = i.client.LocalAddr(), i.client.RemoteAddr() + identd.Add(localAddr, remoteAddr, i.client.Config.Username) } else { - identd.Remove(localPort, i.client.Config.Port) + identd.Remove(localAddr, remoteAddr) } } @@ -310,9 +311,7 @@ func (i *ircHandler) info(msg *irc.Message) { } if identd := i.state.srv.identd; identd != nil { - if localPort := i.client.LocalPort(); localPort != "" { - identd.Remove(localPort, i.client.Config.Port) - } + identd.Remove(i.client.LocalAddr(), i.client.RemoteAddr()) } if network, ok := i.state.network(i.client.Host()); ok { diff --git a/server/server.go b/server/server.go index c0c4f87a..e6ac8ce8 100644 --- a/server/server.go +++ b/server/server.go @@ -2,6 +2,7 @@ package server import ( "log" + "net" "net/http" "os" "strconv" @@ -64,6 +65,7 @@ func (d *Dispatch) Run() { if cfg.Identd { d.identd = ident.NewServer() + d.identd.Addr = net.JoinHostPort(cfg.Address, "113") go d.identd.Listen() } @@ -74,7 +76,7 @@ func (d *Dispatch) Run() { d.loadUsers() d.initFileServer() - d.startHTTP() + d.serveHTTP() } func (d *Dispatch) loadUsers() { @@ -119,12 +121,12 @@ func (d *Dispatch) loadUser(user *storage.User) { } } -func (d *Dispatch) startHTTP() { +func (d *Dispatch) serveHTTP() { cfg := d.Config() port := cfg.Port if cfg.Dev { - // The node dev network will proxy index page requests and + // The node dev server will proxy index page requests and // websocket connections to this port port = "1337" }