Persist, renew and delete sessions, refactor storage package, move reusable packages to pkg

This commit is contained in:
Ken-Håvard Lieng 2018-05-31 23:24:59 +02:00
parent 121582f72a
commit 24f9553aa5
48 changed files with 1872 additions and 1171 deletions

173
pkg/irc/client.go Normal file
View file

@ -0,0 +1,173 @@
package irc
import (
"bufio"
"crypto/tls"
"net"
"strings"
"sync"
"github.com/jpillora/backoff"
)
type Client struct {
Server string
Host string
TLS bool
TLSConfig *tls.Config
Password string
Username string
Realname string
Messages chan *Message
ConnectionChanged chan ConnectionState
HandleNickInUse func(string) string
nick string
channels []string
Support *iSupport
conn net.Conn
connected bool
dialer *net.Dialer
reader *bufio.Reader
backoff *backoff.Backoff
out chan string
quit chan struct{}
reconnect chan struct{}
sendRecv sync.WaitGroup
lock sync.Mutex
}
func NewClient(nick, username string) *Client {
return &Client{
nick: nick,
Support: newISupport(),
Username: username,
Realname: nick,
Messages: make(chan *Message, 32),
ConnectionChanged: make(chan ConnectionState, 16),
out: make(chan string, 32),
quit: make(chan struct{}),
reconnect: make(chan struct{}),
backoff: &backoff.Backoff{
Jitter: true,
},
}
}
func (c *Client) GetNick() string {
c.lock.Lock()
nick := c.nick
c.lock.Unlock()
return nick
}
func (c *Client) setNick(nick string) {
c.lock.Lock()
c.nick = nick
c.lock.Unlock()
}
func (c *Client) Connected() bool {
c.lock.Lock()
connected := c.connected
c.lock.Unlock()
return connected
}
func (c *Client) Nick(nick string) {
c.Write("NICK " + nick)
}
func (c *Client) Oper(name, password string) {
c.Write("OPER " + name + " " + password)
}
func (c *Client) Mode(target, modes, params string) {
c.Write(strings.TrimRight("MODE "+target+" "+modes+" "+params, " "))
}
func (c *Client) Quit() {
go func() {
if c.Connected() {
c.write("QUIT")
}
close(c.quit)
}()
}
func (c *Client) Join(channels ...string) {
c.Write("JOIN " + strings.Join(channels, ","))
}
func (c *Client) Part(channels ...string) {
c.Write("PART " + strings.Join(channels, ","))
}
func (c *Client) Topic(channel string, topic ...string) {
msg := "TOPIC " + channel
if len(topic) > 0 {
msg += " :" + topic[0]
}
c.Write(msg)
}
func (c *Client) Invite(nick, channel string) {
c.Write("INVITE " + nick + " " + channel)
}
func (c *Client) Kick(channel string, users ...string) {
c.Write("KICK " + channel + " " + strings.Join(users, ","))
}
func (c *Client) Privmsg(target, msg string) {
c.Writef("PRIVMSG %s :%s", target, msg)
}
func (c *Client) Notice(target, msg string) {
c.Writef("NOTICE %s :%s", target, msg)
}
func (c *Client) Whois(nick string) {
c.Write("WHOIS " + nick)
}
func (c *Client) Away(message string) {
c.Write("AWAY :" + message)
}
func (c *Client) writePass(password string) {
c.write("PASS " + password)
}
func (c *Client) writeNick(nick string) {
c.write("NICK " + nick)
}
func (c *Client) writeUser(username, realname string) {
c.writef("USER %s 0 * :%s", username, realname)
}
func (c *Client) register() {
if c.Password != "" {
c.writePass(c.Password)
}
c.writeNick(c.nick)
c.writeUser(c.Username, c.Realname)
}
func (c *Client) addChannel(channel string) {
c.lock.Lock()
c.channels = append(c.channels, channel)
c.lock.Unlock()
}
func (c *Client) flushChannels() {
c.lock.Lock()
if len(c.channels) > 0 {
c.Join(c.channels...)
c.channels = []string{}
}
c.lock.Unlock()
}

168
pkg/irc/client_test.go Normal file
View file

