Remove empty anonymous sessions after a certain time period
This commit is contained in:
parent
3bcea0ec98
commit
e856b66f97
@ -36,6 +36,11 @@ func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
|
|||||||
sessionLock.Lock()
|
sessionLock.Lock()
|
||||||
session = sessions[userID]
|
session = sessions[userID]
|
||||||
sessionLock.Unlock()
|
sessionLock.Unlock()
|
||||||
|
|
||||||
|
if session == nil {
|
||||||
|
// A previous anonymous session has been cleaned up, create a new one
|
||||||
|
session = newUser(w, r)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
authLog(r, "Invalid token: "+err.Error())
|
authLog(r, "Invalid token: "+err.Error())
|
||||||
@ -63,7 +68,7 @@ func newUser(w http.ResponseWriter, r *http.Request) *Session {
|
|||||||
sessions[user.ID] = session
|
sessions[user.ID] = session
|
||||||
sessionLock.Unlock()
|
sessionLock.Unlock()
|
||||||
|
|
||||||
go session.write()
|
go session.run()
|
||||||
|
|
||||||
token := jwt.New(jwt.SigningMethodHS256)
|
token := jwt.New(jwt.SigningMethodHS256)
|
||||||
token.Claims["UserID"] = user.ID
|
token.Claims["UserID"] = user.ID
|
||||||
|
@ -12,7 +12,7 @@ func reconnectIRC() {
|
|||||||
for _, user := range storage.LoadUsers() {
|
for _, user := range storage.LoadUsers() {
|
||||||
session := NewSession(user)
|
session := NewSession(user)
|
||||||
sessions[user.ID] = session
|
sessions[user.ID] = session
|
||||||
go session.write()
|
go session.run()
|
||||||
|
|
||||||
channels := user.GetChannels()
|
channels := user.GetChannels()
|
||||||
|
|
||||||
|
@ -2,11 +2,16 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/khlieng/dispatch/irc"
|
"github.com/khlieng/dispatch/irc"
|
||||||
"github.com/khlieng/dispatch/storage"
|
"github.com/khlieng/dispatch/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AnonymousSessionExpiration = 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
irc map[string]*irc.Client
|
irc map[string]*irc.Client
|
||||||
connectionState map[string]bool
|
connectionState map[string]bool
|
||||||
@ -17,6 +22,8 @@ type Session struct {
|
|||||||
out chan WSResponse
|
out chan WSResponse
|
||||||
|
|
||||||
user *storage.User
|
user *storage.User
|
||||||
|
expiration *time.Timer
|
||||||
|
reset chan time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(user *storage.User) *Session {
|
func NewSession(user *storage.User) *Session {
|
||||||
@ -26,6 +33,8 @@ func NewSession(user *storage.User) *Session {
|
|||||||
ws: make(map[string]*wsConn),
|
ws: make(map[string]*wsConn),
|
||||||
out: make(chan WSResponse, 32),
|
out: make(chan WSResponse, 32),
|
||||||
user: user,
|
user: user,
|
||||||
|
expiration: time.NewTimer(AnonymousSessionExpiration),
|
||||||
|
reset: make(chan time.Duration, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,6 +51,8 @@ func (s *Session) setIRC(server string, i *irc.Client) {
|
|||||||
s.irc[server] = i
|
s.irc[server] = i
|
||||||
s.connectionState[server] = false
|
s.connectionState[server] = false
|
||||||
s.ircLock.Unlock()
|
s.ircLock.Unlock()
|
||||||
|
|
||||||
|
s.reset <- 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) deleteIRC(server string) {
|
func (s *Session) deleteIRC(server string) {
|
||||||
@ -49,6 +60,8 @@ func (s *Session) deleteIRC(server string) {
|
|||||||
delete(s.irc, server)
|
delete(s.irc, server)
|
||||||
delete(s.connectionState, server)
|
delete(s.connectionState, server)
|
||||||
s.ircLock.Unlock()
|
s.ircLock.Unlock()
|
||||||
|
|
||||||
|
s.resetExpirationIfEmpty()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) numIRC() int {
|
func (s *Session) numIRC() int {
|
||||||
@ -81,12 +94,16 @@ func (s *Session) setWS(addr string, w *wsConn) {
|
|||||||
s.wsLock.Lock()
|
s.wsLock.Lock()
|
||||||
s.ws[addr] = w
|
s.ws[addr] = w
|
||||||
s.wsLock.Unlock()
|
s.wsLock.Unlock()
|
||||||
|
|
||||||
|
s.reset <- 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) deleteWS(addr string) {
|
func (s *Session) deleteWS(addr string) {
|
||||||
s.wsLock.Lock()
|
s.wsLock.Lock()
|
||||||
delete(s.ws, addr)
|
delete(s.ws, addr)
|
||||||
s.wsLock.Unlock()
|
s.wsLock.Unlock()
|
||||||
|
|
||||||
|
s.resetExpirationIfEmpty()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) numWS() int {
|
func (s *Session) numWS() int {
|
||||||
@ -108,12 +125,35 @@ func (s *Session) sendError(err error, server string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) write() {
|
func (s *Session) resetExpirationIfEmpty() {
|
||||||
for res := range s.out {
|
if s.numIRC() == 0 && s.numWS() == 0 {
|
||||||
|
s.reset <- AnonymousSessionExpiration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) run() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case res := <-s.out:
|
||||||
s.wsLock.Lock()
|
s.wsLock.Lock()
|
||||||
for _, ws := range s.ws {
|
for _, ws := range s.ws {
|
||||||
ws.out <- res
|
ws.out <- res
|
||||||
}
|
}
|
||||||
s.wsLock.Unlock()
|
s.wsLock.Unlock()
|
||||||
|
|
||||||
|
case <-s.expiration.C:
|
||||||
|
sessionLock.Lock()
|
||||||
|
delete(sessions, s.user.ID)
|
||||||
|
sessionLock.Unlock()
|
||||||
|
s.user.Remove()
|
||||||
|
return
|
||||||
|
|
||||||
|
case duration := <-s.reset:
|
||||||
|
if duration == 0 {
|
||||||
|
s.expiration.Stop()
|
||||||
|
} else {
|
||||||
|
s.expiration.Reset(duration)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package storage
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt"
|
"github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt"
|
||||||
@ -169,9 +170,12 @@ func (u *User) RemoveChannel(server, channel string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) Close() {
|
func (u *User) Remove() {
|
||||||
u.messageLog.Close()
|
db.Batch(func(tx *bolt.Tx) error {
|
||||||
u.messageIndex.Close()
|
return tx.Bucket(bucketUsers).Delete(u.id)
|
||||||
|
})
|
||||||
|
u.closeMessageLog()
|
||||||
|
os.RemoveAll(Path.User(u.Username))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) serverID(address string) []byte {
|
func (u *User) serverID(address string) []byte {
|
||||||
|
@ -167,3 +167,8 @@ func (u *User) openMessageLog() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *User) closeMessageLog() {
|
||||||
|
u.messageLog.Close()
|
||||||
|
u.messageIndex.Close()
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package storage
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -41,7 +42,7 @@ func TestUser(t *testing.T) {
|
|||||||
user.AddServer(srv)
|
user.AddServer(srv)
|
||||||
user.AddChannel(chan1)
|
user.AddChannel(chan1)
|
||||||
user.AddChannel(chan2)
|
user.AddChannel(chan2)
|
||||||
user.Close()
|
user.closeMessageLog()
|
||||||
|
|
||||||
users := LoadUsers()
|
users := LoadUsers()
|
||||||
assert.Len(t, users, 1)
|
assert.Len(t, users, 1)
|
||||||
@ -69,6 +70,14 @@ func TestUser(t *testing.T) {
|
|||||||
user.RemoveServer(srv.Host)
|
user.RemoveServer(srv.Host)
|
||||||
assert.Len(t, user.GetServers(), 0)
|
assert.Len(t, user.GetServers(), 0)
|
||||||
assert.Len(t, user.GetChannels(), 0)
|
assert.Len(t, user.GetChannels(), 0)
|
||||||
|
|
||||||
|
user.Remove()
|
||||||
|
_, err = os.Stat(Path.User(user.Username))
|
||||||
|
assert.True(t, os.IsNotExist(err))
|
||||||
|
|
||||||
|
for _, storedUser := range LoadUsers() {
|
||||||
|
assert.NotEqual(t, user.ID, storedUser.ID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMessages(t *testing.T) {
|
func TestMessages(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user