Filter SASL mechanisms on RPL_SASLMECHS without trying next one
This commit is contained in:
parent
876d9ebdd0
commit
9d8d04fa7c
@ -4,11 +4,17 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Client) HasCapability(name string, values ...string) bool {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
var clientWantedCaps = []string{"cap-notify"}
|
||||
|
||||
if capValues, ok := c.enabledCapabilities[name]; ok {
|
||||
func (c *Client) GetCapability(name string) ([]string, bool) {
|
||||
c.lock.Lock()
|
||||
values, ok := c.enabledCapabilities[name]
|
||||
c.lock.Unlock()
|
||||
return values, ok
|
||||
}
|
||||
|
||||
func (c *Client) HasCapability(name string, values ...string) bool {
|
||||
if capValues, ok := c.GetCapability(name); ok {
|
||||
if len(values) == 0 || capValues == nil {
|
||||
return true
|
||||
}
|
||||
@ -25,20 +31,18 @@ func (c *Client) HasCapability(name string, values ...string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var clientWantedCaps = []string{"cap-notify"}
|
||||
|
||||
func (c *Client) beginCAP() {
|
||||
c.write("CAP LS 302")
|
||||
}
|
||||
|
||||
func (c *Client) beginSASL() bool {
|
||||
if c.negotiating {
|
||||
for i, mech := range c.saslMechanisms {
|
||||
if c.HasCapability("sasl", mech.Name()) {
|
||||
c.currentSASLIndex = i
|
||||
c.authenticate(mech.Name())
|
||||
return true
|
||||
if mechs, ok := c.GetCapability("sasl"); ok {
|
||||
if mechs != nil {
|
||||
c.filterSASLMechanisms(mechs)
|
||||
}
|
||||
c.tryNextSASL()
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
@ -49,7 +49,7 @@ type Client struct {
|
||||
enabledCapabilities map[string][]string
|
||||
negotiating bool
|
||||
saslMechanisms []SASL
|
||||
currentSASLIndex int
|
||||
currentSASL SASL
|
||||
|
||||
conn net.Conn
|
||||
connected bool
|
||||
@ -137,11 +137,12 @@ func (c *Client) initSASL() {
|
||||
}
|
||||
|
||||
c.wantedCapabilities = append([]string{}, clientWantedCaps...)
|
||||
c.negotiating = false
|
||||
c.currentSASL = nil
|
||||
|
||||
if len(saslMechanisms) > 0 {
|
||||
c.wantedCapabilities = append(c.wantedCapabilities, "sasl")
|
||||
c.saslMechanisms = saslMechanisms
|
||||
c.currentSASLIndex = 0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,24 +112,41 @@ func (s *SASLScram) Step(response string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) currentSASL() SASL {
|
||||
return c.saslMechanisms[c.currentSASLIndex]
|
||||
func (c *Client) tryNextSASL() {
|
||||
if len(c.saslMechanisms) > 0 {
|
||||
c.currentSASL, c.saslMechanisms = c.saslMechanisms[0], c.saslMechanisms[1:]
|
||||
c.authenticate(c.currentSASL.Name())
|
||||
} else {
|
||||
c.finishCAP()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) nextSASL() SASL {
|
||||
if c.currentSASLIndex < len(c.saslMechanisms)-1 {
|
||||
c.currentSASLIndex++
|
||||
return c.saslMechanisms[c.currentSASLIndex]
|
||||
func (c *Client) filterSASLMechanisms(supportedMechs []string) {
|
||||
saslMechanisms := []SASL{}
|
||||
|
||||
for _, mech := range c.saslMechanisms {
|
||||
for _, supported := range supportedMechs {
|
||||
if mech.Name() == supported {
|
||||
saslMechanisms = append(saslMechanisms, mech)
|
||||
break
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
c.saslMechanisms = saslMechanisms
|
||||
}
|
||||
|
||||
func (c *Client) handleSASL(msg *Message) {
|
||||
switch msg.Command {
|
||||
case AUTHENTICATE:
|
||||
auth, err := c.currentSASL().Step(msg.LastParam())
|
||||
if c.currentSASL == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: handle 400 chunking on incoming messages
|
||||
auth, err := c.currentSASL.Step(msg.LastParam())
|
||||
if err != nil {
|
||||
c.finishCAP()
|
||||
c.tryNextSASL()
|
||||
return
|
||||
}
|
||||
|
||||
@ -143,33 +160,20 @@ func (c *Client) handleSASL(msg *Message) {
|
||||
c.authenticate("+")
|
||||
}
|
||||
|
||||
case ERR_SASLFAIL, ERR_SASLTOOLONG, ERR_SASLABORTED:
|
||||
c.tryNextSASL()
|
||||
|
||||
case RPL_SASLMECHS:
|
||||
if len(msg.Params) > 1 {
|
||||
supportedMechs := strings.Split(msg.Params[1], ",")
|
||||
c.filterSASLMechanisms(supportedMechs)
|
||||
}
|
||||
|
||||
for i, mech := range c.saslMechanisms {
|
||||
for _, supported := range supportedMechs {
|
||||
if mech.Name() == supported && i > c.currentSASLIndex {
|
||||
c.currentSASLIndex = i
|
||||
c.authenticate(mech.Name())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.finishCAP()
|
||||
|
||||
case ERR_SASLFAIL:
|
||||
if next := c.nextSASL(); next != nil {
|
||||
c.authenticate(next.Name())
|
||||
} else {
|
||||
if len(c.saslMechanisms) == 0 {
|
||||
c.finishCAP()
|
||||
}
|
||||
|
||||
case RPL_SASLSUCCESS, RPL_LOGGEDIN:
|
||||
c.finishCAP()
|
||||
|
||||
case ERR_NICKLOCKED, ERR_SASLTOOLONG, ERR_SASLABORTED:
|
||||
case RPL_SASLSUCCESS, RPL_LOGGEDIN, ERR_NICKLOCKED:
|
||||
c.finishCAP()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user