Implement SCRAM-SHA-256
This commit is contained in:
parent
ead3b37cf9
commit
876d9ebdd0
36 changed files with 5089 additions and 72 deletions
|
@ -25,12 +25,32 @@ func (c *Client) HasCapability(name string, values ...string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
var clientWantedCaps = []string{}
|
||||
var clientWantedCaps = []string{"cap-notify"}
|
||||
|
||||
func (c *Client) writeCAP() {
|
||||
func (c *Client) beginCAP() {
|
||||
c.write("CAP LS 302")
|
||||
}
|
||||
|
||||
func (c *Client) beginSASL() bool {
|
||||
if c.negotiating {
|
||||
for i, mech := range c.saslMechanisms {
|
||||
if c.HasCapability("sasl", mech.Name()) {
|
||||
c.currentSASLIndex = i
|
||||
c.authenticate(mech.Name())
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) finishCAP() {
|
||||
if c.negotiating {
|
||||
c.negotiating = false
|
||||
c.write("CAP END")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleCAP(msg *Message) {
|
||||
if len(msg.Params) < 3 {
|
||||
c.write("CAP END")
|
||||
|
@ -39,9 +59,6 @@ func (c *Client) handleCAP(msg *Message) {
|
|||
|
||||
caps := parseCaps(msg.LastParam())
|
||||
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
switch msg.Params[1] {
|
||||
case "LS":
|
||||
for cap, values := range caps {
|
||||
|
@ -58,6 +75,8 @@ func (c *Client) handleCAP(msg *Message) {
|
|||
return
|
||||
}
|
||||
|
||||
c.negotiating = true
|
||||
|
||||
reqCaps := []string{}
|
||||
for cap := range c.requestedCapabilities {
|
||||
reqCaps = append(reqCaps, cap)
|
||||
|
@ -67,19 +86,17 @@ func (c *Client) handleCAP(msg *Message) {
|
|||
}
|
||||
|
||||
case "ACK":
|
||||
c.lock.Lock()
|
||||
for cap := range caps {
|
||||
if v, ok := c.requestedCapabilities[cap]; ok {
|
||||
c.enabledCapabilities[cap] = v
|
||||
delete(c.requestedCapabilities, cap)
|
||||
}
|
||||
}
|
||||
c.lock.Unlock()
|
||||
|
||||
if len(c.requestedCapabilities) == 0 {
|
||||
if c.Config.SASL != nil && c.HasCapability("sasl", c.Config.SASL.Name()) {
|
||||
c.write("AUTHENTICATE " + c.Config.SASL.Name())
|
||||
} else {
|
||||
c.write("CAP END")
|
||||
}
|
||||
if len(c.requestedCapabilities) == 0 && !c.beginSASL() {
|
||||
c.finishCAP()
|
||||
}
|
||||
|
||||
case "NAK":
|
||||
|
@ -87,8 +104,8 @@ func (c *Client) handleCAP(msg *Message) {
|
|||
delete(c.requestedCapabilities, cap)
|
||||
}
|
||||
|
||||
if len(c.requestedCapabilities) == 0 {
|
||||
c.write("CAP END")
|
||||
if len(c.requestedCapabilities) == 0 && !c.beginSASL() {
|
||||
c.finishCAP()
|
||||
}
|
||||
|
||||
case "NEW":
|
||||
|
@ -107,9 +124,11 @@ func (c *Client) handleCAP(msg *Message) {
|
|||
}
|
||||
|
||||
case "DEL":
|
||||
c.lock.Lock()
|
||||
for cap := range caps {
|
||||
delete(c.enabledCapabilities, cap)
|
||||
}
|
||||
c.lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -12,15 +12,19 @@ import (
|
|||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
TLS bool
|
||||
TLSConfig *tls.Config
|
||||
Nick string
|
||||
Password string
|
||||
Username string
|
||||
Realname string
|
||||
SASL SASL
|
||||
Host string
|
||||
Port string
|
||||
TLS bool
|
||||
TLSConfig *tls.Config
|
||||
ServerPassword string
|
||||
Nick string
|
||||
Username string
|
||||
Realname string
|
||||
|
||||
SASLMechanisms []string
|
||||
Account string
|
||||
Password string
|
||||
|
||||
// Version is the reply to VERSION and FINGER CTCP messages
|
||||
Version string
|
||||
// Source is the reply to SOURCE CTCP messages
|
||||
|
@ -43,6 +47,9 @@ type Client struct {
|
|||
wantedCapabilities []string
|
||||
requestedCapabilities map[string][]string
|
||||
enabledCapabilities map[string][]string
|
||||
negotiating bool
|
||||
saslMechanisms []SASL
|
||||
currentSASLIndex int
|
||||
|
||||
conn net.Conn
|
||||
connected bool
|
||||
|
@ -76,10 +83,8 @@ func NewClient(config *Config) *Client {
|
|||
config.Realname = config.Nick
|
||||
}
|
||||
|
||||
wantedCapabilities := append([]string{}, clientWantedCaps...)
|
||||
|
||||
if config.SASL != nil {
|
||||
wantedCapabilities = append(wantedCapabilities, "sasl")
|
||||
if config.SASLMechanisms == nil {
|
||||
config.SASLMechanisms = DefaultSASLMechanisms
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
|
@ -88,7 +93,6 @@ func NewClient(config *Config) *Client {
|
|||
ConnectionChanged: make(chan ConnectionState, 4),
|
||||
Features: NewFeatures(),
|
||||
nick: config.Nick,
|
||||
wantedCapabilities: wantedCapabilities,
|
||||
requestedCapabilities: map[string][]string{},
|
||||
enabledCapabilities: map[string][]string{},
|
||||
dialer: &net.Dialer{Timeout: 10 * time.Second},
|
||||
|
@ -103,10 +107,44 @@ func NewClient(config *Config) *Client {
|
|||
reconnect: make(chan struct{}),
|
||||
}
|
||||
client.state = newState(client)
|
||||
client.initSASL()
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *Client) initSASL() {
|
||||
saslMechanisms := []SASL{}
|
||||
|
||||
for _, mech := range c.Config.SASLMechanisms {
|
||||
if mech == "EXTERNAL" {
|
||||
if c.Config.TLSConfig != nil && len(c.Config.TLSConfig.Certificates) > 0 {
|
||||
saslMechanisms = append(saslMechanisms, &SASLExternal{})
|
||||
}
|
||||
} else if c.Config.Account != "" && c.Config.Password != "" {
|
||||
if mech == "PLAIN" {
|
||||
saslMechanisms = append(saslMechanisms, &SASLPlain{
|
||||
Username: c.Config.Account,
|
||||
Password: c.Config.Password,
|
||||
})
|
||||
} else if strings.HasPrefix(mech, "SCRAM-") {
|
||||
saslMechanisms = append(saslMechanisms, &SASLScram{
|
||||
Username: c.Config.Account,
|
||||
Password: c.Config.Password,
|
||||
Hash: mech[6:],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.wantedCapabilities = append([]string{}, clientWantedCaps...)
|
||||
|
||||
if len(saslMechanisms) > 0 {
|
||||
c.wantedCapabilities = append(c.wantedCapabilities, "sasl")
|
||||
c.saslMechanisms = saslMechanisms
|
||||
c.currentSASLIndex = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) GetNick() string {
|
||||
c.lock.Lock()
|
||||
nick := c.nick
|
||||
|
@ -245,10 +283,14 @@ func (c *Client) writeUser(username, realname string) {
|
|||
c.writef("USER %s 0 * :%s", username, realname)
|
||||
}
|
||||
|
||||
func (c *Client) authenticate(response string) {
|
||||
c.write("AUTHENTICATE " + response)
|
||||
}
|
||||
|
||||
func (c *Client) register() {
|
||||
c.writeCAP()
|
||||
if c.Config.Password != "" {
|
||||
c.writePass(c.Config.Password)
|
||||
c.beginCAP()
|
||||
if c.Config.ServerPassword != "" {
|
||||
c.writePass(c.Config.ServerPassword)
|
||||
}
|
||||
c.writeNick(c.Config.Nick)
|
||||
c.writeUser(c.Config.Username, c.Config.Realname)
|
||||
|
|
|
@ -152,7 +152,7 @@ func TestRegister(t *testing.T) {
|
|||
assert.Equal(t, "NICK nick\r\n", <-out)
|
||||
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
|
||||
|
||||
c.Config.Password = "pass"
|
||||
c.Config.ServerPassword = "pass"
|
||||
c.register()
|
||||
assert.Equal(t, "CAP LS 302\r\n", <-out)
|
||||
assert.Equal(t, "PASS pass\r\n", <-out)
|
||||
|
|
|
@ -64,6 +64,7 @@ func (c *Client) run() {
|
|||
c.sendRecv.Wait()
|
||||
c.reconnect = make(chan struct{})
|
||||
c.state.reset()
|
||||
c.initSASL()
|
||||
|
||||
time.Sleep(c.backoff.Duration())
|
||||
c.tryConnect()
|
||||
|
|
|
@ -85,6 +85,7 @@ func (c *Client) handleMessage(msg *Message) {
|
|||
if len(msg.Params) > 0 {
|
||||
c.setNick(msg.Params[0])
|
||||
}
|
||||
c.negotiating = false
|
||||
c.setRegistered(true)
|
||||
c.flushChannels()
|
||||
|
||||
|
|
137
pkg/irc/sasl.go
137
pkg/irc/sasl.go
|
@ -2,12 +2,28 @@ package irc
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"hash"
|
||||
"strings"
|
||||
|
||||
"github.com/xdg-go/scram"
|
||||
)
|
||||
|
||||
var DefaultSASLMechanisms = []string{
|
||||
"EXTERNAL",
|
||||
//"SCRAM-SHA-512",
|
||||
"SCRAM-SHA-256",
|
||||
//"SCRAM-SHA-1",
|
||||
"PLAIN",
|
||||
}
|
||||
|
||||
type SASL interface {
|
||||
Name() string
|
||||
Encode() string
|
||||
Step(response string) (string, error)
|
||||
}
|
||||
|
||||
type SASLPlain struct {
|
||||
|
@ -19,7 +35,7 @@ func (s *SASLPlain) Name() string {
|
|||
return "PLAIN"
|
||||
}
|
||||
|
||||
func (s *SASLPlain) Encode() string {
|
||||
func (s *SASLPlain) Step(string) (string, error) {
|
||||
buf := bytes.Buffer{}
|
||||
buf.WriteString(s.Username)
|
||||
buf.WriteByte(0x0)
|
||||
|
@ -27,7 +43,7 @@ func (s *SASLPlain) Encode() string {
|
|||
buf.WriteByte(0x0)
|
||||
buf.WriteString(s.Password)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
||||
}
|
||||
|
||||
type SASLExternal struct{}
|
||||
|
@ -36,29 +52,124 @@ func (s *SASLExternal) Name() string {
|
|||
return "EXTERNAL"
|
||||
}
|
||||
|
||||
func (s *SASLExternal) Encode() string {
|
||||
return "+"
|
||||
func (s *SASLExternal) Step(string) (string, error) {
|
||||
return "+", nil
|
||||
}
|
||||
|
||||
var (
|
||||
scramHashes = map[string]scram.HashGeneratorFcn{
|
||||
"SHA-512": func() hash.Hash { return sha512.New() },
|
||||
"SHA-256": func() hash.Hash { return sha256.New() },
|
||||
"SHA-1": func() hash.Hash { return sha1.New() },
|
||||
}
|
||||
|
||||
ErrUnsupportedHash = errors.New("unsupported hash algorithm")
|
||||
)
|
||||
|
||||
type SASLScram struct {
|
||||
Username string
|
||||
Password string
|
||||
Hash string
|
||||
conv *scram.ClientConversation
|
||||
}
|
||||
|
||||
func (s *SASLScram) Name() string {
|
||||
return "SCRAM-" + s.Hash
|
||||
}
|
||||
|
||||
func (s *SASLScram) Step(response string) (string, error) {
|
||||
if s.conv == nil {
|
||||
if hash, ok := scramHashes[s.Hash]; ok {
|
||||
client, err := hash.NewClient(s.Username, s.Password, "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s.conv = client.NewConversation()
|
||||
} else {
|
||||
return "", ErrUnsupportedHash
|
||||
}
|
||||
}
|
||||
|
||||
challenge := ""
|
||||
if response != "+" {
|
||||
b, err := base64.StdEncoding.DecodeString(response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
challenge = string(b)
|
||||
}
|
||||
|
||||
res, err := s.conv.Step(challenge)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if s.conv.Done() {
|
||||
s.conv = nil
|
||||
return "+", nil
|
||||
} else {
|
||||
return base64.StdEncoding.EncodeToString([]byte(res)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) currentSASL() SASL {
|
||||
return c.saslMechanisms[c.currentSASLIndex]
|
||||
}
|
||||
|
||||
func (c *Client) nextSASL() SASL {
|
||||
if c.currentSASLIndex < len(c.saslMechanisms)-1 {
|
||||
c.currentSASLIndex++
|
||||
return c.saslMechanisms[c.currentSASLIndex]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleSASL(msg *Message) {
|
||||
switch msg.Command {
|
||||
case AUTHENTICATE:
|
||||
auth := c.Config.SASL.Encode()
|
||||
auth, err := c.currentSASL().Step(msg.LastParam())
|
||||
if err != nil {
|
||||
c.finishCAP()
|
||||
return
|
||||
}
|
||||
|
||||
for len(auth) >= 400 {
|
||||
c.write("AUTHENTICATE " + auth)
|
||||
c.authenticate(auth)
|
||||
auth = auth[400:]
|
||||
}
|
||||
if len(auth) > 0 {
|
||||
c.write("AUTHENTICATE " + auth)
|
||||
c.authenticate(auth)
|
||||
} else {
|
||||
c.write("AUTHENTICATE +")
|
||||
c.authenticate("+")
|
||||
}
|
||||
|
||||
case RPL_SASLSUCCESS:
|
||||
c.write("CAP END")
|
||||
case RPL_SASLMECHS:
|
||||
if len(msg.Params) > 1 {
|
||||
supportedMechs := strings.Split(msg.Params[1], ",")
|
||||
|
||||
case ERR_NICKLOCKED, ERR_SASLFAIL, ERR_SASLTOOLONG, ERR_SASLABORTED, RPL_SASLMECHS:
|
||||
c.write("CAP END")
|
||||
for i, mech := range c.saslMechanisms {
|
||||
for _, supported := range supportedMechs {
|
||||
if mech.Name() == supported && i > c.currentSASLIndex {
|
||||
c.currentSASLIndex = i
|
||||
c.authenticate(mech.Name())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.finishCAP()
|
||||
|
||||
case ERR_SASLFAIL:
|
||||
if next := c.nextSASL(); next != nil {
|
||||
c.authenticate(next.Name())
|
||||
} else {
|
||||
c.finishCAP()
|
||||
}
|
||||
|
||||
case RPL_SASLSUCCESS, RPL_LOGGEDIN:
|
||||
c.finishCAP()
|
||||
|
||||
case ERR_NICKLOCKED, ERR_SASLTOOLONG, ERR_SASLABORTED:
|
||||
c.finishCAP()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue