diff --git a/cmd/dump/dump.go b/cmd/dump/dump.go index c0cefdf7..bf5777e3 100644 --- a/cmd/dump/dump.go +++ b/cmd/dump/dump.go @@ -4,7 +4,7 @@ import ( "bufio" "bytes" "compress/gzip" - "errors" + "github.com/clevyr/kubedb/internal/database/sqlformat" "github.com/clevyr/kubedb/internal/kubernetes" "github.com/clevyr/kubedb/internal/postgres" "github.com/spf13/cobra" @@ -26,12 +26,6 @@ var Command = &cobra.Command{ RunE: run, } -const ( - GzipFormat = iota - CustomFormat - PlainFormat -) - var ( dbname string username string @@ -60,17 +54,14 @@ func init() { Command.Flags().StringArrayVar(&excludeTableData, "exclude-table-data", []string{}, "do NOT dump data for the specified table(s)") } -func preRun(cmd *cobra.Command, args []string) error { - format, _ := cmd.Flags().GetString("format") - switch format { - case "gzip", "gz", "g": - outputFormat = GzipFormat - case "plain", "sql", "p": - outputFormat = PlainFormat - case "custom", "c": - outputFormat = CustomFormat - default: - return errors.New("invalid output format specified") +func preRun(cmd *cobra.Command, args []string) (err error) { + formatStr, err := cmd.Flags().GetString("format") + if err != nil { + return err + } + outputFormat, err = sqlformat.ParseFormat(formatStr) + if err != nil { + return err } return nil } @@ -127,9 +118,9 @@ func run(cmd *cobra.Command, args []string) (err error) { }() switch outputFormat { - case GzipFormat, CustomFormat: + case sqlformat.Gzip, sqlformat.Custom: _, err = io.Copy(fileWriter, pr) - case PlainFormat: + case sqlformat.Plain: var gzr *gzip.Reader gzr, err = gzip.NewReader(pr) if err != nil { @@ -165,11 +156,11 @@ func generateFilename(directory, namespace string) (string, error) { err = t.Execute(&tpl, data) switch outputFormat { - case GzipFormat: + case sqlformat.Gzip: tpl.WriteString(".sql.gz") - case PlainFormat: + case sqlformat.Plain: tpl.WriteString(".sql") - case CustomFormat: + case sqlformat.Custom: tpl.WriteString(".dmp") } @@ -193,7 +184,7 @@ func buildCommand() []string { for _, table := range excludeTableData { cmd = append(cmd, "--exclude-table-data=" + table) } - if outputFormat == CustomFormat { + if outputFormat == sqlformat.Custom { cmd = append(cmd, "--format=c") } else { cmd = append(cmd, "|", "gzip", "--force") diff --git a/cmd/restore/restore.go b/cmd/restore/restore.go index 7b6f7701..8fa6aba0 100644 --- a/cmd/restore/restore.go +++ b/cmd/restore/restore.go @@ -3,7 +3,7 @@ package restore import ( "bufio" "compress/gzip" - "errors" + "github.com/clevyr/kubedb/internal/database/sqlformat" "github.com/clevyr/kubedb/internal/kubernetes" "github.com/clevyr/kubedb/internal/postgres" "github.com/spf13/cobra" @@ -14,12 +14,6 @@ import ( "strings" ) -const ( - GzipContentType = iota - CustomContentType - PlainContentType -) - var Command = &cobra.Command{ Use: "restore", Aliases: []string{"r"}, @@ -52,25 +46,15 @@ func init() { } func preRun(cmd *cobra.Command, args []string) error { - format, _ := cmd.Flags().GetString("format") - switch format { - case "gzip", "gz", "g": - inputFormat = GzipContentType - case "plain", "sql", "p": - inputFormat = PlainContentType - case "custom", "c": - inputFormat = CustomContentType - default: - lower := strings.ToLower(args[0]) - switch { - case strings.HasSuffix(lower, ".sql.gz"): - inputFormat = GzipContentType - case strings.HasSuffix(lower, ".dmp"): - inputFormat = CustomContentType - case strings.HasSuffix(lower, ".sql"): - inputFormat = PlainContentType - default: - return errors.New("invalid input file type") + formatStr, err := cmd.Flags().GetString("format") + if err != nil { + return err + } + inputFormat, err = sqlformat.ParseFormat(formatStr) + if err != nil { + inputFormat, err = sqlformat.ParseFilename(args[0]) + if err != nil { + return err } } return nil @@ -111,7 +95,7 @@ func run(cmd *cobra.Command, args []string) (err error) { go func() { if clean { resetReader := strings.NewReader("drop schema public cascade; create schema public;") - err := kubernetes.Exec(client, postgresPod, buildCommand(PlainContentType, false), resetReader, os.Stdout, false) + err := kubernetes.Exec(client, postgresPod, buildCommand(sqlformat.Plain, false), resetReader, os.Stdout, false) if err != nil { pw.Close() ch <- err @@ -125,9 +109,9 @@ func run(cmd *cobra.Command, args []string) (err error) { }() switch inputFormat { - case GzipContentType, CustomContentType: + case sqlformat.Gzip, sqlformat.Custom: _, err = io.Copy(pw, fileReader) - case PlainContentType: + case sqlformat.Plain: gzw := gzip.NewWriter(pw) _, err = io.Copy(gzw, fileReader) gzw.Close() @@ -147,12 +131,12 @@ func run(cmd *cobra.Command, args []string) (err error) { func buildCommand(inputFormat uint8, gunzip bool) []string { cmd := []string{"PGPASSWORD=" + password} switch inputFormat { - case GzipContentType, PlainContentType: + case sqlformat.Gzip, sqlformat.Plain: if gunzip { cmd = append([]string{"gunzip", "--force", "|"}, cmd...) } cmd = append(cmd, "psql") - case CustomContentType: + case sqlformat.Custom: cmd = append(cmd, "pg_restore", "--format=custom", "--verbose") if noOwner { cmd = append(cmd, "--no-owner") diff --git a/internal/database/sqlformat/format.go b/internal/database/sqlformat/format.go new file mode 100644 index 00000000..cff9a558 --- /dev/null +++ b/internal/database/sqlformat/format.go @@ -0,0 +1,42 @@ +package sqlformat + +import ( + "errors" + "strings" +) + +const ( + Unknown = iota + Gzip + Plain + Custom +) + +var UnknownFormatError = errors.New("unknown format specified") + +func ParseFormat(format string) (uint8, error) { + format = strings.ToLower(format) + switch format { + case "gzip", "gz", "g": + return Gzip, nil + case "plain", "sql", "p": + return Plain, nil + case "custom", "c": + return Custom, nil + default: + return Unknown, UnknownFormatError + } +} + +func ParseFilename(filename string) (uint8, error) { + filename = strings.ToLower(filename) + switch { + case strings.HasSuffix(filename, ".sql.gz"): + return Gzip, nil + case strings.HasSuffix(filename, ".dmp"): + return Custom, nil + case strings.HasSuffix(filename, ".sql"): + return Plain, nil + } + return Unknown, UnknownFormatError +}