From 637f0d956bf7e55d8f57c9bf15d2eed301111e12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ken-H=C3=A5vard=20Lieng?= Date: Tue, 22 May 2018 03:56:48 +0200 Subject: [PATCH] Wait until a websocket connection comes in before creating new anonymous sessions --- server/auth.go | 13 +++++-------- server/index_data.go | 27 ++++++++++++++++----------- server/serve_files.go | 7 +------ server/server.go | 16 +++++++++++++--- 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/server/auth.go b/server/auth.go index 14a19944..c4809329 100644 --- a/server/auth.go +++ b/server/auth.go @@ -12,18 +12,19 @@ const ( cookieName = "dispatch" ) -func handleAuth(w http.ResponseWriter, r *http.Request) *Session { +func handleAuth(w http.ResponseWriter, r *http.Request, createUser bool) *Session { var session *Session cookie, err := r.Cookie(cookieName) if err != nil { - authLog(r, "No cookie set") - session = newUser(w, r) + if createUser { + session = newUser(w, r) + } } else { session = sessions.get(cookie.Value) if session != nil { log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "| Valid token | User ID:", session.user.ID) - } else { + } else if createUser { session = newUser(w, r) } } @@ -57,7 +58,3 @@ func newUser(w http.ResponseWriter, r *http.Request) *Session { return session } - -func authLog(r *http.Request, s string) { - log.Println(r.RemoteAddr, "[Auth] GET", r.URL.Path, "|", s) -} diff --git a/server/index_data.go b/server/index_data.go index 6cbf210a..74a06e1a 100644 --- a/server/index_data.go +++ b/server/index_data.go @@ -61,6 +61,22 @@ func (d *indexData) addUsersAndMessages(server, channel string, session *Session func getIndexData(r *http.Request, session *Session) *indexData { data := indexData{} + + data.Defaults = connectDefaults{ + Name: viper.GetString("defaults.name"), + Host: viper.GetString("defaults.host"), + Port: viper.GetInt("defaults.port"), + Channels: viper.GetStringSlice("defaults.channels"), + Password: viper.GetString("defaults.password") != "", + SSL: viper.GetBool("defaults.ssl"), + ReadOnly: viper.GetBool("defaults.readonly"), + ShowDetails: viper.GetBool("defaults.show_details"), + } + + if session == nil { + return &data + } + servers := session.user.GetServers() connections := session.getConnectionStates() for _, server := range servers { @@ -80,17 +96,6 @@ func getIndexData(r *http.Request, session *Session) *indexData { } data.Channels = channels - data.Defaults = connectDefaults{ - Name: viper.GetString("defaults.name"), - Host: viper.GetString("defaults.host"), - Port: viper.GetInt("defaults.port"), - Channels: viper.GetStringSlice("defaults.channels"), - Password: viper.GetString("defaults.password") != "", - SSL: viper.GetBool("defaults.ssl"), - ReadOnly: viper.GetBool("defaults.readonly"), - ShowDetails: viper.GetBool("defaults.show_details"), - } - server, channel := getTabFromPath(r.URL.EscapedPath()) if isInChannel(channels, server, channel) { data.addUsersAndMessages(server, channel, session) diff --git a/server/serve_files.go b/server/serve_files.go index 23ebc12a..4cbdfbfb 100644 --- a/server/serve_files.go +++ b/server/serve_files.go @@ -171,12 +171,7 @@ func serveFiles(w http.ResponseWriter, r *http.Request) { } func serveIndex(w http.ResponseWriter, r *http.Request) { - session := handleAuth(w, r) - if session == nil { - log.Println("[Auth] No session") - w.WriteHeader(500) - return - } + session := handleAuth(w, r, false) if cspEnabled { var connectSrc string diff --git a/server/server.go b/server/server.go index 7cc3bd68..eecf7ba2 100644 --- a/server/server.go +++ b/server/server.go @@ -98,15 +98,21 @@ func startHTTP() { func serve(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { - w.WriteHeader(404) + fail(w, http.StatusNotFound) return } if strings.HasPrefix(r.URL.Path, "/ws") { - session := handleAuth(w, r) + if !websocket.IsWebSocketUpgrade(r) { + fail(w, http.StatusBadRequest) + return + } + + session := handleAuth(w, r, true) + if session == nil { log.Println("[Auth] No session") - w.WriteHeader(500) + fail(w, http.StatusInternalServerError) return } @@ -162,3 +168,7 @@ func letsEncryptProxy(w http.ResponseWriter, r *http.Request) { httputil.NewSingleHostReverseProxy(upstream).ServeHTTP(w, r) } + +func fail(w http.ResponseWriter, code int) { + http.Error(w, http.StatusText(code), code) +}