diff --git a/cmd/clair/main.go b/cmd/clair/main.go index 9c864aaa6d..03f19fd44a 100644 --- a/cmd/clair/main.go +++ b/cmd/clair/main.go @@ -3,6 +3,7 @@ package main import ( "context" "crypto/tls" + "errors" "flag" "fmt" golog "log" @@ -98,47 +99,53 @@ func main() { } auto.PrintLogs(ctx) - // Some machinery for starting and stopping server goroutines: - down := &Shutdown{} - srvs, srvctx := errgroup.WithContext(ctx) + // Signal handler, for orderly shutdown. + sig, stop := signal.NotifyContext(ctx, append(platformShutdown, os.Interrupt)...) + defer stop() + zlog.Info(ctx).Msg("registered signal handler") + go func() { + <-sig.Done() + stop() + zlog.Info(ctx).Msg("unregistered signal handler") + }() - // Introspection server goroutine. - srvs.Go(func() (_ error) { - zlog.Info(srvctx).Msg("launching introspection server") - i, err := introspection.New(srvctx, conf, nil) - if err != nil { - zlog.Warn(srvctx). - Err(err).Msg("introspection server configuration failed. continuing anyway") - return - } - down.Add(i.Server) - if err := i.ListenAndServe(); err != http.ErrServerClosed { - zlog.Warn(srvctx). - Err(err).Msg("introspection server failed to launch. continuing anyway") - } - return - }) + srvs, srvctx := errgroup.WithContext(sig) + srvs.Go(serveIntrospection(srvctx, &conf)) + srvs.Go(serveAPI(srvctx, &conf)) + + zlog.Info(ctx). + Str("version", cmd.Version). + Msg("ready") + if err := srvs.Wait(); err != nil { + zlog.Error(ctx). + Err(err). + Msg("fatal error") + fail = true + } +} - // HTTP API server goroutine. - srvs.Go(func() error { - zlog.Info(srvctx).Msg("launching http transport") - srvs, err := initialize.Services(srvctx, &conf) +func serveAPI(ctx context.Context, cfg *config.Config) func() error { + return func() error { + zlog.Info(ctx).Msg("launching http transport") + srvs, err := initialize.Services(ctx, cfg) if err != nil { return fmt.Errorf("service initialization failed: %w", err) } srv := http.Server{ - BaseContext: func(_ net.Listener) context.Context { return srvctx }, + BaseContext: func(_ net.Listener) context.Context { + return context.WithoutCancel(ctx) + }, } - srv.Handler, err = httptransport.New(srvctx, &conf, srvs.Indexer, srvs.Matcher, srvs.Notifier) + srv.Handler, err = httptransport.New(ctx, cfg, srvs.Indexer, srvs.Matcher, srvs.Notifier) if err != nil { return fmt.Errorf("http transport configuration failed: %w", err) } - l, err := net.Listen("tcp", conf.HTTPListenAddr) + l, err := net.Listen("tcp", cfg.HTTPListenAddr) if err != nil { return fmt.Errorf("http transport configuration failed: %w", err) } - if conf.TLS != nil { - cfg, err := conf.TLS.Config() + if cfg.TLS != nil { + cfg, err := cfg.TLS.Config() if err != nil { return fmt.Errorf("tls configuration failed: %w", err) } @@ -146,40 +153,51 @@ func main() { srv.TLSConfig = cfg l = tls.NewListener(l, cfg) } - down.Add(&srv) health.Ready() - if err := srv.Serve(l); err != http.ErrServerClosed { - return fmt.Errorf("http transport failed to launch: %w", err) - } - return nil - }) - // Signal handler goroutine. - go func() { - ctx, stop := signal.NotifyContext(ctx, os.Interrupt) - defer func() { - // Note that we're using a background context here, so that we get a - // full timeout if the signal handler has fired. - tctx, done := context.WithTimeout(context.Background(), 10*time.Second) - err := down.Shutdown(tctx) - if err != nil { - zlog.Error(ctx).Err(err).Msg("error shutting down server") + var eg errgroup.Group + eg.Go(func() error { + if err := srv.Serve(l); !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("http transport failed to launch: %w", err) } - done() - stop() - zlog.Info(ctx).Msg("unregistered signal handler") - }() - zlog.Info(ctx).Msg("registered signal handler") - select { - case <-ctx.Done(): - zlog.Info(ctx).Stringer("signal", os.Interrupt).Msg("gracefully shutting down") - case <-srvctx.Done(): + return nil + }) + eg.Go(func() error { + <-ctx.Done() + ctx, done := context.WithTimeoutCause(context.Background(), 10*time.Second, context.Cause(ctx)) + defer done() + return srv.Shutdown(ctx) + }) + return eg.Wait() + } +} + +func serveIntrospection(ctx context.Context, cfg *config.Config) func() error { + return func() error { + zlog.Info(ctx).Msg("launching introspection server") + srv, err := introspection.New(ctx, cfg, nil) + if err != nil { + zlog.Warn(ctx). + Err(err). + Msg("introspection server configuration failed; continuing anyway") + return nil } - }() - zlog.Info(ctx).Str("version", cmd.Version).Msg("ready") - if err := srvs.Wait(); err != nil { - zlog.Error(ctx).Err(err).Msg("fatal error") - fail = true + var eg errgroup.Group + eg.Go(func() error { + if err := srv.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + zlog.Warn(ctx). + Err(err). + Msg("introspection server failed to launch; continuing anyway") + } + return nil + }) + eg.Go(func() error { + <-ctx.Done() + ctx, done := context.WithTimeoutCause(context.Background(), 10*time.Second, context.Cause(ctx)) + defer done() + return srv.Shutdown(ctx) + }) + return eg.Wait() } } diff --git a/cmd/clair/os_other.go b/cmd/clair/os_other.go new file mode 100644 index 0000000000..2a0aea4982 --- /dev/null +++ b/cmd/clair/os_other.go @@ -0,0 +1,7 @@ +//go:build !unix + +package main + +import "os" + +var platformShutdown = []os.Signal{} diff --git a/cmd/clair/os_unix.go b/cmd/clair/os_unix.go new file mode 100644 index 0000000000..c59bcfc051 --- /dev/null +++ b/cmd/clair/os_unix.go @@ -0,0 +1,10 @@ +//go:build unix + +package main + +import ( + "os" + "syscall" +) + +var platformShutdown = []os.Signal{syscall.SIGTERM} diff --git a/cmd/clair/shutdown.go b/cmd/clair/shutdown.go deleted file mode 100644 index 71fb8aa8d8..0000000000 --- a/cmd/clair/shutdown.go +++ /dev/null @@ -1,44 +0,0 @@ -package main - -import ( - "context" - "fmt" - "net/http" - "sync" - - "golang.org/x/sync/errgroup" -) - -// Shutdown aggregates http.Sever Shutdown methods. -type Shutdown struct { - mu sync.Mutex - m map[*http.Server]struct{} -} - -// Add registers a server. -func (s *Shutdown) Add(srv *http.Server) { - s.mu.Lock() - defer s.mu.Unlock() - if s.m == nil { - s.m = make(map[*http.Server]struct{}) - } - s.m[srv] = struct{}{} -} - -// Shutdown calls Shutdown on all added Servers. If a timeout is needed, it -// should be done via the passed Context. -func (s *Shutdown) Shutdown(ctx context.Context) error { - s.mu.Lock() // Leave locked forever - eg := &errgroup.Group{} - for srv := range s.m { - srv := srv - eg.Go(func() error { - if err := srv.Shutdown(ctx); err != nil { - return fmt.Errorf("unable to shutdown %q: %w", srv.Addr, err) - } - return nil - }) - delete(s.m, srv) - } - return eg.Wait() -} diff --git a/introspection/server.go b/introspection/server.go index bd7ba654f4..31ce6e5929 100644 --- a/introspection/server.go +++ b/introspection/server.go @@ -36,7 +36,7 @@ const ( // exposing Clair metrics and traces type Server struct { // configuration provided when starting Clair - conf config.Config + conf *config.Config // Server embeds a http.Server and http.ServeMux. // The http.Server will be configured with the ServeMux on successful // initialization. @@ -46,7 +46,7 @@ type Server struct { health func() bool } -func New(ctx context.Context, conf config.Config, health func() bool) (*Server, error) { +func New(ctx context.Context, conf *config.Config, health func() bool) (*Server, error) { ctx = zlog.ContextWithValues(ctx, "component", "introspection/New") var addr string @@ -176,7 +176,8 @@ func New(ctx context.Context, conf config.Config, health func() bool) (*Server, ) otel.SetTracerProvider(tp) i.Server.RegisterOnShutdown(func() { - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + zlog.Info(ctx).Msg("shutting down trace provider") + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) defer cancel() if err := tp.Shutdown(ctx); err != nil { zlog.Error(ctx).Err(err).Msg("error shutting down trace provider")