Improve connection handling

This commit is contained in:
Ken-Håvard Lieng 2017-07-03 07:35:38 +02:00
parent 9dffb541b9
commit 0a96ebb428
5 changed files with 24 additions and 25 deletions

View File

@ -8,7 +8,6 @@ import (
"sync" "sync"
"github.com/jpillora/backoff" "github.com/jpillora/backoff"
"github.com/matryer/resync"
) )
type Client struct { type Client struct {
@ -35,9 +34,7 @@ type Client struct {
quit chan struct{} quit chan struct{}
reconnect chan struct{} reconnect chan struct{}
ready sync.WaitGroup
sendRecv sync.WaitGroup sendRecv sync.WaitGroup
once resync.Once
lock sync.Mutex lock sync.Mutex
} }

View File

@ -3,12 +3,17 @@ package irc
import ( import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"time" "time"
) )
var (
ErrBadProtocol = errors.New("This server does not speak IRC")
)
func (c *Client) Connect(address string) { func (c *Client) Connect(address string) {
if idx := strings.Index(address, ":"); idx < 0 { if idx := strings.Index(address, ":"); idx < 0 {
c.Host = address c.Host = address
@ -60,10 +65,10 @@ func (c *Client) run() {
case <-c.reconnect: case <-c.reconnect:
c.disconnect() c.disconnect()
c.connChange(false, nil)
c.sendRecv.Wait() c.sendRecv.Wait()
c.reconnect = make(chan struct{}) c.reconnect = make(chan struct{})
c.once.Reset()
c.tryConnect() c.tryConnect()
} }
@ -83,12 +88,10 @@ func (c *Client) connChange(connected bool, err error) {
} }
func (c *Client) disconnect() { func (c *Client) disconnect() {
c.connChange(false, nil)
c.lock.Lock() c.lock.Lock()
c.connected = false c.connected = false
c.lock.Unlock() c.lock.Unlock()
c.once.Do(c.ready.Done)
c.conn.Close() c.conn.Close()
} }
@ -141,9 +144,7 @@ func (c *Client) connect() error {
c.register() c.register()
c.ready.Add(1) c.sendRecv.Add(1)
c.sendRecv.Add(2)
go c.send()
go c.recv() go c.recv()
return nil return nil
@ -152,8 +153,6 @@ func (c *Client) connect() error {
func (c *Client) send() { func (c *Client) send() {
defer c.sendRecv.Done() defer c.sendRecv.Done()
c.ready.Wait()
for { for {
select { select {
case <-c.quit: case <-c.quit:
@ -188,6 +187,11 @@ func (c *Client) recv() {
} }
msg := parseMessage(line) msg := parseMessage(line)
if msg == nil {
close(c.quit)
c.connChange(false, ErrBadProtocol)
return
}
switch msg.Command { switch msg.Command {
case Ping: case Ping:
@ -205,7 +209,8 @@ func (c *Client) recv() {
case ReplyWelcome: case ReplyWelcome:
c.setNick(msg.Params[0]) c.setNick(msg.Params[0])
c.once.Do(c.ready.Done) c.sendRecv.Add(1)
go c.send()
case ErrNicknameInUse: case ErrNicknameInUse:
if c.HandleNickInUse != nil { if c.HandleNickInUse != nil {

View File

@ -131,9 +131,7 @@ func TestRecv(t *testing.T) {
buf.WriteString("001 foo\r\n") buf.WriteString("001 foo\r\n")
c.reader = bufio.NewReader(buf) c.reader = bufio.NewReader(buf)
c.ready.Add(1)
c.sendRecv.Add(2) c.sendRecv.Add(2)
go c.send()
go c.recv() go c.recv()
assert.Equal(t, "PONG :test\r\n", <-conn.hook) assert.Equal(t, "PONG :test\r\n", <-conn.hook)
@ -143,7 +141,6 @@ func TestRecv(t *testing.T) {
func TestRecvTriggersReconnect(t *testing.T) { func TestRecvTriggersReconnect(t *testing.T) {
c := testClient() c := testClient()
c.conn = &mockConn{} c.conn = &mockConn{}
c.ready.Add(1)
c.reader = bufio.NewReader(&bytes.Buffer{}) c.reader = bufio.NewReader(&bytes.Buffer{})
done := make(chan struct{}) done := make(chan struct{})
ok := false ok := false

View File

@ -30,8 +30,7 @@ func parseMessage(line string) *Message {
if cmdStart > 0 { if cmdStart > 0 {
msg.Prefix = line[1 : cmdStart-1] msg.Prefix = line[1 : cmdStart-1]
} else { } else {
// Invalid message return nil
return &msg
} }
if i := strings.Index(msg.Prefix, "!"); i > 0 { if i := strings.Index(msg.Prefix, "!"); i > 0 {
@ -43,22 +42,24 @@ func parseMessage(line string) *Message {
} }
} }
var usesTrailing bool
var trailing string var trailing string
if i := strings.Index(line, " :"); i > 0 { if i := strings.Index(line, " :"); i > 0 {
cmdEnd = i cmdEnd = i
trailing = line[i+2:] trailing = line[i+2:]
usesTrailing = true
} }
cmd := strings.Split(line[cmdStart:cmdEnd], " ") cmd := strings.Split(line[cmdStart:cmdEnd], " ")
msg.Command = cmd[0] msg.Command = cmd[0]
if msg.Command == "" {
return nil
}
if len(cmd) > 1 { if len(cmd) > 1 {
msg.Params = cmd[1:] msg.Params = cmd[1:]
} }
if usesTrailing { if cmdEnd != len(line) {
msg.Params = append(msg.Params, trailing) msg.Params = append(msg.Params, trailing)
} }

View File

@ -117,14 +117,13 @@ func (h *wsHandler) quit(b []byte) {
var data Quit var data Quit
json.Unmarshal(b, &data) json.Unmarshal(b, &data)
log.Println(h.addr, "[IRC] Remove server", data.Server)
if i, ok := h.session.getIRC(data.Server); ok { if i, ok := h.session.getIRC(data.Server); ok {
log.Println(h.addr, "[IRC] Remove server", data.Server)
i.Quit()
h.session.deleteIRC(data.Server) h.session.deleteIRC(data.Server)
channelStore.RemoveUserAll(i.GetNick(), data.Server) i.Quit()
go h.session.user.RemoveServer(data.Server)
} }
go h.session.user.RemoveServer(data.Server)
} }
func (h *wsHandler) message(b []byte) { func (h *wsHandler) message(b []byte) {