diff --git a/pkg/irc/cap.go b/pkg/irc/cap.go index 58172030..96ddeeae 100644 --- a/pkg/irc/cap.go +++ b/pkg/irc/cap.go @@ -4,11 +4,17 @@ import ( "strings" ) -func (c *Client) HasCapability(name string, values ...string) bool { - c.lock.Lock() - defer c.lock.Unlock() +var clientWantedCaps = []string{"cap-notify"} - if capValues, ok := c.enabledCapabilities[name]; ok { +func (c *Client) GetCapability(name string) ([]string, bool) { + c.lock.Lock() + values, ok := c.enabledCapabilities[name] + c.lock.Unlock() + return values, ok +} + +func (c *Client) HasCapability(name string, values ...string) bool { + if capValues, ok := c.GetCapability(name); ok { if len(values) == 0 || capValues == nil { return true } @@ -25,20 +31,18 @@ func (c *Client) HasCapability(name string, values ...string) bool { return false } -var clientWantedCaps = []string{"cap-notify"} - func (c *Client) beginCAP() { c.write("CAP LS 302") } func (c *Client) beginSASL() bool { if c.negotiating { - for i, mech := range c.saslMechanisms { - if c.HasCapability("sasl", mech.Name()) { - c.currentSASLIndex = i - c.authenticate(mech.Name()) - return true + if mechs, ok := c.GetCapability("sasl"); ok { + if mechs != nil { + c.filterSASLMechanisms(mechs) } + c.tryNextSASL() + return true } } return false diff --git a/pkg/irc/client.go b/pkg/irc/client.go index 3f60faec..bb97ce3f 100644 --- a/pkg/irc/client.go +++ b/pkg/irc/client.go @@ -49,7 +49,7 @@ type Client struct { enabledCapabilities map[string][]string negotiating bool saslMechanisms []SASL - currentSASLIndex int + currentSASL SASL conn net.Conn connected bool @@ -137,11 +137,12 @@ func (c *Client) initSASL() { } c.wantedCapabilities = append([]string{}, clientWantedCaps...) + c.negotiating = false + c.currentSASL = nil if len(saslMechanisms) > 0 { c.wantedCapabilities = append(c.wantedCapabilities, "sasl") c.saslMechanisms = saslMechanisms - c.currentSASLIndex = 0 } } diff --git a/pkg/irc/sasl.go b/pkg/irc/sasl.go index ddeb1235..753670fc 100644 --- a/pkg/irc/sasl.go +++ b/pkg/irc/sasl.go @@ -112,24 +112,41 @@ func (s *SASLScram) Step(response string) (string, error) { } } -func (c *Client) currentSASL() SASL { - return c.saslMechanisms[c.currentSASLIndex] +func (c *Client) tryNextSASL() { + if len(c.saslMechanisms) > 0 { + c.currentSASL, c.saslMechanisms = c.saslMechanisms[0], c.saslMechanisms[1:] + c.authenticate(c.currentSASL.Name()) + } else { + c.finishCAP() + } } -func (c *Client) nextSASL() SASL { - if c.currentSASLIndex < len(c.saslMechanisms)-1 { - c.currentSASLIndex++ - return c.saslMechanisms[c.currentSASLIndex] +func (c *Client) filterSASLMechanisms(supportedMechs []string) { + saslMechanisms := []SASL{} + + for _, mech := range c.saslMechanisms { + for _, supported := range supportedMechs { + if mech.Name() == supported { + saslMechanisms = append(saslMechanisms, mech) + break + } + } } - return nil + + c.saslMechanisms = saslMechanisms } func (c *Client) handleSASL(msg *Message) { switch msg.Command { case AUTHENTICATE: - auth, err := c.currentSASL().Step(msg.LastParam()) + if c.currentSASL == nil { + return + } + + // TODO: handle 400 chunking on incoming messages + auth, err := c.currentSASL.Step(msg.LastParam()) if err != nil { - c.finishCAP() + c.tryNextSASL() return } @@ -143,33 +160,20 @@ func (c *Client) handleSASL(msg *Message) { c.authenticate("+") } + case ERR_SASLFAIL, ERR_SASLTOOLONG, ERR_SASLABORTED: + c.tryNextSASL() + case RPL_SASLMECHS: if len(msg.Params) > 1 { supportedMechs := strings.Split(msg.Params[1], ",") - - for i, mech := range c.saslMechanisms { - for _, supported := range supportedMechs { - if mech.Name() == supported && i > c.currentSASLIndex { - c.currentSASLIndex = i - c.authenticate(mech.Name()) - return - } - } - } + c.filterSASLMechanisms(supportedMechs) } - c.finishCAP() - case ERR_SASLFAIL: - if next := c.nextSASL(); next != nil { - c.authenticate(next.Name()) - } else { + if len(c.saslMechanisms) == 0 { c.finishCAP() } - case RPL_SASLSUCCESS, RPL_LOGGEDIN: - c.finishCAP() - - case ERR_NICKLOCKED, ERR_SASLTOOLONG, ERR_SASLABORTED: + case RPL_SASLSUCCESS, RPL_LOGGEDIN, ERR_NICKLOCKED: c.finishCAP() } }