Persist, renew and delete sessions, refactor storage package, move reusable packages to pkg
This commit is contained in:
parent
121582f72a
commit
24f9553aa5
48 changed files with 1872 additions and 1171 deletions
|
@ -3,58 +3,95 @@ package server
|
|||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/khlieng/dispatch/pkg/session"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieName = "dispatch"
|
||||
)
|
||||
func (d *Dispatch) handleAuth(w http.ResponseWriter, r *http.Request, createUser bool) *State {
|
||||
var state *State
|
||||
|
||||
func handleAuth(w http.ResponseWriter, r *http.Request, createUser bool) *Session {
|
||||
var session *Session
|
||||
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
cookie, err := r.Cookie(session.CookieName)
|
||||
if err != nil {
|
||||
if createUser {
|
||||
session = newUser(w, r)
|
||||
state, err = d.newUser(w, r)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
session = sessions.get(cookie.Value)
|
||||
session := d.states.getSession(cookie.Value)
|
||||
if session != nil {
|
||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", session.user.ID)
|
||||
key := session.Key()
|
||||
newKey, expired, err := session.Refresh()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !expired {
|
||||
state = d.states.get(session.UserID)
|
||||
if newKey != "" {
|
||||
d.states.setSession(session)
|
||||
d.states.deleteSession(key)
|
||||
session.SetCookie(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state != nil {
|
||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", state.user.ID)
|
||||
} else if createUser {
|
||||
session = newUser(w, r)
|
||||
state, err = d.newUser(w, r)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return session
|
||||
return state
|
||||
}
|
||||
|
||||
func newUser(w http.ResponseWriter, r *http.Request) *Session {
|
||||
user, err := storage.NewUser()
|
||||
func (d *Dispatch) newUser(w http.ResponseWriter, r *http.Request) (*State, error) {
|
||||
user, err := storage.NewUser(d.Store)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Println(r.RemoteAddr, "[Auth] Create session | User ID:", user.ID)
|
||||
|
||||
session, err := NewSession(user)
|
||||
messageStore, err := d.GetMessageStore(user)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
sessions.set(session)
|
||||
go session.run()
|
||||
user.SetMessageStore(messageStore)
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: session.id,
|
||||
Path: "/",
|
||||
Expires: time.Now().AddDate(0, 1, 0),
|
||||
HttpOnly: true,
|
||||
Secure: r.TLS != nil,
|
||||
})
|
||||
search, err := d.GetMessageSearchProvider(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.SetMessageSearchProvider(search)
|
||||
|
||||
return session
|
||||
log.Println(r.RemoteAddr, "[Auth] New anonymous user | ID:", user.ID)
|
||||
|
||||
session, err := session.New(user.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.states.setSession(session)
|
||||
go d.deleteSessionWhenExpired(session)
|
||||
|
||||
state := NewState(user, d)
|
||||
d.states.set(state)
|
||||
go state.run()
|
||||
|
||||
session.SetCookie(w, r)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (d *Dispatch) deleteSessionWhenExpired(session *session.Session) {
|
||||
deleteSessionWhenExpired(session, d.states)
|
||||
}
|
||||
|
||||
func deleteSessionWhenExpired(session *session.Session, stateStore *stateStore) {
|
||||
session.WaitUntilExpiration()
|
||||
stateStore.deleteSession(session.Key())
|
||||
}
|
||||
|
|
|
@ -11,55 +11,29 @@ import (
|
|||
)
|
||||
|
||||
type connectDefaults struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Host string `json:"host,omitempty"`
|
||||
Port int `json:"port,omitempty"`
|
||||
Channels []string `json:"channels,omitempty"`
|
||||
Password bool `json:"password,omitempty"`
|
||||
SSL bool `json:"ssl,omitempty"`
|
||||
ReadOnly bool `json:"readonly,omitempty"`
|
||||
ShowDetails bool `json:"showDetails,omitempty"`
|
||||
Name string
|
||||
Host string
|
||||
Port int
|
||||
Channels []string
|
||||
Password bool
|
||||
SSL bool
|
||||
ReadOnly bool
|
||||
ShowDetails bool
|
||||
}
|
||||
|
||||
type indexData struct {
|
||||
Defaults connectDefaults `json:"defaults"`
|
||||
Servers []Server `json:"servers,omitempty"`
|
||||
Channels []storage.Channel `json:"channels,omitempty"`
|
||||
Defaults connectDefaults
|
||||
Servers []Server
|
||||
Channels []storage.Channel
|
||||
|
||||
// Users in the selected channel
|
||||
Users *Userlist `json:"users,omitempty"`
|
||||
Users *Userlist
|
||||
|
||||
// Last messages in the selected channel
|
||||
Messages *Messages `json:"messages,omitempty"`
|
||||
Messages *Messages
|
||||
}
|
||||
|
||||
func (d *indexData) addUsersAndMessages(server, channel string, session *Session) {
|
||||
users := channelStore.GetUsers(server, channel)
|
||||
if len(users) > 0 {
|
||||
d.Users = &Userlist{
|
||||
Server: server,
|
||||
Channel: channel,
|
||||
Users: users,
|
||||
}
|
||||
}
|
||||
|
||||
messages, hasMore, err := session.user.GetLastMessages(server, channel, 50)
|
||||
if err == nil && len(messages) > 0 {
|
||||
m := Messages{
|
||||
Server: server,
|
||||
To: channel,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
m.Next = messages[0].ID
|
||||
}
|
||||
|
||||
d.Messages = &m
|
||||
}
|
||||
}
|
||||
|
||||
func getIndexData(r *http.Request, session *Session) *indexData {
|
||||
func getIndexData(r *http.Request, state *State) *indexData {
|
||||
data := indexData{}
|
||||
|
||||
data.Defaults = connectDefaults{
|
||||
|
@ -73,12 +47,15 @@ func getIndexData(r *http.Request, session *Session) *indexData {
|
|||
ShowDetails: viper.GetBool("defaults.show_details"),
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
if state == nil {
|
||||
return &data
|
||||
}
|
||||
|
||||
servers := session.user.GetServers()
|
||||
connections := session.getConnectionStates()
|
||||
servers, err := state.user.GetServers()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
connections := state.getConnectionStates()
|
||||
for _, server := range servers {
|
||||
server.Password = ""
|
||||
server.Username = ""
|
||||
|
@ -90,7 +67,10 @@ func getIndexData(r *http.Request, session *Session) *indexData {
|
|||
})
|
||||
}
|
||||
|
||||
channels := session.user.GetChannels()
|
||||
channels, err := state.user.GetChannels()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for i, channel := range channels {
|
||||
channels[i].Topic = channelStore.GetTopic(channel.Server, channel.Name)
|
||||
}
|
||||
|
@ -98,18 +78,44 @@ func getIndexData(r *http.Request, session *Session) *indexData {
|
|||
|
||||
server, channel := getTabFromPath(r.URL.EscapedPath())
|
||||
if isInChannel(channels, server, channel) {
|
||||
data.addUsersAndMessages(server, channel, session)
|
||||
data.addUsersAndMessages(server, channel, state)
|
||||
return &data
|
||||
}
|
||||
|
||||
server, channel = parseTabCookie(r, r.URL.Path)
|
||||
if isInChannel(channels, server, channel) {
|
||||
data.addUsersAndMessages(server, channel, session)
|
||||
data.addUsersAndMessages(server, channel, state)
|
||||
}
|
||||
|
||||
return &data
|
||||
}
|
||||
|
||||
func (d *indexData) addUsersAndMessages(server, channel string, state *State) {
|
||||
users := channelStore.GetUsers(server, channel)
|
||||
if len(users) > 0 {
|
||||
d.Users = &Userlist{
|
||||
Server: server,
|
||||
Channel: channel,
|
||||
Users: users,
|
||||
}
|
||||
}
|
||||
|
||||
messages, hasMore, err := state.user.GetLastMessages(server, channel, 50)
|
||||
if err == nil && len(messages) > 0 {
|
||||
m := Messages{
|
||||
Server: server,
|
||||
To: channel,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
m.Next = messages[0].ID
|
||||
}
|
||||
|
||||
d.Messages = &m
|
||||
}
|
||||
}
|
||||
|
||||
func isInChannel(channels []storage.Channel, server, channel string) bool {
|
||||
if channel != "" {
|
||||
for _, ch := range channels {
|
||||
|
|
|
@ -344,7 +344,7 @@ func easyjson7e607aefDecodeGithubComKhliengDispatchServer1(in *jlexer.Lexer, out
|
|||
out.Password = bool(in.Bool())
|
||||
case "ssl":
|
||||
out.SSL = bool(in.Bool())
|
||||
case "readonly":
|
||||
case "readOnly":
|
||||
out.ReadOnly = bool(in.Bool())
|
||||
case "showDetails":
|
||||
out.ShowDetails = bool(in.Bool())
|
||||
|
@ -432,7 +432,7 @@ func easyjson7e607aefEncodeGithubComKhliengDispatchServer1(out *jwriter.Writer,
|
|||
out.Bool(bool(in.SSL))
|
||||
}
|
||||
if in.ReadOnly {
|
||||
const prefix string = ",\"readonly\":"
|
||||
const prefix string = ",\"readOnly\":"
|
||||
if first {
|
||||
first = false
|
||||
out.RawString(prefix[1:])
|
||||
|
|
|
@ -2,61 +2,35 @@ package server
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/khlieng/dispatch/irc"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/khlieng/dispatch/pkg/irc"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
func createNickInUseHandler(i *irc.Client, session *Session) func(string) string {
|
||||
func createNickInUseHandler(i *irc.Client, state *State) func(string) string {
|
||||
return func(nick string) string {
|
||||
newNick := nick + "_"
|
||||
|
||||
if newNick == i.GetNick() {
|
||||
session.sendJSON("nick_fail", NickFail{
|
||||
state.sendJSON("nick_fail", NickFail{
|
||||
Server: i.Host,
|
||||
})
|
||||
}
|
||||
|
||||
session.printError("Nickname", nick, "is already in use, using", newNick, "instead")
|
||||
state.printError("Nickname", nick, "is already in use, using", newNick, "instead")
|
||||
|
||||
return newNick
|
||||
}
|
||||
}
|
||||
|
||||
func reconnectIRC() {
|
||||
for _, user := range storage.LoadUsers() {
|
||||
session, err := NewSession(user)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
sessions.set(session)
|
||||
go session.run()
|
||||
|
||||
channels := user.GetChannels()
|
||||
|
||||
for _, server := range user.GetServers() {
|
||||
i := connectIRC(server, session)
|
||||
|
||||
var joining []string
|
||||
for _, channel := range channels {
|
||||
if channel.Server == server.Host {
|
||||
joining = append(joining, channel.Name)
|
||||
}
|
||||
}
|
||||
i.Join(joining...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func connectIRC(server storage.Server, session *Session) *irc.Client {
|
||||
func connectIRC(server *storage.Server, state *State) *irc.Client {
|
||||
i := irc.NewClient(server.Nick, server.Username)
|
||||
i.TLS = server.TLS
|
||||
i.Realname = server.Realname
|
||||
i.HandleNickInUse = createNickInUseHandler(i, session)
|
||||
i.HandleNickInUse = createNickInUseHandler(i, state)
|
||||
|
||||
address := server.Host
|
||||
if server.Port != "" {
|
||||
|
@ -83,14 +57,14 @@ func connectIRC(server storage.Server, session *Session) *irc.Client {
|
|||
InsecureSkipVerify: !viper.GetBool("verify_certificates"),
|
||||
}
|
||||
|
||||
if cert := session.user.GetCertificate(); cert != nil {
|
||||
if cert := state.user.GetCertificate(); cert != nil {
|
||||
i.TLSConfig.Certificates = []tls.Certificate{*cert}
|
||||
}
|
||||
}
|
||||
|
||||
session.setIRC(server.Host, i)
|
||||
state.setIRC(server.Host, i)
|
||||
i.Connect(address)
|
||||
go newIRCHandler(i, session).run()
|
||||
go newIRCHandler(i, state).run()
|
||||
|
||||
return i
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
"github.com/kjk/betterguid"
|
||||
|
||||
"github.com/khlieng/dispatch/irc"
|
||||
"github.com/khlieng/dispatch/pkg/irc"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
|
@ -17,8 +17,8 @@ var excludedErrors = []string{
|
|||
}
|
||||
|
||||
type ircHandler struct {
|
||||
client *irc.Client
|
||||
session *Session
|
||||
client *irc.Client
|
||||
state *State
|
||||
|
||||
whois WhoisReply
|
||||
userBuffers map[string][]string
|
||||
|
@ -27,10 +27,10 @@ type ircHandler struct {
|
|||
handlers map[string]func(*irc.Message)
|
||||
}
|
||||
|
||||
func newIRCHandler(client *irc.Client, session *Session) *ircHandler {
|
||||
func newIRCHandler(client *irc.Client, state *State) *ircHandler {
|
||||
i := &ircHandler{
|
||||
client: client,
|
||||
session: session,
|
||||
state: state,
|
||||
userBuffers: make(map[string][]string),
|
||||
}
|
||||
i.initHandlers()
|
||||
|
@ -43,15 +43,15 @@ func (i *ircHandler) run() {
|
|||
select {
|
||||
case msg, ok := <-i.client.Messages:
|
||||
if !ok {
|
||||
i.session.deleteIRC(i.client.Host)
|
||||
i.state.deleteIRC(i.client.Host)
|
||||
return
|
||||
}
|
||||
|
||||
i.dispatchMessage(msg)
|
||||
|
||||
case state := <-i.client.ConnectionChanged:
|
||||
i.session.sendJSON("connection_update", newConnectionUpdate(i.client.Host, state))
|
||||
i.session.setConnectionState(i.client.Host, state)
|
||||
i.state.sendJSON("connection_update", newConnectionUpdate(i.client.Host, state))
|
||||
i.state.setConnectionState(i.client.Host, state)
|
||||
|
||||
if state.Error != nil && (lastConnErr == nil ||
|
||||
state.Error.Error() != lastConnErr.Error()) {
|
||||
|
@ -66,7 +66,7 @@ func (i *ircHandler) run() {
|
|||
|
||||
func (i *ircHandler) dispatchMessage(msg *irc.Message) {
|
||||
if msg.Command[0] == '4' && !isExcludedError(msg.Command) {
|
||||
i.session.printError(formatIRCError(msg))
|
||||
i.state.printError(formatIRCError(msg))
|
||||
}
|
||||
|
||||
if handler, ok := i.handlers[msg.Command]; ok {
|
||||
|
@ -75,7 +75,7 @@ func (i *ircHandler) dispatchMessage(msg *irc.Message) {
|
|||
}
|
||||
|
||||
func (i *ircHandler) nick(msg *irc.Message) {
|
||||
i.session.sendJSON("nick", Nick{
|
||||
i.state.sendJSON("nick", Nick{
|
||||
Server: i.client.Host,
|
||||
Old: msg.Nick,
|
||||
New: msg.LastParam(),
|
||||
|
@ -84,12 +84,12 @@ func (i *ircHandler) nick(msg *irc.Message) {
|
|||
channelStore.RenameUser(msg.Nick, msg.LastParam(), i.client.Host)
|
||||
|
||||
if msg.LastParam() == i.client.GetNick() {
|
||||
go i.session.user.SetNick(msg.LastParam(), i.client.Host)
|
||||
go i.state.user.SetNick(msg.LastParam(), i.client.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *ircHandler) join(msg *irc.Message) {
|
||||
i.session.sendJSON("join", Join{
|
||||
i.state.sendJSON("join", Join{
|
||||
Server: i.client.Host,
|
||||
User: msg.Nick,
|
||||
Channels: msg.Params,
|
||||
|
@ -102,9 +102,9 @@ func (i *ircHandler) join(msg *irc.Message) {
|
|||
// Incase no topic is set and theres a cached one that needs to be cleared
|
||||
i.client.Topic(channel)
|
||||
|
||||
i.session.sendLastMessages(i.client.Host, channel, 50)
|
||||
i.state.sendLastMessages(i.client.Host, channel, 50)
|
||||
|
||||
go i.session.user.AddChannel(storage.Channel{
|
||||
go i.state.user.AddChannel(&storage.Channel{
|
||||
Server: i.client.Host,
|
||||
Name: channel,
|
||||
})
|
||||
|
@ -122,12 +122,12 @@ func (i *ircHandler) part(msg *irc.Message) {
|
|||
part.Reason = msg.Params[1]
|
||||
}
|
||||
|
||||
i.session.sendJSON("part", part)
|
||||
i.state.sendJSON("part", part)
|
||||
|
||||
channelStore.RemoveUser(msg.Nick, i.client.Host, part.Channel)
|
||||
|
||||
if msg.Nick == i.client.GetNick() {
|
||||
go i.session.user.RemoveChannel(i.client.Host, part.Channel)
|
||||
go i.state.user.RemoveChannel(i.client.Host, part.Channel)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -139,7 +139,7 @@ func (i *ircHandler) mode(msg *irc.Message) {
|
|||
mode.Channel = target
|
||||
mode.User = msg.Params[2]
|
||||
|
||||
i.session.sendJSON("mode", mode)
|
||||
i.state.sendJSON("mode", mode)
|
||||
|
||||
channelStore.SetMode(i.client.Host, target, msg.Params[2], mode.Add, mode.Remove)
|
||||
}
|
||||
|
@ -154,20 +154,20 @@ func (i *ircHandler) message(msg *irc.Message) {
|
|||
}
|
||||
|
||||
if msg.Params[0] == i.client.GetNick() {
|
||||
i.session.sendJSON("pm", message)
|
||||
i.state.sendJSON("pm", message)
|
||||
} else {
|
||||
message.To = msg.Params[0]
|
||||
i.session.sendJSON("message", message)
|
||||
i.state.sendJSON("message", message)
|
||||
}
|
||||
|
||||
if msg.Params[0] != "*" {
|
||||
go i.session.user.LogMessage(message.ID,
|
||||
go i.state.user.LogMessage(message.ID,
|
||||
i.client.Host, msg.Nick, msg.Params[0], msg.LastParam())
|
||||
}
|
||||
}
|
||||
|
||||
func (i *ircHandler) quit(msg *irc.Message) {
|
||||
i.session.sendJSON("quit", Quit{
|
||||
i.state.sendJSON("quit", Quit{
|
||||
Server: i.client.Host,
|
||||
User: msg.Nick,
|
||||
Reason: msg.LastParam(),
|
||||
|
@ -178,15 +178,15 @@ func (i *ircHandler) quit(msg *irc.Message) {
|
|||
|
||||
func (i *ircHandler) info(msg *irc.Message) {
|
||||
if msg.Command == irc.ReplyWelcome {
|
||||
i.session.sendJSON("nick", Nick{
|
||||
i.state.sendJSON("nick", Nick{
|
||||
Server: i.client.Host,
|
||||
New: msg.Params[0],
|
||||
})
|
||||
|
||||
go i.session.user.SetNick(msg.Params[0], i.client.Host)
|
||||
go i.state.user.SetNick(msg.Params[0], i.client.Host)
|
||||
}
|
||||
|
||||
i.session.sendJSON("pm", Message{
|
||||
i.state.sendJSON("pm", Message{
|
||||
Server: i.client.Host,
|
||||
From: msg.Nick,
|
||||
Content: strings.Join(msg.Params[1:], " "),
|
||||
|
@ -210,7 +210,7 @@ func (i *ircHandler) whoisChannels(msg *irc.Message) {
|
|||
|
||||
func (i *ircHandler) whoisEnd(msg *irc.Message) {
|
||||
if i.whois.Nick != "" {
|
||||
i.session.sendJSON("whois", i.whois)
|
||||
i.state.sendJSON("whois", i.whois)
|
||||
}
|
||||
i.whois = WhoisReply{}
|
||||
}
|
||||
|
@ -226,7 +226,7 @@ func (i *ircHandler) topic(msg *irc.Message) {
|
|||
channel = msg.Params[1]
|
||||
}
|
||||
|
||||
i.session.sendJSON("topic", Topic{
|
||||
i.state.sendJSON("topic", Topic{
|
||||
Server: i.client.Host,
|
||||
Channel: channel,
|
||||
Topic: msg.LastParam(),
|
||||
|
@ -239,7 +239,7 @@ func (i *ircHandler) topic(msg *irc.Message) {
|
|||
func (i *ircHandler) noTopic(msg *irc.Message) {
|
||||
channel := msg.Params[1]
|
||||
|
||||
i.session.sendJSON("topic", Topic{
|
||||
i.state.sendJSON("topic", Topic{
|
||||
Server: i.client.Host,
|
||||
Channel: channel,
|
||||
})
|
||||
|
@ -257,7 +257,7 @@ func (i *ircHandler) namesEnd(msg *irc.Message) {
|
|||
channel := msg.Params[1]
|
||||
users := i.userBuffers[channel]
|
||||
|
||||
i.session.sendJSON("users", Userlist{
|
||||
i.state.sendJSON("users", Userlist{
|
||||
Server: i.client.Host,
|
||||
Channel: channel,
|
||||
Users: users,
|
||||
|
@ -277,18 +277,18 @@ func (i *ircHandler) motd(msg *irc.Message) {
|
|||
}
|
||||
|
||||
func (i *ircHandler) motdEnd(msg *irc.Message) {
|
||||
i.session.sendJSON("motd", i.motdBuffer)
|
||||
i.state.sendJSON("motd", i.motdBuffer)
|
||||
i.motdBuffer = MOTD{}
|
||||
}
|
||||
|
||||
func (i *ircHandler) badNick(msg *irc.Message) {
|
||||
i.session.sendJSON("nick_fail", NickFail{
|
||||
i.state.sendJSON("nick_fail", NickFail{
|
||||
Server: i.client.Host,
|
||||
})
|
||||
}
|
||||
|
||||
func (i *ircHandler) error(msg *irc.Message) {
|
||||
i.session.printError(msg.LastParam())
|
||||
i.state.printError(msg.LastParam())
|
||||
}
|
||||
|
||||
func (i *ircHandler) initHandlers() {
|
||||
|
@ -327,7 +327,7 @@ func (i *ircHandler) initHandlers() {
|
|||
|
||||
func (i *ircHandler) log(v ...interface{}) {
|
||||
s := fmt.Sprintln(v...)
|
||||
log.Println("[IRC]", i.session.user.ID, i.client.Host, s[:len(s)-1])
|
||||
log.Println("[IRC]", i.state.user.ID, i.client.Host, s[:len(s)-1])
|
||||
}
|
||||
|
||||
func parseMode(mode string) *Mode {
|
||||
|
|
|
@ -8,8 +8,9 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/khlieng/dispatch/irc"
|
||||
"github.com/khlieng/dispatch/pkg/irc"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
"github.com/khlieng/dispatch/storage/boltdb"
|
||||
)
|
||||
|
||||
var user *storage.User
|
||||
|
@ -21,11 +22,18 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
storage.Initialize(tempdir)
|
||||
storage.Open()
|
||||
user, err = storage.NewUser()
|
||||
|
||||
db, err := boltdb.New(storage.Path.Database())
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
user, err = storage.NewUser(db)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
user.SetMessageStore(db)
|
||||
|
||||
channelStore = storage.NewChannelStore()
|
||||
|
||||
code := m.Run()
|
||||
|
@ -41,7 +49,7 @@ func dispatchMessage(msg *irc.Message) WSResponse {
|
|||
func dispatchMessageMulti(msg *irc.Message) chan WSResponse {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s, _ := NewSession(user)
|
||||
s := NewState(user, nil)
|
||||
|
||||
newIRCHandler(c, s).dispatchMessage(msg)
|
||||
|
||||
|
@ -187,7 +195,7 @@ func TestHandleIRCWelcome(t *testing.T) {
|
|||
func TestHandleIRCWhois(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s, _ := NewSession(nil)
|
||||
s := NewState(nil, nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -255,7 +263,7 @@ func TestHandleIRCNoTopic(t *testing.T) {
|
|||
func TestHandleIRCNames(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s, _ := NewSession(nil)
|
||||
s := NewState(nil, nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -281,7 +289,7 @@ func TestHandleIRCNames(t *testing.T) {
|
|||
func TestHandleIRCMotd(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s, _ := NewSession(nil)
|
||||
s := NewState(nil, nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -308,7 +316,7 @@ func TestHandleIRCMotd(t *testing.T) {
|
|||
func TestHandleIRCBadNick(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s, _ := NewSession(nil)
|
||||
s := NewState(nil, nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
|
||||
"github.com/mailru/easyjson"
|
||||
|
||||
"github.com/khlieng/dispatch/irc"
|
||||
"github.com/khlieng/dispatch/pkg/irc"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ var (
|
|||
cspEnabled bool
|
||||
)
|
||||
|
||||
func initFileServer() {
|
||||
func (d *Dispatch) initFileServer() {
|
||||
if !viper.GetBool("dev") {
|
||||
data, err := assets.Asset(files[0].Asset)
|
||||
if err != nil {
|
||||
|
@ -154,24 +154,24 @@ func initFileServer() {
|
|||
}
|
||||
}
|
||||
|
||||
func serveFiles(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *Dispatch) serveFiles(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
serveIndex(w, r)
|
||||
d.serveIndex(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if strings.HasSuffix(r.URL.Path, file.Path) {
|
||||
serveFile(w, r, file)
|
||||
d.serveFile(w, r, file)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
serveIndex(w, r)
|
||||
d.serveIndex(w, r)
|
||||
}
|
||||
|
||||
func serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||
session := handleAuth(w, r, false)
|
||||
func (d *Dispatch) serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||
state := d.handleAuth(w, r, false)
|
||||
|
||||
if cspEnabled {
|
||||
var connectSrc string
|
||||
|
@ -228,10 +228,10 @@ func serveIndex(w http.ResponseWriter, r *http.Request) {
|
|||
w.Header().Set("Content-Encoding", "gzip")
|
||||
|
||||
gzw := gzip.NewWriter(w)
|
||||
IndexTemplate(gzw, getIndexData(r, session), files[1].Path, files[0].Path)
|
||||
IndexTemplate(gzw, getIndexData(r, state), files[1].Path, files[0].Path)
|
||||
gzw.Close()
|
||||
} else {
|
||||
IndexTemplate(w, getIndexData(r, session), files[1].Path, files[0].Path)
|
||||
IndexTemplate(w, getIndexData(r, state), files[1].Path, files[0].Path)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -246,7 +246,7 @@ func setPushCookie(w http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
}
|
||||
|
||||
func serveFile(w http.ResponseWriter, r *http.Request, file *File) {
|
||||
func (d *Dispatch) serveFile(w http.ResponseWriter, r *http.Request, file *File) {
|
||||
info, err := assets.AssetInfo(file.Asset)
|
||||
if err != nil {
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
|
|
120
server/server.go
120
server/server.go
|
@ -12,36 +12,99 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/khlieng/dispatch/letsencrypt"
|
||||
"github.com/khlieng/dispatch/pkg/letsencrypt"
|
||||
"github.com/khlieng/dispatch/pkg/session"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
sessions *sessionStore
|
||||
channelStore *storage.ChannelStore
|
||||
var channelStore = storage.NewChannelStore()
|
||||
|
||||
upgrader = websocket.Upgrader{
|
||||
type Dispatch struct {
|
||||
Store storage.Store
|
||||
SessionStore storage.SessionStore
|
||||
|
||||
GetMessageStore func(*storage.User) (storage.MessageStore, error)
|
||||
GetMessageSearchProvider func(*storage.User) (storage.MessageSearchProvider, error)
|
||||
|
||||
upgrader websocket.Upgrader
|
||||
states *stateStore
|
||||
}
|
||||
|
||||
func (d *Dispatch) Run() {
|
||||
d.upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
)
|
||||
|
||||
func Run() {
|
||||
sessions = newSessionStore()
|
||||
channelStore = storage.NewChannelStore()
|
||||
|
||||
if viper.GetBool("dev") {
|
||||
upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
d.upgrader.CheckOrigin = func(r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
reconnectIRC()
|
||||
initFileServer()
|
||||
startHTTP()
|
||||
session.CookieName = "dispatch"
|
||||
|
||||
d.states = newStateStore(d.SessionStore)
|
||||
|
||||
d.loadUsers()
|
||||
d.initFileServer()
|
||||
d.startHTTP()
|
||||
}
|
||||
|
||||
func startHTTP() {
|
||||
func (d *Dispatch) loadUsers() {
|
||||
users, err := storage.LoadUsers(d.Store)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Printf("Loading %d user(s)", len(users))
|
||||
|
||||
for i := range users {
|
||||
go d.loadUser(&users[i])
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatch) loadUser(user *storage.User) {
|
||||
messageStore, err := d.GetMessageStore(user)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
user.SetMessageStore(messageStore)
|
||||
|
||||
search, err := d.GetMessageSearchProvider(user)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
user.SetMessageSearchProvider(search)
|
||||
|
||||
state := NewState(user, d)
|
||||
d.states.set(state)
|
||||
go state.run()
|
||||
|
||||
channels, err := user.GetChannels()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
servers, err := user.GetServers()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, server := range servers {
|
||||
i := connectIRC(&server, state)
|
||||
|
||||
var joining []string
|
||||
for _, channel := range channels {
|
||||
if channel.Server == server.Host {
|
||||
joining = append(joining, channel.Name)
|
||||
}
|
||||
}
|
||||
i.Join(joining...)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatch) startHTTP() {
|
||||
port := viper.GetString("port")
|
||||
|
||||
if viper.GetBool("https.enabled") {
|
||||
|
@ -55,7 +118,7 @@ func startHTTP() {
|
|||
|
||||
server := &http.Server{
|
||||
Addr: ":" + portHTTPS,
|
||||
Handler: http.HandlerFunc(serve),
|
||||
Handler: http.HandlerFunc(d.serve),
|
||||
}
|
||||
|
||||
if certExists() {
|
||||
|
@ -71,13 +134,13 @@ func startHTTP() {
|
|||
go http.ListenAndServe(":80", http.HandlerFunc(letsEncryptProxy))
|
||||
}
|
||||
|
||||
letsEncrypt, err := letsencrypt.Run(dir, domain, email, ":"+lePort)
|
||||
le, err := letsencrypt.Run(dir, domain, email, ":"+lePort)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
server.TLSConfig = &tls.Config{
|
||||
GetCertificate: letsEncrypt.GetCertificate,
|
||||
GetCertificate: le.GetCertificate,
|
||||
}
|
||||
|
||||
log.Println("[HTTPS] Listening on port", portHTTPS)
|
||||
|
@ -92,11 +155,11 @@ func startHTTP() {
|
|||
port = "1337"
|
||||
}
|
||||
log.Println("[HTTP] Listening on port", port)
|
||||
log.Fatal(http.ListenAndServe(":"+port, http.HandlerFunc(serve)))
|
||||
log.Fatal(http.ListenAndServe(":"+port, http.HandlerFunc(d.serve)))
|
||||
}
|
||||
}
|
||||
|
||||
func serve(w http.ResponseWriter, r *http.Request) {
|
||||
func (d *Dispatch) serve(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
fail(w, http.StatusNotFound)
|
||||
return
|
||||
|
@ -108,28 +171,27 @@ func serve(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
session := handleAuth(w, r, true)
|
||||
|
||||
if session == nil {
|
||||
log.Println("[Auth] No session")
|
||||
state := d.handleAuth(w, r, true)
|
||||
if state == nil {
|
||||
log.Println("[Auth] No state")
|
||||
fail(w, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
upgradeWS(w, r, session)
|
||||
d.upgradeWS(w, r, state)
|
||||
} else {
|
||||
serveFiles(w, r)
|
||||
d.serveFiles(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func upgradeWS(w http.ResponseWriter, r *http.Request, session *Session) {
|
||||
conn, err := upgrader.Upgrade(w, r, w.Header())
|
||||
func (d *Dispatch) upgradeWS(w http.ResponseWriter, r *http.Request, state *State) {
|
||||
conn, err := d.upgrader.Upgrade(w, r, w.Header())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
newWSHandler(conn, session, r).run()
|
||||
newWSHandler(conn, state, r).run()
|
||||
}
|
||||
|
||||
func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {
|
||||
|
|
|
@ -1,253 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/khlieng/dispatch/irc"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
AnonymousSessionExpiration = 1 * time.Minute
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
irc map[string]*irc.Client
|
||||
connectionState map[string]irc.ConnectionState
|
||||
ircLock sync.Mutex
|
||||
|
||||
ws map[string]*wsConn
|
||||
wsLock sync.Mutex
|
||||
broadcast chan WSResponse
|
||||
|
||||
id string
|
||||
user *storage.User
|
||||
expiration *time.Timer
|
||||
reset chan time.Duration
|
||||
}
|
||||
|
||||
func NewSession(user *storage.User) (*Session, error) {
|
||||
id, err := newSessionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Session{
|
||||
irc: make(map[string]*irc.Client),
|
||||
connectionState: make(map[string]irc.ConnectionState),
|
||||
ws: make(map[string]*wsConn),
|
||||
broadcast: make(chan WSResponse, 32),
|
||||
id: id,
|
||||
user: user,
|
||||
expiration: time.NewTimer(AnonymousSessionExpiration),
|
||||
reset: make(chan time.Duration, 1),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newSessionID() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
return base64.RawURLEncoding.EncodeToString(key), err
|
||||
}
|
||||
|
||||
func (s *Session) getIRC(server string) (*irc.Client, bool) {
|
||||
s.ircLock.Lock()
|
||||
i, ok := s.irc[server]
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return i, ok
|
||||
}
|
||||
|
||||
func (s *Session) setIRC(server string, i *irc.Client) {
|
||||
s.ircLock.Lock()
|
||||
s.irc[server] = i
|
||||
s.connectionState[server] = irc.ConnectionState{
|
||||
Connected: false,
|
||||
}
|
||||
s.ircLock.Unlock()
|
||||
|
||||
s.reset <- 0
|
||||
}
|
||||
|
||||
func (s *Session) deleteIRC(server string) {
|
||||
s.ircLock.Lock()
|
||||
delete(s.irc, server)
|
||||
delete(s.connectionState, server)
|
||||
s.ircLock.Unlock()
|
||||
|
||||
s.resetExpirationIfEmpty()
|
||||
}
|
||||
|
||||
func (s *Session) numIRC() int {
|
||||
s.ircLock.Lock()
|
||||
n := len(s.irc)
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (s *Session) getConnectionStates() map[string]irc.ConnectionState {
|
||||
s.ircLock.Lock()
|
||||
state := make(map[string]irc.ConnectionState, len(s.connectionState))
|
||||
|
||||
for k, v := range s.connectionState {
|
||||
state[k] = v
|
||||
}
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *Session) setConnectionState(server string, state irc.ConnectionState) {
|
||||
s.ircLock.Lock()
|
||||
s.connectionState[server] = state
|
||||
s.ircLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *Session) setWS(addr string, w *wsConn) {
|
||||
s.wsLock.Lock()
|
||||
s.ws[addr] = w
|
||||
s.wsLock.Unlock()
|
||||
|
||||
s.reset <- 0
|
||||
}
|
||||
|
||||
func (s *Session) deleteWS(addr string) {
|
||||
s.wsLock.Lock()
|
||||
delete(s.ws, addr)
|
||||
s.wsLock.Unlock()
|
||||
|
||||
s.resetExpirationIfEmpty()
|
||||
}
|
||||
|
||||
func (s *Session) numWS() int {
|
||||
s.ircLock.Lock()
|
||||
n := len(s.ws)
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (s *Session) sendJSON(t string, v interface{}) {
|
||||
s.broadcast <- WSResponse{t, v}
|
||||
}
|
||||
|
||||
func (s *Session) sendError(err error, server string) {
|
||||
s.sendJSON("error", Error{
|
||||
Server: server,
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) sendLastMessages(server, channel string, count int) {
|
||||
messages, hasMore, err := s.user.GetLastMessages(server, channel, count)
|
||||
if err == nil && len(messages) > 0 {
|
||||
res := Messages{
|
||||
Server: server,
|
||||
To: channel,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
res.Next = messages[0].ID
|
||||
}
|
||||
|
||||
s.sendJSON("messages", res)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) sendMessages(server, channel string, count int, fromID string) {
|
||||
messages, hasMore, err := s.user.GetMessages(server, channel, count, fromID)
|
||||
if err == nil && len(messages) > 0 {
|
||||
res := Messages{
|
||||
Server: server,
|
||||
To: channel,
|
||||
Messages: messages,
|
||||
Prepend: true,
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
res.Next = messages[0].ID
|
||||
}
|
||||
|
||||
s.sendJSON("messages", res)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) print(a ...interface{}) {
|
||||
s.sendJSON("print", Message{
|
||||
Content: fmt.Sprintln(a...),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) printError(a ...interface{}) {
|
||||
s.sendJSON("print", Message{
|
||||
Content: fmt.Sprintln(a...),
|
||||
Type: "error",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) resetExpirationIfEmpty() {
|
||||
if s.numIRC() == 0 && s.numWS() == 0 {
|
||||
s.reset <- AnonymousSessionExpiration
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) run() {
|
||||
for {
|
||||
select {
|
||||
case res := <-s.broadcast:
|
||||
s.wsLock.Lock()
|
||||
for _, ws := range s.ws {
|
||||
ws.out <- res
|
||||
}
|
||||
s.wsLock.Unlock()
|
||||
|
||||
case <-s.expiration.C:
|
||||
sessions.delete(s.id)
|
||||
s.user.Remove()
|
||||
return
|
||||
|
||||
case duration := <-s.reset:
|
||||
if duration == 0 {
|
||||
s.expiration.Stop()
|
||||
} else {
|
||||
s.expiration.Reset(duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sessionStore struct {
|
||||
sessions map[string]*Session
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func newSessionStore() *sessionStore {
|
||||
return &sessionStore{
|
||||
sessions: make(map[string]*Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sessionStore) get(id string) *Session {
|
||||
s.lock.Lock()
|
||||
session := s.sessions[id]
|
||||
s.lock.Unlock()
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *sessionStore) set(session *Session) {
|
||||
s.lock.Lock()
|
||||
s.sessions[session.id] = session
|
||||
s.lock.Unlock()
|
||||
}
|
||||
|
||||
func (s *sessionStore) delete(id string) {
|
||||
s.lock.Lock()
|
||||
delete(s.sessions, id)
|
||||
s.lock.Unlock()
|
||||
}
|
323
server/state.go
Normal file
323
server/state.go
Normal file
|
@ -0,0 +1,323 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/khlieng/dispatch/pkg/irc"
|
||||
"github.com/khlieng/dispatch/pkg/session"
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
AnonymousUserExpiration = 1 * time.Minute
|
||||
)
|
||||
|
||||
type State struct {
|
||||
irc map[string]*irc.Client
|
||||
connectionState map[string]irc.ConnectionState
|
||||
ircLock sync.Mutex
|
||||
|
||||
ws map[string]*wsConn
|
||||
wsLock sync.Mutex
|
||||
broadcast chan WSResponse
|
||||
|
||||
srv *Dispatch
|
||||
user *storage.User
|
||||
expiration *time.Timer
|
||||
reset chan time.Duration
|
||||
}
|
||||
|
||||
func NewState(user *storage.User, srv *Dispatch) *State {
|
||||
return &State{
|
||||
irc: make(map[string]*irc.Client),
|
||||
connectionState: make(map[string]irc.ConnectionState),
|
||||
ws: make(map[string]*wsConn),
|
||||
broadcast: make(chan WSResponse, 32),
|
||||
srv: srv,
|
||||
user: user,
|
||||
expiration: time.NewTimer(AnonymousUserExpiration),
|
||||
reset: make(chan time.Duration, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) getIRC(server string) (*irc.Client, bool) {
|
||||
s.ircLock.Lock()
|
||||
i, ok := s.irc[server]
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return i, ok
|
||||
}
|
||||
|
||||
func (s *State) setIRC(server string, i *irc.Client) {
|
||||
s.ircLock.Lock()
|
||||
s.irc[server] = i
|
||||
s.connectionState[server] = irc.ConnectionState{
|
||||
Connected: false,
|
||||
}
|
||||
s.ircLock.Unlock()
|
||||
|
||||
s.reset <- 0
|
||||
}
|
||||
|
||||
func (s *State) deleteIRC(server string) {
|
||||
s.ircLock.Lock()
|
||||
delete(s.irc, server)
|
||||
delete(s.connectionState, server)
|
||||
s.ircLock.Unlock()
|
||||
|
||||
s.resetExpirationIfEmpty()
|
||||
}
|
||||
|
||||
func (s *State) numIRC() int {
|
||||
s.ircLock.Lock()
|
||||
n := len(s.irc)
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (s *State) getConnectionStates() map[string]irc.ConnectionState {
|
||||
s.ircLock.Lock()
|
||||
state := make(map[string]irc.ConnectionState, len(s.connectionState))
|
||||
|
||||
for k, v := range s.connectionState {
|
||||
state[k] = v
|
||||
}
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *State) setConnectionState(server string, state irc.ConnectionState) {
|
||||
s.ircLock.Lock()
|
||||
s.connectionState[server] = state
|
||||
s.ircLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *State) setWS(addr string, w *wsConn) {
|
||||
s.wsLock.Lock()
|
||||
s.ws[addr] = w
|
||||
s.wsLock.Unlock()
|
||||
|
||||
s.reset <- 0
|
||||
}
|
||||
|
||||
func (s *State) deleteWS(addr string) {
|
||||
s.wsLock.Lock()
|
||||
delete(s.ws, addr)
|
||||
s.wsLock.Unlock()
|
||||
|
||||
s.resetExpirationIfEmpty()
|
||||
}
|
||||
|
||||
func (s *State) numWS() int {
|
||||
s.ircLock.Lock()
|
||||
n := len(s.ws)
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (s *State) sendJSON(t string, v interface{}) {
|
||||
s.broadcast <- WSResponse{t, v}
|
||||
}
|
||||
|
||||
func (s *State) sendError(err error, server string) {
|
||||
s.sendJSON("error", Error{
|
||||
Server: server,
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *State) sendLastMessages(server, channel string, count int) {
|
||||
messages, hasMore, err := s.user.GetLastMessages(server, channel, count)
|
||||
if err == nil && len(messages) > 0 {
|
||||
res := Messages{
|
||||
Server: server,
|
||||
To: channel,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
res.Next = messages[0].ID
|
||||
}
|
||||
|
||||
s.sendJSON("messages", res)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) sendMessages(server, channel string, count int, fromID string) {
|
||||
messages, hasMore, err := s.user.GetMessages(server, channel, count, fromID)
|
||||
if err == nil && len(messages) > 0 {
|
||||
res := Messages{
|
||||
Server: server,
|
||||
To: channel,
|
||||
Messages: messages,
|
||||
Prepend: true,
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
res.Next = messages[0].ID
|
||||
}
|
||||
|
||||
s.sendJSON("messages", res)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) print(a ...interface{}) {
|
||||
s.sendJSON("print", Message{
|
||||
Content: fmt.Sprintln(a...),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *State) printError(a ...interface{}) {
|
||||
s.sendJSON("print", Message{
|
||||
Content: fmt.Sprintln(a...),
|
||||
Type: "error",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *State) resetExpirationIfEmpty() {
|
||||
if s.numIRC() == 0 && s.numWS() == 0 {
|
||||
s.reset <- AnonymousUserExpiration
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) kill() {
|
||||
s.wsLock.Lock()
|
||||
for _, ws := range s.ws {
|
||||
ws.conn.Close()
|
||||
}
|
||||
s.wsLock.Unlock()
|
||||
s.ircLock.Lock()
|
||||
for _, i := range s.irc {
|
||||
i.Quit()
|
||||
}
|
||||
s.ircLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *State) run() {
|
||||
for {
|
||||
select {
|
||||
case res := <-s.broadcast:
|
||||
s.wsLock.Lock()
|
||||
for _, ws := range s.ws {
|
||||
ws.out <- res
|
||||
}
|
||||
s.wsLock.Unlock()
|
||||
|
||||
case <-s.expiration.C:
|
||||
s.srv.states.delete(s.user.ID)
|
||||
s.user.Remove()
|
||||
return
|
||||
|
||||
case duration := <-s.reset:
|
||||
if duration == 0 {
|
||||
s.expiration.Stop()
|
||||
} else {
|
||||
s.expiration.Reset(duration)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type stateStore struct {
|
||||
states map[uint64]*State
|
||||
sessions map[string]*session.Session
|
||||
sessionStore storage.SessionStore
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func newStateStore(sessionStore storage.SessionStore) *stateStore {
|
||||
store := &stateStore{
|
||||
states: make(map[uint64]*State),
|
||||
sessions: make(map[string]*session.Session),
|
||||
sessionStore: sessionStore,
|
||||
}
|
||||
|
||||
sessions, err := sessionStore.GetSessions()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
if !session.Expired() {
|
||||
session.Init()
|
||||
store.sessions[session.Key()] = &session
|
||||
go deleteSessionWhenExpired(&session, store)
|
||||
} else {
|
||||
go sessionStore.DeleteSession(session.Key())
|
||||
}
|
||||
}
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *stateStore) get(id uint64) *State {
|
||||
s.lock.Lock()
|
||||
state := s.states[id]
|
||||
s.lock.Unlock()
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *stateStore) set(state *State) {
|
||||
s.lock.Lock()
|
||||
s.states[state.user.ID] = state
|
||||
s.lock.Unlock()
|
||||
}
|
||||
|
||||
func (s *stateStore) delete(id uint64) {
|
||||
s.lock.Lock()
|
||||
delete(s.states, id)
|
||||
for key, session := range s.sessions {
|
||||
if session.UserID == id {
|
||||
delete(s.sessions, key)
|
||||
go s.sessionStore.DeleteSession(key)
|
||||
}
|
||||
}
|
||||
s.lock.Unlock()
|
||||
}
|
||||
|
||||
func (s *stateStore) getSession(key string) *session.Session {
|
||||
s.lock.Lock()
|
||||
session := s.sessions[key]
|
||||
s.lock.Unlock()
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *stateStore) setSession(session *session.Session) {
|
||||
s.lock.Lock()
|
||||
s.sessions[session.Key()] = session
|
||||
s.lock.Unlock()
|
||||
s.sessionStore.SaveSession(session)
|
||||
}
|
||||
|
||||
func (s *stateStore) deleteSession(key string) {
|
||||
s.lock.Lock()
|
||||
id := s.sessions[key].UserID
|
||||
delete(s.sessions, key)
|
||||
n := 0
|
||||
for _, session := range s.sessions {
|
||||
if session.UserID == id {
|
||||
n++
|
||||
}
|
||||
}
|
||||
state := s.states[id]
|
||||
if n == 0 {
|
||||
delete(s.states, id)
|
||||
}
|
||||
s.lock.Unlock()
|
||||
|
||||
if n == 0 {
|
||||
// This anonymous user is not reachable anymore since all sessions have
|
||||
// expired, so we clean it up
|
||||
state.kill()
|
||||
state.user.Remove()
|
||||
}
|
||||
|
||||
s.sessionStore.DeleteSession(key)
|
||||
}
|
|
@ -11,16 +11,16 @@ import (
|
|||
|
||||
type wsHandler struct {
|
||||
ws *wsConn
|
||||
session *Session
|
||||
state *State
|
||||
addr string
|
||||
handlers map[string]func([]byte)
|
||||
}
|
||||
|
||||
func newWSHandler(conn *websocket.Conn, session *Session, r *http.Request) *wsHandler {
|
||||
func newWSHandler(conn *websocket.Conn, state *State, r *http.Request) *wsHandler {
|
||||
h := &wsHandler{
|
||||
ws: newWSConn(conn),
|
||||
session: session,
|
||||
addr: conn.RemoteAddr().String(),
|
||||
ws: newWSConn(conn),
|
||||
state: state,
|
||||
addr: conn.RemoteAddr().String(),
|
||||
}
|
||||
h.init(r)
|
||||
h.initHandlers()
|
||||
|
@ -35,8 +35,8 @@ func (h *wsHandler) run() {
|
|||
for {
|
||||
req, ok := <-h.ws.in
|
||||
if !ok {
|
||||
if h.session != nil {
|
||||
h.session.deleteWS(h.addr)
|
||||
if h.state != nil {
|
||||
h.state.deleteWS(h.addr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -52,13 +52,16 @@ func (h *wsHandler) dispatchRequest(req WSRequest) {
|
|||
}
|
||||
|
||||
func (h *wsHandler) init(r *http.Request) {
|
||||
h.session.setWS(h.addr, h.ws)
|
||||
h.state.setWS(h.addr, h.ws)
|
||||
|
||||
log.Println(h.addr, "[Session] User ID:", h.session.user.ID, "|",
|
||||
h.session.numIRC(), "IRC connections |",
|
||||
h.session.numWS(), "WebSocket connections")
|
||||
log.Println(h.addr, "[State] User ID:", h.state.user.ID, "|",
|
||||
h.state.numIRC(), "IRC connections |",
|
||||
h.state.numWS(), "WebSocket connections")
|
||||
|
||||
channels := h.session.user.GetChannels()
|
||||
channels, err := h.state.user.GetChannels()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
path := r.URL.EscapedPath()[3:]
|
||||
pathServer, pathChannel := getTabFromPath(path)
|
||||
cookieServer, cookieChannel := parseTabCookie(r, path)
|
||||
|
@ -66,16 +69,17 @@ func (h *wsHandler) init(r *http.Request) {
|
|||
for _, channel := range channels {
|
||||
if (channel.Server == pathServer && channel.Name == pathChannel) ||
|
||||
(channel.Server == cookieServer && channel.Name == cookieChannel) {
|
||||
// Userlist and messages for this channel gets embedded in the index page
|
||||
continue
|
||||
}
|
||||
|
||||
h.session.sendJSON("users", Userlist{
|
||||
h.state.sendJSON("users", Userlist{
|
||||
Server: channel.Server,
|
||||
Channel: channel.Name,
|
||||
Users: channelStore.GetUsers(channel.Server, channel.Name),
|
||||
})
|
||||
|
||||
h.session.sendLastMessages(channel.Server, channel.Name, 50)
|
||||
h.state.sendLastMessages(channel.Server, channel.Name, 50)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,12 +87,12 @@ func (h *wsHandler) connect(b []byte) {
|
|||
var data Server
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if _, ok := h.session.getIRC(data.Host); !ok {
|
||||
if _, ok := h.state.getIRC(data.Host); !ok {
|
||||
log.Println(h.addr, "[IRC] Add server", data.Host)
|
||||
|
||||
connectIRC(data.Server, h.session)
|
||||
connectIRC(&data.Server, h.state)
|
||||
|
||||
go h.session.user.AddServer(data.Server)
|
||||
go h.state.user.AddServer(&data.Server)
|
||||
} else {
|
||||
log.Println(h.addr, "[IRC]", data.Host, "already added")
|
||||
}
|
||||
|
@ -98,7 +102,7 @@ func (h *wsHandler) reconnect(b []byte) {
|
|||
var data ReconnectSettings
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok && !i.Connected() {
|
||||
if i, ok := h.state.getIRC(data.Server); ok && !i.Connected() {
|
||||
if i.TLS {
|
||||
i.TLSConfig.InsecureSkipVerify = data.SkipVerify
|
||||
}
|
||||
|
@ -110,7 +114,7 @@ func (h *wsHandler) join(b []byte) {
|
|||
var data Join
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Join(data.Channels...)
|
||||
}
|
||||
}
|
||||
|
@ -119,7 +123,7 @@ func (h *wsHandler) part(b []byte) {
|
|||
var data Part
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Part(data.Channels...)
|
||||
}
|
||||
}
|
||||
|
@ -129,22 +133,22 @@ func (h *wsHandler) quit(b []byte) {
|
|||
data.UnmarshalJSON(b)
|
||||
|
||||
log.Println(h.addr, "[IRC] Remove server", data.Server)
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
h.session.deleteIRC(data.Server)
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
h.state.deleteIRC(data.Server)
|
||||
i.Quit()
|
||||
}
|
||||
|
||||
go h.session.user.RemoveServer(data.Server)
|
||||
go h.state.user.RemoveServer(data.Server)
|
||||
}
|
||||
|
||||
func (h *wsHandler) message(b []byte) {
|
||||
var data Message
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Privmsg(data.To, data.Content)
|
||||
|
||||
go h.session.user.LogMessage(betterguid.New(),
|
||||
go h.state.user.LogMessage(betterguid.New(),
|
||||
data.Server, i.GetNick(), data.To, data.Content)
|
||||
}
|
||||
}
|
||||
|
@ -153,7 +157,7 @@ func (h *wsHandler) nick(b []byte) {
|
|||
var data Nick
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Nick(data.New)
|
||||
}
|
||||
}
|
||||
|
@ -162,7 +166,7 @@ func (h *wsHandler) topic(b []byte) {
|
|||
var data Topic
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Topic(data.Channel, data.Topic)
|
||||
}
|
||||
}
|
||||
|
@ -171,7 +175,7 @@ func (h *wsHandler) invite(b []byte) {
|
|||
var data Invite
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Invite(data.User, data.Channel)
|
||||
}
|
||||
}
|
||||
|
@ -180,7 +184,7 @@ func (h *wsHandler) kick(b []byte) {
|
|||
var data Invite
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Kick(data.Channel, data.User)
|
||||
}
|
||||
}
|
||||
|
@ -189,7 +193,7 @@ func (h *wsHandler) whois(b []byte) {
|
|||
var data Whois
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Whois(data.User)
|
||||
}
|
||||
}
|
||||
|
@ -198,7 +202,7 @@ func (h *wsHandler) away(b []byte) {
|
|||
var data Away
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Away(data.Message)
|
||||
}
|
||||
}
|
||||
|
@ -207,7 +211,7 @@ func (h *wsHandler) raw(b []byte) {
|
|||
var data Raw
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
if i, ok := h.state.getIRC(data.Server); ok {
|
||||
i.Write(data.Message)
|
||||
}
|
||||
}
|
||||
|
@ -217,13 +221,13 @@ func (h *wsHandler) search(b []byte) {
|
|||
var data SearchRequest
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
results, err := h.session.user.SearchMessages(data.Server, data.Channel, data.Phrase)
|
||||
results, err := h.state.user.SearchMessages(data.Server, data.Channel, data.Phrase)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
h.session.sendJSON("search", SearchResult{
|
||||
h.state.sendJSON("search", SearchResult{
|
||||
Server: data.Server,
|
||||
Channel: data.Channel,
|
||||
Results: results,
|
||||
|
@ -235,20 +239,20 @@ func (h *wsHandler) cert(b []byte) {
|
|||
var data ClientCert
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
err := h.session.user.SetCertificate(data.Cert, data.Key)
|
||||
err := h.state.user.SetCertificate(data.Cert, data.Key)
|
||||
if err != nil {
|
||||
h.session.sendJSON("cert_fail", Error{Message: err.Error()})
|
||||
h.state.sendJSON("cert_fail", Error{Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.session.sendJSON("cert_success", nil)
|
||||
h.state.sendJSON("cert_success", nil)
|
||||
}
|
||||
|
||||
func (h *wsHandler) fetchMessages(b []byte) {
|
||||
var data FetchMessages
|
||||
data.UnmarshalJSON(b)
|
||||
|
||||
h.session.sendMessages(data.Server, data.Channel, 200, data.Next)
|
||||
h.state.sendMessages(data.Server, data.Channel, 200, data.Next)
|
||||
}
|
||||
|
||||
func (h *wsHandler) setServerName(b []byte) {
|
||||
|
@ -256,7 +260,7 @@ func (h *wsHandler) setServerName(b []byte) {
|
|||
data.UnmarshalJSON(b)
|
||||
|
||||
if isValidServerName(data.Name) {
|
||||
h.session.user.SetServerName(data.Name, data.Server)
|
||||
h.state.user.SetServerName(data.Name, data.Server)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue