Store auth info in a JWT token in a cookie

This commit is contained in:
Ken-Håvard Lieng 2016-01-15 02:27:30 +01:00
parent 3e0a1be6bc
commit fb54d4966c
18 changed files with 499 additions and 331 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,35 +1,22 @@
import React from 'react';
import { render } from 'react-dom';
import { syncReduxAndRouter, replacePath } from 'redux-simple-router';
import { syncReduxAndRouter } from 'redux-simple-router';
import createBrowserHistory from 'history/lib/createBrowserHistory';
import configureStore from './store';
import createRoutes from './routes';
import Socket from './util/Socket';
import handleSocket from './socket';
import { createUUID } from './util';
import Root from './containers/Root';
const host = __DEV__ ? `${window.location.hostname}:1337` : window.location.host;
let uuid = localStorage.uuid;
let newUser = false;
if (!uuid) {
uuid = createUUID();
newUser = true;
}
const socket = new Socket(host, uuid);
const socket = new Socket(host);
const store = configureStore(socket);
handleSocket(socket, store);
const history = createBrowserHistory();
syncReduxAndRouter(history, store);
if (newUser) {
store.dispatch(replacePath('/connect'));
localStorage.uuid = uuid;
}
const routes = createRoutes();
render(<Root store={store} routes={routes} history={history} />, document.getElementById('root'));

View File

