Filter SASL mechanisms on RPL_SASLMECHS without trying next one

This commit is contained in:
Ken-Håvard Lieng 2020-06-04 05:06:59 +02:00
parent 876d9ebdd0
commit 9d8d04fa7c
3 changed files with 50 additions and 41 deletions

View File

@ -4,11 +4,17 @@ import (
"strings" "strings"
) )
func (c *Client) HasCapability(name string, values ...string) bool { var clientWantedCaps = []string{"cap-notify"}
c.lock.Lock()
defer c.lock.Unlock()
if capValues, ok := c.enabledCapabilities[name]; ok { func (c *Client) GetCapability(name string) ([]string, bool) {
c.lock.Lock()
values, ok := c.enabledCapabilities[name]
c.lock.Unlock()
return values, ok
}
func (c *Client) HasCapability(name string, values ...string) bool {
if capValues, ok := c.GetCapability(name); ok {
if len(values) == 0 || capValues == nil { if len(values) == 0 || capValues == nil {
return true return true
} }
@ -25,20 +31,18 @@ func (c *Client) HasCapability(name string, values ...string) bool {
return false return false
} }
var clientWantedCaps = []string{"cap-notify"}
func (c *Client) beginCAP() { func (c *Client) beginCAP() {
c.write("CAP LS 302") c.write("CAP LS 302")
} }
func (c *Client) beginSASL() bool { func (c *Client) beginSASL() bool {
if c.negotiating { if c.negotiating {
for i, mech := range c.saslMechanisms { if mechs, ok := c.GetCapability("sasl"); ok {
if c.HasCapability("sasl", mech.Name()) { if mechs != nil {
c.currentSASLIndex = i c.filterSASLMechanisms(mechs)
c.authenticate(mech.Name())
return true
} }
c.tryNextSASL()
return true
} }
} }
return false return false

View File

@ -49,7 +49,7 @@ type Client struct {
enabledCapabilities map[string][]string enabledCapabilities map[string][]string
negotiating bool negotiating bool
saslMechanisms []SASL saslMechanisms []SASL
currentSASLIndex int currentSASL SASL
conn net.Conn conn net.Conn
connected bool connected bool
@ -137,11 +137,12 @@ func (c *Client) initSASL() {
} }
c.wantedCapabilities = append([]string{}, clientWantedCaps...) c.wantedCapabilities = append([]string{}, clientWantedCaps...)
c.negotiating = false
c.currentSASL = nil
if len(saslMechanisms) > 0 { if len(saslMechanisms) > 0 {
c.wantedCapabilities = append(c.wantedCapabilities, "sasl") c.wantedCapabilities = append(c.wantedCapabilities, "sasl")
c.saslMechanisms = saslMechanisms c.saslMechanisms = saslMechanisms
c.currentSASLIndex = 0
} }
} }

View File

@ -112,24 +112,41 @@ func (s *SASLScram) Step(response string) (string, error) {
} }
} }
func (c *Client) currentSASL() SASL { func (c *Client) tryNextSASL() {
return c.saslMechanisms[c.currentSASLIndex] if len(c.saslMechanisms) > 0 {
c.currentSASL, c.saslMechanisms = c.saslMechanisms[0], c.saslMechanisms[1:]
c.authenticate(c.currentSASL.Name())
} else {
c.finishCAP()
}
} }
func (c *Client) nextSASL() SASL { func (c *Client) filterSASLMechanisms(supportedMechs []string) {
if c.currentSASLIndex < len(c.saslMechanisms)-1 { saslMechanisms := []SASL{}
c.currentSASLIndex++
return c.saslMechanisms[c.currentSASLIndex] for _, mech := range c.saslMechanisms {
for _, supported := range supportedMechs {
if mech.Name() == supported {
saslMechanisms = append(saslMechanisms, mech)
break
} }
return nil }
}
c.saslMechanisms = saslMechanisms
} }
func (c *Client) handleSASL(msg *Message) { func (c *Client) handleSASL(msg *Message) {
switch msg.Command { switch msg.Command {
case AUTHENTICATE: case AUTHENTICATE:
auth, err := c.currentSASL().Step(msg.LastParam()) if c.currentSASL == nil {
return
}
// TODO: handle 400 chunking on incoming messages
auth, err := c.currentSASL.Step(msg.LastParam())
if err != nil { if err != nil {
c.finishCAP() c.tryNextSASL()
return return
} }
@ -143,33 +160,20 @@ func (c *Client) handleSASL(msg *Message) {
c.authenticate("+") c.authenticate("+")
} }
case ERR_SASLFAIL, ERR_SASLTOOLONG, ERR_SASLABORTED:
c.tryNextSASL()
case RPL_SASLMECHS: case RPL_SASLMECHS:
if len(msg.Params) > 1 { if len(msg.Params) > 1 {
supportedMechs := strings.Split(msg.Params[1], ",") supportedMechs := strings.Split(msg.Params[1], ",")
c.filterSASLMechanisms(supportedMechs)
}
for i, mech := range c.saslMechanisms { if len(c.saslMechanisms) == 0 {
for _, supported := range supportedMechs {
if mech.Name() == supported && i > c.currentSASLIndex {
c.currentSASLIndex = i
c.authenticate(mech.Name())
return
}
}
}
}
c.finishCAP()
case ERR_SASLFAIL:
if next := c.nextSASL(); next != nil {
c.authenticate(next.Name())
} else {
c.finishCAP() c.finishCAP()
} }
case RPL_SASLSUCCESS, RPL_LOGGEDIN: case RPL_SASLSUCCESS, RPL_LOGGEDIN, ERR_NICKLOCKED:
c.finishCAP()
case ERR_NICKLOCKED, ERR_SASLTOOLONG, ERR_SASLABORTED:
c.finishCAP() c.finishCAP()
} }
} }