Use random session IDs instead of jwt
This commit is contained in:
parent
6f0ea05f4b
commit
0648b67cb8
26 changed files with 45 additions and 1622 deletions
|
@ -1,15 +1,10 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
|
@ -17,33 +12,6 @@ const (
|
|||
cookieName = "dispatch"
|
||||
)
|
||||
|
||||
var (
|
||||
hmacKey []byte
|
||||
)
|
||||
|
||||
func initAuth() {
|
||||
var err error
|
||||
hmacKey, err = getHMACKey()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func getHMACKey() ([]byte, error) {
|
||||
key, err := ioutil.ReadFile(storage.Path.HMACKey())
|
||||
if err != nil {
|
||||
key = make([]byte, 32)
|
||||
rand.Read(key)
|
||||
|
||||
err = ioutil.WriteFile(storage.Path.HMACKey(), key, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
|
||||
var session *Session
|
||||
|
||||
|
@ -52,26 +20,10 @@ func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
|
|||
authLog(r, "No cookie set")
|
||||
session = newUser(w, r)
|
||||
} else {
|
||||
token, err := parseToken(cookie.Value)
|
||||
|
||||
if err == nil && token.Valid {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
userID := uint64(claims["UserID"].(float64))
|
||||
|
||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", userID)
|
||||
|
||||
session = sessions.get(userID)
|
||||
if session == nil {
|
||||
// A previous anonymous session has been cleaned up, create a new one
|
||||
session = newUser(w, r)
|
||||
}
|
||||
session = sessions.get(cookie.Value)
|
||||
if session != nil {
|
||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", session.user.ID)
|
||||
} else {
|
||||
if err != nil {
|
||||
authLog(r, "Invalid token: "+err.Error())
|
||||
} else {
|
||||
authLog(r, "Invalid token")
|
||||
}
|
||||
|
||||
session = newUser(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -87,21 +39,16 @@ func newUser(w http.ResponseWriter, r *http.Request) *Session {
|
|||
|
||||
log.Println(r.RemoteAddr, "[Auth] Create session | User ID:", user.ID)
|
||||
|
||||
session := NewSession(user)
|
||||
sessions.set(user.ID, session)
|
||||
go session.run()
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
claims["UserID"] = user.ID
|
||||
tokenString, err := token.SignedString(hmacKey)
|
||||
session, err := NewSession(user)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
sessions.set(session)
|
||||
go session.run()
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: tokenString,
|
||||
Value: session.id,
|
||||
Path: "/",
|
||||
Expires: time.Now().AddDate(0, 1, 0),
|
||||
HttpOnly: true,
|
||||
|
@ -111,16 +58,6 @@ func newUser(w http.ResponseWriter, r *http.Request) *Session {
|
|||
return session
|
||||
}
|
||||
|
||||
func parseToken(cookie string) (*jwt.Token, error) {
|
||||
return jwt.Parse(cookie, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return hmacKey, nil
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func authLog(r *http.Request, s string) {
|
||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "|", s)
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/khlieng/dispatch/irc"
|
||||
|
@ -27,8 +28,12 @@ func createNickInUseHandler(i *irc.Client, session *Session) func(string) string
|
|||
|
||||
func reconnectIRC() {
|
||||
for _, user := range storage.LoadUsers() {
|
||||
session := NewSession(user)
|
||||
sessions.set(user.ID, session)
|
||||
session, err := NewSession(user)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
sessions.set(session)
|
||||
go session.run()
|
||||
|
||||
channels := user.GetChannels()
|
||||
|
|
|
@ -41,7 +41,7 @@ func dispatchMessage(msg *irc.Message) WSResponse {
|
|||
func dispatchMessageMulti(msg *irc.Message) chan WSResponse {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession(user)
|
||||
s, _ := NewSession(user)
|
||||
|
||||
newIRCHandler(c, s).dispatchMessage(msg)
|
||||
|
||||
|
@ -187,7 +187,7 @@ func TestHandleIRCWelcome(t *testing.T) {
|
|||
func TestHandleIRCWhois(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession(nil)
|
||||
s, _ := NewSession(nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -255,7 +255,7 @@ func TestHandleIRCNoTopic(t *testing.T) {
|
|||
func TestHandleIRCNames(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession(nil)
|
||||
s, _ := NewSession(nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -281,7 +281,7 @@ func TestHandleIRCNames(t *testing.T) {
|
|||
func TestHandleIRCMotd(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession(nil)
|
||||
s, _ := NewSession(nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -308,7 +308,7 @@ func TestHandleIRCMotd(t *testing.T) {
|
|||
func TestHandleIRCBadNick(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession(nil)
|
||||
s, _ := NewSession(nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
|
|
@ -37,7 +37,6 @@ func Run() {
|
|||
}
|
||||
|
||||
reconnectIRC()
|
||||
initAuth()
|
||||
initFileServer()
|
||||
startHTTP()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -11,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
AnonymousSessionExpiration = 24 * time.Hour
|
||||
AnonymousSessionExpiration = 1 * time.Minute
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
|
@ -23,21 +25,33 @@ type Session struct {
|
|||
wsLock sync.Mutex
|
||||
broadcast chan WSResponse
|
||||
|
||||
id string
|
||||
user *storage.User
|
||||
expiration *time.Timer
|
||||
reset chan time.Duration
|
||||
}
|
||||
|
||||
func NewSession(user *storage.User) *Session {
|
||||
func NewSession(user *storage.User) (*Session, error) {
|
||||
id, err := newSessionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Session{
|
||||
irc: make(map[string]*irc.Client),
|
||||
connectionState: make(map[string]irc.ConnectionState),
|
||||
ws: make(map[string]*wsConn),
|
||||
broadcast: make(chan WSResponse, 32),
|
||||
id: id,
|
||||
user: user,
|
||||
expiration: time.NewTimer(AnonymousSessionExpiration),
|
||||
reset: make(chan time.Duration, 1),
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newSessionID() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
return base64.RawURLEncoding.EncodeToString(key), err
|
||||
}
|
||||
|
||||
func (s *Session) getIRC(server string) (*irc.Client, bool) {
|
||||
|
@ -194,7 +208,7 @@ func (s *Session) run() {
|
|||
s.wsLock.Unlock()
|
||||
|
||||
case <-s.expiration.C:
|
||||
sessions.delete(s.user.ID)
|
||||
sessions.delete(s.id)
|
||||
s.user.Remove()
|
||||
return
|
||||
|
||||
|
@ -209,31 +223,31 @@ func (s *Session) run() {
|
|||
}
|
||||
|
||||
type sessionStore struct {
|
||||
sessions map[uint64]*Session
|
||||
sessions map[string]*Session
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func newSessionStore() *sessionStore {
|
||||
return &sessionStore{
|
||||
sessions: make(map[uint64]*Session),
|
||||
sessions: make(map[string]*Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sessionStore) get(userid uint64) *Session {
|
||||
func (s *sessionStore) get(id string) *Session {
|
||||
s.lock.Lock()
|
||||
session := s.sessions[userid]
|
||||
session := s.sessions[id]
|
||||
s.lock.Unlock()
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *sessionStore) set(userid uint64, session *Session) {
|
||||
func (s *sessionStore) set(session *Session) {
|
||||
s.lock.Lock()
|
||||
s.sessions[userid] = session
|
||||
s.sessions[session.id] = session
|
||||
s.lock.Unlock()
|
||||
}
|
||||
|
||||
func (s *sessionStore) delete(userid uint64) {
|
||||
func (s *sessionStore) delete(id string) {
|
||||
s.lock.Lock()
|
||||
delete(s.sessions, userid)
|
||||
delete(s.sessions, id)
|
||||
s.lock.Unlock()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue