Remove empty anonymous sessions after a certain time period

This commit is contained in:
Ken-Håvard Lieng 2016-01-19 22:02:12 +01:00
parent 3bcea0ec98
commit e856b66f97
6 changed files with 77 additions and 14 deletions

View File

@ -36,6 +36,11 @@ func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
sessionLock.Lock() sessionLock.Lock()
session = sessions[userID] session = sessions[userID]
sessionLock.Unlock() sessionLock.Unlock()
if session == nil {
// A previous anonymous session has been cleaned up, create a new one
session = newUser(w, r)
}
} else { } else {
if err != nil { if err != nil {
authLog(r, "Invalid token: "+err.Error()) authLog(r, "Invalid token: "+err.Error())
@ -63,7 +68,7 @@ func newUser(w http.ResponseWriter, r *http.Request) *Session {
sessions[user.ID] = session sessions[user.ID] = session
sessionLock.Unlock() sessionLock.Unlock()
go session.write() go session.run()
token := jwt.New(jwt.SigningMethodHS256) token := jwt.New(jwt.SigningMethodHS256)
token.Claims["UserID"] = user.ID token.Claims["UserID"] = user.ID

View File

@ -12,7 +12,7 @@ func reconnectIRC() {
for _, user := range storage.LoadUsers() { for _, user := range storage.LoadUsers() {
session := NewSession(user) session := NewSession(user)
sessions[user.ID] = session sessions[user.ID] = session
go session.write() go session.run()
channels := user.GetChannels() channels := user.GetChannels()

View File

@ -2,11 +2,16 @@ package server
import ( import (
"sync" "sync"
"time"
"github.com/khlieng/dispatch/irc" "github.com/khlieng/dispatch/irc"
"github.com/khlieng/dispatch/storage" "github.com/khlieng/dispatch/storage"
) )
const (
AnonymousSessionExpiration = 24 * time.Hour
)
type Session struct { type Session struct {
irc map[string]*irc.Client irc map[string]*irc.Client
connectionState map[string]bool connectionState map[string]bool
@ -16,7 +21,9 @@ type Session struct {
wsLock sync.Mutex wsLock sync.Mutex
out chan WSResponse out chan WSResponse
user *storage.User user *storage.User
expiration *time.Timer
reset chan time.Duration
} }
func NewSession(user *storage.User) *Session { func NewSession(user *storage.User) *Session {
@ -26,6 +33,8 @@ func NewSession(user *storage.User) *Session {
ws: make(map[string]*wsConn), ws: make(map[string]*wsConn),
out: make(chan WSResponse, 32), out: make(chan WSResponse, 32),
user: user, user: user,
expiration: time.NewTimer(AnonymousSessionExpiration),
reset: make(chan time.Duration, 1),
} }
} }
@ -42,6 +51,8 @@ func (s *Session) setIRC(server string, i *irc.Client) {
s.irc[server] = i s.irc[server] = i
s.connectionState[server] = false s.connectionState[server] = false
s.ircLock.Unlock() s.ircLock.Unlock()
s.reset <- 0
} }
func (s *Session) deleteIRC(server string) { func (s *Session) deleteIRC(server string) {
@ -49,6 +60,8 @@ func (s *Session) deleteIRC(server string) {
delete(s.irc, server) delete(s.irc, server)
delete(s.connectionState, server) delete(s.connectionState, server)
s.ircLock.Unlock() s.ircLock.Unlock()
s.resetExpirationIfEmpty()
} }
func (s *Session) numIRC() int { func (s *Session) numIRC() int {
@ -81,12 +94,16 @@ func (s *Session) setWS(addr string, w *wsConn) {
s.wsLock.Lock() s.wsLock.Lock()
s.ws[addr] = w s.ws[addr] = w
s.wsLock.Unlock() s.wsLock.Unlock()
s.reset <- 0
} }
func (s *Session) deleteWS(addr string) { func (s *Session) deleteWS(addr string) {
s.wsLock.Lock() s.wsLock.Lock()
delete(s.ws, addr) delete(s.ws, addr)
s.wsLock.Unlock() s.wsLock.Unlock()
s.resetExpirationIfEmpty()
} }
func (s *Session) numWS() int { func (s *Session) numWS() int {
@ -108,12 +125,35 @@ func (s *Session) sendError(err error, server string) {
}) })
} }
func (s *Session) write() { func (s *Session) resetExpirationIfEmpty() {
for res := range s.out { if s.numIRC() == 0 && s.numWS() == 0 {
s.wsLock.Lock() s.reset <- AnonymousSessionExpiration
for _, ws := range s.ws { }
ws.out <- res }
}
s.wsLock.Unlock() func (s *Session) run() {
for {
select {
case res := <-s.out:
s.wsLock.Lock()
for _, ws := range s.ws {
ws.out <- res
}
s.wsLock.Unlock()
case <-s.expiration.C:
sessionLock.Lock()
delete(sessions, s.user.ID)
sessionLock.Unlock()
s.user.Remove()
return
case duration := <-s.reset:
if duration == 0 {
s.expiration.Stop()
} else {
s.expiration.Reset(duration)
}
}
} }
} }

View File

@ -2,6 +2,7 @@ package storage
import ( import (
"bytes" "bytes"
"os"
"strconv" "strconv"
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt"
@ -169,9 +170,12 @@ func (u *User) RemoveChannel(server, channel string) {
}) })
} }
func (u *User) Close() { func (u *User) Remove() {
u.messageLog.Close() db.Batch(func(tx *bolt.Tx) error {
u.messageIndex.Close() return tx.Bucket(bucketUsers).Delete(u.id)
})
u.closeMessageLog()
os.RemoveAll(Path.User(u.Username))
} }
func (u *User) serverID(address string) []byte { func (u *User) serverID(address string) []byte {

View File

@ -167,3 +167,8 @@ func (u *User) openMessageLog() error {
return nil return nil
} }
func (u *User) closeMessageLog() {
u.messageLog.Close()
u.messageIndex.Close()
}

View File

@ -2,6 +2,7 @@ package storage
import ( import (
"io/ioutil" "io/ioutil"
"os"
"strconv" "strconv"
"testing" "testing"
@ -41,7 +42,7 @@ func TestUser(t *testing.T) {
user.AddServer(srv) user.AddServer(srv)
user.AddChannel(chan1) user.AddChannel(chan1)
user.AddChannel(chan2) user.AddChannel(chan2)
user.Close() user.closeMessageLog()
users := LoadUsers() users := LoadUsers()
assert.Len(t, users, 1) assert.Len(t, users, 1)
@ -69,6 +70,14 @@ func TestUser(t *testing.T) {
user.RemoveServer(srv.Host) user.RemoveServer(srv.Host)
assert.Len(t, user.GetServers(), 0) assert.Len(t, user.GetServers(), 0)
assert.Len(t, user.GetChannels(), 0) assert.Len(t, user.GetChannels(), 0)
user.Remove()
_, err = os.Stat(Path.User(user.Username))
assert.True(t, os.IsNotExist(err))
for _, storedUser := range LoadUsers() {
assert.NotEqual(t, user.ID, storedUser.ID)
}
} }
func TestMessages(t *testing.T) { func TestMessages(t *testing.T) {