Skip to content

Commit ecb30d6

Browse files
committed
feat: add support for reloading certs when renewed
1 parent 15dc534 commit ecb30d6

File tree

2 files changed

+117
-1
lines changed

2 files changed

+117
-1
lines changed

cmd/root.go

+26-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package cmd
2323

2424
import (
2525
"context"
26+
"crypto/tls"
2627
"fmt"
2728
"net/http"
2829
"os"
@@ -34,6 +35,7 @@ import (
3435
"github.com/estahn/k8s-image-swapper/pkg/registry"
3536
"github.com/estahn/k8s-image-swapper/pkg/secrets"
3637
"github.com/estahn/k8s-image-swapper/pkg/types"
38+
"github.com/estahn/k8s-image-swapper/pkg/utils"
3739
"github.com/estahn/k8s-image-swapper/pkg/webhook"
3840
homedir "github.com/mitchellh/go-homedir"
3941
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -151,7 +153,20 @@ A mutating webhook for Kubernetes, pointing the images to a new location.`,
151153
log.Info().Msgf("Listening on %v", cfg.ListenAddress)
152154
//err = http.ListenAndServeTLS(":8080", cfg.certFile, cfg.keyFile, whHandler)
153155
if cfg.TLSCertFile != "" && cfg.TLSKeyFile != "" {
154-
if err := srv.ListenAndServeTLS(cfg.TLSCertFile, cfg.TLSKeyFile); err != nil {
156+
kpr, err := utils.NewKeypairReloader(cfg.TLSCertFile, cfg.TLSKeyFile)
157+
if err != nil {
158+
log.Err(err).Msg("Failed to load key pair")
159+
os.Exit(1)
160+
}
161+
162+
// this will check if there are new certs before every tls handshake
163+
t := &tls.Config{GetCertificate: kpr.GetCertificateFunc()}
164+
srv.TLSConfig = t
165+
166+
srv.TLSConfig = &tls.Config{
167+
GetCertificate: getCertificate,
168+
}
169+
if err := srv.ListenAndServeTLS("", ""); err != nil {
155170
log.Err(err).Msg("error serving webhook")
156171
os.Exit(1)
157172
}
@@ -278,6 +293,16 @@ func initLogger() {
278293
}
279294
}
280295

296+
func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
297+
//log.Info().Msg("Loading TLS")
298+
caFiles, err := tls.LoadX509KeyPair(cfg.TLSCertFile, cfg.TLSKeyFile)
299+
if err != nil {
300+
return nil, err
301+
}
302+
303+
return &caFiles, nil
304+
}
305+
281306
// setupImagePullSecretsProvider configures the provider handling secrets
282307
func setupImagePullSecretsProvider() secrets.ImagePullSecretsProvider {
283308
config, err := rest.InClusterConfig()

pkg/utils/tlsutil.go

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package utils
2+
3+
import (
4+
"crypto/tls"
5+
"path"
6+
"sync"
7+
8+
"github.com/fsnotify/fsnotify"
9+
"github.com/rs/zerolog/log"
10+
)
11+
12+
// KeypairReloader structs holds cert path and certs
13+
type KeypairReloader struct {
14+
certMu sync.RWMutex
15+
cert *tls.Certificate
16+
tlsCertFile string
17+
tlsKeyFile string
18+
}
19+
20+
// NewKeypairReloader will load certs on first run and trigger a goroutine for fsnotify watcher
21+
func NewKeypairReloader(tlsCertFile, tlsKeyFile string) (*KeypairReloader, error) {
22+
result := &KeypairReloader{
23+
tlsCertFile: tlsCertFile,
24+
tlsKeyFile: tlsKeyFile,
25+
}
26+
cert, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile)
27+
if err != nil {
28+
return nil, err
29+
}
30+
result.cert = &cert
31+
32+
// creates a new file watcher
33+
watcher, err := fsnotify.NewWatcher()
34+
if err != nil {
35+
return nil, err
36+
}
37+
38+
defer func() {
39+
if err != nil {
40+
watcher.Close()
41+
}
42+
}()
43+
44+
// Notify on changes to the cert directory
45+
if err := watcher.Add(path.Dir(tlsCertFile)); err != nil {
46+
return nil, err
47+
}
48+
49+
go func() {
50+
for {
51+
select {
52+
// watch for events
53+
case event := <-watcher.Events:
54+
// Watch for changes to the tlsCertFile
55+
if event.Name == tlsCertFile {
56+
log.Info().Msg("Reloading certs")
57+
if err := result.reload(); err != nil {
58+
log.Err(err).Msg("Could not load new certs")
59+
}
60+
}
61+
62+
// watch for errors
63+
case err := <-watcher.Errors:
64+
log.Err(err).Msg("Watcher error")
65+
}
66+
}
67+
}()
68+
69+
return result, nil
70+
}
71+
72+
// reload loads updated cert and key whenever they are updated
73+
func (kpr *KeypairReloader) reload() error {
74+
newCert, err := tls.LoadX509KeyPair(kpr.tlsCertFile, kpr.tlsKeyFile)
75+
if err != nil {
76+
return err
77+
}
78+
kpr.certMu.Lock()
79+
defer kpr.certMu.Unlock()
80+
kpr.cert = &newCert
81+
return nil
82+
}
83+
84+
// GetCertificateFunc will return function which will be used as tls.Config.GetCertificate
85+
func (kpr *KeypairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
86+
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
87+
kpr.certMu.RLock()
88+
defer kpr.certMu.RUnlock()
89+
return kpr.cert, nil
90+
}
91+
}

0 commit comments

Comments
 (0)