Persist, renew and delete sessions, refactor storage package, move reusable packages to pkg

This commit is contained in:
Ken-Håvard Lieng 2018-05-31 23:24:59 +02:00
parent 121582f72a
commit 24f9553aa5
48 changed files with 1872 additions and 1171 deletions

View file

@ -3,58 +3,95 @@ package server
import (
"log"
"net/http"
"time"
"github.com/khlieng/dispatch/pkg/session"
"github.com/khlieng/dispatch/storage"
)
const (
cookieName = "dispatch"
)
func (d *Dispatch) handleAuth(w http.ResponseWriter, r *http.Request, createUser bool) *State {
var state *State
func handleAuth(w http.ResponseWriter, r *http.Request, createUser bool) *Session {
var session *Session
cookie, err := r.Cookie(cookieName)
cookie, err := r.Cookie(session.CookieName)
if err != nil {
if createUser {
session = newUser(w, r)
state, err = d.newUser(w, r)
if err != nil {
log.Println(err)
}
}
} else {
session = sessions.get(cookie.Value)
session := d.states.getSession(cookie.Value)
if session != nil {
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", session.user.ID)
key := session.Key()
newKey, expired, err := session.Refresh()
if err != nil {
return nil
}
if !expired {
state = d.states.get(session.UserID)
if newKey != "" {
d.states.setSession(session)
d.states.deleteSession(key)
session.SetCookie(w, r)
}
}
}
if state != nil {
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", state.user.ID)
} else if createUser {
session = newUser(w, r)
state, err = d.newUser(w, r)
if err != nil {
log.Println(err)
}
}
}
return session
return state
}
func newUser(w http.ResponseWriter, r *http.Request) *Session {
user, err := storage.NewUser()
func (d *Dispatch) newUser(w http.ResponseWriter, r *http.Request) (*State, error) {
user, err := storage.NewUser(d.Store)
if err != nil {
return nil
return nil, err
}
log.Println(r.RemoteAddr, "[Auth] Create session | User ID:", user.ID)
session, err := NewSession(user)
messageStore, err := d.GetMessageStore(user)
if err != nil {
return nil
return nil, err
}
sessions.set(session)
go session.run()
user.SetMessageStore(messageStore)
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: session.id,
Path: "/",
Expires: time.Now().AddDate(0, 1, 0),
HttpOnly: true,
Secure: r.TLS != nil,
})
search, err := d.GetMessageSearchProvider(user)
if err != nil {
return nil, err
}
user.SetMessageSearchProvider(search)
return session
log.Println(r.RemoteAddr, "[Auth] New anonymous user | ID:", user.ID)
session, err := session.New(user.ID)
if err != nil {
return nil, err
}
d.states.setSession(session)
go d.deleteSessionWhenExpired(session)
state := NewState(user, d)
d.states.set(state)
go state.run()
session.SetCookie(w, r)
return state, nil
}
func (d *Dispatch) deleteSessionWhenExpired(session *session.Session) {
deleteSessionWhenExpired(session, d.states)
}
func deleteSessionWhenExpired(session *session.Session, stateStore *stateStore) {
session.WaitUntilExpiration()
stateStore.deleteSession(session.Key())
}

View file

