diff --git a/server/auth.go b/server/auth.go index fe258417..6ba40e63 100644 --- a/server/auth.go +++ b/server/auth.go @@ -36,6 +36,11 @@ func handleAuth(w http.ResponseWriter, r *http.Request) *Session { sessionLock.Lock() session = sessions[userID] sessionLock.Unlock() + + if session == nil { + // A previous anonymous session has been cleaned up, create a new one + session = newUser(w, r) + } } else { if err != nil { authLog(r, "Invalid token: "+err.Error()) @@ -63,7 +68,7 @@ func newUser(w http.ResponseWriter, r *http.Request) *Session { sessions[user.ID] = session sessionLock.Unlock() - go session.write() + go session.run() token := jwt.New(jwt.SigningMethodHS256) token.Claims["UserID"] = user.ID diff --git a/server/irc.go b/server/irc.go index 1d188e20..86b0b743 100644 --- a/server/irc.go +++ b/server/irc.go @@ -12,7 +12,7 @@ func reconnectIRC() { for _, user := range storage.LoadUsers() { session := NewSession(user) sessions[user.ID] = session - go session.write() + go session.run() channels := user.GetChannels() diff --git a/server/session.go b/server/session.go index bf9ec63c..0ff79d57 100644 --- a/server/session.go +++ b/server/session.go @@ -2,11 +2,16 @@ package server import ( "sync" + "time" "github.com/khlieng/dispatch/irc" "github.com/khlieng/dispatch/storage" ) +const ( + AnonymousSessionExpiration = 24 * time.Hour +) + type Session struct { irc map[string]*irc.Client connectionState map[string]bool @@ -16,7 +21,9 @@ type Session struct { wsLock sync.Mutex out chan WSResponse - user *storage.User + user *storage.User + expiration *time.Timer + reset chan time.Duration } func NewSession(user *storage.User) *Session { @@ -26,6 +33,8 @@ func NewSession(user *storage.User) *Session { ws: make(map[string]*wsConn), out: make(chan WSResponse, 32), user: user, + expiration: time.NewTimer(AnonymousSessionExpiration), + reset: make(chan time.Duration, 1), } } @@ -42,6 +51,8 @@ func (s *Session) setIRC(server string, i *irc.Client) { s.irc[server] = i s.connectionState[server] = false s.ircLock.Unlock() + + s.reset <- 0 } func (s *Session) deleteIRC(server string) { @@ -49,6 +60,8 @@ func (s *Session) deleteIRC(server string) { delete(s.irc, server) delete(s.connectionState, server) s.ircLock.Unlock() + + s.resetExpirationIfEmpty() } func (s *Session) numIRC() int { @@ -81,12 +94,16 @@ func (s *Session) setWS(addr string, w *wsConn) { s.wsLock.Lock() s.ws[addr] = w s.wsLock.Unlock() + + s.reset <- 0 } func (s *Session) deleteWS(addr string) { s.wsLock.Lock() delete(s.ws, addr) s.wsLock.Unlock() + + s.resetExpirationIfEmpty() } func (s *Session) numWS() int { @@ -108,12 +125,35 @@ func (s *Session) sendError(err error, server string) { }) } -func (s *Session) write() { - for res := range s.out { - s.wsLock.Lock() - for _, ws := range s.ws { - ws.out <- res - } - s.wsLock.Unlock() +func (s *Session) resetExpirationIfEmpty() { + if s.numIRC() == 0 && s.numWS() == 0 { + s.reset <- AnonymousSessionExpiration + } +} + +func (s *Session) run() { + for { + select { + case res := <-s.out: + s.wsLock.Lock() + for _, ws := range s.ws { + ws.out <- res + } + s.wsLock.Unlock() + + case <-s.expiration.C: + sessionLock.Lock() + delete(sessions, s.user.ID) + sessionLock.Unlock() + s.user.Remove() + return + + case duration := <-s.reset: + if duration == 0 { + s.expiration.Stop() + } else { + s.expiration.Reset(duration) + } + } } } diff --git a/storage/user.go b/storage/user.go index d24f87b8..8f94969c 100644 --- a/storage/user.go +++ b/storage/user.go @@ -2,6 +2,7 @@ package storage import ( "bytes" + "os" "strconv" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" @@ -169,9 +170,12 @@ func (u *User) RemoveChannel(server, channel string) { }) } -func (u *User) Close() { - u.messageLog.Close() - u.messageIndex.Close() +func (u *User) Remove() { + db.Batch(func(tx *bolt.Tx) error { + return tx.Bucket(bucketUsers).Delete(u.id) + }) + u.closeMessageLog() + os.RemoveAll(Path.User(u.Username)) } func (u *User) serverID(address string) []byte { diff --git a/storage/user_messages.go b/storage/user_messages.go index bafb5cfc..1c4a98eb 100644 --- a/storage/user_messages.go +++ b/storage/user_messages.go @@ -167,3 +167,8 @@ func (u *User) openMessageLog() error { return nil } + +func (u *User) closeMessageLog() { + u.messageLog.Close() + u.messageIndex.Close() +} diff --git a/storage/user_test.go b/storage/user_test.go index c83f8361..114fd937 100644 --- a/storage/user_test.go +++ b/storage/user_test.go @@ -2,6 +2,7 @@ package storage import ( "io/ioutil" + "os" "strconv" "testing" @@ -41,7 +42,7 @@ func TestUser(t *testing.T) { user.AddServer(srv) user.AddChannel(chan1) user.AddChannel(chan2) - user.Close() + user.closeMessageLog() users := LoadUsers() assert.Len(t, users, 1) @@ -69,6 +70,14 @@ func TestUser(t *testing.T) { user.RemoveServer(srv.Host) assert.Len(t, user.GetServers(), 0) assert.Len(t, user.GetChannels(), 0) + + user.Remove() + _, err = os.Stat(Path.User(user.Username)) + assert.True(t, os.IsNotExist(err)) + + for _, storedUser := range LoadUsers() { + assert.NotEqual(t, user.ID, storedUser.ID) + } } func TestMessages(t *testing.T) {