@ -2,16 +2,12 @@ import EventEmitter2 from 'eventemitter2';
import Backoff from 'backo';
export default class Socket extends EventEmitter2 {
constructor(host, uuid) {
constructor(host) {
super();
const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
this.url = `${protocol}://${host}/ws`;
if (uuid) {
this.url += `?uuid=${uuid}`;
}
this.connectTimeout = 20000;
this.pingTimeout = 30000;
this.backoff = new Backoff({

View File

@ -10,14 +10,6 @@ export function normalizeChannel(channel) {
return channel.split('#').join('').toLowerCase();
}
export function createUUID() {
return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, c => {
const r = Math.random() * 16 | 0;
const v = c === 'x' ? r : (r & 0x3 | 0x8);
return v.toString(16);
});
}
export function timestamp(date = new Date()) {
const h = padLeft(date.getHours(), 2, '0');
const m = padLeft(date.getMinutes(), 2, '0');

View File

@ -11,7 +11,7 @@ import (
var clearCmd = &cobra.Command{
Use: "clear",
Short: "Clear database and message logs",
Short: "Clear all user data",
Run: func(cmd *cobra.Command, args []string) {
err := os.Remove(storage.Path.Database())
if err == nil || os.IsNotExist(err) {
@ -20,9 +20,16 @@ var clearCmd = &cobra.Command{
log.Println(err)
}
err = os.RemoveAll(storage.Path.Logs())
err = os.RemoveAll(storage.Path.HMACKey())
if err == nil || os.IsNotExist(err) {
log.Println("HMAC key cleared")
} else {
log.Println(err)
}
err = os.RemoveAll(storage.Path.Users())
if err == nil {
log.Println("Logs cleared")
log.Println("User data cleared")
} else {
log.Println(err)
}

102
server/auth.go Normal file
View File

@ -0,0 +1,102 @@
package server
import (
"crypto/rand"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/dgrijalva/jwt-go"
"github.com/khlieng/dispatch/storage"
)
func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
var session *Session
cookie, err := r.Cookie(cookieName)
if err != nil {
authLog(r, "No cookie set")
session = newUser(w, r)
} else {
token, err := jwt.Parse(cookie.Value, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return hmacKey, nil
})
if err == nil && token.Valid {
userID := uint64(token.Claims["UserID"].(float64))
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", userID)
sessionLock.Lock()
session = sessions[userID]
sessionLock.Unlock()
} else {
if err != nil {
authLog(r, "Invalid token: "+err.Error())
} else {
authLog(r, "Invalid token")
}
session = newUser(w, r)
}
}
return session
}
func newUser(w http.ResponseWriter, r *http.Request) *Session {
user := storage.NewUser()
if user == nil {
return nil
}
log.Println(r.RemoteAddr, "[Auth] Create session | User ID:", user.ID)
session := NewSession(user)
sessionLock.Lock()
sessions[user.ID] = session
sessionLock.Unlock()
go session.write()
token := jwt.New(jwt.SigningMethodHS256)
token.Claims["UserID"] = user.ID
tokenString, err := token.SignedString(hmacKey)
if err != nil {
return nil
}
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: tokenString,
Path: "/",
HttpOnly: true,
Secure: r.TLS != nil,
})
return session
}
func getHMACKey() ([]byte, error) {
key, err := ioutil.ReadFile(storage.Path.HMACKey())
if err != nil {
key = make([]byte, 32)
rand.Read(key)
err = ioutil.WriteFile(storage.Path.HMACKey(), key, 0600)
if err != nil {
return nil, err
}
}
return key, nil
}
func authLog(r *http.Request, s string) {
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "|", s)
}

View File

@ -10,9 +10,8 @@ import (
func reconnectIRC() {
for _, user := range storage.LoadUsers() {
session := NewSession()
session.user = user
sessions[user.UUID] = session
session := NewSession(user)
sessions[user.ID] = session
go session.write()
channels := user.GetChannels()
@ -30,7 +29,13 @@ func reconnectIRC() {
}
session.setIRC(server.Host, i)
if server.Port != "" {
i.Connect(net.JoinHostPort(server.Host, server.Port))
} else {
i.Connect(server.Host)
}
go newIRCHandler(i, session).run()
var joining []string

View File

@ -22,7 +22,7 @@ func TestMain(m *testing.M) {
storage.Initialize(tempdir)
storage.Open()
user = storage.NewUser("uuid")
user = storage.NewUser()
channelStore = storage.NewChannelStore()
code := m.Run()
@ -34,8 +34,7 @@ func TestMain(m *testing.M) {
func dispatchMessage(msg *irc.Message) WSResponse {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s.user = user
s := NewSession(user)
newIRCHandler(c, s).dispatchMessage(msg)
@ -168,7 +167,7 @@ func TestHandleIRCWelcome(t *testing.T) {
func TestHandleIRCWhois(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s := NewSession(nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{
@ -212,7 +211,7 @@ func TestHandleIRCTopic(t *testing.T) {
func TestHandleIRCNames(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s := NewSession(nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{
@ -240,7 +239,7 @@ func TestHandleIRCNames(t *testing.T) {
func TestHandleIRCMotd(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s := NewSession(nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{

View File

@ -29,10 +29,16 @@ type File struct {
func serveFiles(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
handleAuth(w, r)
serveFile(w, r, "index.html.gz", "text/html")
return
}
if strings.HasSuffix(r.URL.Path, "favicon.ico") {
w.WriteHeader(404)
return
}
for _, file := range files {
if strings.HasSuffix(r.URL.Path, file.Path) {
serveFile(w, r, file.Path+".gz", file.ContentType)
@ -40,6 +46,7 @@ func serveFiles(w http.ResponseWriter, r *http.Request) {
}
}
handleAuth(w, r)
serveFile(w, r, "index.html.gz", "text/html")
}

View File

@ -17,11 +17,17 @@ import (
"github.com/khlieng/dispatch/storage"
)
const (
cookieName = "dispatch"
)
var (
channelStore *storage.ChannelStore
sessions map[string]*Session
sessions map[uint64]*Session
sessionLock sync.Mutex
hmacKey []byte
upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
@ -35,7 +41,13 @@ func Run() {
defer storage.Close()
channelStore = storage.NewChannelStore()
sessions = make(map[string]*Session)
sessions = make(map[uint64]*Session)
var err error
hmacKey, err = getHMACKey()
if err != nil {
log.Fatal(err)
}
reconnectIRC()
startHTTP()
@ -95,27 +107,32 @@ func startHTTP() {
func serve(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
w.WriteHeader(404)
return
}
if r.URL.Path == "/ws" {
upgradeWS(w, r)
session := handleAuth(w, r)
if session == nil {
log.Println("[Auth] No session")
w.WriteHeader(500)
return
}
upgradeWS(w, r, session)
} else {
serveFiles(w, r)
}
}
func upgradeWS(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
func upgradeWS(w http.ResponseWriter, r *http.Request, session *Session) {
conn, err := upgrader.Upgrade(w, r, w.Header())
if err != nil {
log.Println(err)
return
}
uuid := r.URL.Query().Get("uuid")
if uuid != "" {
newWSHandler(conn, uuid).run()
}
newWSHandler(conn, session).run()
}
func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {

View File

@ -19,12 +19,13 @@ type Session struct {
user *storage.User
}
func NewSession() *Session {
func NewSession(user *storage.User) *Session {
return &Session{
irc: make(map[string]*irc.Client),
connectionState: make(map[string]bool),
ws: make(map[string]*wsConn),
out: make(chan WSResponse, 32),
user: user,
}
}
@ -88,6 +89,14 @@ func (s *Session) deleteWS(addr string) {
s.wsLock.Unlock()
}
func (s *Session) numWS() int {
s.ircLock.Lock()
n := len(s.ws)
s.ircLock.Unlock()
return n
}
func (s *Session) sendJSON(t string, v interface{}) {
s.out <- WSResponse{t, v}
}

View File

@ -20,13 +20,14 @@ type wsHandler struct {
handlers map[string]func([]byte)
}
func newWSHandler(conn *websocket.Conn, uuid string) *wsHandler {
func newWSHandler(conn *websocket.Conn, session *Session) *wsHandler {
h := &wsHandler{
ws: newWSConn(conn),
session: session,
addr: conn.RemoteAddr().String(),
}
h.init(uuid)
h.initHandlers()
h.init()
return h
}
@ -54,16 +55,12 @@ func (h *wsHandler) dispatchRequest(req WSRequest) {
}
}
func (h *wsHandler) init(uuid string) {
log.Println(h.addr, "set UUID", uuid)
sessionLock.Lock()
if storedSession, exists := sessions[uuid]; exists {
sessionLock.Unlock()
h.session = storedSession
func (h *wsHandler) init() {
h.session.setWS(h.addr, h.ws)
log.Println(h.addr, "attached to", h.session.numIRC(), "existing IRC connections")
log.Println(h.addr, "[Session] User ID:", h.session.user.ID, "|",
h.session.numIRC(), "IRC connections |",
h.session.numWS(), "WebSocket connections")
channels := h.session.user.GetChannels()
for i, channel := range channels {
@ -81,18 +78,6 @@ func (h *wsHandler) init(uuid string) {
Users: channelStore.GetUsers(channel.Server, channel.Name),
})
}
} else {
h.session = NewSession()
h.session.user = storage.NewUser(uuid)
sessions[uuid] = h.session
sessionLock.Unlock()
h.session.setWS(h.addr, h.ws)
h.session.sendJSON("servers", nil)
go h.session.write()
}
}
func (h *wsHandler) connect(b []byte) {
@ -105,7 +90,7 @@ func (h *wsHandler) connect(b []byte) {
}
if _, ok := h.session.getIRC(host); !ok {
log.Println(h.addr, "connecting to", data.Server)
log.Println(h.addr, "[IRC] Add server", data.Server)
i := irc.NewClient(data.Nick, data.Username)
i.TLS = data.TLS
@ -134,7 +119,7 @@ func (h *wsHandler) connect(b []byte) {
Realname: data.Realname,
})
} else {
log.Println(h.addr, "already connected to", data.Server)
log.Println(h.addr, "[IRC]", data.Server, "already added")
}
}
@ -161,6 +146,8 @@ func (h *wsHandler) quit(b []byte) {
json.Unmarshal(b, &data)
if i, ok := h.session.getIRC(data.Server); ok {
log.Println(h.addr, "[IRC] Remove server", data.Server)
i.Quit()
h.session.deleteIRC(data.Server)
channelStore.RemoveUserAll(i.GetNick(), data.Server)

View File

@ -21,32 +21,28 @@ func (d directory) LetsEncrypt() string {
return filepath.Join(d.Root(), "letsencrypt")
}
func (d directory) Logs() string {
return filepath.Join(d.Root(), "logs")
}
func (d directory) Log(userID string) string {
return filepath.Join(d.Logs(), userID+".log")
}
func (d directory) Index(userID string) string {
return filepath.Join(d.Logs(), userID+".idx")
}
func (d directory) Users() string {
return filepath.Join(d.Root(), "users")
}
func (d directory) User(userID string) string {
return filepath.Join(d.Users(), userID)
func (d directory) User(username string) string {
return filepath.Join(d.Users(), username)
}
func (d directory) Certificate(userID string) string {
return filepath.Join(d.User(userID), "cert.pem")
func (d directory) Log(username string) string {
return filepath.Join(d.User(username), "log")
}
func (d directory) Key(userID string) string {
return filepath.Join(d.User(userID), "key.pem")
func (d directory) Index(username string) string {
return filepath.Join(d.User(username), "index")
}
func (d directory) Certificate(username string) string {
return filepath.Join(d.User(username), "cert.pem")
}
func (d directory) Key(username string) string {
return filepath.Join(d.User(username), "key.pem")
}
func (d directory) Config() string {
@ -56,3 +52,7 @@ func (d directory) Config() string {
func (d directory) Database() string {
return filepath.Join(d.Root(), "dispatch.db")
}
func (d directory) HMACKey() string {
return filepath.Join(d.Root(), "hmac.key")
}

View File

@ -1,8 +1,8 @@
package storage
import (
"encoding/binary"
"log"
"os"
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt"
)
@ -20,11 +20,6 @@ var (
func Initialize(dir string) {
Path = directory(dir)
err := os.MkdirAll(Path.Logs(), 0700)
if err != nil {
log.Fatal(err)
}
}
func Open() {
@ -46,3 +41,13 @@ func Open() {
func Close() {
db.Close()
}
func idToBytes(i uint64) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, i)
return b
}
func idFromBytes(b []byte) uint64 {
return binary.BigEndian.Uint64(b)
}

View File

@ -6,14 +6,23 @@ import (
"encoding/json"
"log"
"strconv"
"strings"
"sync"
"time"
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/blevesearch/bleve"
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt"
)
type User struct {
ID uint64
Username string
id []byte
messageLog *bolt.DB
messageIndex bleve.Index
certificate *tls.Certificate
lock sync.Mutex
}
type Server struct {
Name string `json:"name"`
Host string `json:"host"`
@ -32,39 +41,37 @@ type Channel struct {
Topic string `json:"topic,omitempty"`
}
type Message struct {
ID uint64 `json:"id"`
Server string `json:"server"`
From string `json:"from"`
To string `json:"to"`
Content string `json:"content"`
Time int64 `json:"time"`
}
func NewUser() *User {
user := &User{}
type User struct {
UUID string
err := db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket(bucketUsers)
messageLog *bolt.DB
messageIndex bleve.Index
certificate *tls.Certificate
lock sync.Mutex
}
func NewUser(uuid string) *User {
user := &User{
UUID: uuid,
var err error
user.ID, err = b.NextSequence()
if err != nil {
return err
}
db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("Users"))
data, _ := json.Marshal(user)
user.Username = strconv.FormatUint(user.ID, 10)
user.id = idToBytes(user.ID)
b.Put([]byte(uuid), data)
data, err := json.Marshal(user)
if err != nil {
return err
}
return nil
return b.Put(user.id, data)
})
user.openMessageLog()
if err != nil {
return nil
}
err = user.openMessageLog()
if err != nil {
log.Println(err)
}
return user
}
@ -73,20 +80,29 @@ func LoadUsers() []*User {
var users []*User
db.View(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("Users"))
b := tx.Bucket(bucketUsers)
b.ForEach(func(k, v []byte) error {
user := User{UUID: string(k)}
b.ForEach(func(k, _ []byte) error {
id := idFromBytes(k)
user := &User{
ID: id,
Username: strconv.FormatUint(id, 10),
id: make([]byte, 8),
}
copy(user.id, k)
users = append(users, user)
return nil
})
return nil
})
for _, user := range users {
user.openMessageLog()
user.loadCertificate()
users = append(users, &user)
return nil
})
return nil
})
}
return users
}
@ -95,10 +111,9 @@ func (u *User) GetServers() []Server {
var servers []Server
db.View(func(tx *bolt.Tx) error {
c := tx.Bucket([]byte("Servers")).Cursor()
prefix := []byte(u.UUID)
c := tx.Bucket(bucketServers).Cursor()
for k, v := c.Seek(prefix); bytes.HasPrefix(k, prefix); k, v = c.Next() {
for k, v := c.Seek(u.id); bytes.HasPrefix(k, u.id); k, v = c.Next() {
var server Server
json.Unmarshal(v, &server)
servers = append(servers, server)
@ -114,11 +129,9 @@ func (u *User) GetChannels() []Channel {
var channels []Channel
db.View(func(tx *bolt.Tx) error {
c := tx.Bucket([]byte("Channels")).Cursor()
c := tx.Bucket(bucketChannels).Cursor()
prefix := []byte(u.UUID)
for k, v := c.Seek(prefix); bytes.HasPrefix(k, prefix); k, v = c.Next() {
for k, v := c.Seek(u.id); bytes.HasPrefix(k, u.id); k, v = c.Next() {
var channel Channel
json.Unmarshal(v, &channel)
channels = append(channels, channel)
@ -132,10 +145,10 @@ func (u *User) GetChannels() []Channel {
func (u *User) AddServer(server Server) {
db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("Servers"))
b := tx.Bucket(bucketServers)
data, _ := json.Marshal(server)
b.Put([]byte(u.UUID+":"+server.Host), data)
b.Put(u.serverID(server.Host), data)
return nil
})
@ -143,10 +156,10 @@ func (u *User) AddServer(server Server) {
func (u *User) AddChannel(channel Channel) {
db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("Channels"))
b := tx.Bucket(bucketChannels)
data, _ := json.Marshal(channel)
b.Put([]byte(u.UUID+":"+channel.Server+":"+channel.Name), data)
b.Put(u.channelID(channel.Server, channel.Name), data)
return nil
})
@ -154,8 +167,8 @@ func (u *User) AddChannel(channel Channel) {
func (u *User) SetNick(nick, address string) {
db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("Servers"))
id := []byte(u.UUID + ":" + address)
b := tx.Bucket(bucketServers)
id := u.serverID(address)
var server Server
json.Unmarshal(b.Get(id), &server)
@ -170,11 +183,10 @@ func (u *User) SetNick(nick, address string) {
func (u *User) RemoveServer(address string) {
db.Update(func(tx *bolt.Tx) error {
serverID := []byte(u.UUID + ":" + address)
serverID := u.serverID(address)
tx.Bucket(bucketServers).Delete(serverID)
tx.Bucket([]byte("Servers")).Delete(serverID)
b := tx.Bucket([]byte("Channels"))
b := tx.Bucket(bucketChannels)
c := b.Cursor()
for k, _ := c.Seek(serverID); bytes.HasPrefix(k, serverID); k, _ = c.Next() {
@ -187,157 +199,31 @@ func (u *User) RemoveServer(address string) {
func (u *User) RemoveChannel(server, channel string) {
db.Update(func(tx *bolt.Tx) error {
tx.Bucket([]byte("Channels")).Delete([]byte(u.UUID + ":" + server + ":" + channel))
b := tx.Bucket(bucketChannels)
id := u.channelID(server, channel)
b.Delete(id)
return nil
})
}
func (u *User) LogMessage(server, from, to, content string) {
bucketKey := server + ":" + to
var id uint64
var idStr string
var message Message
u.messageLog.Update(func(tx *bolt.Tx) error {
b, _ := tx.Bucket(bucketMessages).CreateBucketIfNotExists([]byte(bucketKey))
id, _ = b.NextSequence()
idStr = strconv.FormatUint(id, 10)
message = Message{
ID: id,
Content: content,
Server: server,
From: from,
To: to,
Time: time.Now().Unix(),
}
data, _ := json.Marshal(message)
b.Put([]byte(idStr), data)
return nil
})
u.messageIndex.Index(bucketKey+":"+idStr, message)
}
func (u *User) GetLastMessages(server, channel string, count int) ([]Message, error) {
messages := make([]Message, count)
u.messageLog.View(func(tx *bolt.Tx) error {
b := tx.Bucket(bucketMessages).Bucket([]byte(server + ":" + channel))
if b == nil {
return nil
}
c := b.Cursor()
for k, v := c.Last(); count > 0 && k != nil; k, v = c.Prev() {
count--
json.Unmarshal(v, &messages[count])
}
return nil
})
if count < len(messages) {
return messages[count:], nil
} else {
return nil, nil
}
}
func (u *User) GetMessages(server, channel string, count int, fromID uint64) ([]Message, error) {
messages := make([]Message, count)
u.messageLog.View(func(tx *bolt.Tx) error {
b := tx.Bucket(bucketMessages).Bucket([]byte(server + ":" + channel))
if b == nil {
return nil
}
c := b.Cursor()
c.Seek([]byte(strconv.FormatUint(fromID, 10)))
for k, v := c.Prev(); count > 0 && k != nil; k, v = c.Prev() {
count--
json.Unmarshal(v, &messages[count])
}
return nil
})
if count < len(messages) {
return messages[count:], nil
} else {
return nil, nil
}
}
func (u *User) SearchMessages(server, channel, phrase string) ([]Message, error) {
serverQuery := bleve.NewMatchQuery(server)
serverQuery.SetField("server")
channelQuery := bleve.NewMatchQuery(channel)
channelQuery.SetField("to")
contentQuery := bleve.NewMatchQuery(phrase)
contentQuery.SetField("content")
query := bleve.NewBooleanQuery([]bleve.Query{serverQuery, channelQuery, contentQuery}, nil, nil)
search := bleve.NewSearchRequest(query)
searchResults, err := u.messageIndex.Search(search)
if err != nil {
return nil, err
}
messages := []Message{}
u.messageLog.View(func(tx *bolt.Tx) error {
b := tx.Bucket(bucketMessages)
for _, hit := range searchResults.Hits {
idx := strings.LastIndex(hit.ID, ":")
bc := b.Bucket([]byte(hit.ID[:idx]))
var message Message
json.Unmarshal(bc.Get([]byte(hit.ID[idx+1:])), &message)
messages = append(messages, message)
}
return nil
})
return messages, nil
}
func (u *User) Close() {
u.messageLog.Close()
u.messageIndex.Close()
}
func (u *User) openMessageLog() {
var err error
u.messageLog, err = bolt.Open(Path.Log(u.UUID), 0600, nil)
if err != nil {
log.Fatal(err)
}
u.messageLog.Update(func(tx *bolt.Tx) error {
tx.CreateBucketIfNotExists(bucketMessages)
return nil
})
indexPath := Path.Index(u.UUID)
u.messageIndex, err = bleve.Open(indexPath)
if err == bleve.ErrorIndexPathDoesNotExist {
mapping := bleve.NewIndexMapping()
u.messageIndex, err = bleve.New(indexPath, mapping)
if err != nil {
log.Fatal(err)
}
} else if err != nil {
log.Fatal(err)
}
func (u *User) serverID(address string) []byte {
id := make([]byte, 8+len(address))
copy(id, u.id)
copy(id[8:], address)
return id
}
func (u *User) channelID(server, channel string) []byte {
id := make([]byte, 8+len(server)+1+len(channel))
copy(id, u.id)
copy(id[8:], server)
copy(id[8+len(server)+1:], channel)
return id
}

View File

@ -4,7 +4,6 @@ import (
"crypto/tls"
"errors"
"io/ioutil"
"os"
)
var (
@ -29,17 +28,12 @@ func (u *User) SetCertificate(certPEM, keyPEM []byte) error {
u.certificate = &cert
u.lock.Unlock()
err = os.MkdirAll(Path.User(u.UUID), 0700)
err = ioutil.WriteFile(Path.Certificate(u.Username), certPEM, 0600)
if err != nil {
return ErrCouldNotSaveCert
}
err = ioutil.WriteFile(Path.Certificate(u.UUID), certPEM, 0600)
if err != nil {
return ErrCouldNotSaveCert
}
err = ioutil.WriteFile(Path.Key(u.UUID), keyPEM, 0600)
err = ioutil.WriteFile(Path.Key(u.Username), keyPEM, 0600)
if err != nil {
return ErrCouldNotSaveCert
}
@ -48,12 +42,12 @@ func (u *User) SetCertificate(certPEM, keyPEM []byte) error {
}
func (u *User) loadCertificate() error {
certPEM, err := ioutil.ReadFile(Path.Certificate(u.UUID))
certPEM, err := ioutil.ReadFile(Path.Certificate(u.Username))
if err != nil {
return err
}
keyPEM, err := ioutil.ReadFile(Path.Key(u.UUID))
keyPEM, err := ioutil.ReadFile(Path.Key(u.Username))
if err != nil {
return err
}

170
storage/user_messages.go Normal file
View File

@ -0,0 +1,170 @@
package storage
import (
"encoding/json"
"os"
"strconv"
"strings"
"time"
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/blevesearch/bleve"
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt"
)
type Message struct {
ID uint64 `json:"id"`
Server string `json:"server"`
From string `json:"from"`
To string `json:"to"`
Content string `json:"content"`
Time int64 `json:"time"`
}
func (u *User) LogMessage(server, from, to, content string) {
bucketKey := server + ":" + to
var id uint64
var idStr string
var message Message
u.messageLog.Update(func(tx *bolt.Tx) error {
b, _ := tx.Bucket(bucketMessages).CreateBucketIfNotExists([]byte(bucketKey))
id, _ = b.NextSequence()
idStr = strconv.FormatUint(id, 10)
message = Message{
ID: id,
Content: content,
Server: server,
From: from,
To: to,
Time: time.Now().Unix(),
}
data, _ := json.Marshal(message)
b.Put([]byte(idStr), data)
return nil
})
u.messageIndex.Index(bucketKey+":"+idStr, message)
}
func (u *User) GetLastMessages(server, channel string, count int) ([]Message, error) {
messages := make([]Message, count)
u.messageLog.View(func(tx *bolt.Tx) error {
b := tx.Bucket(bucketMessages).Bucket([]byte(server + ":" + channel))
if b == nil {
return nil
}
c := b.Cursor()
for k, v := c.Last(); count > 0 && k != nil; k, v = c.Prev() {
count--
json.Unmarshal(v, &messages[count])
}
return nil
})
if count < len(messages) {
return messages[count:], nil
} else {
return nil, nil
}
}
func (u *User) GetMessages(server, channel string, count int, fromID uint64) ([]Message, error) {
messages := make([]Message, count)
u.messageLog.View(func(tx *bolt.Tx) error {
b := tx.Bucket(bucketMessages).Bucket([]byte(server + ":" + channel))
if b == nil {
return nil
}
c := b.Cursor()
c.Seek([]byte(strconv.FormatUint(fromID, 10)))
for k, v := c.Prev(); count > 0 && k != nil; k, v = c.Prev() {
count--
json.Unmarshal(v, &messages[count])
}
return nil
})
if count < len(messages) {
return messages[count:], nil
}
return nil, nil
}
func (u *User) SearchMessages(server, channel, phrase string) ([]Message, error) {
serverQuery := bleve.NewMatchQuery(server)
serverQuery.SetField("server")
channelQuery := bleve.NewMatchQuery(channel)
channelQuery.SetField("to")
contentQuery := bleve.NewMatchQuery(phrase)
contentQuery.SetField("content")
query := bleve.NewBooleanQuery([]bleve.Query{serverQuery, channelQuery, contentQuery}, nil, nil)
search := bleve.NewSearchRequest(query)
searchResults, err := u.messageIndex.Search(search)
if err != nil {
return nil, err
}
messages := []Message{}
u.messageLog.View(func(tx *bolt.Tx) error {
b := tx.Bucket(bucketMessages)
for _, hit := range searchResults.Hits {
idx := strings.LastIndex(hit.ID, ":")
bc := b.Bucket([]byte(hit.ID[:idx]))
var message Message
json.Unmarshal(bc.Get([]byte(hit.ID[idx+1:])), &message)
messages = append(messages, message)
}
return nil
})
return messages, nil
}
func (u *User) openMessageLog() error {
err := os.MkdirAll(Path.User(u.Username), 0700)
if err != nil {
return err
}
u.messageLog, err = bolt.Open(Path.Log(u.Username), 0600, nil)
if err != nil {
return err
}
u.messageLog.Update(func(tx *bolt.Tx) error {
tx.CreateBucketIfNotExists(bucketMessages)
return nil
})
indexPath := Path.Index(u.Username)
u.messageIndex, err = bleve.Open(indexPath)
if err == bleve.ErrorIndexPathDoesNotExist {
mapping := bleve.NewIndexMapping()
u.messageIndex, err = bleve.New(indexPath, mapping)
if err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}

View File

@ -13,6 +13,11 @@ func tempdir() string {
}
func TestUser(t *testing.T) {
defer func() {
r := recover()
assert.Nil(t, r)
}()
Initialize(tempdir())
Open()
@ -30,7 +35,7 @@ func TestUser(t *testing.T) {
Name: "#testing",
}
user := NewUser("unique")
user := NewUser()
user.AddServer(srv)
user.AddChannel(chan1)
user.AddChannel(chan2)
@ -40,7 +45,7 @@ func TestUser(t *testing.T) {
assert.Len(t, users, 1)
user = users[0]
assert.Equal(t, "unique", user.UUID)
assert.Equal(t, uint64(1), user.ID)
servers := user.GetServers()
assert.Len(t, servers, 1)