Persist, renew and delete sessions, refactor storage package, move reusable packages to pkg
This commit is contained in:
parent
121582f72a
commit
24f9553aa5
48 changed files with 1872 additions and 1171 deletions
173
pkg/irc/client.go
Normal file
173
pkg/irc/client.go
Normal file
|
@ -0,0 +1,173 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jpillora/backoff"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Server string
|
||||
Host string
|
||||
TLS bool
|
||||
TLSConfig *tls.Config
|
||||
Password string
|
||||
Username string
|
||||
Realname string
|
||||
Messages chan *Message
|
||||
ConnectionChanged chan ConnectionState
|
||||
HandleNickInUse func(string) string
|
||||
|
||||
nick string
|
||||
channels []string
|
||||
Support *iSupport
|
||||
|
||||
conn net.Conn
|
||||
connected bool
|
||||
dialer *net.Dialer
|
||||
reader *bufio.Reader
|
||||
backoff *backoff.Backoff
|
||||
out chan string
|
||||
|
||||
quit chan struct{}
|
||||
reconnect chan struct{}
|
||||
sendRecv sync.WaitGroup
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func NewClient(nick, username string) *Client {
|
||||
return &Client{
|
||||
nick: nick,
|
||||
Support: newISupport(),
|
||||
Username: username,
|
||||
Realname: nick,
|
||||
Messages: make(chan *Message, 32),
|
||||
ConnectionChanged: make(chan ConnectionState, 16),
|
||||
out: make(chan string, 32),
|
||||
quit: make(chan struct{}),
|
||||
reconnect: make(chan struct{}),
|
||||
backoff: &backoff.Backoff{
|
||||
Jitter: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) GetNick() string {
|
||||
c.lock.Lock()
|
||||
nick := c.nick
|
||||
c.lock.Unlock()
|
||||
return nick
|
||||
}
|
||||
|
||||
func (c *Client) setNick(nick string) {
|
||||
c.lock.Lock()
|
||||
c.nick = nick
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) Connected() bool {
|
||||
c.lock.Lock()
|
||||
connected := c.connected
|
||||
c.lock.Unlock()
|
||||
return connected
|
||||
}
|
||||
|
||||
func (c *Client) Nick(nick string) {
|
||||
c.Write("NICK " + nick)
|
||||
}
|
||||
|
||||
func (c *Client) Oper(name, password string) {
|
||||
c.Write("OPER " + name + " " + password)
|
||||
}
|
||||
|
||||
func (c *Client) Mode(target, modes, params string) {
|
||||
c.Write(strings.TrimRight("MODE "+target+" "+modes+" "+params, " "))
|
||||
}
|
||||
|
||||
func (c *Client) Quit() {
|
||||
go func() {
|
||||
if c.Connected() {
|
||||
c.write("QUIT")
|
||||
}
|
||||
close(c.quit)
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Client) Join(channels ...string) {
|
||||
c.Write("JOIN " + strings.Join(channels, ","))
|
||||
}
|
||||
|
||||
func (c *Client) Part(channels ...string) {
|
||||
c.Write("PART " + strings.Join(channels, ","))
|
||||
}
|
||||
|
||||
func (c *Client) Topic(channel string, topic ...string) {
|
||||
msg := "TOPIC " + channel
|
||||
if len(topic) > 0 {
|
||||
msg += " :" + topic[0]
|
||||
}
|
||||
c.Write(msg)
|
||||
}
|
||||
|
||||
func (c *Client) Invite(nick, channel string) {
|
||||
c.Write("INVITE " + nick + " " + channel)
|
||||
}
|
||||
|
||||
func (c *Client) Kick(channel string, users ...string) {
|
||||
c.Write("KICK " + channel + " " + strings.Join(users, ","))
|
||||
}
|
||||
|
||||
func (c *Client) Privmsg(target, msg string) {
|
||||
c.Writef("PRIVMSG %s :%s", target, msg)
|
||||
}
|
||||
|
||||
func (c *Client) Notice(target, msg string) {
|
||||
c.Writef("NOTICE %s :%s", target, msg)
|
||||
}
|
||||
|
||||
func (c *Client) Whois(nick string) {
|
||||
c.Write("WHOIS " + nick)
|
||||
}
|
||||
|
||||
func (c *Client) Away(message string) {
|
||||
c.Write("AWAY :" + message)
|
||||
}
|
||||
|
||||
func (c *Client) writePass(password string) {
|
||||
c.write("PASS " + password)
|
||||
}
|
||||
|
||||
func (c *Client) writeNick(nick string) {
|
||||
c.write("NICK " + nick)
|
||||
}
|
||||
|
||||
func (c *Client) writeUser(username, realname string) {
|
||||
c.writef("USER %s 0 * :%s", username, realname)
|
||||
}
|
||||
|
||||
func (c *Client) register() {
|
||||
if c.Password != "" {
|
||||
c.writePass(c.Password)
|
||||
}
|
||||
c.writeNick(c.nick)
|
||||
c.writeUser(c.Username, c.Realname)
|
||||
}
|
||||
|
||||
func (c *Client) addChannel(channel string) {
|
||||
c.lock.Lock()
|
||||
c.channels = append(c.channels, channel)
|
||||
c.lock.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) flushChannels() {
|
||||
c.lock.Lock()
|
||||
if len(c.channels) > 0 {
|
||||
c.Join(c.channels...)
|
||||
c.channels = []string{}
|
||||
}
|
||||
c.lock.Unlock()
|
||||
}
|
168
pkg/irc/client_test.go
Normal file
168
pkg/irc/client_test.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testClient() *Client {
|
||||
return NewClient("test", "testing")
|
||||
}
|
||||
|
||||
func testClientSend() (*Client, chan string) {
|
||||
c := testClient()
|
||||
conn := &mockConn{hook: make(chan string, 16)}
|
||||
c.conn = conn
|
||||
c.sendRecv.Add(1)
|
||||
go c.send()
|
||||
return c, conn.hook
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
hook chan string
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(b []byte) (int, error) {
|
||||
c.hook <- string(b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPass(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.writePass("pass")
|
||||
assert.Equal(t, "PASS pass\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestNick(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Nick("test2")
|
||||
assert.Equal(t, "NICK test2\r\n", <-out)
|
||||
|
||||
c.writeNick("nick")
|
||||
assert.Equal(t, "NICK nick\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.writeUser("user", "rn")
|
||||
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestOper(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Oper("name", "pass")
|
||||
assert.Equal(t, "OPER name pass\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestMode(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Mode("#chan", "+o", "user")
|
||||
assert.Equal(t, "MODE #chan +o user\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestQuit(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.connected = true
|
||||
c.Quit()
|
||||
assert.Equal(t, "QUIT\r\n", <-out)
|
||||
_, ok := <-c.quit
|
||||
assert.Equal(t, false, ok)
|
||||
}
|
||||
|
||||
func TestJoin(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Join("#a")
|
||||
assert.Equal(t, "JOIN #a\r\n", <-out)
|
||||
c.Join("#b", "#c")
|
||||
assert.Equal(t, "JOIN #b,#c\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestPart(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Part("#a")
|
||||
assert.Equal(t, "PART #a\r\n", <-out)
|
||||
c.Part("#b", "#c")
|
||||
assert.Equal(t, "PART #b,#c\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestTopic(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Topic("#chan")
|
||||
assert.Equal(t, "TOPIC #chan\r\n", <-out)
|
||||
c.Topic("#chan", "apple pie")
|
||||
assert.Equal(t, "TOPIC #chan :apple pie\r\n", <-out)
|
||||
c.Topic("#chan", "")
|
||||
assert.Equal(t, "TOPIC #chan :\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestInvite(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Invite("user", "#chan")
|
||||
assert.Equal(t, "INVITE user #chan\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestKick(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Kick("#chan", "user")
|
||||
assert.Equal(t, "KICK #chan user\r\n", <-out)
|
||||
c.Kick("#chan", "a", "b")
|
||||
assert.Equal(t, "KICK #chan a,b\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestPrivmsg(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Privmsg("user", "the message")
|
||||
assert.Equal(t, "PRIVMSG user :the message\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestNotice(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Notice("user", "the message")
|
||||
assert.Equal(t, "NOTICE user :the message\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestWhois(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Whois("user")
|
||||
assert.Equal(t, "WHOIS user\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestAway(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.Away("not here")
|
||||
assert.Equal(t, "AWAY :not here\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.nick = "nick"
|
||||
c.Username = "user"
|
||||
c.Realname = "rn"
|
||||
c.register()
|
||||
assert.Equal(t, "NICK nick\r\n", <-out)
|
||||
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
|
||||
|
||||
c.Password = "pass"
|
||||
c.register()
|
||||
assert.Equal(t, "PASS pass\r\n", <-out)
|
||||
assert.Equal(t, "NICK nick\r\n", <-out)
|
||||
assert.Equal(t, "USER user 0 * :rn\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestFlushChannels(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.addChannel("#chan1")
|
||||
c.flushChannels()
|
||||
assert.Equal(t, <-out, "JOIN #chan1\r\n")
|
||||
c.addChannel("#chan2")
|
||||
c.addChannel("#chan3")
|
||||
c.flushChannels()
|
||||
assert.Equal(t, <-out, "JOIN #chan2,#chan3\r\n")
|
||||
}
|
236
pkg/irc/conn.go
Normal file
236
pkg/irc/conn.go
Normal file
|
@ -0,0 +1,236 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadProtocol = errors.New("This server does not speak IRC")
|
||||
)
|
||||
|
||||
func (c *Client) Connect(address string) {
|
||||
if idx := strings.Index(address, ":"); idx < 0 {
|
||||
c.Host = address
|
||||
|
||||
if c.TLS {
|
||||
address += ":6697"
|
||||
} else {
|
||||
address += ":6667"
|
||||
}
|
||||
} else {
|
||||
c.Host = address[:idx]
|
||||
}
|
||||
c.Server = address
|
||||
c.dialer = &net.Dialer{Timeout: 10 * time.Second}
|
||||
|
||||
c.connChange(false, nil)
|
||||
go c.run()
|
||||
}
|
||||
|
||||
func (c *Client) Reconnect() {
|
||||
close(c.reconnect)
|
||||
}
|
||||
|
||||
func (c *Client) Write(data string) {
|
||||
c.out <- data + "\r\n"
|
||||
}
|
||||
|
||||
func (c *Client) Writef(format string, a ...interface{}) {
|
||||
c.out <- fmt.Sprintf(format+"\r\n", a...)
|
||||
}
|
||||
|
||||
func (c *Client) write(data string) {
|
||||
c.conn.Write([]byte(data + "\r\n"))
|
||||
}
|
||||
|
||||
func (c *Client) writef(format string, a ...interface{}) {
|
||||
fmt.Fprintf(c.conn, format+"\r\n", a...)
|
||||
}
|
||||
|
||||
func (c *Client) run() {
|
||||
c.tryConnect()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.quit:
|
||||
if c.Connected() {
|
||||
c.disconnect()
|
||||
}
|
||||
|
||||
c.sendRecv.Wait()
|
||||
close(c.Messages)
|
||||
return
|
||||
|
||||
case <-c.reconnect:
|
||||
if c.Connected() {
|
||||
c.disconnect()
|
||||
}
|
||||
|
||||
c.sendRecv.Wait()
|
||||
c.reconnect = make(chan struct{})
|
||||
|
||||
c.tryConnect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectionState struct {
|
||||
Connected bool
|
||||
Error error
|
||||
}
|
||||
|
||||
func (c *Client) connChange(connected bool, err error) {
|
||||
c.ConnectionChanged <- ConnectionState{
|
||||
Connected: connected,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) disconnect() {
|
||||
c.lock.Lock()
|
||||
c.connected = false
|
||||
c.lock.Unlock()
|
||||
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Client) tryConnect() {
|
||||
for {
|
||||
select {
|
||||
case <-c.quit:
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
c.connChange(false, err)
|
||||
if _, ok := err.(x509.UnknownAuthorityError); ok {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.backoff.Reset()
|
||||
|
||||
c.flushChannels()
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(c.backoff.Duration())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.TLS {
|
||||
conn, err := tls.DialWithDialer(c.dialer, "tcp", c.Server, c.TLSConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
} else {
|
||||
conn, err := c.dialer.Dial("tcp", c.Server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
}
|
||||
|
||||
c.connected = true
|
||||
c.connChange(true, nil)
|
||||
c.reader = bufio.NewReader(c.conn)
|
||||
|
||||
c.register()
|
||||
|
||||
c.sendRecv.Add(1)
|
||||
go c.recv()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) send() {
|
||||
defer c.sendRecv.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.quit:
|
||||
return
|
||||
|
||||
case <-c.reconnect:
|
||||
return
|
||||
|
||||
case msg := <-c.out:
|
||||
_, err := c.conn.Write([]byte(msg))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) recv() {
|
||||
defer c.sendRecv.Done()
|
||||
|
||||
for {
|
||||
line, err := c.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
select {
|
||||
case <-c.quit:
|
||||
return
|
||||
|
||||
default:
|
||||
c.connChange(false, nil)
|
||||
c.Reconnect()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
msg := parseMessage(line)
|
||||
if msg == nil {
|
||||
close(c.quit)
|
||||
c.connChange(false, ErrBadProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Command {
|
||||
case Ping:
|
||||
go c.write("PONG :" + msg.LastParam())
|
||||
|
||||
case Join:
|
||||
if msg.Nick == c.GetNick() {
|
||||
c.addChannel(msg.Params[0])
|
||||
}
|
||||
|
||||
case Nick:
|
||||
if msg.Nick == c.GetNick() {
|
||||
c.setNick(msg.LastParam())
|
||||
}
|
||||
|
||||
case ReplyWelcome:
|
||||
c.setNick(msg.Params[0])
|
||||
c.sendRecv.Add(1)
|
||||
go c.send()
|
||||
|
||||
case ReplyISupport:
|
||||
c.Support.parse(msg.Params)
|
||||
|
||||
case ErrNicknameInUse:
|
||||
if c.HandleNickInUse != nil {
|
||||
go c.writeNick(c.HandleNickInUse(msg.Params[1]))
|
||||
}
|
||||
}
|
||||
|
||||
c.Messages <- msg
|
||||
}
|
||||
}
|
233
pkg/irc/conn_test.go
Normal file
233
pkg/irc/conn_test.go
Normal file
|
@ -0,0 +1,233 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var ircd *mockIrcd
|
||||
|
||||
func init() {
|
||||
initTestServer()
|
||||
}
|
||||
|
||||
func initTestServer() {
|
||||
ircd = &mockIrcd{
|
||||
conn: make(chan bool, 1),
|
||||
connClosed: make(chan bool, 1),
|
||||
}
|
||||
ircd.start()
|
||||
}
|
||||
|
||||
type mockIrcd struct {
|
||||
conn chan bool
|
||||
connClosed chan bool
|
||||
}
|
||||
|
||||
func (i *mockIrcd) start() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:45678")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(testCert, testKey)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
lnTLS, err := tls.Listen("tcp", "127.0.0.1:45679", tlsConfig)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
go i.accept(ln)
|
||||
go i.accept(lnTLS)
|
||||
}
|
||||
|
||||
func (i *mockIrcd) accept(ln net.Listener) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go i.handle(conn)
|
||||
i.conn <- true
|
||||
}
|
||||
}
|
||||
|
||||
func (i *mockIrcd) handle(conn net.Conn) {
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
i.connClosed <- true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
c := testClient()
|
||||
c.Connect("127.0.0.1:45678")
|
||||
assert.Equal(t, c.Host, "127.0.0.1")
|
||||
assert.Equal(t, c.Server, "127.0.0.1:45678")
|
||||
waitConnAndClose(t, c)
|
||||
}
|
||||
|
||||
func TestConnectTLS(t *testing.T) {
|
||||
c := testClient()
|
||||
c.TLS = true
|
||||
c.TLSConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
c.Connect("127.0.0.1:45679")
|
||||
assert.Equal(t, c.Host, "127.0.0.1")
|
||||
assert.Equal(t, c.Server, "127.0.0.1:45679")
|
||||
waitConnAndClose(t, c)
|
||||
}
|
||||
|
||||
func TestConnectDefaultPorts(t *testing.T) {
|
||||
c := testClient()
|
||||
c.Connect("127.0.0.1")
|
||||
assert.Equal(t, "127.0.0.1:6667", c.Server)
|
||||
|
||||
c = testClient()
|
||||
c.TLS = true
|
||||
c.Connect("127.0.0.1")
|
||||
assert.Equal(t, "127.0.0.1:6697", c.Server)
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
c, out := testClientSend()
|
||||
c.write("test")
|
||||
assert.Equal(t, "test\r\n", <-out)
|
||||
c.Write("test")
|
||||
assert.Equal(t, "test\r\n", <-out)
|
||||
c.writef("test %d", 2)
|
||||
assert.Equal(t, "test 2\r\n", <-out)
|
||||
c.Writef("test %d", 2)
|
||||
assert.Equal(t, "test 2\r\n", <-out)
|
||||
}
|
||||
|
||||
func TestRecv(t *testing.T) {
|
||||
c := testClient()
|
||||
conn := &mockConn{hook: make(chan string, 16)}
|
||||
c.conn = conn
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("CMD\r\n")
|
||||
buf.WriteString("PING :test\r\n")
|
||||
buf.WriteString("001 foo\r\n")
|
||||
c.reader = bufio.NewReader(buf)
|
||||
|
||||
c.sendRecv.Add(1)
|
||||
go c.recv()
|
||||
|
||||
assert.Equal(t, "PONG :test\r\n", <-conn.hook)
|
||||
assert.Equal(t, &Message{Command: "CMD"}, <-c.Messages)
|
||||
}
|
||||
|
||||
func TestRecvTriggersReconnect(t *testing.T) {
|
||||
c := testClient()
|
||||
c.conn = &mockConn{}
|
||||
c.reader = bufio.NewReader(&bytes.Buffer{})
|
||||
done := make(chan struct{})
|
||||
ok := false
|
||||
go func() {
|
||||
c.sendRecv.Add(1)
|
||||
c.recv()
|
||||
_, ok = <-c.reconnect
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
assert.False(t, ok)
|
||||
return
|
||||
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Reconnect not triggered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
c := testClient()
|
||||
close(c.quit)
|
||||
ok := false
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_, ok = <-c.Messages
|
||||
close(done)
|
||||
}()
|
||||
|
||||
c.run()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
assert.False(t, ok)
|
||||
return
|
||||
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Channels not closed")
|
||||
}
|
||||
}
|
||||
|
||||
func waitConnAndClose(t *testing.T, c *Client) {
|
||||
done := make(chan struct{})
|
||||
quit := make(chan struct{})
|
||||
go func() {
|
||||
<-ircd.conn
|
||||
quit <- struct{}{}
|
||||
<-ircd.connClosed
|
||||
close(done)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
|
||||
case <-quit:
|
||||
assert.True(t, c.Connected())
|
||||
c.Quit()
|
||||
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Took too long")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var testCert = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
|
||||
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
|
||||
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
|
||||
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
|
||||
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
|
||||
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
|
||||
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
|
||||
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
|
||||
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
|
||||
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
var testKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
|
||||
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
|
||||
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
|
||||
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
|
||||
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
|
||||
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
|
||||
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
|
||||
-----END RSA PRIVATE KEY-----`)
|
40
pkg/irc/const.go
Normal file
40
pkg/irc/const.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package irc
|
||||
|
||||
const (
|
||||
Error = "ERROR"
|
||||
Join = "JOIN"
|
||||
Mode = "MODE"
|
||||
Nick = "NICK"
|
||||
Notice = "NOTICE"
|
||||
Part = "PART"
|
||||
Ping = "PING"
|
||||
Privmsg = "PRIVMSG"
|
||||
Quit = "QUIT"
|
||||
Topic = "TOPIC"
|
||||
|
||||
ReplyWelcome = "001"
|
||||
ReplyYourHost = "002"
|
||||
ReplyCreated = "003"
|
||||
ReplyISupport = "005"
|
||||
ReplyLUserClient = "251"
|
||||
ReplyLUserOp = "252"
|
||||
ReplyLUserUnknown = "253"
|
||||
ReplyLUserChannels = "254"
|
||||
ReplyLUserMe = "255"
|
||||
ReplyAway = "301"
|
||||
ReplyWhoisUser = "311"
|
||||
ReplyWhoisServer = "312"
|
||||
ReplyWhoisOperator = "313"
|
||||
ReplyWhoisIdle = "317"
|
||||
ReplyEndOfWhois = "318"
|
||||
ReplyWhoisChannels = "319"
|
||||
ReplyNoTopic = "331"
|
||||
ReplyTopic = "332"
|
||||
ReplyNamReply = "353"
|
||||
ReplyEndOfNames = "366"
|
||||
ReplyMotd = "372"
|
||||
ReplyMotdStart = "375"
|
||||
ReplyEndOfMotd = "376"
|
||||
ErrErroneousNickname = "432"
|
||||
ErrNicknameInUse = "433"
|
||||
)
|
161
pkg/irc/message.go
Normal file
161
pkg/irc/message.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Tags map[string]string
|
||||
Prefix string
|
||||
Nick string
|
||||
Command string
|
||||
Params []string
|
||||
}
|
||||
|
||||
func (m *Message) LastParam() string {
|
||||
if len(m.Params) > 0 {
|
||||
return m.Params[len(m.Params)-1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseMessage(line string) *Message {
|
||||
line = strings.Trim(line, "\r\n ")
|
||||
msg := Message{}
|
||||
|
||||
if strings.HasPrefix(line, "@") {
|
||||
next := strings.Index(line, " ")
|
||||
if next == -1 {
|
||||
return nil
|
||||
}
|
||||
tags := strings.Split(line[1:next], ";")
|
||||
|
||||
if len(tags) > 0 {
|
||||
msg.Tags = map[string]string{}
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
key, val := splitParam(tag)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if val != "" {
|
||||
msg.Tags[key] = unescapeTag(val)
|
||||
} else {
|
||||
msg.Tags[key] = ""
|
||||
}
|
||||
}
|
||||
|
||||
line = line[next+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, ":") {
|
||||
next := strings.Index(line, " ")
|
||||
if next == -1 {
|
||||
return nil
|
||||
}
|
||||
msg.Prefix = line[1:next]
|
||||
|
||||
if i := strings.Index(msg.Prefix, "!"); i > 0 {
|
||||
msg.Nick = msg.Prefix[:i]
|
||||
} else if i := strings.Index(msg.Prefix, "@"); i > 0 {
|
||||
msg.Nick = msg.Prefix[:i]
|
||||
} else {
|
||||
msg.Nick = msg.Prefix
|
||||
}
|
||||
|
||||
line = line[next+1:]
|
||||
}
|
||||
|
||||
cmdEnd := len(line)
|
||||
trailing := ""
|
||||
if i := strings.Index(line, " :"); i > 0 {
|
||||
cmdEnd = i
|
||||
trailing = line[i+2:]
|
||||
}
|
||||
|
||||
cmd := strings.Fields(line[:cmdEnd])
|
||||
if len(cmd) == 0 {
|
||||
return nil
|
||||
}
|
||||
msg.Command = cmd[0]
|
||||
|
||||
if len(cmd) > 1 {
|
||||
msg.Params = cmd[1:]
|
||||
}
|
||||
if cmdEnd != len(line) {
|
||||
msg.Params = append(msg.Params, trailing)
|
||||
}
|
||||
|
||||
return &msg
|
||||
}
|
||||
|
||||
type iSupport struct {
|
||||
support map[string]string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func newISupport() *iSupport {
|
||||
return &iSupport{
|
||||
support: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (i *iSupport) parse(params []string) {
|
||||
i.lock.Lock()
|
||||
for _, param := range params[1 : len(params)-1] {
|
||||
key, val := splitParam(param)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if key[0] == '-' {
|
||||
delete(i.support, key[1:])
|
||||
} else {
|
||||
i.support[key] = val
|
||||
}
|
||||
}
|
||||
i.lock.Unlock()
|
||||
}
|
||||
|
||||
func (i *iSupport) Has(key string) bool {
|
||||
i.lock.Lock()
|
||||
_, has := i.support[key]
|
||||
i.lock.Unlock()
|
||||
return has
|
||||
}
|
||||
|
||||
func (i *iSupport) Get(key string) string {
|
||||
i.lock.Lock()
|
||||
v := i.support[key]
|
||||
i.lock.Unlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func (i *iSupport) GetInt(key string) int {
|
||||
i.lock.Lock()
|
||||
v := cast.ToInt(i.support[key])
|
||||
i.lock.Unlock()
|
||||
return v
|
||||
}
|
||||
|
||||
func splitParam(param string) (string, string) {
|
||||
parts := strings.SplitN(param, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
return parts[0], ""
|
||||
}
|
||||
|
||||
func unescapeTag(s string) string {
|
||||
s = strings.Replace(s, "\\:", ";", -1)
|
||||
s = strings.Replace(s, "\\s", " ", -1)
|
||||
s = strings.Replace(s, "\\\\", "\\", -1)
|
||||
s = strings.Replace(s, "\\r", "\r", -1)
|
||||
s = strings.Replace(s, "\\n", "\n", -1)
|
||||
return s
|
||||
}
|
199
pkg/irc/message_test.go
Normal file
199
pkg/irc/message_test.go
Normal file
|
@ -0,0 +1,199 @@
|
|||
package irc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseMessage(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected *Message
|
||||
}{
|
||||
{
|
||||
":user CMD #chan :some message\r\n",
|
||||
&Message{
|
||||
Prefix: "user",
|
||||
Nick: "user",
|
||||
Command: "CMD",
|
||||
Params: []string{"#chan", "some message"},
|
||||
},
|
||||
}, {
|
||||
":nick!user@host.com CMD a b\r\n",
|
||||
&Message{
|
||||
Prefix: "nick!user@host.com",
|
||||
Nick: "nick",
|
||||
Command: "CMD",
|
||||
Params: []string{"a", "b"},
|
||||
},
|
||||
}, {
|
||||
"CMD a b :\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"a", "b", ""},
|
||||
},
|
||||
}, {
|
||||
"CMD a b\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"a", "b"},
|
||||
},
|
||||
}, {
|
||||
"CMD\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
},
|
||||
}, {
|
||||
"CMD :tests and stuff\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"tests and stuff"},
|
||||
},
|
||||
}, {
|
||||
":nick@host.com CMD\r\n",
|
||||
&Message{
|
||||
Prefix: "nick@host.com",
|
||||
Nick: "nick",
|
||||
Command: "CMD",
|
||||
},
|
||||
}, {
|
||||
":ni@ck!user!name@host!.com CMD\r\n",
|
||||
&Message{
|
||||
Prefix: "ni@ck!user!name@host!.com",
|
||||
Nick: "ni@ck",
|
||||
Command: "CMD",
|
||||
},
|
||||
}, {
|
||||
"CMD #cake pie \r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"#cake", "pie"},
|
||||
},
|
||||
}, {
|
||||
" CMD #cake pie\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"#cake", "pie"},
|
||||
},
|
||||
}, {
|
||||
"CMD #cake ::pie\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"#cake", ":pie"},
|
||||
},
|
||||
}, {
|
||||
"CMD #cake : pie\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"#cake", " pie"},
|
||||
},
|
||||
}, {
|
||||
"CMD #cake :pie :P <3\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"#cake", "pie :P <3"},
|
||||
},
|
||||
}, {
|
||||
"CMD #cake :pie!\r\n",
|
||||
&Message{
|
||||
Command: "CMD",
|
||||
Params: []string{"#cake", "pie!"},
|
||||
},
|
||||
}, {
|
||||
"@x=y CMD\r\n",
|
||||
&Message{
|
||||
Tags: map[string]string{
|
||||
"x": "y",
|
||||
},
|
||||
Command: "CMD",
|
||||
},
|
||||
}, {
|
||||
"@x=y :nick!user@host.com CMD\r\n",
|
||||
&Message{
|
||||
Tags: map[string]string{
|
||||
"x": "y",
|
||||
},
|
||||
Prefix: "nick!user@host.com",
|
||||
Nick: "nick",
|
||||
Command: "CMD",
|
||||
},
|
||||
}, {
|
||||
"@x=y :nick!user@host.com CMD :pie and cake\r\n",
|
||||
&Message{
|
||||
Tags: map[string]string{
|
||||
"x": "y",
|
||||
},
|
||||
Prefix: "nick!user@host.com",
|
||||
Nick: "nick",
|
||||
Command: "CMD",
|
||||
Params: []string{"pie and cake"},
|
||||
},
|
||||
}, {
|
||||
"@x=y;a=b CMD\r\n",
|
||||
&Message{
|
||||
Tags: map[string]string{
|
||||
"x": "y",
|
||||
"a": "b",
|
||||
},
|
||||
Command: "CMD",
|
||||
},
|
||||
}, {
|
||||
"@x=y;a=\\\\\\:\\s\\r\\n CMD\r\n",
|
||||
&Message{
|
||||
Tags: map[string]string{
|
||||
"x": "y",
|
||||
"a": "\\; \r\n",
|
||||
},
|
||||
Command: "CMD",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
assert.Equal(t, tc.expected, parseMessage(tc.input))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLastParam(t *testing.T) {
|
||||
assert.Equal(t, "some message", parseMessage(":user CMD #chan :some message\r\n").LastParam())
|
||||
assert.Equal(t, "", parseMessage("NO_PARAMS").LastParam())
|
||||
}
|
||||
|
||||
func TestBadMessagePanic(t *testing.T) {
|
||||
parseMessage("@\r\n")
|
||||
parseMessage("@ :\r\n")
|
||||
parseMessage("@ :\r\n")
|
||||
parseMessage(":user\r\n")
|
||||
parseMessage(":\r\n")
|
||||
parseMessage(":")
|
||||
parseMessage("")
|
||||
}
|
||||
|
||||
func TestParseISupport(t *testing.T) {
|
||||
s := newISupport()
|
||||
s.parse([]string{"bob", "CAKE=31", "PIE", ":durr"})
|
||||
assert.Equal(t, 31, s.GetInt("CAKE"))
|
||||
assert.Equal(t, "31", s.Get("CAKE"))
|
||||
assert.True(t, s.Has("CAKE"))
|
||||
assert.True(t, s.Has("PIE"))
|
||||
assert.False(t, s.Has("APPLES"))
|
||||
assert.Equal(t, "", s.Get("APPLES"))
|
||||
assert.Equal(t, 0, s.GetInt("APPLES"))
|
||||
|
||||
s.parse([]string{"bob", "-PIE", ":hurr"})
|
||||
assert.False(t, s.Has("PIE"))
|
||||
|
||||
s.parse([]string{"bob", "CAKE=1337", ":durr"})
|
||||
assert.Equal(t, 1337, s.GetInt("CAKE"))
|
||||
|
||||
s.parse([]string{"bob", "CAKE=", ":durr"})
|
||||
assert.Equal(t, "", s.Get("CAKE"))
|
||||
assert.True(t, s.Has("CAKE"))
|
||||
|
||||
s.parse([]string{"bob", "CAKE===", ":durr"})
|
||||
assert.Equal(t, "==", s.Get("CAKE"))
|
||||
|
||||
s.parse([]string{"bob", "-CAKE=31", ":durr"})
|
||||
assert.False(t, s.Has("CAKE"))
|
||||
}
|
38
pkg/letsencrypt/directory.go
Normal file
38
pkg/letsencrypt/directory.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package letsencrypt
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type Directory string
|
||||
|
||||
func (d Directory) Domain(domain string) string {
|
||||
return filepath.Join(string(d), "certs", domain)
|
||||
}
|
||||
|
||||
func (d Directory) Cert(domain string) string {
|
||||
return filepath.Join(d.Domain(domain), "cert.pem")
|
||||
}
|
||||
|
||||
func (d Directory) Key(domain string) string {
|
||||
return filepath.Join(d.Domain(domain), "key.pem")
|
||||
}
|
||||
|
||||
func (d Directory) Meta(domain string) string {
|
||||
return filepath.Join(d.Domain(domain), "metadata.json")
|
||||
}
|
||||
|
||||
func (d Directory) User(email string) string {
|
||||
if email == "" {
|
||||
email = defaultUser
|
||||
}
|
||||
return filepath.Join(string(d), "users", email)
|
||||
}
|
||||
|
||||
func (d Directory) UserRegistration(email string) string {
|
||||
return filepath.Join(d.User(email), "registration.json")
|
||||
}
|
||||
|
||||
func (d Directory) UserKey(email string) string {
|
||||
return filepath.Join(d.User(email), "key.pem")
|
||||
}
|
287
pkg/letsencrypt/letsencrypt.go
Normal file
287
pkg/letsencrypt/letsencrypt.go
Normal file
|
@ -0,0 +1,287 @@
|
|||
package letsencrypt
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
const URL = "https://acme-v01.api.letsencrypt.org/directory"
|
||||
const KeySize = 2048
|
||||
|
||||
var directory Directory
|
||||
|
||||
func Run(dir, domain, email, port string) (*state, error) {
|
||||
directory = Directory(dir)
|
||||
|
||||
user, err := getUser(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := acme.NewClient(URL, &user, acme.RSA2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01})
|
||||
client.SetHTTPAddress(port)
|
||||
|
||||
if user.Registration == nil {
|
||||
user.Registration, err = client.Register()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = client.AgreeToTOS()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = saveUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s := &state{
|
||||
client: client,
|
||||
domain: domain,
|
||||
}
|
||||
|
||||
if certExists(domain) {
|
||||
if !s.renew() {
|
||||
err = s.loadCert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
s.refreshOCSP()
|
||||
} else {
|
||||
err = s.obtain()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
go s.maintain()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
type state struct {
|
||||
client *acme.Client
|
||||
domain string
|
||||
cert *tls.Certificate
|
||||
certPEM []byte
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *state) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
s.lock.Lock()
|
||||
cert := s.cert
|
||||
s.lock.Unlock()
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func (s *state) getCertPEM() []byte {
|
||||
s.lock.Lock()
|
||||
certPEM := s.certPEM
|
||||
s.lock.Unlock()
|
||||
|
||||
return certPEM
|
||||
}
|
||||
|
||||
func (s *state) setCert(meta acme.CertificateResource) {
|
||||
cert, err := tls.X509KeyPair(meta.Certificate, meta.PrivateKey)
|
||||
if err == nil {
|
||||
s.lock.Lock()
|
||||
if s.cert != nil {
|
||||
cert.OCSPStaple = s.cert.OCSPStaple
|
||||
}
|
||||
|
||||
s.cert = &cert
|
||||
s.certPEM = meta.Certificate
|
||||
s.lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) setOCSP(ocsp []byte) {
|
||||
cert := tls.Certificate{
|
||||
OCSPStaple: ocsp,
|
||||
}
|
||||
|
||||
s.lock.Lock()
|
||||
if s.cert != nil {
|
||||
cert.Certificate = s.cert.Certificate
|
||||
cert.PrivateKey = s.cert.PrivateKey
|
||||
}
|
||||
s.cert = &cert
|
||||
s.lock.Unlock()
|
||||
}
|
||||
|
||||
func (s *state) obtain() error {
|
||||
cert, errors := s.client.ObtainCertificate([]string{s.domain}, true, nil, false)
|
||||
if err := errors[s.domain]; err != nil {
|
||||
if _, ok := err.(acme.TOSError); ok {
|
||||
err := s.client.AgreeToTOS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.obtain()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
s.setCert(cert)
|
||||
s.refreshOCSP()
|
||||
|
||||
err := saveCert(cert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *state) renew() bool {
|
||||
cert, err := ioutil.ReadFile(directory.Cert(s.domain))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
exp, err := acme.GetPEMCertExpiration(cert)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
daysLeft := int(exp.Sub(time.Now().UTC()).Hours() / 24)
|
||||
|
||||
if daysLeft <= 30 {
|
||||
metaBytes, err := ioutil.ReadFile(directory.Meta(s.domain))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
key, err := ioutil.ReadFile(directory.Key(s.domain))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var meta acme.CertificateResource
|
||||
err = json.Unmarshal(metaBytes, &meta)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
meta.Certificate = cert
|
||||
meta.PrivateKey = key
|
||||
|
||||
Renew:
|
||||
newMeta, err := s.client.RenewCertificate(meta, true, false)
|
||||
if err != nil {
|
||||
if _, ok := err.(acme.TOSError); ok {
|
||||
err := s.client.AgreeToTOS()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
goto Renew
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
s.setCert(newMeta)
|
||||
|
||||
err = saveCert(newMeta)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *state) refreshOCSP() {
|
||||
ocsp, resp, err := acme.GetOCSPForCert(s.getCertPEM())
|
||||
if err == nil && resp.Status == acme.OCSPGood {
|
||||
s.setOCSP(ocsp)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) maintain() {
|
||||
renew := time.Tick(24 * time.Hour)
|
||||
ocsp := time.Tick(1 * time.Hour)
|
||||
for {
|
||||
select {
|
||||
case <-renew:
|
||||
s.renew()
|
||||
|
||||
case <-ocsp:
|
||||
s.refreshOCSP()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *state) loadCert() error {
|
||||
cert, err := ioutil.ReadFile(directory.Cert(s.domain))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := ioutil.ReadFile(directory.Key(s.domain))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setCert(acme.CertificateResource{
|
||||
Certificate: cert,
|
||||
PrivateKey: key,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func certExists(domain string) bool {
|
||||
if _, err := os.Stat(directory.Cert(domain)); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := os.Stat(directory.Key(domain)); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func saveCert(cert acme.CertificateResource) error {
|
||||
err := os.MkdirAll(directory.Domain(cert.Domain), 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(directory.Cert(cert.Domain), cert.Certificate, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(directory.Key(cert.Domain), cert.PrivateKey, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonBytes, err := json.MarshalIndent(&cert, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ioutil.WriteFile(directory.Meta(cert.Domain), jsonBytes, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
109
pkg/letsencrypt/user.go
Normal file
109
pkg/letsencrypt/user.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package letsencrypt
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
const defaultUser = "default"
|
||||
|
||||
type User struct {
|
||||
Email string
|
||||
Registration *acme.RegistrationResource
|
||||
key crypto.PrivateKey
|
||||
}
|
||||
|
||||
func (u User) GetEmail() string {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
func (u User) GetRegistration() *acme.RegistrationResource {
|
||||
return u.Registration
|
||||
}
|
||||
|
||||
func (u User) GetPrivateKey() crypto.PrivateKey {
|
||||
return u.key
|
||||
}
|
||||
|
||||
func newUser(email string) (User, error) {
|
||||
var err error
|
||||
user := User{Email: email}
|
||||
user.key, err = rsa.GenerateKey(rand.Reader, KeySize)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func getUser(email string) (User, error) {
|
||||
var user User
|
||||
|
||||
reg, err := os.Open(directory.UserRegistration(email))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return newUser(email)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
defer reg.Close()
|
||||
|
||||
err = json.NewDecoder(reg).Decode(&user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
user.key, err = loadRSAPrivateKey(directory.UserKey(email))
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func saveUser(user User) error {
|
||||
err := os.MkdirAll(directory.User(user.Email), 0700)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = saveRSAPrivateKey(user.key, directory.UserKey(user.Email))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonBytes, err := json.MarshalIndent(&user, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(directory.UserRegistration(user.Email), jsonBytes, 0600)
|
||||
}
|
||||
|
||||
func loadRSAPrivateKey(file string) (crypto.PrivateKey, error) {
|
||||
keyBytes, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyBlock, _ := pem.Decode(keyBytes)
|
||||
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
|
||||
}
|
||||
|
||||
func saveRSAPrivateKey(key crypto.PrivateKey, file string) error {
|
||||
pemKey := pem.Block{
|
||||
Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key.(*rsa.PrivateKey)),
|
||||
}
|
||||
keyOut, err := os.Create(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer keyOut.Close()
|
||||
return pem.Encode(keyOut, &pemKey)
|
||||
}
|
35
pkg/letsencrypt/user_test.go
Normal file
35
pkg/letsencrypt/user_test.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package letsencrypt
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func tempdir() string {
|
||||
f, _ := ioutil.TempDir("", "")
|
||||
return f
|
||||
}
|
||||
|
||||
func testUser(t *testing.T, email string) {
|
||||
user, err := newUser(email)
|
||||
assert.Nil(t, err)
|
||||
key := user.GetPrivateKey()
|
||||
assert.NotNil(t, key)
|
||||
|
||||
err = saveUser(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
user, err = getUser(email)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, email, user.GetEmail())
|
||||
assert.Equal(t, key, user.GetPrivateKey())
|
||||
}
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
directory = Directory(tempdir())
|
||||
|
||||
testUser(t, "test@test.com")
|
||||
testUser(t, "")
|
||||
}
|
132
pkg/linkmeta/linkmeta.go
Normal file
132
pkg/linkmeta/linkmeta.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package linkmeta
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
"golang.org/x/net/html/atom"
|
||||
)
|
||||
|
||||
var (
|
||||
Client = &http.Client{
|
||||
Timeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
ErrContentType = errors.New("Unsupported Content-Type")
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
URL string `json:"URL"`
|
||||
SiteName string `json:"siteName,omitempty"`
|
||||
Color string `json:"color,omitempty"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
ImageURL string `json:"imageURL,omitempty"`
|
||||
VideoURL string `json:"videoURL,omitempty"`
|
||||
}
|
||||
|
||||
func Fetch(url string) (*Meta, error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// TODO: Image links
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "text/html") {
|
||||
return nil, ErrContentType
|
||||
}
|
||||
|
||||
return ExtractMeta(resp.Body, url)
|
||||
}
|
||||
|
||||
func ExtractMeta(body io.Reader, url string) (*Meta, error) {
|
||||
meta := Meta{URL: url}
|
||||
var currentNode atom.Atom
|
||||
|
||||
z := html.NewTokenizer(body)
|
||||
for {
|
||||
tt := z.Next()
|
||||
switch tt {
|
||||
case html.ErrorToken:
|
||||
if z.Err() == io.EOF {
|
||||
return &meta, nil
|
||||
}
|
||||
return nil, z.Err()
|
||||
|
||||
case html.TextToken:
|
||||
if currentNode == atom.Title && meta.Title == "" {
|
||||
meta.Title = string(z.Text())
|
||||
}
|
||||
|
||||
case html.StartTagToken, html.SelfClosingTagToken, html.EndTagToken:
|
||||
name, hasAttr := z.TagName()
|
||||
node := atom.Lookup(name)
|
||||
|
||||
if node == atom.Meta && hasAttr {
|
||||
var key, val []byte
|
||||
var name, content string
|
||||
for hasAttr {
|
||||
key, val, hasAttr = z.TagAttr()
|
||||
switch atom.String(key) {
|
||||
case "name":
|
||||
name = string(val)
|
||||
|
||||
case "property":
|
||||
name = string(val)
|
||||
|
||||
case "content":
|
||||
content = string(val)
|
||||
}
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
switch name {
|
||||
case "og:site_name":
|
||||
meta.SiteName = content
|
||||
|
||||
case "theme-color", "msapplication-TileColor":
|
||||
meta.Color = content
|
||||
|
||||
case "og:title", "twitter:title", "title":
|
||||
meta.Title = content
|
||||
|
||||
case "og:description", "twitter:description":
|
||||
meta.Description = content
|
||||
|
||||
case "description":
|
||||
if meta.Description == "" {
|
||||
meta.Description = content
|
||||
}
|
||||
|
||||
case "og:image", "og:image:secure_url", "twitter:image":
|
||||
if !strings.HasPrefix(meta.ImageURL, "https:") {
|
||||
meta.ImageURL = content
|
||||
}
|
||||
|
||||
case "og:video:url", "og:video:secure_url", "twitter:player":
|
||||
if !strings.HasPrefix(meta.VideoURL, "https:") {
|
||||
meta.VideoURL = content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if tt == html.StartTagToken {
|
||||
currentNode = node
|
||||
} else {
|
||||
currentNode = 0
|
||||
}
|
||||
|
||||
if (node == atom.Head && tt == html.EndTagToken) || node == atom.Body {
|
||||
return &meta, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
106
pkg/session/session.go
Normal file
106
pkg/session/session.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
CookieName = "session"
|
||||
|
||||
Expiration = time.Hour * 24 * 7
|
||||
RefreshInterval = time.Hour
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
UserID uint64
|
||||
|
||||
key string
|
||||
createdAt int64
|
||||
expiration *time.Timer
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func New(id uint64) (*Session, error) {
|
||||
key, err := newSessionKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Session{
|
||||
key: key,
|
||||
createdAt: time.Now().Unix(),
|
||||
UserID: id,
|
||||
expiration: time.NewTimer(Expiration),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Session) Init() {
|
||||
exp := time.Until(time.Unix(s.createdAt, 0).Add(Expiration))
|
||||
s.expiration = time.NewTimer(exp)
|
||||
}
|
||||
|
||||
func (s *Session) Key() string {
|
||||
s.lock.Lock()
|
||||
key := s.key
|
||||
s.lock.Unlock()
|
||||
return key
|
||||
}
|
||||
|
||||
func (s *Session) SetCookie(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: s.Key(),
|
||||
Path: "/",
|
||||
Expires: time.Now().Add(Expiration),
|
||||
HttpOnly: true,
|
||||
Secure: r.TLS != nil,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) Expired() bool {
|
||||
s.lock.Lock()
|
||||
created := time.Unix(s.createdAt, 0)
|
||||
s.lock.Unlock()
|
||||
return time.Since(created) > Expiration
|
||||
}
|
||||
|
||||
func (s *Session) Refresh() (string, bool, error) {
|
||||
s.lock.Lock()
|
||||
created := time.Unix(s.createdAt, 0)
|
||||
s.lock.Unlock()
|
||||
|
||||
if time.Since(created) > Expiration {
|
||||
return "", true, nil
|
||||
}
|
||||
|
||||
if time.Since(created) > RefreshInterval {
|
||||
key, err := newSessionKey()
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
s.expiration.Reset(Expiration)
|
||||
|
||||
s.lock.Lock()
|
||||
s.createdAt = time.Now().Unix()
|
||||
s.key = key
|
||||
s.lock.Unlock()
|
||||
return key, false, nil
|
||||
}
|
||||
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (s *Session) WaitUntilExpiration() {
|
||||
<-s.expiration.C
|
||||
}
|
||||
|
||||
func newSessionKey() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
return base64.RawURLEncoding.EncodeToString(key), err
|
||||
}
|
5
pkg/session/session.schema
Normal file
5
pkg/session/session.schema
Normal file
|
@ -0,0 +1,5 @@
|
|||
struct Session {
|
||||
UserID uint64
|
||||
key string
|
||||
createdAt int64
|
||||
}
|
112
pkg/session/session.schema.gen.go
Normal file
112
pkg/session/session.schema.gen.go
Normal file
|
@ -0,0 +1,112 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
_ = unsafe.Sizeof(0)
|
||||
_ = io.ReadFull
|
||||
_ = time.Now()
|
||||
)
|
||||
|
||||
func (d *Session) Size() (s uint64) {
|
||||
|
||||
{
|
||||
l := uint64(len(d.key))
|
||||
|
||||
{
|
||||
|
||||
t := l
|
||||
for t >= 0x80 {
|
||||
t >>= 7
|
||||
s++
|
||||
}
|
||||
s++
|
||||
|
||||
}
|
||||
s += l
|
||||
}
|
||||
s += 16
|
||||
return
|
||||
}
|
||||
func (d *Session) Marshal(buf []byte) ([]byte, error) {
|
||||
size := d.Size()
|
||||
{
|
||||
if uint64(cap(buf)) >= size {
|
||||
buf = buf[:size]
|
||||
} else {
|
||||
buf = make([]byte, size)
|
||||
}
|
||||
}
|
||||
i := uint64(0)
|
||||
|
||||
{
|
||||
|
||||
*(*uint64)(unsafe.Pointer(&buf[0])) = d.UserID
|
||||
|
||||
}
|
||||
{
|
||||
l := uint64(len(d.key))
|
||||
|
||||
{
|
||||
|
||||
t := uint64(l)
|
||||
|
||||
for t >= 0x80 {
|
||||
buf[i+8] = byte(t) | 0x80
|
||||
t >>= 7
|
||||
i++
|
||||
}
|
||||
buf[i+8] = byte(t)
|
||||
i++
|
||||
|
||||
}
|
||||
copy(buf[i+8:], d.key)
|
||||
i += l
|
||||
}
|
||||
{
|
||||
|
||||
*(*int64)(unsafe.Pointer(&buf[i+8])) = d.createdAt
|
||||
|
||||
}
|
||||
return buf[:i+16], nil
|
||||
}
|
||||
|
||||
func (d *Session) Unmarshal(buf []byte) (uint64, error) {
|
||||
i := uint64(0)
|
||||
|
||||
{
|
||||
|
||||
d.UserID = *(*uint64)(unsafe.Pointer(&buf[i+0]))
|
||||
|
||||
}
|
||||
{
|
||||
l := uint64(0)
|
||||
|
||||
{
|
||||
|
||||
bs := uint8(7)
|
||||
t := uint64(buf[i+8] & 0x7F)
|
||||
for buf[i+8]&0x80 == 0x80 {
|
||||
i++
|
||||
t |= uint64(buf[i+8]&0x7F) << bs
|
||||
bs += 7
|
||||
}
|
||||
i++
|
||||
|
||||
l = t
|
||||
|
||||
}
|
||||
d.key = string(buf[i+8 : i+8+l])
|
||||
i += l
|
||||
}
|
||||
{
|
||||
|
||||
d.createdAt = *(*int64)(unsafe.Pointer(&buf[i+8]))
|
||||
|
||||
}
|
||||
return i + 16, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue