Store auth info in a JWT token in a cookie
This commit is contained in:
parent
3e0a1be6bc
commit
fb54d4966c
18 changed files with 499 additions and 331 deletions
102
server/auth.go
Normal file
102
server/auth.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
|
||||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
|
||||
var session *Session
|
||||
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
authLog(r, "No cookie set")
|
||||
session = newUser(w, r)
|
||||
} else {
|
||||
token, err := jwt.Parse(cookie.Value, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return hmacKey, nil
|
||||
})
|
||||
|
||||
if err == nil && token.Valid {
|
||||
userID := uint64(token.Claims["UserID"].(float64))
|
||||
|
||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", userID)
|
||||
|
||||
sessionLock.Lock()
|
||||
session = sessions[userID]
|
||||
sessionLock.Unlock()
|
||||
} else {
|
||||
if err != nil {
|
||||
authLog(r, "Invalid token: "+err.Error())
|
||||
} else {
|
||||
authLog(r, "Invalid token")
|
||||
}
|
||||
session = newUser(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func newUser(w http.ResponseWriter, r *http.Request) *Session {
|
||||
user := storage.NewUser()
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Println(r.RemoteAddr, "[Auth] Create session | User ID:", user.ID)
|
||||
|
||||
session := NewSession(user)
|
||||
|
||||
sessionLock.Lock()
|
||||
sessions[user.ID] = session
|
||||
sessionLock.Unlock()
|
||||
|
||||
go session.write()
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
token.Claims["UserID"] = user.ID
|
||||
tokenString, err := token.SignedString(hmacKey)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: tokenString,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: r.TLS != nil,
|
||||
})
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func getHMACKey() ([]byte, error) {
|
||||
key, err := ioutil.ReadFile(storage.Path.HMACKey())
|
||||
if err != nil {
|
||||
key = make([]byte, 32)
|
||||
rand.Read(key)
|
||||
|
||||
err = ioutil.WriteFile(storage.Path.HMACKey(), key, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func authLog(r *http.Request, s string) {
|
||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "|", s)
|
||||
}
|
|
@ -10,9 +10,8 @@ import (
|
|||
|
||||
func reconnectIRC() {
|
||||
for _, user := range storage.LoadUsers() {
|
||||
session := NewSession()
|
||||
session.user = user
|
||||
sessions[user.UUID] = session
|
||||
session := NewSession(user)
|
||||
sessions[user.ID] = session
|
||||
go session.write()
|
||||
|
||||
channels := user.GetChannels()
|
||||
|
@ -30,7 +29,13 @@ func reconnectIRC() {
|
|||
}
|
||||
|
||||
session.setIRC(server.Host, i)
|
||||
i.Connect(net.JoinHostPort(server.Host, server.Port))
|
||||
|
||||
if server.Port != "" {
|
||||
i.Connect(net.JoinHostPort(server.Host, server.Port))
|
||||
} else {
|
||||
i.Connect(server.Host)
|
||||
}
|
||||
|
||||
go newIRCHandler(i, session).run()
|
||||
|
||||
var joining []string
|
||||
|
|
|
@ -22,7 +22,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
storage.Initialize(tempdir)
|
||||
storage.Open()
|
||||
user = storage.NewUser("uuid")
|
||||
user = storage.NewUser()
|
||||
channelStore = storage.NewChannelStore()
|
||||
|
||||
code := m.Run()
|
||||
|
@ -34,8 +34,7 @@ func TestMain(m *testing.M) {
|
|||
func dispatchMessage(msg *irc.Message) WSResponse {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession()
|
||||
s.user = user
|
||||
s := NewSession(user)
|
||||
|
||||
newIRCHandler(c, s).dispatchMessage(msg)
|
||||
|
||||
|
@ -168,7 +167,7 @@ func TestHandleIRCWelcome(t *testing.T) {
|
|||
func TestHandleIRCWhois(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession()
|
||||
s := NewSession(nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -212,7 +211,7 @@ func TestHandleIRCTopic(t *testing.T) {
|
|||
func TestHandleIRCNames(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession()
|
||||
s := NewSession(nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
@ -240,7 +239,7 @@ func TestHandleIRCNames(t *testing.T) {
|
|||
func TestHandleIRCMotd(t *testing.T) {
|
||||
c := irc.NewClient("nick", "user")
|
||||
c.Host = "host.com"
|
||||
s := NewSession()
|
||||
s := NewSession(nil)
|
||||
i := newIRCHandler(c, s)
|
||||
|
||||
i.dispatchMessage(&irc.Message{
|
||||
|
|
|
@ -29,10 +29,16 @@ type File struct {
|
|||
|
||||
func serveFiles(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" {
|
||||
handleAuth(w, r)
|
||||
serveFile(w, r, "index.html.gz", "text/html")
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasSuffix(r.URL.Path, "favicon.ico") {
|
||||
w.WriteHeader(404)
|
||||
return
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if strings.HasSuffix(r.URL.Path, file.Path) {
|
||||
serveFile(w, r, file.Path+".gz", file.ContentType)
|
||||
|
@ -40,6 +46,7 @@ func serveFiles(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
handleAuth(w, r)
|
||||
serveFile(w, r, "index.html.gz", "text/html")
|
||||
}
|
||||
|
||||
|
|
|
@ -17,11 +17,17 @@ import (
|
|||
"github.com/khlieng/dispatch/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieName = "dispatch"
|
||||
)
|
||||
|
||||
var (
|
||||
channelStore *storage.ChannelStore
|
||||
sessions map[string]*Session
|
||||
sessions map[uint64]*Session
|
||||
sessionLock sync.Mutex
|
||||
|
||||
hmacKey []byte
|
||||
|
||||
upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
|
@ -35,7 +41,13 @@ func Run() {
|
|||
defer storage.Close()
|
||||
|
||||
channelStore = storage.NewChannelStore()
|
||||
sessions = make(map[string]*Session)
|
||||
sessions = make(map[uint64]*Session)
|
||||
|
||||
var err error
|
||||
hmacKey, err = getHMACKey()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
reconnectIRC()
|
||||
startHTTP()
|
||||
|
@ -95,27 +107,32 @@ func startHTTP() {
|
|||
|
||||
func serve(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
w.WriteHeader(404)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/ws" {
|
||||
upgradeWS(w, r)
|
||||
session := handleAuth(w, r)
|
||||
if session == nil {
|
||||
log.Println("[Auth] No session")
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
|
||||
upgradeWS(w, r, session)
|
||||
} else {
|
||||
serveFiles(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func upgradeWS(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
func upgradeWS(w http.ResponseWriter, r *http.Request, session *Session) {
|
||||
conn, err := upgrader.Upgrade(w, r, w.Header())
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
uuid := r.URL.Query().Get("uuid")
|
||||
if uuid != "" {
|
||||
newWSHandler(conn, uuid).run()
|
||||
}
|
||||
newWSHandler(conn, session).run()
|
||||
}
|
||||
|
||||
func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {
|
||||
|
|
|
@ -19,12 +19,13 @@ type Session struct {
|
|||
user *storage.User
|
||||
}
|
||||
|
||||
func NewSession() *Session {
|
||||
func NewSession(user *storage.User) *Session {
|
||||
return &Session{
|
||||
irc: make(map[string]*irc.Client),
|
||||
connectionState: make(map[string]bool),
|
||||
ws: make(map[string]*wsConn),
|
||||
out: make(chan WSResponse, 32),
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,6 +89,14 @@ func (s *Session) deleteWS(addr string) {
|
|||
s.wsLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *Session) numWS() int {
|
||||
s.ircLock.Lock()
|
||||
n := len(s.ws)
|
||||
s.ircLock.Unlock()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func (s *Session) sendJSON(t string, v interface{}) {
|
||||
s.out <- WSResponse{t, v}
|
||||
}
|
||||
|
|
|
@ -20,13 +20,14 @@ type wsHandler struct {
|
|||
handlers map[string]func([]byte)
|
||||
}
|
||||
|
||||
func newWSHandler(conn *websocket.Conn, uuid string) *wsHandler {
|
||||
func newWSHandler(conn *websocket.Conn, session *Session) *wsHandler {
|
||||
h := &wsHandler{
|
||||
ws: newWSConn(conn),
|
||||
addr: conn.RemoteAddr().String(),
|
||||
ws: newWSConn(conn),
|
||||
session: session,
|
||||
addr: conn.RemoteAddr().String(),
|
||||
}
|
||||
h.init(uuid)
|
||||
h.initHandlers()
|
||||
h.init()
|
||||
return h
|
||||
}
|
||||
|
||||
|
@ -54,44 +55,28 @@ func (h *wsHandler) dispatchRequest(req WSRequest) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *wsHandler) init(uuid string) {
|
||||
log.Println(h.addr, "set UUID", uuid)
|
||||
func (h *wsHandler) init() {
|
||||
h.session.setWS(h.addr, h.ws)
|
||||
|
||||
sessionLock.Lock()
|
||||
if storedSession, exists := sessions[uuid]; exists {
|
||||
sessionLock.Unlock()
|
||||
h.session = storedSession
|
||||
h.session.setWS(h.addr, h.ws)
|
||||
log.Println(h.addr, "[Session] User ID:", h.session.user.ID, "|",
|
||||
h.session.numIRC(), "IRC connections |",
|
||||
h.session.numWS(), "WebSocket connections")
|
||||
|
||||
log.Println(h.addr, "attached to", h.session.numIRC(), "existing IRC connections")
|
||||
channels := h.session.user.GetChannels()
|
||||
for i, channel := range channels {
|
||||
channels[i].Topic = channelStore.GetTopic(channel.Server, channel.Name)
|
||||
}
|
||||
|
||||
channels := h.session.user.GetChannels()
|
||||
for i, channel := range channels {
|
||||
channels[i].Topic = channelStore.GetTopic(channel.Server, channel.Name)
|
||||
}
|
||||
h.session.sendJSON("channels", channels)
|
||||
h.session.sendJSON("servers", h.session.user.GetServers())
|
||||
h.session.sendJSON("connection_update", h.session.getConnectionStates())
|
||||
|
||||
h.session.sendJSON("channels", channels)
|
||||
h.session.sendJSON("servers", h.session.user.GetServers())
|
||||
h.session.sendJSON("connection_update", h.session.getConnectionStates())
|
||||
|
||||
for _, channel := range channels {
|
||||
h.session.sendJSON("users", Userlist{
|
||||
Server: channel.Server,
|
||||
Channel: channel.Name,
|
||||
Users: channelStore.GetUsers(channel.Server, channel.Name),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
h.session = NewSession()
|
||||
h.session.user = storage.NewUser(uuid)
|
||||
|
||||
sessions[uuid] = h.session
|
||||
sessionLock.Unlock()
|
||||
|
||||
h.session.setWS(h.addr, h.ws)
|
||||
h.session.sendJSON("servers", nil)
|
||||
|
||||
go h.session.write()
|
||||
for _, channel := range channels {
|
||||
h.session.sendJSON("users", Userlist{
|
||||
Server: channel.Server,
|
||||
Channel: channel.Name,
|
||||
Users: channelStore.GetUsers(channel.Server, channel.Name),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,7 +90,7 @@ func (h *wsHandler) connect(b []byte) {
|
|||
}
|
||||
|
||||
if _, ok := h.session.getIRC(host); !ok {
|
||||
log.Println(h.addr, "connecting to", data.Server)
|
||||
log.Println(h.addr, "[IRC] Add server", data.Server)
|
||||
|
||||
i := irc.NewClient(data.Nick, data.Username)
|
||||
i.TLS = data.TLS
|
||||
|
@ -134,7 +119,7 @@ func (h *wsHandler) connect(b []byte) {
|
|||
Realname: data.Realname,
|
||||
})
|
||||
} else {
|
||||
log.Println(h.addr, "already connected to", data.Server)
|
||||
log.Println(h.addr, "[IRC]", data.Server, "already added")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -161,6 +146,8 @@ func (h *wsHandler) quit(b []byte) {
|
|||
json.Unmarshal(b, &data)
|
||||
|
||||
if i, ok := h.session.getIRC(data.Server); ok {
|
||||
log.Println(h.addr, "[IRC] Remove server", data.Server)
|
||||
|
||||
i.Quit()
|
||||
h.session.deleteIRC(data.Server)
|
||||
channelStore.RemoveUserAll(i.GetNick(), data.Server)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue