Store auth info in a JWT token in a cookie

This commit is contained in:
Ken-Håvard Lieng 2016-01-15 02:27:30 +01:00
parent 3e0a1be6bc
commit fb54d4966c
18 changed files with 499 additions and 331 deletions

102
server/auth.go Normal file
View file

@ -0,0 +1,102 @@
package server
import (
"crypto/rand"
"fmt"
"io/ioutil"
"log"
"net/http"
"github.com/dgrijalva/jwt-go"
"github.com/khlieng/dispatch/storage"
)
func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
var session *Session
cookie, err := r.Cookie(cookieName)
if err != nil {
authLog(r, "No cookie set")
session = newUser(w, r)
} else {
token, err := jwt.Parse(cookie.Value, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return hmacKey, nil
})
if err == nil && token.Valid {
userID := uint64(token.Claims["UserID"].(float64))
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", userID)
sessionLock.Lock()
session = sessions[userID]
sessionLock.Unlock()
} else {
if err != nil {
authLog(r, "Invalid token: "+err.Error())
} else {
authLog(r, "Invalid token")
}
session = newUser(w, r)
}
}
return session
}
func newUser(w http.ResponseWriter, r *http.Request) *Session {
user := storage.NewUser()
if user == nil {
return nil
}
log.Println(r.RemoteAddr, "[Auth] Create session | User ID:", user.ID)
session := NewSession(user)
sessionLock.Lock()
sessions[user.ID] = session
sessionLock.Unlock()
go session.write()
token := jwt.New(jwt.SigningMethodHS256)
token.Claims["UserID"] = user.ID
tokenString, err := token.SignedString(hmacKey)
if err != nil {
return nil
}
http.SetCookie(w, &http.Cookie{
Name: cookieName,
Value: tokenString,
Path: "/",
HttpOnly: true,
Secure: r.TLS != nil,
})
return session
}
func getHMACKey() ([]byte, error) {
key, err := ioutil.ReadFile(storage.Path.HMACKey())
if err != nil {
key = make([]byte, 32)
rand.Read(key)
err = ioutil.WriteFile(storage.Path.HMACKey(), key, 0600)
if err != nil {
return nil, err
}
}
return key, nil
}
func authLog(r *http.Request, s string) {
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "|", s)
}

View file

