Handle channel names ending with a slash better

This commit is contained in:
Ken-Håvard Lieng 2017-06-13 04:25:59 +02:00
parent f03b30eff6
commit 0f5c3b57d2
8 changed files with 96 additions and 40 deletions

File diff suppressed because one or more lines are too long

View File

@ -21,7 +21,7 @@ export default function initialState({ store }) {
if (!store.getState().router.route) { if (!store.getState().router.route) {
const tab = Cookie.get('tab'); const tab = Cookie.get('tab');
if (tab) { if (tab) {
const [server, name = null] = tab.split(':'); const [server, name = null] = tab.split('-');
if (find(env.servers, srv => srv.host === server)) { if (find(env.servers, srv => srv.host === server)) {
store.dispatch(select(server, name, true)); store.dispatch(select(server, name, true));

View File

@ -16,7 +16,7 @@ class Tab extends TabRecord {
toString() { toString() {
let str = this.server; let str = this.server;
if (this.name) { if (this.name) {
str += `:${this.name}`; str += `-${this.name}`;
} }
return str; return str;
} }

View File

@ -3,7 +3,7 @@ import Backoff from 'backo';
export default class Socket { export default class Socket {
constructor(host) { constructor(host) {
const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws'; const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
this.url = `${protocol}://${host}/ws?path=${window.location.pathname}`; this.url = `${protocol}://${host}/ws${window.location.pathname}`;
this.connectTimeout = 20000; this.connectTimeout = 20000;
this.pingTimeout = 30000; this.pingTimeout = 30000;

View File

@ -2,6 +2,7 @@ package server
import ( import (
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -84,17 +85,18 @@ func getIndexData(r *http.Request, session *Session) *indexData {
Channels: channels, Channels: channels,
} }
params := strings.Split(strings.Trim(r.URL.Path, "/"), "/") server, channel := getTabFromPath(r.URL.EscapedPath())
if len(params) == 2 && isChannel(params[1]) { if channel != "" {
data.addUsersAndMessages(params[0], params[1], session) data.addUsersAndMessages(server, channel, session)
} else { return &data
server, channel := parseTabCookie(r, r.URL.Path) }
if channel != "" {
for _, ch := range channels { server, channel = parseTabCookie(r, r.URL.Path)
if server == ch.Server && channel == ch.Name { if channel != "" {
data.addUsersAndMessages(server, channel, session) for _, ch := range channels {
break if server == ch.Server && channel == ch.Name {
} data.addUsersAndMessages(server, channel, session)
break
} }
} }
} }
@ -102,11 +104,22 @@ func getIndexData(r *http.Request, session *Session) *indexData {
return &data return &data
} }
func getTabFromPath(rawPath string) (string, string) {
path := strings.Split(strings.Trim(rawPath, "/"), "/")
if len(path) == 2 {
name, err := url.PathUnescape(path[1])
if err == nil && isChannel(name) {
return path[0], name
}
}
return "", ""
}
func parseTabCookie(r *http.Request, path string) (string, string) { func parseTabCookie(r *http.Request, path string) (string, string) {
if path == "/" { if path == "/" {
cookie, err := r.Cookie("tab") cookie, err := r.Cookie("tab")
if err == nil { if err == nil {
tab := strings.Split(cookie.Value, ":") tab := strings.Split(cookie.Value, "-")
if len(tab) == 2 && isChannel(tab[1]) { if len(tab) == 2 && isChannel(tab[1]) {
return tab[0], tab[1] return tab[0], tab[1]

43
server/index_data_test.go Normal file
View File

@ -0,0 +1,43 @@
package server
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetTabFromPath(t *testing.T) {
cases := []struct {
input string
expectedServer string
expectedChannel string
}{
{
"/chat.freenode.net/%23r%2Fstuff%2F/",
"chat.freenode.net",
"#r/stuff/",
}, {
"/chat.freenode.net/%23r%2Fstuff%2F",
"chat.freenode.net",
"#r/stuff/",
}, {
"/chat.freenode.net/%23r%2Fstuff",
"chat.freenode.net",
"#r/stuff",
}, {
"/chat.freenode.net/%23stuff",
"chat.freenode.net",
"#stuff",
}, {
"/chat.freenode.net/%23stuff/cake",
"",
"",
},
}
for _, tc := range cases {
server, channel := getTabFromPath(tc.input)
assert.Equal(t, tc.expectedServer, server)
assert.Equal(t, tc.expectedChannel, channel)
}
}

View File

@ -103,7 +103,7 @@ func serve(w http.ResponseWriter, r *http.Request) {
return return
} }
if r.URL.Path == "/ws" { if strings.HasPrefix(r.URL.Path, "/ws") {
session := handleAuth(w, r) session := handleAuth(w, r)
if session == nil { if session == nil {
log.Println("[Auth] No session") log.Println("[Auth] No session")

View File

@ -66,13 +66,13 @@ func (h *wsHandler) init(r *http.Request) {
h.session.numWS(), "WebSocket connections") h.session.numWS(), "WebSocket connections")
channels := h.session.user.GetChannels() channels := h.session.user.GetChannels()
path := r.URL.Query().Get("path") path := r.URL.EscapedPath()[3:]
params := strings.Split(strings.Trim(path, "/"), "/") pathServer, pathChannel := getTabFromPath(path)
tabServer, tabChannel := parseTabCookie(r, path) cookieServer, cookieChannel := parseTabCookie(r, path)
for _, channel := range channels { for _, channel := range channels {
if (len(params) == 2 && channel.Server == params[0] && channel.Name == params[1]) || if (channel.Server == pathServer && channel.Name == pathChannel) ||
(channel.Server == tabServer && channel.Name == tabChannel) { (channel.Server == cookieServer && channel.Name == cookieChannel) {
continue continue
} }