From 47dd4f51cbea70f1d548bb26c1bcb5045fe5a52b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ken-H=C3=A5vard=20Lieng?= Date: Wed, 6 Jan 2016 22:19:06 +0100 Subject: [PATCH] Refactor lets encrypt integration to support cert changes and ocsp stapling without restarting --- letsencrypt/letsencrypt.go | 151 ++++++++++++++++++++++++++++++------ server/server.go | 28 +++---- server/{https.go => tls.go} | 37 ++------- 3 files changed, 146 insertions(+), 70 deletions(-) rename server/{https.go => tls.go} (51%) diff --git a/letsencrypt/letsencrypt.go b/letsencrypt/letsencrypt.go index 8404cf1b..fb1152b7 100644 --- a/letsencrypt/letsencrypt.go +++ b/letsencrypt/letsencrypt.go @@ -1,9 +1,11 @@ package letsencrypt import ( + "crypto/tls" "encoding/json" "io/ioutil" "os" + "sync" "time" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/xenolf/lego/acme" @@ -14,12 +16,12 @@ const KeySize = 2048 var directory Directory -func Run(dir, domain, email, port string, onChange func()) (string, string, error) { +func Run(dir, domain, email, port string) (*state, error) { directory = Directory(dir) user, err := getUser(email) if err != nil { - return "", "", nil + return nil, err } client, err := acme.NewClient(URL, &user, KeySize) @@ -29,48 +31,114 @@ func Run(dir, domain, email, port string, onChange func()) (string, string, erro if user.Registration == nil { user.Registration, err = client.Register() if err != nil { - return "", "", err + return nil, err } err = client.AgreeToTOS() if err != nil { - return "", "", err + return nil, err } err = saveUser(user) if err != nil { - return "", "", err + return nil, err } } + s := &state{ + client: client, + domain: domain, + } + if certExists(domain) { - renew(client, domain) + if !s.renew() { + err = s.loadCert() + if err != nil { + return nil, err + } + } + s.refreshOCSP() } else { - err = obtain(client, domain) + err = s.obtain() if err != nil { - return "", "", err + return nil, err } } - go keepRenewed(client, domain, onChange) + go s.maintain() - return directory.Cert(domain), directory.Key(domain), nil + return s, nil } -func obtain(client *acme.Client, domain string) error { - cert, errors := client.ObtainCertificate([]string{domain}, true) - if err := errors[domain]; err != nil { +type state struct { + client *acme.Client + domain string + cert *tls.Certificate + certPEM []byte + lock sync.Mutex +} + +func (s *state) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + s.lock.Lock() + cert := s.cert + s.lock.Unlock() + + return cert, nil +} + +func (s *state) getCertPEM() []byte { + s.lock.Lock() + certPEM := s.certPEM + s.lock.Unlock() + + return certPEM +} + +func (s *state) setCert(meta acme.CertificateResource) { + cert, err := tls.X509KeyPair(meta.Certificate, meta.PrivateKey) + if err == nil { + s.lock.Lock() + if s.cert != nil { + cert.OCSPStaple = s.cert.OCSPStaple + } + + s.cert = &cert + s.certPEM = meta.Certificate + s.lock.Unlock() + } +} + +func (s *state) setOCSP(ocsp []byte) { + cert := tls.Certificate{ + OCSPStaple: ocsp, + } + + s.lock.Lock() + if s.cert != nil { + cert.Certificate = s.cert.Certificate + cert.PrivateKey = s.cert.PrivateKey + } + s.cert = &cert + s.lock.Unlock() +} + +func (s *state) obtain() error { + cert, errors := s.client.ObtainCertificate([]string{s.domain}, true) + if err := errors[s.domain]; err != nil { if _, ok := err.(acme.TOSError); ok { - err := client.AgreeToTOS() + err := s.client.AgreeToTOS() if err != nil { return err } - return obtain(client, domain) + return s.obtain() } return err } + s.setCert(cert) + s.refreshOCSP() + err := saveCert(cert) if err != nil { return err @@ -79,8 +147,8 @@ func obtain(client *acme.Client, domain string) error { return nil } -func renew(client *acme.Client, domain string) bool { - cert, err := ioutil.ReadFile(directory.Cert(domain)) +func (s *state) renew() bool { + cert, err := ioutil.ReadFile(directory.Cert(s.domain)) if err != nil { return false } @@ -93,12 +161,12 @@ func renew(client *acme.Client, domain string) bool { daysLeft := int(exp.Sub(time.Now().UTC()).Hours() / 24) if daysLeft <= 30 { - metaBytes, err := ioutil.ReadFile(directory.Meta(domain)) + metaBytes, err := ioutil.ReadFile(directory.Meta(s.domain)) if err != nil { return false } - key, err := ioutil.ReadFile(directory.Key(domain)) + key, err := ioutil.ReadFile(directory.Key(s.domain)) if err != nil { return false } @@ -112,10 +180,10 @@ func renew(client *acme.Client, domain string) bool { meta.PrivateKey = key Renew: - newMeta, err := client.RenewCertificate(meta, true) + newMeta, err := s.client.RenewCertificate(meta, true) if err != nil { if _, ok := err.(acme.TOSError); ok { - err := client.AgreeToTOS() + err := s.client.AgreeToTOS() if err != nil { return false } @@ -124,6 +192,8 @@ func renew(client *acme.Client, domain string) bool { return false } + s.setCert(newMeta) + err = saveCert(newMeta) if err != nil { return false @@ -135,15 +205,46 @@ func renew(client *acme.Client, domain string) bool { return false } -func keepRenewed(client *acme.Client, domain string, onChange func()) { +func (s *state) refreshOCSP() { + ocsp, resp, err := acme.GetOCSPForCert(s.getCertPEM()) + if err == nil && resp.Status == acme.OCSPGood { + s.setOCSP(ocsp) + } +} + +func (s *state) maintain() { + renew := time.Tick(24 * time.Hour) + ocsp := time.Tick(1 * time.Hour) for { - time.Sleep(24 * time.Hour) - if renew(client, domain) { - onChange() + select { + case <-renew: + s.renew() + + case <-ocsp: + s.refreshOCSP() } } } +func (s *state) loadCert() error { + cert, err := ioutil.ReadFile(directory.Cert(s.domain)) + if err != nil { + return err + } + + key, err := ioutil.ReadFile(directory.Key(s.domain)) + if err != nil { + return err + } + + s.setCert(acme.CertificateResource{ + Certificate: cert, + PrivateKey: key, + }) + + return nil +} + func certExists(domain string) bool { if _, err := os.Stat(directory.Cert(domain)); err != nil { return false diff --git a/server/server.go b/server/server.go index 268d7c19..ecba716c 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "crypto/tls" "log" "net" "net/http" @@ -46,23 +47,22 @@ func startHTTP() { port := viper.GetString("port") if viper.GetBool("https.enabled") { - var err error portHTTPS := viper.GetString("https.port") redirect := viper.GetBool("https.redirect") - https := restartableHTTPS{ - addr: ":" + portHTTPS, - handler: http.HandlerFunc(serve), - } - - if viper.GetBool("https.redirect") { + if redirect { log.Println("[HTTP] Listening on port", port, "(HTTPS Redirect)") go http.ListenAndServe(":"+port, createHTTPSRedirect(portHTTPS)) } + server := &http.Server{ + Addr: ":" + portHTTPS, + Handler: http.HandlerFunc(serve), + } + if certExists() { - https.cert = viper.GetString("https.cert") - https.key = viper.GetString("https.key") + log.Println("[HTTPS] Listening on port", portHTTPS) + server.ListenAndServeTLS(viper.GetString("https.cert"), viper.GetString("https.key")) } else if domain := viper.GetString("letsencrypt.domain"); domain != "" { dir := storage.Path.LetsEncrypt() email := viper.GetString("letsencrypt.email") @@ -73,16 +73,18 @@ func startHTTP() { go http.ListenAndServe(":80", http.HandlerFunc(letsEncryptProxy)) } - https.cert, https.key, err = letsencrypt.Run(dir, domain, email, lePort, https.restart) + letsEncrypt, err := letsencrypt.Run(dir, domain, email, lePort) if err != nil { log.Fatal(err) } + + server.TLSConfig = &tls.Config{GetCertificate: letsEncrypt.GetCertificate} + + log.Println("[HTTPS] Listening on port", portHTTPS) + log.Fatal(listenAndServeTLS(server)) } else { log.Fatal("Could not locate SSL certificate or private key") } - - log.Println("[HTTPS] Listening on port", portHTTPS) - https.start() } else { log.Println("[HTTP] Listening on port", port) log.Fatal(http.ListenAndServe(":"+port, http.HandlerFunc(serve))) diff --git a/server/https.go b/server/tls.go similarity index 51% rename from server/https.go rename to server/tls.go index 8be88a5c..a3419c90 100644 --- a/server/https.go +++ b/server/tls.go @@ -10,43 +10,16 @@ import ( "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/viper" ) -type restartableHTTPS struct { - listener net.Listener - handler http.Handler - addr string - cert string - key string -} +func listenAndServeTLS(srv *http.Server) error { + srv.TLSConfig.NextProtos = []string{"http/1.1"} -func (r *restartableHTTPS) start() error { - var err error - - config := &tls.Config{ - NextProtos: []string{"http/1.1"}, - Certificates: make([]tls.Certificate, 1), - } - - config.Certificates[0], err = tls.LoadX509KeyPair(r.cert, r.key) + ln, err := net.Listen("tcp", srv.Addr) if err != nil { return err } - ln, err := net.Listen("tcp", r.addr) - if err != nil { - return err - } - - r.listener = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) - return http.Serve(r.listener, r.handler) -} - -func (r *restartableHTTPS) stop() { - r.listener.Close() -} - -func (r *restartableHTTPS) restart() { - r.stop() - go r.start() + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + return srv.Serve(tlsListener) } type tcpKeepAliveListener struct {