diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 8a69c7ee3e..edec41a0cf 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -32,6 +32,8 @@ import ( "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" "github.com/dexidp/dex/api/v2" @@ -59,6 +61,8 @@ var buildInfo = prometheus.NewGaugeVec( []string{"version", "go_version", "platform"}, ) +var healthCheckPeriod = 15 * time.Second + func commandServe() *cobra.Command { options := serveOptions{} @@ -379,7 +383,7 @@ func runServe(options serveOptions) error { CheckName: "storage", CheckFunc: storage.NewCustomHealthCheckFunc(serverConfig.Storage, serverConfig.Now), }, - gosundheit.ExecutionPeriod(15*time.Second), + gosundheit.ExecutionPeriod(healthCheckPeriod), gosundheit.InitiallyPassing(true), ) @@ -508,8 +512,24 @@ func runServe(options serveOptions) error { } grpcSrv := grpc.NewServer(grpcOptions...) + healthcheck := health.NewServer() + healthgrpc.RegisterHealthServer(grpcSrv, healthcheck) api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger, version, serv)) + go func() { + var status healthgrpc.HealthCheckResponse_ServingStatus + for { + switch healthChecker.IsHealthy() { + case true: + status = healthgrpc.HealthCheckResponse_SERVING + default: + status = healthgrpc.HealthCheckResponse_NOT_SERVING + } + healthcheck.SetServingStatus("", status) + time.Sleep(healthCheckPeriod) + } + }() + grpcMetrics.InitializeMetrics(grpcSrv) if c.GRPC.Reflection { logger.Info("enabling reflection in grpc service") @@ -520,6 +540,24 @@ func runServe(options serveOptions) error { return grpcSrv.Serve(grpcListener) }, func(err error) { logger.Debug("starting graceful shutdown", "server", "grpc") + done := make(chan struct{}) + + go func() { + healthcheck.Shutdown() + grpcSrv.GracefulStop() + close(done) + }() + + select { + case <-done: + // Graceful shutdown completed within the timeout + logger.Debug("Graceful shutdown completed", "server", "grpc") + case <-time.After(time.Minute): + // Timeout reached, force stop the server + logger.Debug("Graceful shutdown timed out. forcing shutdown", "server", "grpc") + grpcSrv.Stop() + } + grpcSrv.GracefulStop() }) }