Add SASL auth and CAP negotiation
This commit is contained in:
parent
be8b785813
commit
2f8dad2529
18 changed files with 563 additions and 127 deletions
129
pkg/irc/cap.go
Normal file
129
pkg/irc/cap.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Client) HasCapability(name string, values ...string) bool {
|
||||
if capValues, ok := c.enabledCapabilities[name]; ok {
|
||||
if len(values) == 0 || capValues == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, v := range values {
|
||||
for _, vCap := range capValues {
|
||||
if v == vCap {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var clientWantedCaps = []string{}
|
||||
|
||||
func (c *Client) writeCAP() {
|
||||
c.write("CAP LS 302")
|
||||
}
|
||||
|
||||
func (c *Client) handleCAP(msg *Message) {
|
||||
if len(msg.Params) < 3 {
|
||||
c.write("CAP END")
|
||||
return
|
||||
}
|
||||
|
||||
caps := parseCaps(msg.LastParam())
|
||||
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
switch msg.Params[1] {
|
||||
case "LS":
|
||||
for cap, values := range caps {
|
||||
for _, wanted := range c.wantedCapabilities {
|
||||
if cap == wanted {
|
||||
c.requestedCapabilities[cap] = values
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(msg.Params) == 3 {
|
||||
if len(c.requestedCapabilities) == 0 {
|
||||
c.write("CAP END")
|
||||
return
|
||||
}
|
||||
|
||||
reqCaps := []string{}
|
||||
for cap := range c.requestedCapabilities {
|
||||
reqCaps = append(reqCaps, cap)
|
||||
}
|
||||
|
||||
c.write("CAP REQ :" + strings.Join(reqCaps, " "))
|
||||
}
|
||||
|
||||
case "ACK":
|
||||
for cap := range caps {
|
||||
if v, ok := c.requestedCapabilities[cap]; ok {
|
||||
c.enabledCapabilities[cap] = v
|
||||
delete(c.requestedCapabilities, cap)
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.requestedCapabilities) == 0 {
|
||||
if c.SASL != nil && c.HasCapability("sasl", c.SASL.Name()) {
|
||||
c.write("AUTHENTICATE " + c.SASL.Name())
|
||||
} else {
|
||||
c.write("CAP END")
|
||||
}
|
||||
}
|
||||
|
||||
case "NAK":
|
||||
for cap := range caps {
|
||||
delete(c.requestedCapabilities, cap)
|
||||
}
|
||||
|
||||
if len(c.requestedCapabilities) == 0 {
|
||||
c.write("CAP END")
|
||||
}
|
||||
|
||||
case "NEW":
|
||||
reqCaps := []string{}
|
||||
for cap, values := range caps {
|
||||
for _, wanted := range c.wantedCapabilities {
|
||||
if cap == wanted && !c.HasCapability(cap) {
|
||||
c.requestedCapabilities[cap] = values
|
||||
reqCaps = append(reqCaps, cap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(reqCaps) > 0 {
|
||||
c.write("CAP REQ :" + strings.Join(reqCaps, " "))
|
||||
}
|
||||
|
||||
case "DEL":
|
||||
for cap := range caps {
|
||||
delete(c.enabledCapabilities, cap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseCaps(caps string) map[string][]string {
|
||||
result := map[string][]string{}
|
||||
|
||||
parts := strings.Split(caps, " ")
|
||||
for _, part := range parts {
|
||||
capParts := strings.Split(part, "=")
|
||||
name := capParts[0]
|
||||
|
||||
if len(capParts) > 1 {
|
||||
result[name] = strings.Split(capParts[1], ",")
|
||||
} else {
|
||||
result[name] = nil
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
58
pkg/irc/cap_test.go
Normal file
58
pkg/irc/cap_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseCaps(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected map[string][]string
|
||||
}{
|
||||
{
|
||||
"sasl",
|
||||
map[string][]string{
|
||||
"sasl": nil,
|
||||
},
|
||||
}, {
|
||||
"sasl=PLAIN",
|
||||
map[string][]string{
|
||||
"sasl": {"PLAIN"},
|
||||
},
|
||||
}, {
|
||||
"cake sasl=PLAIN",
|
||||
map[string][]string{
|
||||
"cake": nil,
|
||||
"sasl": {"PLAIN"},
|
||||
},
|
||||
}, {
|
||||
"cake sasl=PLAIN pie",
|
||||
map[string][]string{
|
||||
"cake": nil,
|
||||
"sasl": {"PLAIN"},
|
||||
"pie": nil,
|
||||
},
|
||||
}, {
|
||||
"cake sasl=PLAIN pie=BLUEBERRY,RASPBERRY",
|
||||
map[string][]string{
|
||||
"cake": nil,
|
||||
"sasl": {"PLAIN"},
|
||||
"pie": {"BLUEBERRY", "RASPBERRY"},
|
||||
},
|
||||
}, {
|
||||
"cake sasl=PLAIN pie=BLUEBERRY,RASPBERRY cheesecake",
|
||||
map[string][]string{
|
||||
"cake": nil,
|
||||
"sasl": {"PLAIN"},
|
||||
"pie": {"BLUEBERRY", "RASPBERRY"},
|
||||
"cheesecake": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
assert.Equal(t, tc.expected, parseCaps(tc.input))
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ type Client struct {
|
|||
Password string
|
||||
Username string
|
||||
Realname string
|
||||
SASL SASL
|
||||
HandleNickInUse func(string) string
|
||||
|
||||
// Version is the reply to VERSION and FINGER CTCP messages
|
||||
|
@ -32,6 +33,10 @@ type Client struct {
|
|||
nick string
|
||||
channels []string
|
||||
|
||||
wantedCapabilities []string
|
||||
requestedCapabilities map[string][]string
|
||||
enabledCapabilities map[string][]string
|
||||
|
||||
conn net.Conn
|
||||
connected bool
|
||||
registered bool
|
||||
|
@ -49,17 +54,20 @@ type Client struct {
|
|||
|
||||
func NewClient(nick, username string) *Client {
|
||||
return &Client{
|
||||
nick: nick,
|
||||
Features: NewFeatures(),
|
||||
Username: username,
|
||||
Realname: nick,
|
||||
Messages: make(chan *Message, 32),
|
||||
ConnectionChanged: make(chan ConnectionState, 4),
|
||||
out: make(chan string, 32),
|
||||
quit: make(chan struct{}),
|
||||
reconnect: make(chan struct{}),
|
||||
dialer: &net.Dialer{Timeout: 10 * time.Second},
|
||||
recvBuf: make([]byte, 0, 4096),
|
||||
nick: nick,
|
||||
Features: NewFeatures(),
|
||||
Username: username,
|
||||
Realname: nick,
|
||||
Messages: make(chan *Message, 32),
|
||||
ConnectionChanged: make(chan ConnectionState, 4),
|
||||
out: make(chan string, 32),
|
||||
quit: make(chan struct{}),
|
||||
reconnect: make(chan struct{}),
|
||||
wantedCapabilities: clientWantedCaps,
|
||||
enabledCapabilities: map[string][]string{},
|
||||
requestedCapabilities: map[string][]string{},
|
||||
dialer: &net.Dialer{Timeout: 10 * time.Second},
|
||||
recvBuf: make([]byte, 0, 4096),
|
||||
backoff: &backoff.Backoff{
|
||||
Min: 500 * time.Millisecond,
|
||||
Max: 30 * time.Second,
|
||||
|
@ -191,6 +199,11 @@ func (c *Client) writeUser(username, realname string) {
|
|||
}
|
||||
|
||||
func (c *Client) register() {
|
||||
if c.SASL != nil {
|
||||
c.wantedCapabilities = append(c.wantedCapabilities, "sasl")
|
||||
}
|
||||
|
||||
c.writeCAP()
|
||||
if c.Password != "" {
|
||||
c.writePass(c.Password)
|
||||
}
|
||||
|
|
|
@ -152,11 +152,13 @@ func TestRegister(t *testing.T) {
|
|||
c.Username = "user"
|
||||
c.Realname = "rn"
|
||||
c.register()
|
||||
assert.Equal(t, "CAP LS 302\r\n", <-out)
|
||||
assert.Equal(t, "NICK nick\r\n", <-out)
|
||||
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
|
||||
|
||||
c.Password = "pass"
|
||||
c.register()
|
||||
assert.Equal(t, "CAP LS 302\r\n", <-out)
|
||||
assert.Equal(t, "PASS pass\r\n", <-out)
|
||||
assert.Equal(t, "NICK nick\r\n", <-out)
|
||||
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
|
||||
|
|
|
@ -226,6 +226,9 @@ func (c *Client) recv() {
|
|||
c.handleCTCP(ctcp, msg)
|
||||
}
|
||||
|
||||
case CAP:
|
||||
c.handleCAP(msg)
|
||||
|
||||
case RPL_WELCOME:
|
||||
c.setNick(msg.Params[0])
|
||||
c.setRegistered(true)
|
||||
|
@ -251,6 +254,8 @@ func (c *Client) recv() {
|
|||
return
|
||||
}
|
||||
|
||||
c.handleSASL(msg)
|
||||
|
||||
c.Messages <- msg
|
||||
}
|
||||
}
|
||||
|
|
64
pkg/irc/sasl.go
Normal file
64
pkg/irc/sasl.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
type SASL interface {
|
||||
Name() string
|
||||
Encode() string
|
||||
}
|
||||
|
||||
type SASLPlain struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func (s *SASLPlain) Name() string {
|
||||
return "PLAIN"
|
||||
}
|
||||
|
||||
func (s *SASLPlain) Encode() string {
|
||||
buf := bytes.Buffer{}
|
||||
buf.WriteString(s.Username)
|
||||
buf.WriteByte(0x0)
|
||||
buf.WriteString(s.Username)
|
||||
buf.WriteByte(0x0)
|
||||
buf.WriteString(s.Password)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(buf.Bytes())
|
||||
}
|
||||
|
||||
type SASLExternal struct{}
|
||||
|
||||
func (s *SASLExternal) Name() string {
|
||||
return "EXTERNAL"
|
||||
}
|
||||
|
||||
func (s *SASLExternal) Encode() string {
|
||||
return "+"
|
||||
}
|
||||
|
||||
func (c *Client) handleSASL(msg *Message) {
|
||||
switch msg.Command {
|
||||
case AUTHENTICATE:
|
||||
auth := c.SASL.Encode()
|
||||
|
||||
for len(auth) >= 400 {
|
||||
c.write("AUTHENTICATE " + auth)
|
||||
auth = auth[400:]
|
||||
}
|
||||
if len(auth) > 0 {
|
||||
c.write("AUTHENTICATE " + auth)
|
||||
} else {
|
||||
c.write("AUTHENTICATE +")
|
||||
}
|
||||
|
||||
case RPL_SASLSUCCESS:
|
||||
c.write("CAP END")
|
||||
|
||||
case ERR_NICKLOCKED, ERR_SASLFAIL, ERR_SASLTOOLONG, ERR_SASLABORTED, RPL_SASLMECHS:
|
||||
c.write("CAP END")
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue