Improve connection handling
This commit is contained in:
parent
9dffb541b9
commit
0a96ebb428
@ -8,7 +8,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/jpillora/backoff"
|
"github.com/jpillora/backoff"
|
||||||
"github.com/matryer/resync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@ -35,9 +34,7 @@ type Client struct {
|
|||||||
|
|
||||||
quit chan struct{}
|
quit chan struct{}
|
||||||
reconnect chan struct{}
|
reconnect chan struct{}
|
||||||
ready sync.WaitGroup
|
|
||||||
sendRecv sync.WaitGroup
|
sendRecv sync.WaitGroup
|
||||||
once resync.Once
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
23
irc/conn.go
23
irc/conn.go
@ -3,12 +3,17 @@ package irc
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadProtocol = errors.New("This server does not speak IRC")
|
||||||
|
)
|
||||||
|
|
||||||
func (c *Client) Connect(address string) {
|
func (c *Client) Connect(address string) {
|
||||||
if idx := strings.Index(address, ":"); idx < 0 {
|
if idx := strings.Index(address, ":"); idx < 0 {
|
||||||
c.Host = address
|
c.Host = address
|
||||||
@ -60,10 +65,10 @@ func (c *Client) run() {
|
|||||||
|
|
||||||
case <-c.reconnect:
|
case <-c.reconnect:
|
||||||
c.disconnect()
|
c.disconnect()
|
||||||
|
c.connChange(false, nil)
|
||||||
|
|
||||||
c.sendRecv.Wait()
|
c.sendRecv.Wait()
|
||||||
c.reconnect = make(chan struct{})
|
c.reconnect = make(chan struct{})
|
||||||
c.once.Reset()
|
|
||||||
|
|
||||||
c.tryConnect()
|
c.tryConnect()
|
||||||
}
|
}
|
||||||
@ -83,12 +88,10 @@ func (c *Client) connChange(connected bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) disconnect() {
|
func (c *Client) disconnect() {
|
||||||
c.connChange(false, nil)
|
|
||||||
c.lock.Lock()
|
c.lock.Lock()
|
||||||
c.connected = false
|
c.connected = false
|
||||||
c.lock.Unlock()
|
c.lock.Unlock()
|
||||||
|
|
||||||
c.once.Do(c.ready.Done)
|
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,9 +144,7 @@ func (c *Client) connect() error {
|
|||||||
|
|
||||||
c.register()
|
c.register()
|
||||||
|
|
||||||
c.ready.Add(1)
|
c.sendRecv.Add(1)
|
||||||
c.sendRecv.Add(2)
|
|
||||||
go c.send()
|
|
||||||
go c.recv()
|
go c.recv()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -152,8 +153,6 @@ func (c *Client) connect() error {
|
|||||||
func (c *Client) send() {
|
func (c *Client) send() {
|
||||||
defer c.sendRecv.Done()
|
defer c.sendRecv.Done()
|
||||||
|
|
||||||
c.ready.Wait()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.quit:
|
case <-c.quit:
|
||||||
@ -188,6 +187,11 @@ func (c *Client) recv() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
msg := parseMessage(line)
|
msg := parseMessage(line)
|
||||||
|
if msg == nil {
|
||||||
|
close(c.quit)
|
||||||
|
c.connChange(false, ErrBadProtocol)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch msg.Command {
|
switch msg.Command {
|
||||||
case Ping:
|
case Ping:
|
||||||
@ -205,7 +209,8 @@ func (c *Client) recv() {
|
|||||||
|
|
||||||
case ReplyWelcome:
|
case ReplyWelcome:
|
||||||
c.setNick(msg.Params[0])
|
c.setNick(msg.Params[0])
|
||||||
c.once.Do(c.ready.Done)
|
c.sendRecv.Add(1)
|
||||||
|
go c.send()
|
||||||
|
|
||||||
case ErrNicknameInUse:
|
case ErrNicknameInUse:
|
||||||
if c.HandleNickInUse != nil {
|
if c.HandleNickInUse != nil {
|
||||||
|
@ -131,9 +131,7 @@ func TestRecv(t *testing.T) {
|
|||||||
buf.WriteString("001 foo\r\n")
|
buf.WriteString("001 foo\r\n")
|
||||||
c.reader = bufio.NewReader(buf)
|
c.reader = bufio.NewReader(buf)
|
||||||
|
|
||||||
c.ready.Add(1)
|
|
||||||
c.sendRecv.Add(2)
|
c.sendRecv.Add(2)
|
||||||
go c.send()
|
|
||||||
go c.recv()
|
go c.recv()
|
||||||
|
|
||||||
assert.Equal(t, "PONG :test\r\n", <-conn.hook)
|
assert.Equal(t, "PONG :test\r\n", <-conn.hook)
|
||||||
@ -143,7 +141,6 @@ func TestRecv(t *testing.T) {
|
|||||||
func TestRecvTriggersReconnect(t *testing.T) {
|
func TestRecvTriggersReconnect(t *testing.T) {
|
||||||
c := testClient()
|
c := testClient()
|
||||||
c.conn = &mockConn{}
|
c.conn = &mockConn{}
|
||||||
c.ready.Add(1)
|
|
||||||
c.reader = bufio.NewReader(&bytes.Buffer{})
|
c.reader = bufio.NewReader(&bytes.Buffer{})
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
ok := false
|
ok := false
|
||||||
|
@ -30,8 +30,7 @@ func parseMessage(line string) *Message {
|
|||||||
if cmdStart > 0 {
|
if cmdStart > 0 {
|
||||||
msg.Prefix = line[1 : cmdStart-1]
|
msg.Prefix = line[1 : cmdStart-1]
|
||||||
} else {
|
} else {
|
||||||
// Invalid message
|
return nil
|
||||||
return &msg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if i := strings.Index(msg.Prefix, "!"); i > 0 {
|
if i := strings.Index(msg.Prefix, "!"); i > 0 {
|
||||||
@ -43,22 +42,24 @@ func parseMessage(line string) *Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var usesTrailing bool
|
|
||||||
var trailing string
|
var trailing string
|
||||||
|
|
||||||
if i := strings.Index(line, " :"); i > 0 {
|
if i := strings.Index(line, " :"); i > 0 {
|
||||||
cmdEnd = i
|
cmdEnd = i
|
||||||
trailing = line[i+2:]
|
trailing = line[i+2:]
|
||||||
usesTrailing = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := strings.Split(line[cmdStart:cmdEnd], " ")
|
cmd := strings.Split(line[cmdStart:cmdEnd], " ")
|
||||||
msg.Command = cmd[0]
|
msg.Command = cmd[0]
|
||||||
|
if msg.Command == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if len(cmd) > 1 {
|
if len(cmd) > 1 {
|
||||||
msg.Params = cmd[1:]
|
msg.Params = cmd[1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
if usesTrailing {
|
if cmdEnd != len(line) {
|
||||||
msg.Params = append(msg.Params, trailing)
|
msg.Params = append(msg.Params, trailing)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,14 +117,13 @@ func (h *wsHandler) quit(b []byte) {
|
|||||||
var data Quit
|
var data Quit
|
||||||
json.Unmarshal(b, &data)
|
json.Unmarshal(b, &data)
|
||||||
|
|
||||||
|
log.Println(h.addr, "[IRC] Remove server", data.Server)
|
||||||
if i, ok := h.session.getIRC(data.Server); ok {
|
if i, ok := h.session.getIRC(data.Server); ok {
|
||||||
log.Println(h.addr, "[IRC] Remove server", data.Server)
|
|
||||||
|
|
||||||
i.Quit()
|
|
||||||
h.session.deleteIRC(data.Server)
|
h.session.deleteIRC(data.Server)
|
||||||
channelStore.RemoveUserAll(i.GetNick(), data.Server)
|
i.Quit()
|
||||||
go h.session.user.RemoveServer(data.Server)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go h.session.user.RemoveServer(data.Server)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *wsHandler) message(b []byte) {
|
func (h *wsHandler) message(b []byte) {
|
||||||
|
Loading…
Reference in New Issue
Block a user