dispatch/server/server.go

243 lines
5.5 KiB
Go
Raw Normal View History

package server
import (
2017-04-15 02:48:24 +00:00
"crypto/tls"
"log"
2016-01-04 18:26:32 +00:00
"net"
"net/http"
2016-01-04 18:26:32 +00:00
"net/http/httputil"
"net/url"
"strings"
2016-03-01 00:51:26 +00:00
"github.com/gorilla/websocket"
"github.com/khlieng/dispatch/pkg/letsencrypt"
"github.com/khlieng/dispatch/pkg/session"
2015-12-11 03:35:48 +00:00
"github.com/khlieng/dispatch/storage"
2018-11-06 10:13:32 +00:00
"github.com/spf13/viper"
)
// channelStore is a package-level storage.ChannelStore, created once at
// package initialization by storage.NewChannelStore.
var (
	channelStore = storage.NewChannelStore()
)
// Dispatch holds the storage backends and runtime state used to serve the
// web client. Run loads the persisted users, connects them to their saved
// IRC servers and starts the HTTP(S)/websocket listeners.
type Dispatch struct {
	// Store is the persistent store users are loaded from (storage.LoadUsers).
	Store storage.Store
	// SessionStore is handed to the state store created in Run.
	SessionStore storage.SessionStore

	// GetMessageStore returns the message store attached to a loaded user.
	GetMessageStore func(*storage.User) (storage.MessageStore, error)
	// GetMessageSearchProvider returns the search provider attached to a
	// loaded user.
	GetMessageSearchProvider func(*storage.User) (storage.MessageSearchProvider, error)

	// upgrader is configured in Run and used to upgrade /ws requests.
	upgrader websocket.Upgrader
	// states holds the State values registered for loaded users; created in Run.
	states *stateStore
}
// Run prepares the websocket upgrader, starts the state store, restores the
// persisted users in the background, initializes the static file server and
// then starts the HTTP(S) listeners via startHTTP.
func (d *Dispatch) Run() {
	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}
	// In dev mode requests arrive through the node dev server (see
	// startHTTP), so the default same-origin check is disabled.
	if viper.GetBool("dev") {
		upgrader.CheckOrigin = func(*http.Request) bool { return true }
	}
	d.upgrader = upgrader

	session.CookieName = "dispatch"

	d.states = newStateStore(d.SessionStore)
	go d.states.run()

	d.loadUsers()
	d.initFileServer()
	d.startHTTP()
}
// loadUsers reads every persisted user from d.Store and restores each one on
// its own goroutine via loadUser. Failing to list the users is fatal.
func (d *Dispatch) loadUsers() {
	users, err := storage.LoadUsers(d.Store)
	if err != nil {
		log.Fatal(err)
	}

	log.Printf("[Init] %d users", len(users))

	for i := range users {
		go d.loadUser(users[i])
	}
}
func (d *Dispatch) loadUser(user *storage.User) {
messageStore, err := d.GetMessageStore(user)
if err != nil {
log.Fatal(err)
}
user.SetMessageStore(messageStore)
search, err := d.GetMessageSearchProvider(user)
if err != nil {
log.Fatal(err)
}
user.SetMessageSearchProvider(search)
state := NewState(user, d)
d.states.set(state)
go state.run()
channels, err := user.GetChannels()
if err != nil {
log.Fatal(err)
}
servers, err := user.GetServers()
if err != nil {
log.Fatal(err)
}
for _, server := range servers {
i := connectIRC(server, state, user.GetLastIP())
var joining []string
for _, channel := range channels {
if channel.Server == server.Host {
joining = append(joining, channel.Name)
}
}
i.Join(joining...)
}
}
// startHTTP starts the listeners configured through viper and then blocks
// serving the main one.
//
// When https.enabled is false, d.serve is served over plain HTTP on "port".
// In dev mode that port is forced to 1337.
//
// When https.enabled is true, d.serve is served over TLS on https.port:
//   - If certExists reports a certificate, https.cert/https.key are used.
//   - Otherwise, if letsencrypt.domain is set, a certificate is obtained via
//     letsencrypt.Run.
//   - Otherwise the process exits.
//
// With https.enabled there are two optional helpers:
//   - https.redirect also serves an HTTPS redirect on "port".
//   - letsencrypt.proxy forwards port 80 to the Let's Encrypt port, unless
//     the redirect handler already owns port 80.
//
// A failure of the main listener is fatal. Failures of the auxiliary
// listeners (redirect, Let's Encrypt proxy) are logged.
func (d *Dispatch) startHTTP() {
	addr := viper.GetString("address")
	port := viper.GetString("port")

	if !viper.GetBool("https.enabled") {
		if viper.GetBool("dev") {
			// The node dev server will proxy index page requests and
			// websocket connections to this port
			port = "1337"
		}

		log.Println("[HTTP] Listening on port", port)
		log.Fatal(http.ListenAndServe(net.JoinHostPort(addr, port), http.HandlerFunc(d.serve)))
		return
	}

	portHTTPS := viper.GetString("https.port")
	redirect := viper.GetBool("https.redirect")

	// listenBackground runs an auxiliary plain-HTTP listener on its own
	// goroutine and logs why it stopped (ListenAndServe always returns a
	// non-nil error), so e.g. a port conflict does not go unnoticed.
	listenBackground := func(name, listenPort string, handler http.Handler) {
		go func() {
			err := http.ListenAndServe(net.JoinHostPort(addr, listenPort), handler)
			log.Printf("[HTTP] %s listener on port %s stopped: %v", name, listenPort, err)
		}()
	}

	if redirect {
		log.Println("[HTTP] Listening on port", port, "(HTTPS Redirect)")
		listenBackground("HTTPS redirect", port, createHTTPSRedirect(portHTTPS))
	}

	server := &http.Server{
		Addr:    net.JoinHostPort(addr, portHTTPS),
		Handler: http.HandlerFunc(d.serve),
	}

	if certExists() {
		log.Println("[HTTPS] Listening on port", portHTTPS)
		// ListenAndServeTLS only returns on failure; surface it like the
		// Let's Encrypt branch does instead of silently returning.
		log.Fatal(server.ListenAndServeTLS(viper.GetString("https.cert"), viper.GetString("https.key")))
	} else if domain := viper.GetString("letsencrypt.domain"); domain != "" {
		dir := storage.Path.LetsEncrypt()
		email := viper.GetString("letsencrypt.email")
		lePort := viper.GetString("letsencrypt.port")

		// When the HTTPS redirect handler listens on port 80 it already
		// forwards ACME challenge requests, so a dedicated proxy is only
		// needed otherwise.
		if viper.GetBool("letsencrypt.proxy") && lePort != "" && (port != "80" || !redirect) {
			log.Println("[HTTP] Listening on port 80 (Let's Encrypt Proxy)")
			listenBackground("Let's Encrypt proxy", "80", http.HandlerFunc(letsEncryptProxy))
		}

		le, err := letsencrypt.Run(dir, domain, email, ":"+lePort)
		if err != nil {
			log.Fatal(err)
		}

		server.TLSConfig = &tls.Config{
			GetCertificate: le.GetCertificate,
		}

		log.Println("[HTTPS] Listening on port", portHTTPS)
		log.Fatal(server.ListenAndServeTLS("", ""))
	} else {
		log.Fatal("Could not locate SSL certificate or private key")
	}
}
func (d *Dispatch) serve(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
fail(w, http.StatusNotFound)
return
}
if strings.HasPrefix(r.URL.Path, "/ws") {
if !websocket.IsWebSocketUpgrade(r) {
fail(w, http.StatusBadRequest)
return
}
state := d.handleAuth(w, r, true, true)
if state == nil {
log.Println("[Auth] No state")
fail(w, http.StatusInternalServerError)
return
}
d.upgradeWS(w, r, state)
2018-11-06 10:13:32 +00:00
} else if strings.HasPrefix(r.URL.Path, "/data") {
state := d.handleAuth(w, r, false, false)
2018-11-09 05:30:31 +00:00
data := getIndexData(r, "/", state)
2018-11-06 10:13:32 +00:00
2018-11-09 05:30:31 +00:00
writeJSON(w, r, data)
} else {
d.serveFiles(w, r)
}
}
// upgradeWS upgrades the request to a websocket connection. It passes the
// headers already set on w as the upgrade response headers, then runs the
// websocket handler for state on the current goroutine. If the upgrade fails,
// the error is only logged; the upgrader itself has already replied to the
// client.
func (d *Dispatch) upgradeWS(w http.ResponseWriter, r *http.Request, state *State) {
	conn, err := d.upgrader.Upgrade(w, r, w.Header())
	if err == nil {
		newWSHandler(conn, state, r).run()
		return
	}
	log.Println(err)
}
// createHTTPSRedirect returns a handler that permanently (301) redirects
// plain-HTTP requests to the same host, path and query over HTTPS on
// portHTTPS. ACME challenge requests (/.well-known/acme-challenge...) are
// handed to letsEncryptProxy instead, so certificate validation over plain
// HTTP keeps working.
func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if strings.HasPrefix(r.URL.Path, "/.well-known/acme-challenge") {
			letsEncryptProxy(w, r)
			return
		}

		host, _, err := net.SplitHostPort(r.Host)
		if err != nil {
			// r.Host carries no port. Strip IPv6 brackets so JoinHostPort
			// below does not double them ("[::1]" -> "[[::1]]:443").
			host = r.Host
			if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
				host = host[1 : len(host)-1]
			}
		}

		// Build the target from the parsed URL, not r.RequestURI. RequestURI
		// is already escaped and contains the query string, so using it as
		// Path would escape it again ("%20" -> "%2520", "?" -> "%3F").
		u := url.URL{
			Scheme:     "https",
			Host:       net.JoinHostPort(host, portHTTPS),
			Path:       r.URL.Path,
			RawPath:    r.URL.RawPath,
			RawQuery:   r.URL.RawQuery,
			ForceQuery: r.URL.ForceQuery,
		}
		w.Header().Set("Location", u.String())
		w.WriteHeader(http.StatusMovedPermanently)
	}
}
// letsEncryptProxy reverse-proxies a plain-HTTP request to the Let's Encrypt
// challenge listener on letsencrypt.port. startHTTP passes ":"+letsencrypt.port
// to letsencrypt.Run, presumably binding it on this host; confirm if that
// listener is ever moved to another machine.
//
// The upstream is always loopback. The request's Host header is client
// controlled, and this handler may serve all of port 80. Dialing that host
// would therefore let any client make the server fetch arbitrary hosts on
// that port. The original Host header is still forwarded unchanged, because
// NewSingleHostReverseProxy does not rewrite it.
//
// When letsencrypt.port is not configured there is nothing to forward to, so
// the request gets a 404.
func letsEncryptProxy(w http.ResponseWriter, r *http.Request) {
	lePort := viper.GetString("letsencrypt.port")
	if lePort == "" {
		http.NotFound(w, r)
		return
	}

	upstream := &url.URL{
		Scheme: "http",
		Host:   net.JoinHostPort("localhost", lePort),
	}
	httputil.NewSingleHostReverseProxy(upstream).ServeHTTP(w, r)
}
// fail replies with status code and a plain-text body containing the
// standard status text for that code (e.g. "Not Found" for 404).
func fail(w http.ResponseWriter, code int) {
	msg := http.StatusText(code)
	http.Error(w, msg, code)
}