@ -0,0 +1,168 @@
package irc
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func testClient() *Client {
return NewClient("test", "testing")
}
func testClientSend() (*Client, chan string) {
c := testClient()
conn := &mockConn{hook: make(chan string, 16)}
c.conn = conn
c.sendRecv.Add(1)
go c.send()
return c, conn.hook
}
type mockConn struct {
hook chan string
net.Conn
}
func (c *mockConn) Write(b []byte) (int, error) {
c.hook <- string(b)
return len(b), nil
}
func (c *mockConn) Close() error {
return nil
}
func TestPass(t *testing.T) {
c, out := testClientSend()
c.writePass("pass")
assert.Equal(t, "PASS pass\r\n", <-out)
}
func TestNick(t *testing.T) {
c, out := testClientSend()
c.Nick("test2")
assert.Equal(t, "NICK test2\r\n", <-out)
c.writeNick("nick")
assert.Equal(t, "NICK nick\r\n", <-out)
}
func TestUser(t *testing.T) {
c, out := testClientSend()
c.writeUser("user", "rn")
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
}
func TestOper(t *testing.T) {
c, out := testClientSend()
c.Oper("name", "pass")
assert.Equal(t, "OPER name pass\r\n", <-out)
}
func TestMode(t *testing.T) {
c, out := testClientSend()
c.Mode("#chan", "+o", "user")
assert.Equal(t, "MODE #chan +o user\r\n", <-out)
}
func TestQuit(t *testing.T) {
c, out := testClientSend()
c.connected = true
c.Quit()
assert.Equal(t, "QUIT\r\n", <-out)
_, ok := <-c.quit
assert.Equal(t, false, ok)
}
func TestJoin(t *testing.T) {
c, out := testClientSend()
c.Join("#a")
assert.Equal(t, "JOIN #a\r\n", <-out)
c.Join("#b", "#c")
assert.Equal(t, "JOIN #b,#c\r\n", <-out)
}
func TestPart(t *testing.T) {
c, out := testClientSend()
c.Part("#a")
assert.Equal(t, "PART #a\r\n", <-out)
c.Part("#b", "#c")
assert.Equal(t, "PART #b,#c\r\n", <-out)
}
func TestTopic(t *testing.T) {
c, out := testClientSend()
c.Topic("#chan")
assert.Equal(t, "TOPIC #chan\r\n", <-out)
c.Topic("#chan", "apple pie")
assert.Equal(t, "TOPIC #chan :apple pie\r\n", <-out)
c.Topic("#chan", "")
assert.Equal(t, "TOPIC #chan :\r\n", <-out)
}
func TestInvite(t *testing.T) {
c, out := testClientSend()
c.Invite("user", "#chan")
assert.Equal(t, "INVITE user #chan\r\n", <-out)
}
func TestKick(t *testing.T) {
c, out := testClientSend()
c.Kick("#chan", "user")
assert.Equal(t, "KICK #chan user\r\n", <-out)
c.Kick("#chan", "a", "b")
assert.Equal(t, "KICK #chan a,b\r\n", <-out)
}
func TestPrivmsg(t *testing.T) {
c, out := testClientSend()
c.Privmsg("user", "the message")
assert.Equal(t, "PRIVMSG user :the message\r\n", <-out)
}
func TestNotice(t *testing.T) {
c, out := testClientSend()
c.Notice("user", "the message")
assert.Equal(t, "NOTICE user :the message\r\n", <-out)
}
func TestWhois(t *testing.T) {
c, out := testClientSend()
c.Whois("user")
assert.Equal(t, "WHOIS user\r\n", <-out)
}
func TestAway(t *testing.T) {
c, out := testClientSend()
c.Away("not here")
assert.Equal(t, "AWAY :not here\r\n", <-out)
}
func TestRegister(t *testing.T) {
c, out := testClientSend()
c.nick = "nick"
c.Username = "user"
c.Realname = "rn"
c.register()
assert.Equal(t, "NICK nick\r\n", <-out)
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
c.Password = "pass"
c.register()
assert.Equal(t, "PASS pass\r\n", <-out)
assert.Equal(t, "NICK nick\r\n", <-out)
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
}
func TestFlushChannels(t *testing.T) {
c, out := testClientSend()
c.addChannel("#chan1")
c.flushChannels()
assert.Equal(t, <-out, "JOIN #chan1\r\n")
c.addChannel("#chan2")
c.addChannel("#chan3")
c.flushChannels()
assert.Equal(t, <-out, "JOIN #chan2,#chan3\r\n")
}

236
pkg/irc/conn.go Normal file
View file