@ -11,55 +11,29 @@ import (
)
type connectDefaults struct {
Name string `json:"name,omitempty"`
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
Channels []string `json:"channels,omitempty"`
Password bool `json:"password,omitempty"`
SSL bool `json:"ssl,omitempty"`
ReadOnly bool `json:"readonly,omitempty"`
ShowDetails bool `json:"showDetails,omitempty"`
Name string
Host string
Port int
Channels []string
Password bool
SSL bool
ReadOnly bool
ShowDetails bool
}
type indexData struct {
Defaults connectDefaults `json:"defaults"`
Servers []Server `json:"servers,omitempty"`
Channels []storage.Channel `json:"channels,omitempty"`
Defaults connectDefaults
Servers []Server
Channels []storage.Channel
// Users in the selected channel
Users *Userlist `json:"users,omitempty"`
Users *Userlist
// Last messages in the selected channel
Messages *Messages `json:"messages,omitempty"`
Messages *Messages
}
func (d *indexData) addUsersAndMessages(server, channel string, session *Session) {
users := channelStore.GetUsers(server, channel)
if len(users) > 0 {
d.Users = &Userlist{
Server: server,
Channel: channel,
Users: users,
}
}
messages, hasMore, err := session.user.GetLastMessages(server, channel, 50)
if err == nil && len(messages) > 0 {
m := Messages{
Server: server,
To: channel,
Messages: messages,
}
if hasMore {
m.Next = messages[0].ID
}
d.Messages = &m
}
}
func getIndexData(r *http.Request, session *Session) *indexData {
func getIndexData(r *http.Request, state *State) *indexData {
data := indexData{}
data.Defaults = connectDefaults{
@ -73,12 +47,15 @@ func getIndexData(r *http.Request, session *Session) *indexData {
ShowDetails: viper.GetBool("defaults.show_details"),
}
if session == nil {
if state == nil {
return &data
}
servers := session.user.GetServers()
connections := session.getConnectionStates()
servers, err := state.user.GetServers()
if err != nil {
return nil
}
connections := state.getConnectionStates()
for _, server := range servers {
server.Password = ""
server.Username = ""
@ -90,7 +67,10 @@ func getIndexData(r *http.Request, session *Session) *indexData {
})
}
channels := session.user.GetChannels()
channels, err := state.user.GetChannels()
if err != nil {
return nil
}
for i, channel := range channels {
channels[i].Topic = channelStore.GetTopic(channel.Server, channel.Name)
}
@ -98,18 +78,44 @@ func getIndexData(r *http.Request, session *Session) *indexData {
server, channel := getTabFromPath(r.URL.EscapedPath())
if isInChannel(channels, server, channel) {
data.addUsersAndMessages(server, channel, session)
data.addUsersAndMessages(server, channel, state)
return &data
}
server, channel = parseTabCookie(r, r.URL.Path)
if isInChannel(channels, server, channel) {
data.addUsersAndMessages(server, channel, session)
data.addUsersAndMessages(server, channel, state)
}
return &data
}
func (d *indexData) addUsersAndMessages(server, channel string, state *State) {
users := channelStore.GetUsers(server, channel)
if len(users) > 0 {
d.Users = &Userlist{
Server: server,
Channel: channel,
Users: users,
}
}
messages, hasMore, err := state.user.GetLastMessages(server, channel, 50)
if err == nil && len(messages) > 0 {
m := Messages{
Server: server,
To: channel,
Messages: messages,
}
if hasMore {
m.Next = messages[0].ID
}
d.Messages = &m
}
}
func isInChannel(channels []storage.Channel, server, channel string) bool {
if channel != "" {
for _, ch := range channels {

View file

@ -344,7 +344,7 @@ func easyjson7e607aefDecodeGithubComKhliengDispatchServer1(in *jlexer.Lexer, out
out.Password = bool(in.Bool())
case "ssl":
out.SSL = bool(in.Bool())
case "readonly":
case "readOnly":
out.ReadOnly = bool(in.Bool())
case "showDetails":
out.ShowDetails = bool(in.Bool())
@ -432,7 +432,7 @@ func easyjson7e607aefEncodeGithubComKhliengDispatchServer1(out *jwriter.Writer,
out.Bool(bool(in.SSL))
}
if in.ReadOnly {
const prefix string = ",\"readonly\":"
const prefix string = ",\"readOnly\":"
if first {
first = false
out.RawString(prefix[1:])

View file

@ -2,61 +2,35 @@ package server
import (
"crypto/tls"
"log"
"net"
"github.com/khlieng/dispatch/irc"
"github.com/khlieng/dispatch/storage"
"github.com/spf13/viper"
"github.com/khlieng/dispatch/pkg/irc"
"github.com/khlieng/dispatch/storage"
)
func createNickInUseHandler(i *irc.Client, session *Session) func(string) string {
func createNickInUseHandler(i *irc.Client, state *State) func(string) string {
return func(nick string) string {
newNick := nick + "_"
if newNick == i.GetNick() {
session.sendJSON("nick_fail", NickFail{
state.sendJSON("nick_fail", NickFail{
Server: i.Host,
})
}
session.printError("Nickname", nick, "is already in use, using", newNick, "instead")
state.printError("Nickname", nick, "is already in use, using", newNick, "instead")
return newNick
}
}
func reconnectIRC() {
for _, user := range storage.LoadUsers() {
session, err := NewSession(user)
if err != nil {
log.Println(err)
continue
}
sessions.set(session)
go session.run()
channels := user.GetChannels()
for _, server := range user.GetServers() {
i := connectIRC(server, session)
var joining []string
for _, channel := range channels {
if channel.Server == server.Host {
joining = append(joining, channel.Name)
}
}
i.Join(joining...)
}
}
}
func connectIRC(server storage.Server, session *Session) *irc.Client {
func connectIRC(server *storage.Server, state *State) *irc.Client {
i := irc.NewClient(server.Nick, server.Username)
i.TLS = server.TLS
i.Realname = server.Realname
i.HandleNickInUse = createNickInUseHandler(i, session)
i.HandleNickInUse = createNickInUseHandler(i, state)
address := server.Host
if server.Port != "" {
@ -83,14 +57,14 @@ func connectIRC(server storage.Server, session *Session) *irc.Client {
InsecureSkipVerify: !viper.GetBool("verify_certificates"),
}
if cert := session.user.GetCertificate(); cert != nil {
if cert := state.user.GetCertificate(); cert != nil {
i.TLSConfig.Certificates = []tls.Certificate{*cert}
}
}
session.setIRC(server.Host, i)
state.setIRC(server.Host, i)
i.Connect(address)
go newIRCHandler(i, session).run()
go newIRCHandler(i, state).run()
return i
}

View file

@ -8,7 +8,7 @@ import (
"github.com/kjk/betterguid"
"github.com/khlieng/dispatch/irc"
"github.com/khlieng/dispatch/pkg/irc"
"github.com/khlieng/dispatch/storage"
)
@ -17,8 +17,8 @@ var excludedErrors = []string{
}
type ircHandler struct {
client *irc.Client
session *Session
client *irc.Client
state *State
whois WhoisReply
userBuffers map[string][]string
@ -27,10 +27,10 @@ type ircHandler struct {
handlers map[string]func(*irc.Message)
}
func newIRCHandler(client *irc.Client, session *Session) *ircHandler {
func newIRCHandler(client *irc.Client, state *State) *ircHandler {
i := &ircHandler{
client: client,
session: session,
state: state,
userBuffers: make(map[string][]string),
}
i.initHandlers()
@ -43,15 +43,15 @@ func (i *ircHandler) run() {
select {
case msg, ok := <-i.client.Messages:
if !ok {
i.session.deleteIRC(i.client.Host)
i.state.deleteIRC(i.client.Host)
return
}
i.dispatchMessage(msg)
case state := <-i.client.ConnectionChanged:
i.session.sendJSON("connection_update", newConnectionUpdate(i.client.Host, state))
i.session.setConnectionState(i.client.Host, state)
i.state.sendJSON("connection_update", newConnectionUpdate(i.client.Host, state))
i.state.setConnectionState(i.client.Host, state)
if state.Error != nil && (lastConnErr == nil ||
state.Error.Error() != lastConnErr.Error()) {
@ -66,7 +66,7 @@ func (i *ircHandler) run() {
func (i *ircHandler) dispatchMessage(msg *irc.Message) {
if msg.Command[0] == '4' && !isExcludedError(msg.Command) {
i.session.printError(formatIRCError(msg))
i.state.printError(formatIRCError(msg))
}
if handler, ok := i.handlers[msg.Command]; ok {
@ -75,7 +75,7 @@ func (i *ircHandler) dispatchMessage(msg *irc.Message) {
}
func (i *ircHandler) nick(msg *irc.Message) {
i.session.sendJSON("nick", Nick{
i.state.sendJSON("nick", Nick{
Server: i.client.Host,
Old: msg.Nick,
New: msg.LastParam(),
@ -84,12 +84,12 @@ func (i *ircHandler) nick(msg *irc.Message) {
channelStore.RenameUser(msg.Nick, msg.LastParam(), i.client.Host)
if msg.LastParam() == i.client.GetNick() {
go i.session.user.SetNick(msg.LastParam(), i.client.Host)
go i.state.user.SetNick(msg.LastParam(), i.client.Host)
}
}
func (i *ircHandler) join(msg *irc.Message) {
i.session.sendJSON("join", Join{
i.state.sendJSON("join", Join{
Server: i.client.Host,
User: msg.Nick,
Channels: msg.Params,
@ -102,9 +102,9 @@ func (i *ircHandler) join(msg *irc.Message) {
// Incase no topic is set and theres a cached one that needs to be cleared
i.client.Topic(channel)
i.session.sendLastMessages(i.client.Host, channel, 50)
i.state.sendLastMessages(i.client.Host, channel, 50)
go i.session.user.AddChannel(storage.Channel{
go i.state.user.AddChannel(&storage.Channel{
Server: i.client.Host,
Name: channel,
})
@ -122,12 +122,12 @@ func (i *ircHandler) part(msg *irc.Message) {
part.Reason = msg.Params[1]
}
i.session.sendJSON("part", part)
i.state.sendJSON("part", part)
channelStore.RemoveUser(msg.Nick, i.client.Host, part.Channel)
if msg.Nick == i.client.GetNick() {
go i.session.user.RemoveChannel(i.client.Host, part.Channel)
go i.state.user.RemoveChannel(i.client.Host, part.Channel)
}
}
@ -139,7 +139,7 @@ func (i *ircHandler) mode(msg *irc.Message) {
mode.Channel = target
mode.User = msg.Params[2]
i.session.sendJSON("mode", mode)
i.state.sendJSON("mode", mode)
channelStore.SetMode(i.client.Host, target, msg.Params[2], mode.Add, mode.Remove)
}
@ -154,20 +154,20 @@ func (i *ircHandler) message(msg *irc.Message) {
}
if msg.Params[0] == i.client.GetNick() {
i.session.sendJSON("pm", message)
i.state.sendJSON("pm", message)
} else {
message.To = msg.Params[0]
i.session.sendJSON("message", message)
i.state.sendJSON("message", message)
}
if msg.Params[0] != "*" {
go i.session.user.LogMessage(message.ID,
go i.state.user.LogMessage(message.ID,
i.client.Host, msg.Nick, msg.Params[0], msg.LastParam())
}
}
func (i *ircHandler) quit(msg *irc.Message) {
i.session.sendJSON("quit", Quit{
i.state.sendJSON("quit", Quit{
Server: i.client.Host,
User: msg.Nick,
Reason: msg.LastParam(),
@ -178,15 +178,15 @@ func (i *ircHandler) quit(msg *irc.Message) {
func (i *ircHandler) info(msg *irc.Message) {
if msg.Command == irc.ReplyWelcome {
i.session.sendJSON("nick", Nick{
i.state.sendJSON("nick", Nick{
Server: i.client.Host,
New: msg.Params[0],
})
go i.session.user.SetNick(msg.Params[0], i.client.Host)
go i.state.user.SetNick(msg.Params[0], i.client.Host)
}
i.session.sendJSON("pm", Message{
i.state.sendJSON("pm", Message{
Server: i.client.Host,
From: msg.Nick,
Content: strings.Join(msg.Params[1:], " "),
@ -210,7 +210,7 @@ func (i *ircHandler) whoisChannels(msg *irc.Message) {
func (i *ircHandler) whoisEnd(msg *irc.Message) {
if i.whois.Nick != "" {
i.session.sendJSON("whois", i.whois)
i.state.sendJSON("whois", i.whois)
}
i.whois = WhoisReply{}
}
@ -226,7 +226,7 @@ func (i *ircHandler) topic(msg *irc.Message) {
channel = msg.Params[1]
}
i.session.sendJSON("topic", Topic{
i.state.sendJSON("topic", Topic{
Server: i.client.Host,
Channel: channel,
Topic: msg.LastParam(),
@ -239,7 +239,7 @@ func (i *ircHandler) topic(msg *irc.Message) {
func (i *ircHandler) noTopic(msg *irc.Message) {
channel := msg.Params[1]
i.session.sendJSON("topic", Topic{
i.state.sendJSON("topic", Topic{
Server: i.client.Host,
Channel: channel,
})
@ -257,7 +257,7 @@ func (i *ircHandler) namesEnd(msg *irc.Message) {
channel := msg.Params[1]
users := i.userBuffers[channel]
i.session.sendJSON("users", Userlist{
i.state.sendJSON("users", Userlist{
Server: i.client.Host,
Channel: channel,
Users: users,
@ -277,18 +277,18 @@ func (i *ircHandler) motd(msg *irc.Message) {
}
func (i *ircHandler) motdEnd(msg *irc.Message) {
i.session.sendJSON("motd", i.motdBuffer)
i.state.sendJSON("motd", i.motdBuffer)
i.motdBuffer = MOTD{}
}
func (i *ircHandler) badNick(msg *irc.Message) {
i.session.sendJSON("nick_fail", NickFail{
i.state.sendJSON("nick_fail", NickFail{
Server: i.client.Host,
})
}
func (i *ircHandler) error(msg *irc.Message) {
i.session.printError(msg.LastParam())
i.state.printError(msg.LastParam())
}
func (i *ircHandler) initHandlers() {
@ -327,7 +327,7 @@ func (i *ircHandler) initHandlers() {
func (i *ircHandler) log(v ...interface{}) {
s := fmt.Sprintln(v...)
log.Println("[IRC]", i.session.user.ID, i.client.Host, s[:len(s)-1])
log.Println("[IRC]", i.state.user.ID, i.client.Host, s[:len(s)-1])
}
func parseMode(mode string) *Mode {

View file

@ -8,8 +8,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/khlieng/dispatch/irc"
"github.com/khlieng/dispatch/pkg/irc"
"github.com/khlieng/dispatch/storage"
"github.com/khlieng/dispatch/storage/boltdb"
)
var user *storage.User
@ -21,11 +22,18 @@ func TestMain(m *testing.M) {
}
storage.Initialize(tempdir)
storage.Open()
user, err = storage.NewUser()
db, err := boltdb.New(storage.Path.Database())
if err != nil {
os.Exit(1)
log.Fatal(err)
}
user, err = storage.NewUser(db)
if err != nil {
log.Fatal(err)
}
user.SetMessageStore(db)
channelStore = storage.NewChannelStore()
code := m.Run()
@ -41,7 +49,7 @@ func dispatchMessage(msg *irc.Message) WSResponse {
func dispatchMessageMulti(msg *irc.Message) chan WSResponse {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s, _ := NewSession(user)
s := NewState(user, nil)
newIRCHandler(c, s).dispatchMessage(msg)
@ -187,7 +195,7 @@ func TestHandleIRCWelcome(t *testing.T) {
func TestHandleIRCWhois(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s, _ := NewSession(nil)
s := NewState(nil, nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{
@ -255,7 +263,7 @@ func TestHandleIRCNoTopic(t *testing.T) {
func TestHandleIRCNames(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s, _ := NewSession(nil)
s := NewState(nil, nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{
@ -281,7 +289,7 @@ func TestHandleIRCNames(t *testing.T) {
func TestHandleIRCMotd(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s, _ := NewSession(nil)
s := NewState(nil, nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{
@ -308,7 +316,7 @@ func TestHandleIRCMotd(t *testing.T) {
func TestHandleIRCBadNick(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s, _ := NewSession(nil)
s := NewState(nil, nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{

View file

@ -5,7 +5,7 @@ import (
"github.com/mailru/easyjson"
"github.com/khlieng/dispatch/irc"
"github.com/khlieng/dispatch/pkg/irc"
"github.com/khlieng/dispatch/storage"
)

View file

@ -62,7 +62,7 @@ var (
cspEnabled bool
)
func initFileServer() {
func (d *Dispatch) initFileServer() {
if !viper.GetBool("dev") {
data, err := assets.Asset(files[0].Asset)
if err != nil {
@ -154,24 +154,24 @@ func initFileServer() {
}
}
func serveFiles(w http.ResponseWriter, r *http.Request) {
func (d *Dispatch) serveFiles(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
serveIndex(w, r)
d.serveIndex(w, r)
return
}
for _, file := range files {
if strings.HasSuffix(r.URL.Path, file.Path) {
serveFile(w, r, file)
d.serveFile(w, r, file)
return
}
}
serveIndex(w, r)
d.serveIndex(w, r)
}
func serveIndex(w http.ResponseWriter, r *http.Request) {
session := handleAuth(w, r, false)
func (d *Dispatch) serveIndex(w http.ResponseWriter, r *http.Request) {
state := d.handleAuth(w, r, false)
if cspEnabled {
var connectSrc string
@ -228,10 +228,10 @@ func serveIndex(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Encoding", "gzip")
gzw := gzip.NewWriter(w)
IndexTemplate(gzw, getIndexData(r, session), files[1].Path, files[0].Path)
IndexTemplate(gzw, getIndexData(r, state), files[1].Path, files[0].Path)
gzw.Close()
} else {
IndexTemplate(w, getIndexData(r, session), files[1].Path, files[0].Path)
IndexTemplate(w, getIndexData(r, state), files[1].Path, files[0].Path)
}
}
@ -246,7 +246,7 @@ func setPushCookie(w http.ResponseWriter, r *http.Request) {
})
}
func serveFile(w http.ResponseWriter, r *http.Request, file *File) {
func (d *Dispatch) serveFile(w http.ResponseWriter, r *http.Request, file *File) {
info, err := assets.AssetInfo(file.Asset)
if err != nil {
http.Error(w, "", http.StatusInternalServerError)

View file

@ -12,36 +12,99 @@ import (
"github.com/gorilla/websocket"
"github.com/spf13/viper"
"github.com/khlieng/dispatch/letsencrypt"
"github.com/khlieng/dispatch/pkg/letsencrypt"
"github.com/khlieng/dispatch/pkg/session"
"github.com/khlieng/dispatch/storage"
)
var (
sessions *sessionStore
channelStore *storage.ChannelStore
var channelStore = storage.NewChannelStore()
upgrader = websocket.Upgrader{
type Dispatch struct {
Store storage.Store
SessionStore storage.SessionStore
GetMessageStore func(*storage.User) (storage.MessageStore, error)
GetMessageSearchProvider func(*storage.User) (storage.MessageSearchProvider, error)
upgrader websocket.Upgrader
states *stateStore
}
func (d *Dispatch) Run() {
d.upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
)
func Run() {
sessions = newSessionStore()
channelStore = storage.NewChannelStore()
if viper.GetBool("dev") {
upgrader.CheckOrigin = func(r *http.Request) bool {
d.upgrader.CheckOrigin = func(r *http.Request) bool {
return true
}
}
reconnectIRC()
initFileServer()
startHTTP()
session.CookieName = "dispatch"
d.states = newStateStore(d.SessionStore)
d.loadUsers()
d.initFileServer()
d.startHTTP()
}
func startHTTP() {
func (d *Dispatch) loadUsers() {
users, err := storage.LoadUsers(d.Store)
if err != nil {
log.Fatal(err)
}
log.Printf("Loading %d user(s)", len(users))
for i := range users {
go d.loadUser(&users[i])
}
}
func (d *Dispatch) loadUser(user *storage.User) {
messageStore, err := d.GetMessageStore(user)
if err != nil {
log.Fatal(err)
}
user.SetMessageStore(messageStore)
search, err := d.GetMessageSearchProvider(user)
if err != nil {
log.Fatal(err)
}
user.SetMessageSearchProvider(search)
state := NewState(user, d)
d.states.set(state)
go state.run()
channels, err := user.GetChannels()
if err != nil {
log.Fatal(err)
}
servers, err := user.GetServers()
if err != nil {
log.Fatal(err)
}
for _, server := range servers {
i := connectIRC(&server, state)
var joining []string
for _, channel := range channels {
if channel.Server == server.Host {
joining = append(joining, channel.Name)
}
}
i.Join(joining...)
}
}
func (d *Dispatch) startHTTP() {
port := viper.GetString("port")
if viper.GetBool("https.enabled") {
@ -55,7 +118,7 @@ func startHTTP() {
server := &http.Server{
Addr: ":" + portHTTPS,
Handler: http.HandlerFunc(serve),
Handler: http.HandlerFunc(d.serve),
}
if certExists() {
@ -71,13 +134,13 @@ func startHTTP() {
go http.ListenAndServe(":80", http.HandlerFunc(letsEncryptProxy))
}
letsEncrypt, err := letsencrypt.Run(dir, domain, email, ":"+lePort)
le, err := letsencrypt.Run(dir, domain, email, ":"+lePort)
if err != nil {
log.Fatal(err)
}
server.TLSConfig = &tls.Config{
GetCertificate: letsEncrypt.GetCertificate,
GetCertificate: le.GetCertificate,
}
log.Println("[HTTPS] Listening on port", portHTTPS)
@ -92,11 +155,11 @@ func startHTTP() {
port = "1337"
}
log.Println("[HTTP] Listening on port", port)
log.Fatal(http.ListenAndServe(":"+port, http.HandlerFunc(serve)))
log.Fatal(http.ListenAndServe(":"+port, http.HandlerFunc(d.serve)))
}
}
func serve(w http.ResponseWriter, r *http.Request) {
func (d *Dispatch) serve(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
fail(w, http.StatusNotFound)
return
@ -108,28 +171,27 @@ func serve(w http.ResponseWriter, r *http.Request) {
return
}
session := handleAuth(w, r, true)
if session == nil {
log.Println("[Auth] No session")
state := d.handleAuth(w, r, true)
if state == nil {
log.Println("[Auth] No state")
fail(w, http.StatusInternalServerError)
return
}
upgradeWS(w, r, session)
d.upgradeWS(w, r, state)
} else {
serveFiles(w, r)
d.serveFiles(w, r)
}
}
func upgradeWS(w http.ResponseWriter, r *http.Request, session *Session) {
conn, err := upgrader.Upgrade(w, r, w.Header())
func (d *Dispatch) upgradeWS(w http.ResponseWriter, r *http.Request, state *State) {
conn, err := d.upgrader.Upgrade(w, r, w.Header())
if err != nil {
log.Println(err)
return
}
newWSHandler(conn, session, r).run()
newWSHandler(conn, state, r).run()
}
func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {

View file

@ -1,253 +0,0 @@
package server
import (
"crypto/rand"
"encoding/base64"
"sync"
"time"
"fmt"
"github.com/khlieng/dispatch/irc"
"github.com/khlieng/dispatch/storage"
)
const (
AnonymousSessionExpiration = 1 * time.Minute
)
type Session struct {
irc map[string]*irc.Client
connectionState map[string]irc.ConnectionState
ircLock sync.Mutex
ws map[string]*wsConn
wsLock sync.Mutex
broadcast chan WSResponse
id string
user *storage.User
expiration *time.Timer
reset chan time.Duration
}
func NewSession(user *storage.User) (*Session, error) {
id, err := newSessionID()
if err != nil {
return nil, err
}
return &Session{
irc: make(map[string]*irc.Client),
connectionState: make(map[string]irc.ConnectionState),
ws: make(map[string]*wsConn),
broadcast: make(chan WSResponse, 32),
id: id,
user: user,
expiration: time.NewTimer(AnonymousSessionExpiration),
reset: make(chan time.Duration, 1),
}, nil
}
func newSessionID() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
return base64.RawURLEncoding.EncodeToString(key), err
}
func (s *Session) getIRC(server string) (*irc.Client, bool) {
s.ircLock.Lock()
i, ok := s.irc[server]
s.ircLock.Unlock()
return i, ok
}
func (s *Session) setIRC(server string, i *irc.Client) {
s.ircLock.Lock()
s.irc[server] = i
s.connectionState[server] = irc.ConnectionState{
Connected: false,
}
s.ircLock.Unlock()
s.reset <- 0
}
func (s *Session) deleteIRC(server string) {
s.ircLock.Lock()
delete(s.irc, server)
delete(s.connectionState, server)
s.ircLock.Unlock()
s.resetExpirationIfEmpty()
}
func (s *Session) numIRC() int {
s.ircLock.Lock()
n := len(s.irc)
s.ircLock.Unlock()
return n
}
func (s *Session) getConnectionStates() map[string]irc.ConnectionState {
s.ircLock.Lock()
state := make(map[string]irc.ConnectionState, len(s.connectionState))
for k, v := range s.connectionState {
state[k] = v
}
s.ircLock.Unlock()
return state
}
func (s *Session) setConnectionState(server string, state irc.ConnectionState) {
s.ircLock.Lock()
s.connectionState[server] = state
s.ircLock.Unlock()
}
func (s *Session) setWS(addr string, w *wsConn) {
s.wsLock.Lock()
s.ws[addr] = w
s.wsLock.Unlock()
s.reset <- 0
}
func (s *Session) deleteWS(addr string) {
s.wsLock.Lock()
delete(s.ws, addr)
s.wsLock.Unlock()
s.resetExpirationIfEmpty()
}
func (s *Session) numWS() int {
s.ircLock.Lock()
n := len(s.ws)
s.ircLock.Unlock()
return n
}
func (s *Session) sendJSON(t string, v interface{}) {
s.broadcast <- WSResponse{t, v}
}
func (s *Session) sendError(err error, server string) {
s.sendJSON("error", Error{
Server: server,
Message: err.Error(),
})
}
func (s *Session) sendLastMessages(server, channel string, count int) {
messages, hasMore, err := s.user.GetLastMessages(server, channel, count)
if err == nil && len(messages) > 0 {
res := Messages{
Server: server,
To: channel,
Messages: messages,
}
if hasMore {
res.Next = messages[0].ID
}
s.sendJSON("messages", res)
}
}
func (s *Session) sendMessages(server, channel string, count int, fromID string) {
messages, hasMore, err := s.user.GetMessages(server, channel, count, fromID)
if err == nil && len(messages) > 0 {
res := Messages{
Server: server,
To: channel,
Messages: messages,
Prepend: true,
}
if hasMore {
res.Next = messages[0].ID
}
s.sendJSON("messages", res)
}
}
func (s *Session) print(a ...interface{}) {
s.sendJSON("print", Message{
Content: fmt.Sprintln(a...),
})
}
func (s *Session) printError(a ...interface{}) {
s.sendJSON("print", Message{
Content: fmt.Sprintln(a...),
Type: "error",
})
}
func (s *Session) resetExpirationIfEmpty() {
if s.numIRC() == 0 && s.numWS() == 0 {
s.reset <- AnonymousSessionExpiration
}
}
func (s *Session) run() {
for {
select {
case res := <-s.broadcast:
s.wsLock.Lock()
for _, ws := range s.ws {
ws.out <- res
}
s.wsLock.Unlock()
case <-s.expiration.C:
sessions.delete(s.id)
s.user.Remove()
return
case duration := <-s.reset:
if duration == 0 {
s.expiration.Stop()
} else {
s.expiration.Reset(duration)
}
}
}
}
type sessionStore struct {
sessions map[string]*Session
lock sync.Mutex
}
func newSessionStore() *sessionStore {
return &sessionStore{
sessions: make(map[string]*Session),
}
}
func (s *sessionStore) get(id string) *Session {
s.lock.Lock()
session := s.sessions[id]
s.lock.Unlock()
return session
}
func (s *sessionStore) set(session *Session) {
s.lock.Lock()
s.sessions[session.id] = session
s.lock.Unlock()
}
func (s *sessionStore) delete(id string) {
s.lock.Lock()
delete(s.sessions, id)
s.lock.Unlock()
}

323
server/state.go Normal file
View file

@ -0,0 +1,323 @@
package server
import (
"log"
"sync"
"time"
"fmt"
"github.com/khlieng/dispatch/pkg/irc"
"github.com/khlieng/dispatch/pkg/session"
"github.com/khlieng/dispatch/storage"
)
const (
AnonymousUserExpiration = 1 * time.Minute
)
type State struct {
irc map[string]*irc.Client
connectionState map[string]irc.ConnectionState
ircLock sync.Mutex
ws map[string]*wsConn
wsLock sync.Mutex
broadcast chan WSResponse
srv *Dispatch
user *storage.User
expiration *time.Timer
reset chan time.Duration
}
func NewState(user *storage.User, srv *Dispatch) *State {
return &State{
irc: make(map[string]*irc.Client),
connectionState: make(map[string]irc.ConnectionState),
ws: make(map[string]*wsConn),
broadcast: make(chan WSResponse, 32),
srv: srv,
user: user,
expiration: time.NewTimer(AnonymousUserExpiration),
reset: make(chan time.Duration, 1),
}
}
func (s *State) getIRC(server string) (*irc.Client, bool) {
s.ircLock.Lock()
i, ok := s.irc[server]
s.ircLock.Unlock()
return i, ok
}
func (s *State) setIRC(server string, i *irc.Client) {
s.ircLock.Lock()
s.irc[server] = i
s.connectionState[server] = irc.ConnectionState{
Connected: false,
}
s.ircLock.Unlock()
s.reset <- 0
}
func (s *State) deleteIRC(server string) {
s.ircLock.Lock()
delete(s.irc, server)
delete(s.connectionState, server)
s.ircLock.Unlock()
s.resetExpirationIfEmpty()
}
func (s *State) numIRC() int {
s.ircLock.Lock()
n := len(s.irc)
s.ircLock.Unlock()
return n
}
func (s *State) getConnectionStates() map[string]irc.ConnectionState {
s.ircLock.Lock()
state := make(map[string]irc.ConnectionState, len(s.connectionState))
for k, v := range s.connectionState {
state[k] = v
}
s.ircLock.Unlock()
return state
}
func (s *State) setConnectionState(server string, state irc.ConnectionState) {
s.ircLock.Lock()
s.connectionState[server] = state
s.ircLock.Unlock()
}
func (s *State) setWS(addr string, w *wsConn) {
s.wsLock.Lock()
s.ws[addr] = w
s.wsLock.Unlock()
s.reset <- 0
}
func (s *State) deleteWS(addr string) {
s.wsLock.Lock()
delete(s.ws, addr)
s.wsLock.Unlock()
s.resetExpirationIfEmpty()
}
func (s *State) numWS() int {
s.ircLock.Lock()
n := len(s.ws)
s.ircLock.Unlock()
return n
}
func (s *State) sendJSON(t string, v interface{}) {
s.broadcast <- WSResponse{t, v}
}
func (s *State) sendError(err error, server string) {
s.sendJSON("error", Error{
Server: server,
Message: err.Error(),
})
}
func (s *State) sendLastMessages(server, channel string, count int) {
messages, hasMore, err := s.user.GetLastMessages(server, channel, count)
if err == nil && len(messages) > 0 {
res := Messages{
Server: server,
To: channel,
Messages: messages,
}
if hasMore {
res.Next = messages[0].ID
}
s.sendJSON("messages", res)
}
}
func (s *State) sendMessages(server, channel string, count int, fromID string) {
messages, hasMore, err := s.user.GetMessages(server, channel, count, fromID)
if err == nil && len(messages) > 0 {
res := Messages{
Server: server,
To: channel,
Messages: messages,
Prepend: true,
}
if hasMore {
res.Next = messages[0].ID
}
s.sendJSON("messages", res)
}
}
func (s *State) print(a ...interface{}) {
s.sendJSON("print", Message{
Content: fmt.Sprintln(a...),
})
}
func (s *State) printError(a ...interface{}) {
s.sendJSON("print", Message{
Content: fmt.Sprintln(a...),
Type: "error",
})
}
func (s *State) resetExpirationIfEmpty() {
if s.numIRC() == 0 && s.numWS() == 0 {
s.reset <- AnonymousUserExpiration
}
}
func (s *State) kill() {
s.wsLock.Lock()
for _, ws := range s.ws {
ws.conn.Close()
}
s.wsLock.Unlock()
s.ircLock.Lock()
for _, i := range s.irc {
i.Quit()
}
s.ircLock.Unlock()
}
func (s *State) run() {
for {
select {
case res := <-s.broadcast:
s.wsLock.Lock()
for _, ws := range s.ws {
ws.out <- res
}
s.wsLock.Unlock()
case <-s.expiration.C:
s.srv.states.delete(s.user.ID)
s.user.Remove()
return
case duration := <-s.reset:
if duration == 0 {
s.expiration.Stop()
} else {
s.expiration.Reset(duration)
}
}
}
}
type stateStore struct {
states map[uint64]*State
sessions map[string]*session.Session
sessionStore storage.SessionStore
lock sync.Mutex
}
func newStateStore(sessionStore storage.SessionStore) *stateStore {
store := &stateStore{
states: make(map[uint64]*State),
sessions: make(map[string]*session.Session),
sessionStore: sessionStore,
}
sessions, err := sessionStore.GetSessions()
if err != nil {
log.Fatal(err)
}
for _, session := range sessions {
if !session.Expired() {
session.Init()
store.sessions[session.Key()] = &session
go deleteSessionWhenExpired(&session, store)
} else {
go sessionStore.DeleteSession(session.Key())
}
}
return store
}
func (s *stateStore) get(id uint64) *State {
s.lock.Lock()
state := s.states[id]
s.lock.Unlock()
return state
}
func (s *stateStore) set(state *State) {
s.lock.Lock()
s.states[state.user.ID] = state
s.lock.Unlock()
}
func (s *stateStore) delete(id uint64) {
s.lock.Lock()
delete(s.states, id)
for key, session := range s.sessions {
if session.UserID == id {
delete(s.sessions, key)
go s.sessionStore.DeleteSession(key)
}
}
s.lock.Unlock()
}
func (s *stateStore) getSession(key string) *session.Session {
s.lock.Lock()
session := s.sessions[key]
s.lock.Unlock()
return session
}
func (s *stateStore) setSession(session *session.Session) {
s.lock.Lock()
s.sessions[session.Key()] = session
s.lock.Unlock()
s.sessionStore.SaveSession(session)
}
func (s *stateStore) deleteSession(key string) {
s.lock.Lock()
id := s.sessions[key].UserID
delete(s.sessions, key)
n := 0
for _, session := range s.sessions {
if session.UserID == id {
n++
}
}
state := s.states[id]
if n == 0 {
delete(s.states, id)
}
s.lock.Unlock()
if n == 0 {
// This anonymous user is not reachable anymore since all sessions have
// expired, so we clean it up
state.kill()
state.user.Remove()
}
s.sessionStore.DeleteSession(key)
}

View file

@ -11,16 +11,16 @@ import (
type wsHandler struct {
ws *wsConn
session *Session
state *State
addr string
handlers map[string]func([]byte)
}
func newWSHandler(conn *websocket.Conn, session *Session, r *http.Request) *wsHandler {
func newWSHandler(conn *websocket.Conn, state *State, r *http.Request) *wsHandler {
h := &wsHandler{
ws: newWSConn(conn),
session: session,
addr: conn.RemoteAddr().String(),
ws: newWSConn(conn),
state: state,
addr: conn.RemoteAddr().String(),
}
h.init(r)
h.initHandlers()
@ -35,8 +35,8 @@ func (h *wsHandler) run() {
for {
req, ok := <-h.ws.in
if !ok {
if h.session != nil {
h.session.deleteWS(h.addr)
if h.state != nil {
h.state.deleteWS(h.addr)
}
return
}
@ -52,13 +52,16 @@ func (h *wsHandler) dispatchRequest(req WSRequest) {
}
func (h *wsHandler) init(r *http.Request) {
h.session.setWS(h.addr, h.ws)
h.state.setWS(h.addr, h.ws)
log.Println(h.addr, "[Session] User ID:", h.session.user.ID, "|",
h.session.numIRC(), "IRC connections |",
h.session.numWS(), "WebSocket connections")
log.Println(h.addr, "[State] User ID:", h.state.user.ID, "|",
h.state.numIRC(), "IRC connections |",
h.state.numWS(), "WebSocket connections")
channels := h.session.user.GetChannels()
channels, err := h.state.user.GetChannels()
if err != nil {
log.Println(err)
}
path := r.URL.EscapedPath()[3:]
pathServer, pathChannel := getTabFromPath(path)
cookieServer, cookieChannel := parseTabCookie(r, path)
@ -66,16 +69,17 @@ func (h *wsHandler) init(r *http.Request) {
for _, channel := range channels {
if (channel.Server == pathServer && channel.Name == pathChannel) ||
(channel.Server == cookieServer && channel.Name == cookieChannel) {
// Userlist and messages for this channel gets embedded in the index page
continue
}
h.session.sendJSON("users", Userlist{
h.state.sendJSON("users", Userlist{
Server: channel.Server,
Channel: channel.Name,
Users: channelStore.GetUsers(channel.Server, channel.Name),
})
h.session.sendLastMessages(channel.Server, channel.Name, 50)
h.state.sendLastMessages(channel.Server, channel.Name, 50)
}
}
@ -83,12 +87,12 @@ func (h *wsHandler) connect(b []byte) {
var data Server
data.UnmarshalJSON(b)
if _, ok := h.session.getIRC(data.Host); !ok {
if _, ok := h.state.getIRC(data.Host); !ok {
log.Println(h.addr, "[IRC] Add server", data.Host)
connectIRC(data.Server, h.session)
connectIRC(&data.Server, h.state)
go h.session.user.AddServer(data.Server)
go h.state.user.AddServer(&data.Server)
} else {
log.Println(h.addr, "[IRC]", data.Host, "already added")
}
@ -98,7 +102,7 @@ func (h *wsHandler) reconnect(b []byte) {
var data ReconnectSettings
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok && !i.Connected() {
if i, ok := h.state.getIRC(data.Server); ok && !i.Connected() {
if i.TLS {
i.TLSConfig.InsecureSkipVerify = data.SkipVerify
}
@ -110,7 +114,7 @@ func (h *wsHandler) join(b []byte) {
var data Join
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Join(data.Channels...)
}
}
@ -119,7 +123,7 @@ func (h *wsHandler) part(b []byte) {
var data Part
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Part(data.Channels...)
}
}
@ -129,22 +133,22 @@ func (h *wsHandler) quit(b []byte) {
data.UnmarshalJSON(b)
log.Println(h.addr, "[IRC] Remove server", data.Server)
if i, ok := h.session.getIRC(data.Server); ok {
h.session.deleteIRC(data.Server)
if i, ok := h.state.getIRC(data.Server); ok {
h.state.deleteIRC(data.Server)
i.Quit()
}
go h.session.user.RemoveServer(data.Server)
go h.state.user.RemoveServer(data.Server)
}
func (h *wsHandler) message(b []byte) {
var data Message
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Privmsg(data.To, data.Content)
go h.session.user.LogMessage(betterguid.New(),
go h.state.user.LogMessage(betterguid.New(),
data.Server, i.GetNick(), data.To, data.Content)
}
}
@ -153,7 +157,7 @@ func (h *wsHandler) nick(b []byte) {
var data Nick
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Nick(data.New)
}
}
@ -162,7 +166,7 @@ func (h *wsHandler) topic(b []byte) {
var data Topic
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Topic(data.Channel, data.Topic)
}
}
@ -171,7 +175,7 @@ func (h *wsHandler) invite(b []byte) {
var data Invite
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Invite(data.User, data.Channel)
}
}
@ -180,7 +184,7 @@ func (h *wsHandler) kick(b []byte) {
var data Invite
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Kick(data.Channel, data.User)
}
}
@ -189,7 +193,7 @@ func (h *wsHandler) whois(b []byte) {
var data Whois
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Whois(data.User)
}
}
@ -198,7 +202,7 @@ func (h *wsHandler) away(b []byte) {
var data Away
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Away(data.Message)
}
}
@ -207,7 +211,7 @@ func (h *wsHandler) raw(b []byte) {
var data Raw
data.UnmarshalJSON(b)
if i, ok := h.session.getIRC(data.Server); ok {
if i, ok := h.state.getIRC(data.Server); ok {
i.Write(data.Message)
}
}
@ -217,13 +221,13 @@ func (h *wsHandler) search(b []byte) {
var data SearchRequest
data.UnmarshalJSON(b)
results, err := h.session.user.SearchMessages(data.Server, data.Channel, data.Phrase)
results, err := h.state.user.SearchMessages(data.Server, data.Channel, data.Phrase)
if err != nil {
log.Println(err)
return
}
h.session.sendJSON("search", SearchResult{
h.state.sendJSON("search", SearchResult{
Server: data.Server,
Channel: data.Channel,
Results: results,
@ -235,20 +239,20 @@ func (h *wsHandler) cert(b []byte) {
var data ClientCert
data.UnmarshalJSON(b)
err := h.session.user.SetCertificate(data.Cert, data.Key)
err := h.state.user.SetCertificate(data.Cert, data.Key)
if err != nil {
h.session.sendJSON("cert_fail", Error{Message: err.Error()})
h.state.sendJSON("cert_fail", Error{Message: err.Error()})
return
}
h.session.sendJSON("cert_success", nil)
h.state.sendJSON("cert_success", nil)
}
func (h *wsHandler) fetchMessages(b []byte) {
var data FetchMessages
data.UnmarshalJSON(b)
h.session.sendMessages(data.Server, data.Channel, 200, data.Next)
h.state.sendMessages(data.Server, data.Channel, 200, data.Next)
}
func (h *wsHandler) setServerName(b []byte) {
@ -256,7 +260,7 @@ func (h *wsHandler) setServerName(b []byte) {
data.UnmarshalJSON(b)
if isValidServerName(data.Name) {
h.session.user.SetServerName(data.Name, data.Server)
h.state.user.SetServerName(data.Name, data.Server)
}
}