Pass in config struct

This commit is contained in:
Ken-Håvard Lieng 2018-12-11 10:51:20 +01:00
parent 8f1105bc59
commit 71f79fd84e
8 changed files with 360 additions and 103 deletions

View File

@ -5,10 +5,9 @@ import (
"io/ioutil"
"log"
"os"
"time"
"github.com/fsnotify/fsnotify"
"github.com/khlieng/dispatch/assets"
"github.com/khlieng/dispatch/config"
"github.com/khlieng/dispatch/server"
"github.com/khlieng/dispatch/storage"
"github.com/khlieng/dispatch/storage/bleve"
@ -36,34 +35,18 @@ var rootCmd = &cobra.Command{
Use: "dispatch",
Short: "Web-based IRC client in Go.",
PersistentPreRun: func(cmd *cobra.Command, args []string) {
if v, _ := cmd.Flags().GetBool("version"); v {
if viper.GetBool("version") {
printVersion()
os.Exit(0)
}
if cmd.Use == "dispatch" {
if cmd == cmd.Root() {
fmt.Printf(logo, version.Tag, version.Commit, version.Date)
}
storage.Initialize(viper.GetString("dir"))
initConfig(storage.Path.Config(), viper.GetBool("reset_config"))
viper.SetConfigName("config")
viper.AddConfigPath(storage.Path.Root())
viper.ReadInConfig()
viper.WatchConfig()
prev := time.Now()
viper.OnConfigChange(func(e fsnotify.Event) {
now := time.Now()
// fsnotify sometimes fires twice
if now.Sub(prev) > time.Second {
log.Println("New config loaded")
prev = now
}
})
initConfig(storage.Path.Config(), viper.GetBool("reset-config"))
},
Run: func(cmd *cobra.Command, args []string) {
@ -78,19 +61,28 @@ var rootCmd = &cobra.Command{
}
defer db.Close()
srv := server.Dispatch{
Store: db,
SessionStore: db,
cfg, cfgUpdated := config.LoadConfig()
dispatch := server.New(cfg)
GetMessageStore: func(user *storage.User) (storage.MessageStore, error) {
go func() {
for {
dispatch.SetConfig(<-cfgUpdated)
log.Println("New config loaded")
}
}()
dispatch.Store = db
dispatch.SessionStore = db
dispatch.GetMessageStore = func(user *storage.User) (storage.MessageStore, error) {
return boltdb.New(storage.Path.Log(user.Username))
},
GetMessageSearchProvider: func(user *storage.User) (storage.MessageSearchProvider, error) {
return bleve.New(storage.Path.Index(user.Username))
},
}
srv.Run()
dispatch.GetMessageSearchProvider = func(user *storage.User) (storage.MessageSearchProvider, error) {
return bleve.New(storage.Path.Index(user.Username))
}
dispatch.Run()
},
}
@ -110,14 +102,11 @@ func init() {
rootCmd.Flags().Bool("dev", false, "development mode")
rootCmd.Flags().BoolP("version", "v", false, "show version")
viper.BindPFlag("dir", rootCmd.PersistentFlags().Lookup("dir"))
viper.BindPFlag("reset_config", rootCmd.PersistentFlags().Lookup("reset-config"))
viper.BindPFlag("address", rootCmd.Flags().Lookup("address"))
viper.BindPFlag("port", rootCmd.Flags().Lookup("port"))
viper.BindPFlag("dev", rootCmd.Flags().Lookup("dev"))
viper.BindPFlags(rootCmd.PersistentFlags())
viper.BindPFlags(rootCmd.Flags())
viper.SetDefault("hexIP", false)
viper.SetDefault("verify_client_certificates", true)
viper.SetDefault("verify_certificates", true)
}
func initConfig(configPath string, overwrite bool) {

84
config/config.go Normal file
View File

@ -0,0 +1,84 @@
package config
import (
"time"
"github.com/fsnotify/fsnotify"
"github.com/khlieng/dispatch/storage"
"github.com/spf13/viper"
)
type Config struct {
Address string
Port string
Dev bool
HexIP bool
VerifyCertificates bool `mapstructure:"verify_certificates"`
Defaults *Defaults
HTTPS *HTTPS
LetsEncrypt *LetsEncrypt
}
type Defaults struct {
Name string
Host string
Port int
Channels []string
Password string
SSL bool
ReadOnly bool
ShowDetails bool `mapstructure:"show_details"`
}
type HTTPS struct {
Enabled bool
Port string
Redirect bool
Cert string
Key string
HSTS *HSTS
}
type HSTS struct {
Enabled bool
MaxAge string `mapstructure:"max_age"`
IncludeSubdomains bool `mapstructure:"include_subdomains"`
Preload bool
}
type LetsEncrypt struct {
Domain string
Email string
Port string
Proxy bool
}
func LoadConfig() (*Config, chan *Config) {
viper.SetConfigName("config")
viper.AddConfigPath(storage.Path.Root())
viper.ReadInConfig()
config := &Config{}
viper.Unmarshal(config)
viper.WatchConfig()
configCh := make(chan *Config, 1)
prev := time.Now()
viper.OnConfigChange(func(e fsnotify.Event) {
now := time.Now()
// fsnotify sometimes fires twice
if now.Sub(prev) > time.Second {
config := &Config{}
err := viper.Unmarshal(config)
if err == nil {
configCh <- config
}
prev = now
}
})
return config, configCh
}

View File

@ -5,9 +5,9 @@ import (
"net/url"
"strings"
"github.com/khlieng/dispatch/config"
"github.com/khlieng/dispatch/storage"
"github.com/khlieng/dispatch/version"
"github.com/spf13/viper"
)
type connectDefaults struct {
@ -28,7 +28,7 @@ type dispatchVersion struct {
}
type indexData struct {
Defaults connectDefaults
Defaults *config.Defaults
Servers []Server
Channels []*storage.Channel
HexIP bool
@ -43,9 +43,12 @@ type indexData struct {
Messages *Messages
}
func getIndexData(r *http.Request, path string, state *State) *indexData {
func (d *Dispatch) getIndexData(r *http.Request, path string, state *State) *indexData {
cfg := d.Config()
data := indexData{
HexIP: viper.GetBool("hexIP"),
Defaults: cfg.Defaults,
HexIP: cfg.HexIP,
Version: dispatchVersion{
Tag: version.Tag,
Commit: version.Commit,
@ -53,15 +56,8 @@ func getIndexData(r *http.Request, path string, state *State) *indexData {
},
}
data.Defaults = connectDefaults{
Name: viper.GetString("defaults.name"),
Host: viper.GetString("defaults.host"),
Port: viper.GetInt("defaults.port"),
Channels: viper.GetStringSlice("defaults.channels"),
Password: viper.GetString("defaults.password") != "",
SSL: viper.GetBool("defaults.ssl"),
ReadOnly: viper.GetBool("defaults.readonly"),
ShowDetails: viper.GetBool("defaults.show_details"),
if data.Defaults.Password != "" {
data.Defaults.Password = "******"
}
if state == nil {

View File

@ -4,6 +4,7 @@ package server
import (
json "encoding/json"
config "github.com/khlieng/dispatch/config"
storage "github.com/khlieng/dispatch/storage"
easyjson "github.com/mailru/easyjson"
jlexer "github.com/mailru/easyjson/jlexer"
@ -38,8 +39,14 @@ func easyjson7e607aefDecodeGithubComKhliengDispatchServer(in *jlexer.Lexer, out
}
switch key {
case "defaults":
if data := in.Raw(); in.Ok() {
in.AddError((out.Defaults).UnmarshalJSON(data))
if in.IsNull() {
in.Skip()
out.Defaults = nil
} else {
if out.Defaults == nil {
out.Defaults = new(config.Defaults)
}
easyjson7e607aefDecodeGithubComKhliengDispatchConfig(in, &*out.Defaults)
}
case "servers":
if in.IsNull() {
@ -153,7 +160,7 @@ func easyjson7e607aefEncodeGithubComKhliengDispatchServer(out *jwriter.Writer, i
out.RawByte('{')
first := true
_ = first
if true {
if in.Defaults != nil {
const prefix string = ",\"defaults\":"
if first {
first = false
@ -161,7 +168,7 @@ func easyjson7e607aefEncodeGithubComKhliengDispatchServer(out *jwriter.Writer, i
} else {
out.RawString(prefix)
}
out.Raw((in.Defaults).MarshalJSON())
easyjson7e607aefEncodeGithubComKhliengDispatchConfig(out, *in.Defaults)
}
if len(in.Servers) != 0 {
const prefix string = ",\"servers\":"
@ -352,6 +359,167 @@ func easyjson7e607aefEncodeGithubComKhliengDispatchStorage(out *jwriter.Writer,
}
out.RawByte('}')
}
func easyjson7e607aefDecodeGithubComKhliengDispatchConfig(in *jlexer.Lexer, out *config.Defaults) {
isTopLevel := in.IsStart()
if in.IsNull() {
if isTopLevel {
in.Consumed()
}
in.Skip()
return
}
in.Delim('{')
for !in.IsDelim('}') {
key := in.UnsafeString()
in.WantColon()
if in.IsNull() {
in.Skip()
in.WantComma()
continue
}
switch key {
case "name":
out.Name = string(in.String())
case "host":
out.Host = string(in.String())
case "port":
out.Port = int(in.Int())
case "channels":
if in.IsNull() {
in.Skip()
out.Channels = nil
} else {
in.Delim('[')
if out.Channels == nil {
if !in.IsDelim(']') {
out.Channels = make([]string, 0, 4)
} else {
out.Channels = []string{}
}
} else {
out.Channels = (out.Channels)[:0]
}
for !in.IsDelim(']') {
var v7 string
v7 = string(in.String())
out.Channels = append(out.Channels, v7)
in.WantComma()
}
in.Delim(']')
}
case "password":
out.Password = string(in.String())
case "ssl":
out.SSL = bool(in.Bool())
case "readOnly":
out.ReadOnly = bool(in.Bool())
case "showDetails":
out.ShowDetails = bool(in.Bool())
default:
in.SkipRecursive()
}
in.WantComma()
}
in.Delim('}')
if isTopLevel {
in.Consumed()
}
}
func easyjson7e607aefEncodeGithubComKhliengDispatchConfig(out *jwriter.Writer, in config.Defaults) {
out.RawByte('{')
first := true
_ = first
if in.Name != "" {
const prefix string = ",\"name\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
out.String(string(in.Name))
}
if in.Host != "" {
const prefix string = ",\"host\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
out.String(string(in.Host))
}
if in.Port != 0 {
const prefix string = ",\"port\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
out.Int(int(in.Port))
}
if len(in.Channels) != 0 {
const prefix string = ",\"channels\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
{
out.RawByte('[')
for v8, v9 := range in.Channels {
if v8 > 0 {
out.RawByte(',')
}
out.String(string(v9))
}
out.RawByte(']')
}
}
if in.Password != "" {
const prefix string = ",\"password\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
out.String(string(in.Password))
}
if in.SSL {
const prefix string = ",\"ssl\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
out.Bool(bool(in.SSL))
}
if in.ReadOnly {
const prefix string = ",\"readOnly\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
out.Bool(bool(in.ReadOnly))
}
if in.ShowDetails {
const prefix string = ",\"showDetails\":"
if first {
first = false
out.RawString(prefix[1:])
} else {
out.RawString(prefix)
}
out.Bool(bool(in.ShowDetails))
}
out.RawByte('}')
}
func easyjson7e607aefDecodeGithubComKhliengDispatchServer1(in *jlexer.Lexer, out *dispatchVersion) {
isTopLevel := in.IsStart()
if in.IsNull() {
@ -488,9 +656,9 @@ func easyjson7e607aefDecodeGithubComKhliengDispatchServer2(in *jlexer.Lexer, out
out.Channels = (out.Channels)[:0]
}
for !in.IsDelim(']') {
var v7 string
v7 = string(in.String())
out.Channels = append(out.Channels, v7)
var v10 string
v10 = string(in.String())
out.Channels = append(out.Channels, v10)
in.WantComma()
}
in.Delim(']')
@ -557,11 +725,11 @@ func easyjson7e607aefEncodeGithubComKhliengDispatchServer2(out *jwriter.Writer,
}
{
out.RawByte('[')
for v8, v9 := range in.Channels {
if v8 > 0 {
for v11, v12 := range in.Channels {
if v11 > 0 {
out.RawByte(',')
}
out.String(string(v9))
out.String(string(v12))
}
out.RawByte(']')
}

View File

@ -5,8 +5,6 @@ import (
"encoding/hex"
"net"
"github.com/spf13/viper"
"github.com/khlieng/dispatch/pkg/irc"
"github.com/khlieng/dispatch/storage"
)
@ -38,7 +36,9 @@ func connectIRC(server *storage.Server, state *State, srcIP []byte) *irc.Client
address = net.JoinHostPort(server.Host, server.Port)
}
if viper.GetBool("hexIP") {
cfg := state.srv.Config()
if cfg.HexIP {
i.Username = hex.EncodeToString(srcIP)
} else if i.Username == "" {
i.Username = server.Nick
@ -49,16 +49,16 @@ func connectIRC(server *storage.Server, state *State, srcIP []byte) *irc.Client
}
if server.Password == "" &&
viper.GetString("defaults.password") != "" &&
address == viper.GetString("defaults.host") {
i.Password = viper.GetString("defaults.password")
cfg.Defaults.Password != "" &&
address == cfg.Defaults.Host {
i.Password = cfg.Defaults.Password
} else {
i.Password = server.Password
}
if i.TLS {
i.TLSConfig = &tls.Config{
InsecureSkipVerify: !viper.GetBool("verify_certificates"),
InsecureSkipVerify: !cfg.VerifyCertificates,
}
if cert := state.user.GetCertificate(); cert != nil {

View File

@ -16,7 +16,6 @@ import (
"github.com/dsnet/compress/brotli"
"github.com/khlieng/dispatch/assets"
"github.com/spf13/viper"
)
const longCacheControl = "public, max-age=31536000, immutable"
@ -73,7 +72,9 @@ var (
)
func (d *Dispatch) initFileServer() {
if viper.GetBool("dev") {
cfg := d.Config()
if cfg.Dev {
indexScripts = []string{"boot.js", "main.js"}
} else {
bootloader := decompressedAsset(findAssetName("boot*.js"))
@ -149,13 +150,13 @@ workbox.precaching.precacheAndRoute([{
}]);
workbox.routing.registerNavigationRoute('/');`)...)
if viper.GetBool("https.hsts.enabled") && viper.GetBool("https.enabled") {
hstsHeader = "max-age=" + viper.GetString("https.hsts.max_age")
if cfg.HTTPS.HSTS.Enabled && cfg.HTTPS.Enabled {
hstsHeader = "max-age=" + cfg.HTTPS.HSTS.MaxAge
if viper.GetBool("https.hsts.include_subdomains") {
if cfg.HTTPS.HSTS.IncludeSubdomains {
hstsHeader += "; includeSubDomains"
}
if viper.GetBool("https.hsts.preload") {
if cfg.HTTPS.HSTS.Preload {
hstsHeader += "; preload"
}
}

View File

@ -8,12 +8,13 @@ import (
"net/http/httputil"
"net/url"
"strings"
"sync"
"github.com/gorilla/websocket"
"github.com/khlieng/dispatch/config"
"github.com/khlieng/dispatch/pkg/letsencrypt"
"github.com/khlieng/dispatch/pkg/session"
"github.com/khlieng/dispatch/storage"
"github.com/spf13/viper"
)
var channelStore = storage.NewChannelStore()
@ -25,8 +26,29 @@ type Dispatch struct {
GetMessageStore func(*storage.User) (storage.MessageStore, error)
GetMessageSearchProvider func(*storage.User) (storage.MessageSearchProvider, error)
cfg *config.Config
upgrader websocket.Upgrader
states *stateStore
lock sync.Mutex
}
func New(cfg *config.Config) *Dispatch {
return &Dispatch{
cfg: cfg,
}
}
func (d *Dispatch) Config() *config.Config {
d.lock.Lock()
cfg := d.cfg
d.lock.Unlock()
return cfg
}
func (d *Dispatch) SetConfig(cfg *config.Config) {
d.lock.Lock()
d.cfg = cfg
d.lock.Unlock()
}
func (d *Dispatch) Run() {
@ -35,7 +57,7 @@ func (d *Dispatch) Run() {
WriteBufferSize: 1024,
}
if viper.GetBool("dev") {
if d.Config().Dev {
d.upgrader.CheckOrigin = func(r *http.Request) bool {
return true
}
@ -105,16 +127,17 @@ func (d *Dispatch) loadUser(user *storage.User) {
}
func (d *Dispatch) startHTTP() {
addr := viper.GetString("address")
port := viper.GetString("port")
cfg := d.Config()
addr := cfg.Address
port := cfg.Port
if viper.GetBool("https.enabled") {
portHTTPS := viper.GetString("https.port")
redirect := viper.GetBool("https.redirect")
if cfg.HTTPS.Enabled {
portHTTPS := cfg.HTTPS.Port
redirect := cfg.HTTPS.Redirect
if redirect {
log.Println("[HTTP] Listening on port", port, "(HTTPS Redirect)")
go http.ListenAndServe(net.JoinHostPort(addr, port), createHTTPSRedirect(portHTTPS))
go http.ListenAndServe(net.JoinHostPort(addr, port), d.createHTTPSRedirect(portHTTPS))
}
server := &http.Server{
@ -122,17 +145,17 @@ func (d *Dispatch) startHTTP() {
Handler: d,
}
if certExists() {
if d.certExists() {
log.Println("[HTTPS] Listening on port", portHTTPS)
server.ListenAndServeTLS(viper.GetString("https.cert"), viper.GetString("https.key"))
} else if domain := viper.GetString("letsencrypt.domain"); domain != "" {
server.ListenAndServeTLS(cfg.HTTPS.Cert, cfg.HTTPS.Key)
} else if domain := cfg.LetsEncrypt.Domain; domain != "" {
dir := storage.Path.LetsEncrypt()
email := viper.GetString("letsencrypt.email")
lePort := viper.GetString("letsencrypt.port")
email := cfg.LetsEncrypt.Email
lePort := cfg.LetsEncrypt.Port
if viper.GetBool("letsencrypt.proxy") && lePort != "" && (port != "80" || !redirect) {
if cfg.LetsEncrypt.Proxy && lePort != "" && (port != "80" || !redirect) {
log.Println("[HTTP] Listening on port 80 (Let's Encrypt Proxy))")
go http.ListenAndServe(net.JoinHostPort(addr, "80"), http.HandlerFunc(letsEncryptProxy))
go http.ListenAndServe(net.JoinHostPort(addr, "80"), http.HandlerFunc(d.letsEncryptProxy))
}
le, err := letsencrypt.Run(dir, domain, email, ":"+lePort)
@ -150,7 +173,7 @@ func (d *Dispatch) startHTTP() {
log.Fatal("Could not locate SSL certificate or private key")
}
} else {
if viper.GetBool("dev") {
if cfg.Dev {
// The node dev server will proxy index page requests and
// websocket connections to this port
port = "1337"
@ -174,10 +197,9 @@ func (d *Dispatch) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
state := d.handleAuth(w, r, true, true)
data := getIndexData(r, referer.EscapedPath(), state)
data := d.getIndexData(r, referer.EscapedPath(), state)
writeJSON(w, r, data)
} else if strings.HasPrefix(r.URL.Path, "/ws") {
if !websocket.IsWebSocketUpgrade(r) {
fail(w, http.StatusBadRequest)
@ -207,10 +229,10 @@ func (d *Dispatch) upgradeWS(w http.ResponseWriter, r *http.Request, state *Stat
newWSHandler(conn, state, r).run()
}
func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {
func (d *Dispatch) createHTTPSRedirect(portHTTPS string) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/.well-known/acme-challenge") {
letsEncryptProxy(w, r)
d.letsEncryptProxy(w, r)
return
}
@ -230,7 +252,7 @@ func createHTTPSRedirect(portHTTPS string) http.HandlerFunc {
})
}
func letsEncryptProxy(w http.ResponseWriter, r *http.Request) {
func (d *Dispatch) letsEncryptProxy(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
@ -238,7 +260,7 @@ func letsEncryptProxy(w http.ResponseWriter, r *http.Request) {
upstream := &url.URL{
Scheme: "http",
Host: net.JoinHostPort(host, viper.GetString("letsencrypt.port")),
Host: net.JoinHostPort(host, d.Config().LetsEncrypt.Port),
}
httputil.NewSingleHostReverseProxy(upstream).ServeHTTP(w, r)

View File

@ -2,22 +2,19 @@ package server
import (
"os"
"github.com/spf13/viper"
)
func certExists() bool {
cert := viper.GetString("https.cert")
key := viper.GetString("https.key")
func (d *Dispatch) certExists() bool {
cfg := d.Config().HTTPS
if cert == "" || key == "" {
if cfg.Cert == "" || cfg.Key == "" {
return false
}
if _, err := os.Stat(cert); err != nil {
if _, err := os.Stat(cfg.Cert); err != nil {
return false
}
if _, err := os.Stat(key); err != nil {
if _, err := os.Stat(cfg.Key); err != nil {
return false
}