@ -0,0 +1,236 @@
package irc
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"strings"
"time"
)
var (
ErrBadProtocol = errors.New("This server does not speak IRC")
)
func (c *Client) Connect(address string) {
if idx := strings.Index(address, ":"); idx < 0 {
c.Host = address
if c.TLS {
address += ":6697"
} else {
address += ":6667"
}
} else {
c.Host = address[:idx]
}
c.Server = address
c.dialer = &net.Dialer{Timeout: 10 * time.Second}
c.connChange(false, nil)
go c.run()
}
func (c *Client) Reconnect() {
close(c.reconnect)
}
func (c *Client) Write(data string) {
c.out <- data + "\r\n"
}
func (c *Client) Writef(format string, a ...interface{}) {
c.out <- fmt.Sprintf(format+"\r\n", a...)
}
func (c *Client) write(data string) {
c.conn.Write([]byte(data + "\r\n"))
}
func (c *Client) writef(format string, a ...interface{}) {
fmt.Fprintf(c.conn, format+"\r\n", a...)
}
func (c *Client) run() {
c.tryConnect()
for {
select {
case <-c.quit:
if c.Connected() {
c.disconnect()
}
c.sendRecv.Wait()
close(c.Messages)
return
case <-c.reconnect:
if c.Connected() {
c.disconnect()
}
c.sendRecv.Wait()
c.reconnect = make(chan struct{})
c.tryConnect()
}
}
}
type ConnectionState struct {
Connected bool
Error error
}
func (c *Client) connChange(connected bool, err error) {
c.ConnectionChanged <- ConnectionState{
Connected: connected,
Error: err,
}
}
func (c *Client) disconnect() {
c.lock.Lock()
c.connected = false
c.lock.Unlock()
c.conn.Close()
}
func (c *Client) tryConnect() {
for {
select {
case <-c.quit:
return
default:
}
err := c.connect()
if err != nil {
c.connChange(false, err)
if _, ok := err.(x509.UnknownAuthorityError); ok {
return
}
} else {
c.backoff.Reset()
c.flushChannels()
return
}
time.Sleep(c.backoff.Duration())
}
}
func (c *Client) connect() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.TLS {
conn, err := tls.DialWithDialer(c.dialer, "tcp", c.Server, c.TLSConfig)
if err != nil {
return err
}
c.conn = conn
} else {
conn, err := c.dialer.Dial("tcp", c.Server)
if err != nil {
return err
}
c.conn = conn
}
c.connected = true
c.connChange(true, nil)
c.reader = bufio.NewReader(c.conn)
c.register()
c.sendRecv.Add(1)
go c.recv()
return nil
}
func (c *Client) send() {
defer c.sendRecv.Done()
for {
select {
case <-c.quit:
return
case <-c.reconnect:
return
case msg := <-c.out:
_, err := c.conn.Write([]byte(msg))
if err != nil {
return
}
}
}
}
func (c *Client) recv() {
defer c.sendRecv.Done()
for {
line, err := c.reader.ReadString('\n')
if err != nil {
select {
case <-c.quit:
return
default:
c.connChange(false, nil)
c.Reconnect()
return
}
}
msg := parseMessage(line)
if msg == nil {
close(c.quit)
c.connChange(false, ErrBadProtocol)
return
}
switch msg.Command {
case Ping:
go c.write("PONG :" + msg.LastParam())
case Join:
if msg.Nick == c.GetNick() {
c.addChannel(msg.Params[0])
}
case Nick:
if msg.Nick == c.GetNick() {
c.setNick(msg.LastParam())
}
case ReplyWelcome:
c.setNick(msg.Params[0])
c.sendRecv.Add(1)
go c.send()
case ReplyISupport:
c.Support.parse(msg.Params)
case ErrNicknameInUse:
if c.HandleNickInUse != nil {
go c.writeNick(c.HandleNickInUse(msg.Params[1]))
}
}
c.Messages <- msg
}
}

233
pkg/irc/conn_test.go Normal file
View file

