Wait until a websocket connection comes in before creating new anonymous sessions
This commit is contained in:
parent
27653982d7
commit
637f0d956b
@ -12,18 +12,19 @@ const (
|
|||||||
cookieName = "dispatch"
|
cookieName = "dispatch"
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleAuth(w http.ResponseWriter, r *http.Request) *Session {
|
func handleAuth(w http.ResponseWriter, r *http.Request, createUser bool) *Session {
|
||||||
var session *Session
|
var session *Session
|
||||||
|
|
||||||
cookie, err := r.Cookie(cookieName)
|
cookie, err := r.Cookie(cookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
authLog(r, "No cookie set")
|
if createUser {
|
||||||
session = newUser(w, r)
|
session = newUser(w, r)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
session = sessions.get(cookie.Value)
|
session = sessions.get(cookie.Value)
|
||||||
if session != nil {
|
if session != nil {
|
||||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", session.user.ID)
|
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", session.user.ID)
|
||||||
} else {
|
} else if createUser {
|
||||||
session = newUser(w, r)
|
session = newUser(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -57,7 +58,3 @@ func newUser(w http.ResponseWriter, r *http.Request) *Session {
|
|||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
func authLog(r *http.Request, s string) {
|
|
||||||
log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "|", s)
|
|
||||||
}
|
|
||||||
|
@ -61,6 +61,22 @@ func (d *indexData) addUsersAndMessages(server, channel string, session *Session
|
|||||||
|
|
||||||
func getIndexData(r *http.Request, session *Session) *indexData {
|
func getIndexData(r *http.Request, session *Session) *indexData {
|
||||||
data := indexData{}
|
data := indexData{}
|
||||||
|
|
||||||
|
data.Defaults = connectDefaults{
|
||||||
|
Name: viper.GetString("defaults.name"),
|
||||||
|
Host: viper.GetString("defaults.host"),
|
||||||
|
Port: viper.GetInt("defaults.port"),
|
||||||
|
Channels: viper.GetStringSlice("defaults.channels"),
|
||||||
|
Password: viper.GetString("defaults.password") != "",
|
||||||
|
SSL: viper.GetBool("defaults.ssl"),
|
||||||
|
ReadOnly: viper.GetBool("defaults.readonly"),
|
||||||
|
ShowDetails: viper.GetBool("defaults.show_details"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if session == nil {
|
||||||
|
return &data
|
||||||
|
}
|
||||||
|
|
||||||
servers := session.user.GetServers()
|
servers := session.user.GetServers()
|
||||||
connections := session.getConnectionStates()
|
connections := session.getConnectionStates()
|
||||||
for _, server := range servers {
|
for _, server := range servers {
|
||||||
@ -80,17 +96,6 @@ func getIndexData(r *http.Request, session *Session) *indexData {
|
|||||||
}
|
}
|
||||||
data.Channels = channels
|
data.Channels = channels
|
||||||
|
|
||||||
data.Defaults = connectDefaults{
|
|
||||||
Name: viper.GetString("defaults.name"),
|
|
||||||
Host: viper.GetString("defaults.host"),
|
|
||||||
Port: viper.GetInt("defaults.port"),
|
|
||||||
Channels: viper.GetStringSlice("defaults.channels"),
|
|
||||||
Password: viper.GetString("defaults.password") != "",
|
|
||||||
SSL: viper.GetBool("defaults.ssl"),
|
|
||||||
ReadOnly: viper.GetBool("defaults.readonly"),
|
|
||||||
ShowDetails: viper.GetBool("defaults.show_details"),
|
|
||||||
}
|
|
||||||
|
|
||||||
server, channel := getTabFromPath(r.URL.EscapedPath())
|
server, channel := getTabFromPath(r.URL.EscapedPath())
|
||||||
if isInChannel(channels, server, channel) {
|
if isInChannel(channels, server, channel) {
|
||||||
data.addUsersAndMessages(server, channel, session)
|
data.addUsersAndMessages(server, channel, session)
|
||||||
|
@ -171,12 +171,7 @@ func serveFiles(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func serveIndex(w http.ResponseWriter, r *http.Request) {
|
func serveIndex(w http.ResponseWriter, r *http.Request) {
|
||||||
session := handleAuth(w, r)
|
session := handleAuth(w, r, false)
|
||||||
if session == nil {
|
|
||||||
log.Println("[Auth] No session")
|
|
||||||
w.WriteHeader(500)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if cspEnabled {
|
if cspEnabled {
|
||||||
var connectSrc string
|
var connectSrc string
|
||||||
|
@ -98,15 +98,21 @@ func startHTTP() {
|
|||||||
|
|
||||||
func serve(w http.ResponseWriter, r *http.Request) {
|
func serve(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "GET" {
|
if r.Method != "GET" {
|
||||||
w.WriteHeader(404)
|
fail(w, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(r.URL.Path, "/ws") {
|
if strings.HasPrefix(r.URL.Path, "/ws") {
|
||||||
session := handleAuth(w, r)
|
if !websocket.IsWebSocketUpgrade(r) {
|
||||||
|
fail(w, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := handleAuth(w, r, true)
|
||||||
|
|
||||||
if session == nil {
|
if session == nil {
|
||||||
log.Println("[Auth] No session")
|
log.Println("[Auth] No session")
|
||||||
w.WriteHeader(500)
|
fail(w, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,3 +168,7 @@ func letsEncryptProxy(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
httputil.NewSingleHostReverseProxy(upstream).ServeHTTP(w, r)
|
httputil.NewSingleHostReverseProxy(upstream).ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fail(w http.ResponseWriter, code int) {
|
||||||
|
http.Error(w, http.StatusText(code), code)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user