diff --git a/irc/client.go b/irc/client.go index 4229cda0..f42385a5 100644 --- a/irc/client.go +++ b/irc/client.go @@ -8,7 +8,6 @@ import ( "sync" "github.com/jpillora/backoff" - "github.com/matryer/resync" ) type Client struct { @@ -35,9 +34,7 @@ type Client struct { quit chan struct{} reconnect chan struct{} - ready sync.WaitGroup sendRecv sync.WaitGroup - once resync.Once lock sync.Mutex } diff --git a/irc/conn.go b/irc/conn.go index f7ff061b..a93c588f 100644 --- a/irc/conn.go +++ b/irc/conn.go @@ -3,12 +3,17 @@ package irc import ( "bufio" "crypto/tls" + "errors" "fmt" "net" "strings" "time" ) +var ( + ErrBadProtocol = errors.New("This server does not speak IRC") +) + func (c *Client) Connect(address string) { if idx := strings.Index(address, ":"); idx < 0 { c.Host = address @@ -60,10 +65,10 @@ func (c *Client) run() { case <-c.reconnect: c.disconnect() + c.connChange(false, nil) c.sendRecv.Wait() c.reconnect = make(chan struct{}) - c.once.Reset() c.tryConnect() } @@ -83,12 +88,10 @@ func (c *Client) connChange(connected bool, err error) { } func (c *Client) disconnect() { - c.connChange(false, nil) c.lock.Lock() c.connected = false c.lock.Unlock() - c.once.Do(c.ready.Done) c.conn.Close() } @@ -141,9 +144,7 @@ func (c *Client) connect() error { c.register() - c.ready.Add(1) - c.sendRecv.Add(2) - go c.send() + c.sendRecv.Add(1) go c.recv() return nil @@ -152,8 +153,6 @@ func (c *Client) connect() error { func (c *Client) send() { defer c.sendRecv.Done() - c.ready.Wait() - for { select { case <-c.quit: @@ -188,6 +187,11 @@ func (c *Client) recv() { } msg := parseMessage(line) + if msg == nil { + close(c.quit) + c.connChange(false, ErrBadProtocol) + return + } switch msg.Command { case Ping: @@ -205,7 +209,8 @@ func (c *Client) recv() { case ReplyWelcome: c.setNick(msg.Params[0]) - c.once.Do(c.ready.Done) + c.sendRecv.Add(1) + go c.send() case ErrNicknameInUse: if c.HandleNickInUse != nil { diff --git a/irc/conn_test.go b/irc/conn_test.go index c56a792a..6f781b9c 100644 --- a/irc/conn_test.go +++ b/irc/conn_test.go @@ -131,9 +131,7 @@ func TestRecv(t *testing.T) { buf.WriteString("001 foo\r\n") c.reader = bufio.NewReader(buf) - c.ready.Add(1) c.sendRecv.Add(2) - go c.send() go c.recv() assert.Equal(t, "PONG :test\r\n", <-conn.hook) @@ -143,7 +141,6 @@ func TestRecv(t *testing.T) { func TestRecvTriggersReconnect(t *testing.T) { c := testClient() c.conn = &mockConn{} - c.ready.Add(1) c.reader = bufio.NewReader(&bytes.Buffer{}) done := make(chan struct{}) ok := false diff --git a/irc/message.go b/irc/message.go index 3a309086..32a320e3 100644 --- a/irc/message.go +++ b/irc/message.go @@ -30,8 +30,7 @@ func parseMessage(line string) *Message { if cmdStart > 0 { msg.Prefix = line[1 : cmdStart-1] } else { - // Invalid message - return &msg + return nil } if i := strings.Index(msg.Prefix, "!"); i > 0 { @@ -43,22 +42,24 @@ func parseMessage(line string) *Message { } } - var usesTrailing bool var trailing string if i := strings.Index(line, " :"); i > 0 { cmdEnd = i trailing = line[i+2:] - usesTrailing = true } cmd := strings.Split(line[cmdStart:cmdEnd], " ") msg.Command = cmd[0] + if msg.Command == "" { + return nil + } + if len(cmd) > 1 { msg.Params = cmd[1:] } - if usesTrailing { + if cmdEnd != len(line) { msg.Params = append(msg.Params, trailing) } diff --git a/server/websocket_handler.go b/server/websocket_handler.go index 4ea652f7..565be204 100644 --- a/server/websocket_handler.go +++ b/server/websocket_handler.go @@ -117,14 +117,13 @@ func (h *wsHandler) quit(b []byte) { var data Quit json.Unmarshal(b, &data) + log.Println(h.addr, "[IRC] Remove server", data.Server) if i, ok := h.session.getIRC(data.Server); ok { - log.Println(h.addr, "[IRC] Remove server", data.Server) - - i.Quit() h.session.deleteIRC(data.Server) - channelStore.RemoveUserAll(i.GetNick(), data.Server) - go h.session.user.RemoveServer(data.Server) + i.Quit() } + + go h.session.user.RemoveServer(data.Server) } func (h *wsHandler) message(b []byte) {