package irc import ( "bytes" "crypto/sha1" "crypto/sha256" "crypto/sha512" "encoding/base64" "errors" "hash" "strings" "github.com/xdg-go/scram" ) var DefaultSASLMechanisms = []string{ "EXTERNAL", //"SCRAM-SHA-512", "SCRAM-SHA-256", //"SCRAM-SHA-1", "PLAIN", } type SASL interface { Name() string Step(response string) (string, error) } type SASLPlain struct { Username string Password string } func (s *SASLPlain) Name() string { return "PLAIN" } func (s *SASLPlain) Step(string) (string, error) { buf := bytes.Buffer{} buf.WriteString(s.Username) buf.WriteByte(0x0) buf.WriteString(s.Username) buf.WriteByte(0x0) buf.WriteString(s.Password) return base64.StdEncoding.EncodeToString(buf.Bytes()), nil } type SASLExternal struct{} func (s *SASLExternal) Name() string { return "EXTERNAL" } func (s *SASLExternal) Step(string) (string, error) { return "+", nil } var ( scramHashes = map[string]scram.HashGeneratorFcn{ "SHA-512": func() hash.Hash { return sha512.New() }, "SHA-256": func() hash.Hash { return sha256.New() }, "SHA-1": func() hash.Hash { return sha1.New() }, } ErrUnsupportedHash = errors.New("unsupported hash algorithm") ) type SASLScram struct { Username string Password string Hash string conv *scram.ClientConversation } func (s *SASLScram) Name() string { return "SCRAM-" + s.Hash } func (s *SASLScram) Step(response string) (string, error) { if s.conv == nil { if hash, ok := scramHashes[s.Hash]; ok { client, err := hash.NewClient(s.Username, s.Password, "") if err != nil { return "", err } s.conv = client.NewConversation() } else { return "", ErrUnsupportedHash } } challenge := "" if response != "+" { b, err := base64.StdEncoding.DecodeString(response) if err != nil { return "", err } challenge = string(b) } res, err := s.conv.Step(challenge) if err != nil { return "", err } if s.conv.Done() { s.conv = nil return "+", nil } else { return base64.StdEncoding.EncodeToString([]byte(res)), nil } } func (c *Client) tryNextSASL() { if len(c.saslMechanisms) > 0 { c.currentSASL, c.saslMechanisms = c.saslMechanisms[0], c.saslMechanisms[1:] c.authenticate(c.currentSASL.Name()) } else { c.finishCAP() } } func (c *Client) filterSASLMechanisms(supportedMechs []string) { saslMechanisms := []SASL{} for _, mech := range c.saslMechanisms { for _, supported := range supportedMechs { if mech.Name() == supported { saslMechanisms = append(saslMechanisms, mech) break } } } c.saslMechanisms = saslMechanisms } func (c *Client) handleSASL(msg *Message) { switch msg.Command { case AUTHENTICATE: if c.currentSASL == nil { return } // TODO: handle 400 chunking on incoming messages auth, err := c.currentSASL.Step(msg.LastParam()) if err != nil { c.tryNextSASL() return } for len(auth) >= 400 { c.authenticate(auth) auth = auth[400:] } if len(auth) > 0 { c.authenticate(auth) } else { c.authenticate("+") } case ERR_SASLFAIL, ERR_SASLTOOLONG, ERR_SASLABORTED: c.tryNextSASL() case RPL_SASLMECHS: if len(msg.Params) > 1 { supportedMechs := strings.Split(msg.Params[1], ",") c.filterSASLMechanisms(supportedMechs) } if len(c.saslMechanisms) == 0 { c.finishCAP() } case RPL_SASLSUCCESS, RPL_LOGGEDIN, ERR_NICKLOCKED: c.finishCAP() } }