diff --git a/internal/util/cmd_setup.go b/internal/util/cmd_setup.go index e626d835..3de3c926 100644 --- a/internal/util/cmd_setup.go +++ b/internal/util/cmd_setup.go @@ -18,6 +18,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" + "golang.org/x/sync/errgroup" batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -33,6 +34,8 @@ type SetupOptions struct { func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) (err error) { cmd.SilenceUsage = true + ctx := cmd.Context() + conf.Kubeconfig = viper.GetString(consts.KubeconfigKey) conf.Context, err = cmd.Flags().GetString(consts.ContextFlag) if err != nil { @@ -50,12 +53,12 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) (e conf.Context = conf.Client.Context conf.Namespace = conf.Client.Namespace - access := namespace_filter.NewFromContext(cmd.Context()) + access := namespace_filter.NewFromContext(ctx) if !access.Match(conf.Client.Namespace) { return errors.New("The current action is disabled for namespace " + conf.Client.Namespace) } - if _, err := conf.Client.Namespaces().Get(cmd.Context(), conf.Namespace, metav1.GetOptions{}); err != nil { + if _, err := conf.Client.Namespaces().Get(ctx, conf.Namespace, metav1.GetOptions{}); err != nil { log.WithError(err).Warn("namespace may not exist") } @@ -69,7 +72,7 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) (e if slashIdx != 0 && slashIdx+1 < len(podFlag) { podFlag = podFlag[slashIdx+1:] } - pod, err := conf.Client.Pods().Get(cmd.Context(), podFlag, metav1.GetOptions{}) + pod, err := conf.Client.Pods().Get(ctx, podFlag, metav1.GetOptions{}) if err != nil { return err } @@ -83,7 +86,7 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) (e if dialectFlag == "" { // Configure via detection if len(pods) == 0 { - conf.Dialect, pods, err = database.DetectDialect(cmd.Context(), conf.Client) + conf.Dialect, pods, err = database.DetectDialect(ctx, conf.Client) if err != nil { return err } @@ -103,7 +106,7 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) (e log.WithField("dialect", conf.Dialect.Name()).Debug("configured database") if len(pods) == 0 { - pods, err = conf.Client.GetPodsFiltered(cmd.Context(), conf.Dialect.PodLabels()) + pods, err = conf.Client.GetPodsFiltered(ctx, conf.Dialect.PodLabels()) if err != nil { return err } @@ -111,7 +114,7 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) (e } if podFlag == "" { - pods, err = conf.Dialect.FilterPods(cmd.Context(), conf.Client, pods) + pods, err = conf.Dialect.FilterPods(ctx, conf.Client, pods) if err != nil { log.WithError(err).Warn("could not query primary instance") } @@ -136,70 +139,86 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) (e conf.DbPod = pods[idx] } - conf.Port, err = cmd.Flags().GetUint16(consts.PortFlag) - if err != nil { - panic(err) - } - if conf.Port == 0 { - port, err := conf.Client.GetValueFromEnv(cmd.Context(), conf.DbPod, conf.Dialect.PortEnvNames()) + group, ctx := errgroup.WithContext(ctx) + + group.Go(func() error { + conf.Port, err = cmd.Flags().GetUint16(consts.PortFlag) if err != nil { - log.Debug("could not detect port from pod env") - } else { - port, err := strconv.ParseUint(port, 10, 16) + panic(err) + } + if conf.Port == 0 { + port, err := conf.Client.GetValueFromEnv(ctx, conf.DbPod, conf.Dialect.PortEnvNames()) if err != nil { - log.WithField("port", port).Debug("failed to parse port from pod env") + log.Debug("could not detect port from pod env") } else { - conf.Port = uint16(port) - log.WithField("port", conf.Port).Debug("found port in pod env") + port, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithField("port", port).Debug("failed to parse port from pod env") + } else { + conf.Port = uint16(port) + log.WithField("port", conf.Port).Debug("found port in pod env") + } } } - } - if conf.Port == 0 { - conf.Port = conf.Dialect.DefaultPort() - } + if conf.Port == 0 { + conf.Port = conf.Dialect.DefaultPort() + } + return nil + }) - conf.Database, err = cmd.Flags().GetString(consts.DbnameFlag) - if err != nil && !opts.DisableAuthFlags { - panic(err) - } - if conf.Database == "" { - conf.Database, err = conf.Client.GetValueFromEnv(cmd.Context(), conf.DbPod, conf.Dialect.DatabaseEnvNames()) - if err != nil { - log.Debug("could not detect database from pod env") - } else { - log.WithField("database", conf.Database).Debug("found db name in pod env") + group.Go(func() error { + conf.Database, err = cmd.Flags().GetString(consts.DbnameFlag) + if err != nil && !opts.DisableAuthFlags { + panic(err) } - } + if conf.Database == "" { + conf.Database, err = conf.Client.GetValueFromEnv(ctx, conf.DbPod, conf.Dialect.DatabaseEnvNames()) + if err != nil { + log.Debug("could not detect database from pod env") + } else { + log.WithField("database", conf.Database).Debug("found db name in pod env") + } + } + return nil + }) - conf.Username, err = cmd.Flags().GetString(consts.UsernameFlag) - if err != nil && !opts.DisableAuthFlags { - panic(err) - } - if conf.Username == "" { - conf.Username, err = conf.Client.GetValueFromEnv(cmd.Context(), conf.DbPod, conf.Dialect.UserEnvNames()) - if err != nil { - conf.Username = conf.Dialect.DefaultUser() - log.WithField("user", conf.Username).Debug("could not detect user from pod env, using default") - } else { - log.WithField("user", conf.Username).Debug("found user in pod env") + group.Go(func() error { + conf.Username, err = cmd.Flags().GetString(consts.UsernameFlag) + if err != nil && !opts.DisableAuthFlags { + panic(err) } - } + if conf.Username == "" { + conf.Username, err = conf.Client.GetValueFromEnv(ctx, conf.DbPod, conf.Dialect.UserEnvNames()) + if err != nil { + conf.Username = conf.Dialect.DefaultUser() + log.WithField("user", conf.Username).Debug("could not detect user from pod env, using default") + } else { + log.WithField("user", conf.Username).Debug("found user in pod env") + } + } + return nil + }) - conf.Password, err = cmd.Flags().GetString(consts.PasswordFlag) - if err != nil && !opts.DisableAuthFlags { - panic(err) - } - if conf.Password == "" { - conf.Password, err = conf.Client.GetValueFromEnv(cmd.Context(), conf.DbPod, conf.Dialect.PasswordEnvNames(*conf)) - if err != nil { - return err + group.Go(func() error { + conf.Password, err = cmd.Flags().GetString(consts.PasswordFlag) + if err != nil && !opts.DisableAuthFlags { + panic(err) + } + if conf.Password == "" { + conf.Password, err = conf.Client.GetValueFromEnv(ctx, conf.DbPod, conf.Dialect.PasswordEnvNames(*conf)) + if err != nil { + return err + } } - } - if viper.GetBool(consts.LogRedactKey) { - log.AddHook(log_hooks.Redact(conf.Password)) - } - return nil + if viper.GetBool(consts.LogRedactKey) { + log.AddHook(log_hooks.Redact(conf.Password)) + } + + return nil + }) + + return group.Wait() } func CreateJob(cmd *cobra.Command, conf *config.Global, opts SetupOptions) error {