Pull https handling out into a new package

This commit is contained in:
Ken-Håvard Lieng 2018-12-31 03:33:05 +01:00
parent 67e32661f1
commit 63cf65100d
3 changed files with 187 additions and 146 deletions

170
pkg/https/https.go Normal file
View File

@ -0,0 +1,170 @@
package https
import (
"crypto/tls"
"net"
"net/http"
"net/url"
"time"
"github.com/khlieng/dispatch/pkg/netutil"
"github.com/klauspost/cpuid"
"github.com/mholt/certmagic"
)
type Config struct {
Addr string
PortHTTP string
PortHTTPS string
HTTPOnly bool
StoragePath string
Domain string
Email string
Cert string
Key string
}
func Serve(handler http.Handler, cfg Config) error {
errCh := make(chan error, 1)
httpSrv := &http.Server{
Addr: net.JoinHostPort(cfg.Addr, cfg.PortHTTP),
}
if !cfg.HTTPOnly {
httpSrv.ReadTimeout = 5 * time.Second
httpSrv.WriteTimeout = 5 * time.Second
httpsSrv := &http.Server{
Addr: net.JoinHostPort(cfg.Addr, cfg.PortHTTPS),
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
Handler: handler,
}
redirect := HTTPSRedirect(cfg.PortHTTPS, handler)
if cfg.Cert != "" || cfg.Key != "" {
httpSrv.Handler = redirect
httpsSrv.TLSConfig = TLSConfig(nil)
go func() {
errCh <- httpSrv.ListenAndServe()
}()
go func() {
errCh <- httpsSrv.ListenAndServeTLS(cfg.Cert, cfg.Key)
}()
} else {
var cache *certmagic.Cache
if cfg.StoragePath != "" {
cache = certmagic.NewCache(&certmagic.FileStorage{
Path: cfg.StoragePath,
})
}
magic := certmagic.NewWithCache(cache, certmagic.Config{
Agreed: true,
Email: cfg.Email,
MustStaple: true,
})
domains := []string{cfg.Domain}
if cfg.Domain == "" {
domains = []string{}
magic.OnDemand = &certmagic.OnDemandConfig{MaxObtain: 3}
}
err := magic.Manage(domains)
if err != nil {
return err
}
httpSrv.Handler = magic.HTTPChallengeHandler(redirect)
httpsSrv.TLSConfig = TLSConfig(magic.TLSConfig())
go func() {
errCh <- httpSrv.ListenAndServe()
}()
go func() {
errCh <- httpsSrv.ListenAndServeTLS("", "")
}()
}
} else {
httpSrv.ReadTimeout = 5 * time.Second
httpSrv.WriteTimeout = 10 * time.Second
httpSrv.IdleTimeout = 120 * time.Second
httpSrv.Handler = handler
return httpSrv.ListenAndServe()
}
return <-errCh
}
func HTTPSRedirect(portHTTPS string, fallback http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
if fallback != nil && netutil.IsPrivate(host) {
fallback.ServeHTTP(w, r)
return
}
u := url.URL{
Scheme: "https",
Host: net.JoinHostPort(host, portHTTPS),
Path: r.RequestURI,
}
w.Header().Set("Connection", "close")
w.Header().Set("Location", u.String())
w.WriteHeader(http.StatusMovedPermanently)
}
}
func TLSConfig(tlsConfig *tls.Config) *tls.Config {
if tlsConfig == nil {
tlsConfig = &tls.Config{}
}
tlsConfig.MinVersion = tls.VersionTLS12
tlsConfig.CipherSuites = defaultCipherSuites()
tlsConfig.CurvePreferences = []tls.CurveID{
tls.X25519,
tls.CurveP256,
}
tlsConfig.PreferServerCipherSuites = true
return tlsConfig
}
func defaultCipherSuites() []uint16 {
if cpuid.CPU.AesNi() {
return []uint16{
tls.TLS_FALLBACK_SCSV,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
}
}
return []uint16{
tls.TLS_FALLBACK_SCSV,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
}

View File

@ -1,21 +1,17 @@
package server
import (
"crypto/tls"
"log"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/khlieng/dispatch/config"
"github.com/khlieng/dispatch/pkg/netutil"
"github.com/khlieng/dispatch/pkg/https"
"github.com/khlieng/dispatch/pkg/session"
"github.com/khlieng/dispatch/storage"
"github.com/mholt/certmagic"
)
var channelStore = storage.NewChannelStore()
@ -137,79 +133,26 @@ func (d *Dispatch) startHTTP() {
port = "1337"
}
httpSrv := &http.Server{
Addr: net.JoinHostPort(cfg.Address, port),
}
if cfg.HTTPS.Enabled {
httpSrv.ReadTimeout = 5 * time.Second
httpSrv.WriteTimeout = 5 * time.Second
httpsSrv := &http.Server{
Addr: net.JoinHostPort(cfg.Address, cfg.HTTPS.Port),
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
Handler: d,
}
redirect := createHTTPSRedirect(cfg.HTTPS.Port, d)
if d.certExists() {
httpSrv.Handler = redirect
log.Println("[HTTP] Listening on port", port, "(HTTPS Redirect)")
go httpSrv.ListenAndServe()
log.Println("[HTTPS] Listening on port", cfg.HTTPS.Port)
log.Fatal(httpsSrv.ListenAndServeTLS(cfg.HTTPS.Cert, cfg.HTTPS.Key))
} else {
cache := certmagic.NewCache(&certmagic.FileStorage{
Path: storage.Path.LetsEncrypt(),
})
magic := certmagic.NewWithCache(cache, certmagic.Config{
Agreed: true,
Email: cfg.LetsEncrypt.Email,
MustStaple: true,
})
domains := []string{cfg.LetsEncrypt.Domain}
if cfg.LetsEncrypt.Domain == "" {
domains = []string{}
magic.OnDemand = &certmagic.OnDemandConfig{MaxObtain: 3}
}
err := magic.Manage(domains)
if err != nil {
log.Fatal(err)
}
tlsConfig := magic.TLSConfig()
tlsConfig.MinVersion = tls.VersionTLS12
tlsConfig.CipherSuites = getCipherSuites()
tlsConfig.CurvePreferences = []tls.CurveID{
tls.X25519,
tls.CurveP256,
}
tlsConfig.PreferServerCipherSuites = true
httpsSrv.TLSConfig = tlsConfig
httpSrv.Handler = magic.HTTPChallengeHandler(redirect)
log.Println("[HTTP] Listening on port", port, "(HTTPS Redirect)")
go httpSrv.ListenAndServe()
log.Println("[HTTPS] Listening on port", cfg.HTTPS.Port)
log.Fatal(httpsSrv.ListenAndServeTLS("", ""))
}
log.Println("[HTTP] Listening on port", port, "(HTTPS Redirect)")
log.Println("[HTTPS] Listening on port", cfg.HTTPS.Port)
} else {
httpSrv.ReadHeaderTimeout = 5 * time.Second
httpSrv.WriteTimeout = 10 * time.Second
httpSrv.IdleTimeout = 120 * time.Second
httpSrv.Handler = d
log.Println("[HTTP] Listening on port", port)
log.Fatal(httpSrv.ListenAndServe())
}
log.Fatal(https.Serve(d, https.Config{
Addr: cfg.Address,
PortHTTP: port,
PortHTTPS: cfg.HTTPS.Port,
HTTPOnly: !cfg.HTTPS.Enabled,
StoragePath: storage.Path.LetsEncrypt(),
Domain: cfg.LetsEncrypt.Domain,
Email: cfg.LetsEncrypt.Email,
Cert: cfg.HTTPS.Cert,
Key: cfg.HTTPS.Key,
}))
}
func (d *Dispatch) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -258,30 +201,6 @@ func (d *Dispatch) upgradeWS(w http.ResponseWriter, r *http.Request, state *Stat
newWSHandler(conn, state, r).run()
}
func createHTTPSRedirect(portHTTPS string, fallback http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
if netutil.IsPrivate(host) {
fallback.ServeHTTP(w, r)
return
}
u := url.URL{
Scheme: "https",
Host: net.JoinHostPort(host, portHTTPS),
Path: r.RequestURI,
}
w.Header().Set("Connection", "close")
w.Header().Set("Location", u.String())
w.WriteHeader(http.StatusMovedPermanently)
}
}
func fail(w http.ResponseWriter, code int) {
http.Error(w, http.StatusText(code), code)
}

View File

@ -1,48 +0,0 @@
package server
import (
"crypto/tls"
"os"
"github.com/klauspost/cpuid"
)
func getCipherSuites() []uint16 {
if cpuid.CPU.AesNi() {
return []uint16{
tls.TLS_FALLBACK_SCSV,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
}
}
return []uint16{
tls.TLS_FALLBACK_SCSV,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
}
func (d *Dispatch) certExists() bool {
cfg := d.Config().HTTPS
if cfg.Cert == "" || cfg.Key == "" {
return false
}
if _, err := os.Stat(cfg.Cert); err != nil {
return false
}
if _, err := os.Stat(cfg.Key); err != nil {
return false
}
return true
}