@ -0,0 +1,233 @@
package irc
import (
"bufio"
"bytes"
"crypto/tls"
"log"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var ircd *mockIrcd
func init() {
initTestServer()
}
func initTestServer() {
ircd = &mockIrcd{
conn: make(chan bool, 1),
connClosed: make(chan bool, 1),
}
ircd.start()
}
type mockIrcd struct {
conn chan bool
connClosed chan bool
}
func (i *mockIrcd) start() {
ln, err := net.Listen("tcp", "127.0.0.1:45678")
if err != nil {
log.Fatal(err)
}
cert, err := tls.X509KeyPair(testCert, testKey)
if err != nil {
log.Fatal(err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
lnTLS, err := tls.Listen("tcp", "127.0.0.1:45679", tlsConfig)
if err != nil {
log.Fatal(err)
}
go i.accept(ln)
go i.accept(lnTLS)
}
func (i *mockIrcd) accept(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go i.handle(conn)
i.conn <- true
}
}
func (i *mockIrcd) handle(conn net.Conn) {
buf := make([]byte, 1024)
for {
_, err := conn.Read(buf)
if err != nil {
i.connClosed <- true
return
}
}
}
func TestConnect(t *testing.T) {
c := testClient()
c.Connect("127.0.0.1:45678")
assert.Equal(t, c.Host, "127.0.0.1")
assert.Equal(t, c.Server, "127.0.0.1:45678")
waitConnAndClose(t, c)
}
func TestConnectTLS(t *testing.T) {
c := testClient()
c.TLS = true
c.TLSConfig = &tls.Config{
InsecureSkipVerify: true,
}
c.Connect("127.0.0.1:45679")
assert.Equal(t, c.Host, "127.0.0.1")
assert.Equal(t, c.Server, "127.0.0.1:45679")
waitConnAndClose(t, c)
}
func TestConnectDefaultPorts(t *testing.T) {
c := testClient()
c.Connect("127.0.0.1")
assert.Equal(t, "127.0.0.1:6667", c.Server)
c = testClient()
c.TLS = true
c.Connect("127.0.0.1")
assert.Equal(t, "127.0.0.1:6697", c.Server)
}
func TestWrite(t *testing.T) {
c, out := testClientSend()
c.write("test")
assert.Equal(t, "test\r\n", <-out)
c.Write("test")
assert.Equal(t, "test\r\n", <-out)
c.writef("test %d", 2)
assert.Equal(t, "test 2\r\n", <-out)
c.Writef("test %d", 2)
assert.Equal(t, "test 2\r\n", <-out)
}
func TestRecv(t *testing.T) {
c := testClient()
conn := &mockConn{hook: make(chan string, 16)}
c.conn = conn
buf := &bytes.Buffer{}
buf.WriteString("CMD\r\n")
buf.WriteString("PING :test\r\n")
buf.WriteString("001 foo\r\n")
c.reader = bufio.NewReader(buf)
c.sendRecv.Add(1)
go c.recv()
assert.Equal(t, "PONG :test\r\n", <-conn.hook)
assert.Equal(t, &Message{Command: "CMD"}, <-c.Messages)
}
func TestRecvTriggersReconnect(t *testing.T) {
c := testClient()
c.conn = &mockConn{}
c.reader = bufio.NewReader(&bytes.Buffer{})
done := make(chan struct{})
ok := false
go func() {
c.sendRecv.Add(1)
c.recv()
_, ok = <-c.reconnect
close(done)
}()
select {
case <-done:
assert.False(t, ok)
return
case <-time.After(100 * time.Millisecond):
t.Error("Reconnect not triggered")
}
}
func TestClose(t *testing.T) {
c := testClient()
close(c.quit)
ok := false
done := make(chan struct{})
go func() {
_, ok = <-c.Messages
close(done)
}()
c.run()
select {
case <-done:
assert.False(t, ok)
return
case <-time.After(100 * time.Millisecond):
t.Error("Channels not closed")
}
}
func waitConnAndClose(t *testing.T, c *Client) {
done := make(chan struct{})
quit := make(chan struct{})
go func() {
<-ircd.conn
quit <- struct{}{}
<-ircd.connClosed
close(done)
}()
for {
select {
case <-done:
return
case <-quit:
assert.True(t, c.Connected())
c.Quit()
case <-time.After(500 * time.Millisecond):
t.Error("Took too long")
return
}
}
}
var testCert = []byte(`-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE-----`)
var testKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA PRIVATE KEY-----`)

40
pkg/irc/const.go Normal file
View file

@ -0,0 +1,40 @@
package irc
const (
Error = "ERROR"
Join = "JOIN"
Mode = "MODE"
Nick = "NICK"
Notice = "NOTICE"
Part = "PART"
Ping = "PING"
Privmsg = "PRIVMSG"
Quit = "QUIT"
Topic = "TOPIC"
ReplyWelcome = "001"
ReplyYourHost = "002"
ReplyCreated = "003"
ReplyISupport = "005"
ReplyLUserClient = "251"
ReplyLUserOp = "252"
ReplyLUserUnknown = "253"
ReplyLUserChannels = "254"
ReplyLUserMe = "255"
ReplyAway = "301"
ReplyWhoisUser = "311"
ReplyWhoisServer = "312"
ReplyWhoisOperator = "313"
ReplyWhoisIdle = "317"
ReplyEndOfWhois = "318"
ReplyWhoisChannels = "319"
ReplyNoTopic = "331"
ReplyTopic = "332"
ReplyNamReply = "353"
ReplyEndOfNames = "366"
ReplyMotd = "372"
ReplyMotdStart = "375"
ReplyEndOfMotd = "376"
ErrErroneousNickname = "432"
ErrNicknameInUse = "433"
)

161
pkg/irc/message.go Normal file
View file

@ -0,0 +1,161 @@
package irc
import (
"strings"
"sync"
"github.com/spf13/cast"
)
type Message struct {
Tags map[string]string
Prefix string
Nick string
Command string
Params []string
}
func (m *Message) LastParam() string {
if len(m.Params) > 0 {
return m.Params[len(m.Params)-1]
}
return ""
}
func parseMessage(line string) *Message {
line = strings.Trim(line, "\r\n ")
msg := Message{}
if strings.HasPrefix(line, "@") {
next := strings.Index(line, " ")
if next == -1 {
return nil
}
tags := strings.Split(line[1:next], ";")
if len(tags) > 0 {
msg.Tags = map[string]string{}
}
for _, tag := range tags {
key, val := splitParam(tag)
if key == "" {
continue
}
if val != "" {
msg.Tags[key] = unescapeTag(val)
} else {
msg.Tags[key] = ""
}
}
line = line[next+1:]
}
if strings.HasPrefix(line, ":") {
next := strings.Index(line, " ")
if next == -1 {
return nil
}
msg.Prefix = line[1:next]
if i := strings.Index(msg.Prefix, "!"); i > 0 {
msg.Nick = msg.Prefix[:i]
} else if i := strings.Index(msg.Prefix, "@"); i > 0 {
msg.Nick = msg.Prefix[:i]
} else {
msg.Nick = msg.Prefix
}
line = line[next+1:]
}
cmdEnd := len(line)
trailing := ""
if i := strings.Index(line, " :"); i > 0 {
cmdEnd = i
trailing = line[i+2:]
}
cmd := strings.Fields(line[:cmdEnd])
if len(cmd) == 0 {
return nil
}
msg.Command = cmd[0]
if len(cmd) > 1 {
msg.Params = cmd[1:]
}
if cmdEnd != len(line) {
msg.Params = append(msg.Params, trailing)
}
return &msg
}
type iSupport struct {
support map[string]string
lock sync.Mutex
}
func newISupport() *iSupport {
return &iSupport{
support: map[string]string{},
}
}
func (i *iSupport) parse(params []string) {
i.lock.Lock()
for _, param := range params[1 : len(params)-1] {
key, val := splitParam(param)
if key == "" {
continue
}
if key[0] == '-' {
delete(i.support, key[1:])
} else {
i.support[key] = val
}
}
i.lock.Unlock()
}
func (i *iSupport) Has(key string) bool {
i.lock.Lock()
_, has := i.support[key]
i.lock.Unlock()
return has
}
func (i *iSupport) Get(key string) string {
i.lock.Lock()
v := i.support[key]
i.lock.Unlock()
return v
}
func (i *iSupport) GetInt(key string) int {
i.lock.Lock()
v := cast.ToInt(i.support[key])
i.lock.Unlock()
return v
}
func splitParam(param string) (string, string) {
parts := strings.SplitN(param, "=", 2)
if len(parts) == 2 {
return parts[0], parts[1]
}
return parts[0], ""
}
func unescapeTag(s string) string {
s = strings.Replace(s, "\\:", ";", -1)
s = strings.Replace(s, "\\s", " ", -1)
s = strings.Replace(s, "\\\\", "\\", -1)
s = strings.Replace(s, "\\r", "\r", -1)
s = strings.Replace(s, "\\n", "\n", -1)
return s
}

199
pkg/irc/message_test.go Normal file
View file

@ -0,0 +1,199 @@
package irc
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseMessage(t *testing.T) {
cases := []struct {
input string
expected *Message
}{
{
":user CMD #chan :some message\r\n",
&Message{
Prefix: "user",
Nick: "user",
Command: "CMD",
Params: []string{"#chan", "some message"},
},
}, {
":nick!user@host.com CMD a b\r\n",
&Message{
Prefix: "nick!user@host.com",
Nick: "nick",
Command: "CMD",
Params: []string{"a", "b"},
},
}, {
"CMD a b :\r\n",
&Message{
Command: "CMD",
Params: []string{"a", "b", ""},
},
}, {
"CMD a b\r\n",
&Message{
Command: "CMD",
Params: []string{"a", "b"},
},
}, {
"CMD\r\n",
&Message{
Command: "CMD",
},
}, {
"CMD :tests and stuff\r\n",
&Message{
Command: "CMD",
Params: []string{"tests and stuff"},
},
}, {
":nick@host.com CMD\r\n",
&Message{
Prefix: "nick@host.com",
Nick: "nick",
Command: "CMD",
},
}, {
":ni@ck!user!name@host!.com CMD\r\n",
&Message{
Prefix: "ni@ck!user!name@host!.com",
Nick: "ni@ck",
Command: "CMD",
},
}, {
"CMD #cake pie \r\n",
&Message{
Command: "CMD",
Params: []string{"#cake", "pie"},
},
}, {
" CMD #cake pie\r\n",
&Message{
Command: "CMD",
Params: []string{"#cake", "pie"},
},
}, {
"CMD #cake ::pie\r\n",
&Message{
Command: "CMD",
Params: []string{"#cake", ":pie"},
},
}, {
"CMD #cake : pie\r\n",
&Message{
Command: "CMD",
Params: []string{"#cake", " pie"},
},
}, {
"CMD #cake :pie :P <3\r\n",
&Message{
Command: "CMD",
Params: []string{"#cake", "pie :P <3"},
},
}, {
"CMD #cake :pie!\r\n",
&Message{
Command: "CMD",
Params: []string{"#cake", "pie!"},
},
}, {
"@x=y CMD\r\n",
&Message{
Tags: map[string]string{
"x": "y",
},
Command: "CMD",
},
}, {
"@x=y :nick!user@host.com CMD\r\n",
&Message{
Tags: map[string]string{
"x": "y",
},
Prefix: "nick!user@host.com",
Nick: "nick",
Command: "CMD",
},
}, {
"@x=y :nick!user@host.com CMD :pie and cake\r\n",
&Message{
Tags: map[string]string{
"x": "y",
},
Prefix: "nick!user@host.com",
Nick: "nick",
Command: "CMD",
Params: []string{"pie and cake"},
},
}, {
"@x=y;a=b CMD\r\n",
&Message{
Tags: map[string]string{
"x": "y",
"a": "b",
},
Command: "CMD",
},
}, {
"@x=y;a=\\\\\\:\\s\\r\\n CMD\r\n",
&Message{
Tags: map[string]string{
"x": "y",
"a": "\\; \r\n",
},
Command: "CMD",
},
},
}
for _, tc := range cases {
assert.Equal(t, tc.expected, parseMessage(tc.input))
}
}
func TestLastParam(t *testing.T) {
assert.Equal(t, "some message", parseMessage(":user CMD #chan :some message\r\n").LastParam())
assert.Equal(t, "", parseMessage("NO_PARAMS").LastParam())
}
func TestBadMessagePanic(t *testing.T) {
parseMessage("@\r\n")
parseMessage("@ :\r\n")
parseMessage("@ :\r\n")
parseMessage(":user\r\n")
parseMessage(":\r\n")
parseMessage(":")
parseMessage("")
}
func TestParseISupport(t *testing.T) {
s := newISupport()
s.parse([]string{"bob", "CAKE=31", "PIE", ":durr"})
assert.Equal(t, 31, s.GetInt("CAKE"))
assert.Equal(t, "31", s.Get("CAKE"))
assert.True(t, s.Has("CAKE"))
assert.True(t, s.Has("PIE"))
assert.False(t, s.Has("APPLES"))
assert.Equal(t, "", s.Get("APPLES"))
assert.Equal(t, 0, s.GetInt("APPLES"))
s.parse([]string{"bob", "-PIE", ":hurr"})
assert.False(t, s.Has("PIE"))
s.parse([]string{"bob", "CAKE=1337", ":durr"})
assert.Equal(t, 1337, s.GetInt("CAKE"))
s.parse([]string{"bob", "CAKE=", ":durr"})
assert.Equal(t, "", s.Get("CAKE"))
assert.True(t, s.Has("CAKE"))
s.parse([]string{"bob", "CAKE===", ":durr"})
assert.Equal(t, "==", s.Get("CAKE"))
s.parse([]string{"bob", "-CAKE=31", ":durr"})
assert.False(t, s.Has("CAKE"))
}

View file

@ -0,0 +1,38 @@
package letsencrypt
import (
"path/filepath"
)
type Directory string
func (d Directory) Domain(domain string) string {
return filepath.Join(string(d), "certs", domain)
}
func (d Directory) Cert(domain string) string {
return filepath.Join(d.Domain(domain), "cert.pem")
}
func (d Directory) Key(domain string) string {
return filepath.Join(d.Domain(domain), "key.pem")
}
func (d Directory) Meta(domain string) string {
return filepath.Join(d.Domain(domain), "metadata.json")
}
func (d Directory) User(email string) string {
if email == "" {
email = defaultUser
}
return filepath.Join(string(d), "users", email)
}
func (d Directory) UserRegistration(email string) string {
return filepath.Join(d.User(email), "registration.json")
}
func (d Directory) UserKey(email string) string {
return filepath.Join(d.User(email), "key.pem")
}

View file

@ -0,0 +1,287 @@
package letsencrypt
import (
"crypto/tls"
"encoding/json"
"io/ioutil"
"os"
"sync"
"time"
"github.com/xenolf/lego/acme"
)
const URL = "https://acme-v01.api.letsencrypt.org/directory"
const KeySize = 2048
var directory Directory
func Run(dir, domain, email, port string) (*state, error) {
directory = Directory(dir)
user, err := getUser(email)
if err != nil {
return nil, err
}
client, err := acme.NewClient(URL, &user, acme.RSA2048)
if err != nil {
return nil, err
}
client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01})
client.SetHTTPAddress(port)
if user.Registration == nil {
user.Registration, err = client.Register()
if err != nil {
return nil, err
}
err = client.AgreeToTOS()
if err != nil {
return nil, err
}
err = saveUser(user)
if err != nil {
return nil, err
}
}
s := &state{
client: client,
domain: domain,
}
if certExists(domain) {
if !s.renew() {
err = s.loadCert()
if err != nil {
return nil, err
}
}
s.refreshOCSP()
} else {
err = s.obtain()
if err != nil {
return nil, err
}
}
go s.maintain()
return s, nil
}
type state struct {
client *acme.Client
domain string
cert *tls.Certificate
certPEM []byte
lock sync.Mutex
}
func (s *state) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
s.lock.Lock()
cert := s.cert
s.lock.Unlock()
return cert, nil
}
func (s *state) getCertPEM() []byte {
s.lock.Lock()
certPEM := s.certPEM
s.lock.Unlock()
return certPEM
}
func (s *state) setCert(meta acme.CertificateResource) {
cert, err := tls.X509KeyPair(meta.Certificate, meta.PrivateKey)
if err == nil {
s.lock.Lock()
if s.cert != nil {
cert.OCSPStaple = s.cert.OCSPStaple
}
s.cert = &cert
s.certPEM = meta.Certificate
s.lock.Unlock()
}
}
func (s *state) setOCSP(ocsp []byte) {
cert := tls.Certificate{
OCSPStaple: ocsp,
}
s.lock.Lock()
if s.cert != nil {
cert.Certificate = s.cert.Certificate
cert.PrivateKey = s.cert.PrivateKey
}
s.cert = &cert
s.lock.Unlock()
}
func (s *state) obtain() error {
cert, errors := s.client.ObtainCertificate([]string{s.domain}, true, nil, false)
if err := errors[s.domain]; err != nil {
if _, ok := err.(acme.TOSError); ok {
err := s.client.AgreeToTOS()
if err != nil {
return err
}
return s.obtain()
}
return err
}
s.setCert(cert)
s.refreshOCSP()
err := saveCert(cert)
if err != nil {
return err
}
return nil
}
func (s *state) renew() bool {
cert, err := ioutil.ReadFile(directory.Cert(s.domain))
if err != nil {
return false
}
exp, err := acme.GetPEMCertExpiration(cert)
if err != nil {
return false
}
daysLeft := int(exp.Sub(time.Now().UTC()).Hours() / 24)
if daysLeft <= 30 {
metaBytes, err := ioutil.ReadFile(directory.Meta(s.domain))
if err != nil {
return false
}
key, err := ioutil.ReadFile(directory.Key(s.domain))
if err != nil {
return false
}
var meta acme.CertificateResource
err = json.Unmarshal(metaBytes, &meta)
if err != nil {
return false
}
meta.Certificate = cert
meta.PrivateKey = key
Renew:
newMeta, err := s.client.RenewCertificate(meta, true, false)
if err != nil {
if _, ok := err.(acme.TOSError); ok {
err := s.client.AgreeToTOS()
if err != nil {
return false
}
goto Renew
}
return false
}
s.setCert(newMeta)
err = saveCert(newMeta)
if err != nil {
return false
}
return true
}
return false
}
func (s *state) refreshOCSP() {
ocsp, resp, err := acme.GetOCSPForCert(s.getCertPEM())
if err == nil && resp.Status == acme.OCSPGood {
s.setOCSP(ocsp)
}
}
func (s *state) maintain() {
renew := time.Tick(24 * time.Hour)
ocsp := time.Tick(1 * time.Hour)
for {
select {
case <-renew:
s.renew()
case <-ocsp:
s.refreshOCSP()
}
}
}
func (s *state) loadCert() error {
cert, err := ioutil.ReadFile(directory.Cert(s.domain))
if err != nil {
return err
}
key, err := ioutil.ReadFile(directory.Key(s.domain))
if err != nil {
return err
}
s.setCert(acme.CertificateResource{
Certificate: cert,
PrivateKey: key,
})
return nil
}
func certExists(domain string) bool {
if _, err := os.Stat(directory.Cert(domain)); err != nil {
return false
}
if _, err := os.Stat(directory.Key(domain)); err != nil {
return false
}
return true
}
func saveCert(cert acme.CertificateResource) error {
err := os.MkdirAll(directory.Domain(cert.Domain), 0700)
if err != nil {
return err
}
err = ioutil.WriteFile(directory.Cert(cert.Domain), cert.Certificate, 0600)
if err != nil {
return err
}
err = ioutil.WriteFile(directory.Key(cert.Domain), cert.PrivateKey, 0600)
if err != nil {
return err
}
jsonBytes, err := json.MarshalIndent(&cert, "", " ")
if err != nil {
return err
}
err = ioutil.WriteFile(directory.Meta(cert.Domain), jsonBytes, 0600)
if err != nil {
return err
}
return nil
}

109
pkg/letsencrypt/user.go Normal file
View file

@ -0,0 +1,109 @@
package letsencrypt
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"io/ioutil"
"os"
"github.com/xenolf/lego/acme"
)
const defaultUser = "default"
type User struct {
Email string
Registration *acme.RegistrationResource
key crypto.PrivateKey
}
func (u User) GetEmail() string {
return u.Email
}
func (u User) GetRegistration() *acme.RegistrationResource {
return u.Registration
}
func (u User) GetPrivateKey() crypto.PrivateKey {
return u.key
}
func newUser(email string) (User, error) {
var err error
user := User{Email: email}
user.key, err = rsa.GenerateKey(rand.Reader, KeySize)
if err != nil {
return user, err
}
return user, nil
}
func getUser(email string) (User, error) {
var user User
reg, err := os.Open(directory.UserRegistration(email))
if err != nil {
if os.IsNotExist(err) {
return newUser(email)
}
return user, err
}
defer reg.Close()
err = json.NewDecoder(reg).Decode(&user)
if err != nil {
return user, err
}
user.key, err = loadRSAPrivateKey(directory.UserKey(email))
if err != nil {
return user, err
}
return user, nil
}
func saveUser(user User) error {
err := os.MkdirAll(directory.User(user.Email), 0700)
if err != nil {
return err
}
err = saveRSAPrivateKey(user.key, directory.UserKey(user.Email))
if err != nil {
return err
}
jsonBytes, err := json.MarshalIndent(&user, "", " ")
if err != nil {
return err
}
return ioutil.WriteFile(directory.UserRegistration(user.Email), jsonBytes, 0600)
}
func loadRSAPrivateKey(file string) (crypto.PrivateKey, error) {
keyBytes, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
keyBlock, _ := pem.Decode(keyBytes)
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
}
func saveRSAPrivateKey(key crypto.PrivateKey, file string) error {
pemKey := pem.Block{
Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key.(*rsa.PrivateKey)),
}
keyOut, err := os.Create(file)
if err != nil {
return err
}
defer keyOut.Close()
return pem.Encode(keyOut, &pemKey)
}

View file

@ -0,0 +1,35 @@
package letsencrypt
import (
"io/ioutil"
"testing"
"github.com/stretchr/testify/assert"
)
func tempdir() string {
f, _ := ioutil.TempDir("", "")
return f
}
func testUser(t *testing.T, email string) {
user, err := newUser(email)
assert.Nil(t, err)
key := user.GetPrivateKey()
assert.NotNil(t, key)
err = saveUser(user)
assert.Nil(t, err)
user, err = getUser(email)
assert.Nil(t, err)
assert.Equal(t, email, user.GetEmail())
assert.Equal(t, key, user.GetPrivateKey())
}
func TestUser(t *testing.T) {
directory = Directory(tempdir())
testUser(t, "test@test.com")
testUser(t, "")
}

132
pkg/linkmeta/linkmeta.go Normal file
View file

@ -0,0 +1,132 @@
package linkmeta
import (
"errors"
"io"
"net/http"
"strings"
"time"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
)
var (
Client = &http.Client{
Timeout: 15 * time.Second,
}
ErrContentType = errors.New("Unsupported Content-Type")
)
type Meta struct {
URL string `json:"URL"`
SiteName string `json:"siteName,omitempty"`
Color string `json:"color,omitempty"`
Title string `json:"title"`
Description string `json:"description"`
ImageURL string `json:"imageURL,omitempty"`
VideoURL string `json:"videoURL,omitempty"`
}
func Fetch(url string) (*Meta, error) {
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// TODO: Image links
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
return nil, ErrContentType
}
return ExtractMeta(resp.Body, url)
}
func ExtractMeta(body io.Reader, url string) (*Meta, error) {
meta := Meta{URL: url}
var currentNode atom.Atom
z := html.NewTokenizer(body)
for {
tt := z.Next()
switch tt {
case html.ErrorToken:
if z.Err() == io.EOF {
return &meta, nil
}
return nil, z.Err()
case html.TextToken:
if currentNode == atom.Title && meta.Title == "" {
meta.Title = string(z.Text())
}
case html.StartTagToken, html.SelfClosingTagToken, html.EndTagToken:
name, hasAttr := z.TagName()
node := atom.Lookup(name)
if node == atom.Meta && hasAttr {
var key, val []byte
var name, content string
for hasAttr {
key, val, hasAttr = z.TagAttr()
switch atom.String(key) {
case "name":
name = string(val)
case "property":
name = string(val)
case "content":
content = string(val)
}
}
if content != "" {
switch name {
case "og:site_name":
meta.SiteName = content
case "theme-color", "msapplication-TileColor":
meta.Color = content
case "og:title", "twitter:title", "title":
meta.Title = content
case "og:description", "twitter:description":
meta.Description = content
case "description":
if meta.Description == "" {
meta.Description = content
}
case "og:image", "og:image:secure_url", "twitter:image":
if !strings.HasPrefix(meta.ImageURL, "https:") {
meta.ImageURL = content
}
case "og:video:url", "og:video:secure_url", "twitter:player":
if !strings.HasPrefix(meta.VideoURL, "https:") {
meta.VideoURL = content
}
}
}
continue
}
if tt == html.StartTagToken {
currentNode = node
} else {
currentNode = 0
}
if (node == atom.Head && tt == html.EndTagToken) || node == atom.Body {
return &meta, nil
}
}
}
}

106
pkg/session/session.go Normal file
View file

@ -0,0 +1,106 @@
package session
import (
"crypto/rand"
"encoding/base64"
"net/http"
"sync"
"time"
)
var (
CookieName = "session"
Expiration = time.Hour * 24 * 7
RefreshInterval = time.Hour
)
type Session struct {
UserID uint64
key string
createdAt int64
expiration *time.Timer
lock sync.Mutex
}
func New(id uint64) (*Session, error) {
key, err := newSessionKey()
if err != nil {
return nil, err
}
return &Session{
key: key,
createdAt: time.Now().Unix(),
UserID: id,
expiration: time.NewTimer(Expiration),
}, nil
}
func (s *Session) Init() {
exp := time.Until(time.Unix(s.createdAt, 0).Add(Expiration))
s.expiration = time.NewTimer(exp)
}
func (s *Session) Key() string {
s.lock.Lock()
key := s.key
s.lock.Unlock()
return key
}
func (s *Session) SetCookie(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: CookieName,
Value: s.Key(),
Path: "/",
Expires: time.Now().Add(Expiration),
HttpOnly: true,
Secure: r.TLS != nil,
})
}
func (s *Session) Expired() bool {
s.lock.Lock()
created := time.Unix(s.createdAt, 0)
s.lock.Unlock()
return time.Since(created) > Expiration
}
func (s *Session) Refresh() (string, bool, error) {
s.lock.Lock()
created := time.Unix(s.createdAt, 0)
s.lock.Unlock()
if time.Since(created) > Expiration {
return "", true, nil
}
if time.Since(created) > RefreshInterval {
key, err := newSessionKey()
if err != nil {
return "", false, err
}
s.expiration.Reset(Expiration)
s.lock.Lock()
s.createdAt = time.Now().Unix()
s.key = key
s.lock.Unlock()
return key, false, nil
}
return "", false, nil
}
func (s *Session) WaitUntilExpiration() {
<-s.expiration.C
}
func newSessionKey() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
return base64.RawURLEncoding.EncodeToString(key), err
}

View file

@ -0,0 +1,5 @@
struct Session {
UserID uint64
key string
createdAt int64
}

View file

@ -0,0 +1,112 @@
package session
import (
"io"
"time"
"unsafe"
)
var (
_ = unsafe.Sizeof(0)
_ = io.ReadFull
_ = time.Now()
)
func (d *Session) Size() (s uint64) {
{
l := uint64(len(d.key))
{
t := l
for t >= 0x80 {
t >>= 7
s++
}
s++
}
s += l
}
s += 16
return
}
func (d *Session) Marshal(buf []byte) ([]byte, error) {
size := d.Size()
{
if uint64(cap(buf)) >= size {
buf = buf[:size]
} else {
buf = make([]byte, size)
}
}
i := uint64(0)
{
*(*uint64)(unsafe.Pointer(&buf[0])) = d.UserID
}
{
l := uint64(len(d.key))
{
t := uint64(l)
for t >= 0x80 {
buf[i+8] = byte(t) | 0x80
t >>= 7
i++
}
buf[i+8] = byte(t)
i++
}
copy(buf[i+8:], d.key)
i += l
}
{
*(*int64)(unsafe.Pointer(&buf[i+8])) = d.createdAt
}
return buf[:i+16], nil
}
func (d *Session) Unmarshal(buf []byte) (uint64, error) {
i := uint64(0)
{
d.UserID = *(*uint64)(unsafe.Pointer(&buf[i+0]))
}
{
l := uint64(0)
{
bs := uint8(7)
t := uint64(buf[i+8] & 0x7F)
for buf[i+8]&0x80 == 0x80 {
i++
t |= uint64(buf[i+8]&0x7F) << bs
bs += 7
}
i++
l = t
}
d.key = string(buf[i+8 : i+8+l])
i += l
}
{
d.createdAt = *(*int64)(unsafe.Pointer(&buf[i+8]))
}
return i + 16, nil
}