Filter SASL mechanisms on RPL_SASLMECHS without trying next one

This commit is contained in:
Ken-Håvard Lieng 2020-06-04 05:06:59 +02:00
parent 876d9ebdd0
commit 9d8d04fa7c
3 changed files with 50 additions and 41 deletions

View File

@ -4,11 +4,17 @@ import (
"strings"
)
func (c *Client) HasCapability(name string, values ...string) bool {
c.lock.Lock()
defer c.lock.Unlock()
var clientWantedCaps = []string{"cap-notify"}
if capValues, ok := c.enabledCapabilities[name]; ok {
func (c *Client) GetCapability(name string) ([]string, bool) {
c.lock.Lock()
values, ok := c.enabledCapabilities[name]
c.lock.Unlock()
return values, ok
}
func (c *Client) HasCapability(name string, values ...string) bool {
if capValues, ok := c.GetCapability(name); ok {
if len(values) == 0 || capValues == nil {
return true
}
@ -25,20 +31,18 @@ func (c *Client) HasCapability(name string, values ...string) bool {
return false
}
var clientWantedCaps = []string{"cap-notify"}
func (c *Client) beginCAP() {
c.write("CAP LS 302")
}
func (c *Client) beginSASL() bool {
if c.negotiating {
for i, mech := range c.saslMechanisms {
if c.HasCapability("sasl", mech.Name()) {
c.currentSASLIndex = i
c.authenticate(mech.Name())
return true
if mechs, ok := c.GetCapability("sasl"); ok {
if mechs != nil {
c.filterSASLMechanisms(mechs)
}
c.tryNextSASL()
return true
}
}
return false

View File

@ -49,7 +49,7 @@ type Client struct {
enabledCapabilities map[string][]string
negotiating bool
saslMechanisms []SASL
currentSASLIndex int
currentSASL SASL
conn net.Conn
connected bool
@ -137,11 +137,12 @@ func (c *Client) initSASL() {
}
c.wantedCapabilities = append([]string{}, clientWantedCaps...)
c.negotiating = false
c.currentSASL = nil
if len(saslMechanisms) > 0 {
c.wantedCapabilities = append(c.wantedCapabilities, "sasl")
c.saslMechanisms = saslMechanisms
c.currentSASLIndex = 0
}
}

View File

@ -112,24 +112,41 @@ func (s *SASLScram) Step(response string) (string, error) {
}
}
func (c *Client) currentSASL() SASL {
return c.saslMechanisms[c.currentSASLIndex]
func (c *Client) tryNextSASL() {
if len(c.saslMechanisms) > 0 {
c.currentSASL, c.saslMechanisms = c.saslMechanisms[0], c.saslMechanisms[1:]
c.authenticate(c.currentSASL.Name())
} else {
c.finishCAP()
}
}
func (c *Client) nextSASL() SASL {
if c.currentSASLIndex < len(c.saslMechanisms)-1 {
c.currentSASLIndex++
return c.saslMechanisms[c.currentSASLIndex]
func (c *Client) filterSASLMechanisms(supportedMechs []string) {
saslMechanisms := []SASL{}
for _, mech := range c.saslMechanisms {
for _, supported := range supportedMechs {
if mech.Name() == supported {
saslMechanisms = append(saslMechanisms, mech)
break
}
}
}
return nil
c.saslMechanisms = saslMechanisms
}
func (c *Client) handleSASL(msg *Message) {
switch msg.Command {
case AUTHENTICATE:
auth, err := c.currentSASL().Step(msg.LastParam())
if c.currentSASL == nil {
return
}
// TODO: handle 400 chunking on incoming messages
auth, err := c.currentSASL.Step(msg.LastParam())
if err != nil {
c.finishCAP()
c.tryNextSASL()
return
}
@ -143,33 +160,20 @@ func (c *Client) handleSASL(msg *Message) {
c.authenticate("+")
}
case ERR_SASLFAIL, ERR_SASLTOOLONG, ERR_SASLABORTED:
c.tryNextSASL()
case RPL_SASLMECHS:
if len(msg.Params) > 1 {
supportedMechs := strings.Split(msg.Params[1], ",")
for i, mech := range c.saslMechanisms {
for _, supported := range supportedMechs {
if mech.Name() == supported && i > c.currentSASLIndex {
c.currentSASLIndex = i
c.authenticate(mech.Name())
return
}
}
}
c.filterSASLMechanisms(supportedMechs)
}
c.finishCAP()
case ERR_SASLFAIL:
if next := c.nextSASL(); next != nil {
c.authenticate(next.Name())
} else {
if len(c.saslMechanisms) == 0 {
c.finishCAP()
}
case RPL_SASLSUCCESS, RPL_LOGGEDIN:
c.finishCAP()
case ERR_NICKLOCKED, ERR_SASLTOOLONG, ERR_SASLABORTED:
case RPL_SASLSUCCESS, RPL_LOGGEDIN, ERR_NICKLOCKED:
c.finishCAP()
}
}