@ -10,9 +10,8 @@ import (
func reconnectIRC() {
for _, user := range storage.LoadUsers() {
session := NewSession()
session.user = user
sessions[user.UUID] = session
session := NewSession(user)
sessions[user.ID] = session
go session.write()
channels := user.GetChannels()
@ -30,7 +29,13 @@ func reconnectIRC() {
}
session.setIRC(server.Host, i)
i.Connect(net.JoinHostPort(server.Host, server.Port))
if server.Port != "" {
i.Connect(net.JoinHostPort(server.Host, server.Port))
} else {
i.Connect(server.Host)
}
go newIRCHandler(i, session).run()
var joining []string

View file

@ -22,7 +22,7 @@ func TestMain(m *testing.M) {
storage.Initialize(tempdir)
storage.Open()
user = storage.NewUser("uuid")
user = storage.NewUser()
channelStore = storage.NewChannelStore()
code := m.Run()
@ -34,8 +34,7 @@ func TestMain(m *testing.M) {
func dispatchMessage(msg *irc.Message) WSResponse {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s.user = user
s := NewSession(user)
newIRCHandler(c, s).dispatchMessage(msg)
@ -168,7 +167,7 @@ func TestHandleIRCWelcome(t *testing.T) {
func TestHandleIRCWhois(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s := NewSession(nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{
@ -212,7 +211,7 @@ func TestHandleIRCTopic(t *testing.T) {
func TestHandleIRCNames(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s := NewSession(nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{
@ -240,7 +239,7 @@ func TestHandleIRCNames(t *testing.T) {
func TestHandleIRCMotd(t *testing.T) {
c := irc.NewClient("nick", "user")
c.Host = "host.com"
s := NewSession()
s := NewSession(nil)
i := newIRCHandler(c, s)
i.dispatchMessage(&irc.Message{

View file

@ -29,10 +29,16 @@ type File struct {
func serveFiles(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
handleAuth(w, r)
serveFile(w, r, "index.html.gz", "text/html")
return
}
if strings.HasSuffix(r.URL.Path, "favicon.ico") {
w.WriteHeader(404)
return
}
for _, file := range files {
if strings.HasSuffix(r.URL.Path, file.Path) {
serveFile(w, r, file.Path+".gz", file.ContentType)
@ -40,6 +46,7 @@ func serveFiles(w http.ResponseWriter, r *http.Request) {
}
}
handleAuth(w, r)
serveFile(w, r, "index.html.gz", "text/html")
}

View file

@ -17,11 +17,17 @@ import (
"github.com/khlieng/dispatch/storage"
)
const (
cookieName = "dispatch"
)
var (
channelStore *storage.ChannelStore
sessions map[string]*Session
sessions map[uint64]*Session
sessionLock sync.Mutex
hmacKey []byte
upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
@ -35,7 +41,13 @@ func Run() {
defer storage.Close()
channelStore = storage.NewChannelStore()
sessions = make(map[string]*Session)
sessions = make(map[uint64]*Session)
var err error
hmacKey, err = getHMACKey()
if err != nil {
log.Fatal(err)
}
reconnectIRC()
startHTTP()
@ -95,27 +107,32 @@ func startHTTP() {
func serve(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
w.WriteHeader(404)
return
}
if r.URL.Path == "/ws" {
upgradeWS(w, r)
session := handleAuth(w, r)
if session == nil {
log.Println("[Auth] No session")
w.WriteHeader(500)
return
}
upgradeWS(w, r, session)
} else {
serveFiles(w, r)
}
}
func upgradeWS(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
func upgradeWS(w http.ResponseWriter, r *http.Request, session *Session) {
conn, err := upgrader.Upgrade(w, r, w.Header())
if err != nil {
log.Println(err)
return
}
uuid := r.URL.Query().Get("uuid")
if uuid != "" {
newWSHandler(conn, uuid).run()
}
newWSHandler(conn, session).run()
}
func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {

View file

@ -19,12 +19,13 @@ type Session struct {
user *storage.User
}
func NewSession() *Session {
func NewSession(user *storage.User) *Session {
return &Session{
irc: make(map[string]*irc.Client),
connectionState: make(map[string]bool),
ws: make(map[string]*wsConn),
out: make(chan WSResponse, 32),
user: user,
}
}
@ -88,6 +89,14 @@ func (s *Session) deleteWS(addr string) {
s.wsLock.Unlock()
}
func (s *Session) numWS() int {
s.ircLock.Lock()
n := len(s.ws)
s.ircLock.Unlock()
return n
}
func (s *Session) sendJSON(t string, v interface{}) {
s.out <- WSResponse{t, v}
}

View file

@ -20,13 +20,14 @@ type wsHandler struct {
handlers map[string]func([]byte)
}
func newWSHandler(conn *websocket.Conn, uuid string) *wsHandler {
func newWSHandler(conn *websocket.Conn, session *Session) *wsHandler {
h := &wsHandler{
ws: newWSConn(conn),
addr: conn.RemoteAddr().String(),
ws: newWSConn(conn),
session: session,
addr: conn.RemoteAddr().String(),
}
h.init(uuid)
h.initHandlers()
h.init()
return h
}
@ -54,44 +55,28 @@ func (h *wsHandler) dispatchRequest(req WSRequest) {
}
}
func (h *wsHandler) init(uuid string) {
log.Println(h.addr, "set UUID", uuid)
func (h *wsHandler) init() {
h.session.setWS(h.addr, h.ws)
sessionLock.Lock()
if storedSession, exists := sessions[uuid]; exists {
sessionLock.Unlock()
h.session = storedSession
h.session.setWS(h.addr, h.ws)
log.Println(h.addr, "[Session] User ID:", h.session.user.ID, "|",
h.session.numIRC(), "IRC connections |",
h.session.numWS(), "WebSocket connections")
log.Println(h.addr, "attached to", h.session.numIRC(), "existing IRC connections")
channels := h.session.user.GetChannels()
for i, channel := range channels {
channels[i].Topic = channelStore.GetTopic(channel.Server, channel.Name)
}
channels := h.session.user.GetChannels()
for i, channel := range channels {
channels[i].Topic = channelStore.GetTopic(channel.Server, channel.Name)
}
h.session.sendJSON("channels", channels)
h.session.sendJSON("servers", h.session.user.GetServers())
h.session.sendJSON("connection_update", h.session.getConnectionStates())
h.session.sendJSON("channels", channels)
h.session.sendJSON("servers", h.session.user.GetServers())
h.session.sendJSON("connection_update", h.session.getConnectionStates())
for _, channel := range channels {
h.session.sendJSON("users", Userlist{
Server: channel.Server,
Channel: channel.Name,
Users: channelStore.GetUsers(channel.Server, channel.Name),
})
}
} else {
h.session = NewSession()
h.session.user = storage.NewUser(uuid)
sessions[uuid] = h.session
sessionLock.Unlock()
h.session.setWS(h.addr, h.ws)
h.session.sendJSON("servers", nil)
go h.session.write()
for _, channel := range channels {
h.session.sendJSON("users", Userlist{
Server: channel.Server,
Channel: channel.Name,
Users: channelStore.GetUsers(channel.Server, channel.Name),
})
}
}
@ -105,7 +90,7 @@ func (h *wsHandler) connect(b []byte) {
}
if _, ok := h.session.getIRC(host); !ok {
log.Println(h.addr, "connecting to", data.Server)
log.Println(h.addr, "[IRC] Add server", data.Server)
i := irc.NewClient(data.Nick, data.Username)
i.TLS = data.TLS
@ -134,7 +119,7 @@ func (h *wsHandler) connect(b []byte) {
Realname: data.Realname,
})
} else {
log.Println(h.addr, "already connected to", data.Server)
log.Println(h.addr, "[IRC]", data.Server, "already added")
}
}
@ -161,6 +146,8 @@ func (h *wsHandler) quit(b []byte) {
json.Unmarshal(b, &data)
if i, ok := h.session.getIRC(data.Server); ok {
log.Println(h.addr, "[IRC] Remove server", data.Server)
i.Quit()
h.session.deleteIRC(data.Server)
channelStore.RemoveUserAll(i.GetNick(), data.Server)