diff --git a/cmd/restore/restore.go b/cmd/restore/restore.go index 4ef1681b..7b6f7701 100644 --- a/cmd/restore/restore.go +++ b/cmd/restore/restore.go @@ -3,21 +3,29 @@ package restore import ( "bufio" "compress/gzip" + "errors" "github.com/clevyr/kubedb/internal/kubernetes" "github.com/clevyr/kubedb/internal/postgres" "github.com/spf13/cobra" "io" _ "k8s.io/client-go/plugin/pkg/client/auth" "log" - "net/http" "os" "strings" ) +const ( + GzipContentType = iota + CustomContentType + PlainContentType +) + var Command = &cobra.Command{ Use: "restore", Aliases: []string{"r"}, Short: "Restore a database", + Args: cobra.ExactArgs(1), + PreRunE: preRun, RunE: run, } @@ -25,7 +33,10 @@ var ( dbname string username string password string + inputFormat uint8 singleTransaction bool + clean bool + noOwner bool ) func init() { @@ -33,7 +44,36 @@ func init() { Command.Flags().StringVarP(&username, "username", "U", "postgres", "database username") Command.Flags().StringVarP(&password, "password", "p", "", "database password") - Command.Flags().BoolVar(&singleTransaction, "single-transaction", true, "execute as a single transaction") + Command.Flags().StringP("format", "F", "", "input format. inferred by default ([g]zip, [c]ustom, [p]lain text)") + + Command.Flags().BoolVarP(&singleTransaction, "single-transaction", "1", true, "restore as a single transaction") + Command.Flags().BoolVarP(&clean, "clean", "c", true, "clean (drop) database objects before recreating") + Command.Flags().BoolVarP(&noOwner, "no-owner", "O", true, "skip restoration of object ownership in plain-text format") +} + +func preRun(cmd *cobra.Command, args []string) error { + format, _ := cmd.Flags().GetString("format") + switch format { + case "gzip", "gz", "g": + inputFormat = GzipContentType + case "plain", "sql", "p": + inputFormat = PlainContentType + case "custom", "c": + inputFormat = CustomContentType + default: + lower := strings.ToLower(args[0]) + switch { + case strings.HasSuffix(lower, ".sql.gz"): + inputFormat = GzipContentType + case strings.HasSuffix(lower, ".dmp"): + inputFormat = CustomContentType + case strings.HasSuffix(lower, ".sql"): + inputFormat = PlainContentType + default: + return errors.New("invalid input file type") + } + } + return nil } func run(cmd *cobra.Command, args []string) (err error) { @@ -69,26 +109,25 @@ func run(cmd *cobra.Command, args []string) (err error) { ch := make(chan error) go func() { - resetReader := strings.NewReader("drop schema public cascade; create schema public;") - err := kubernetes.Exec(client, postgresPod, buildCommand(false), resetReader, os.Stdout, false) - if err != nil { - ch <- err - return + if clean { + resetReader := strings.NewReader("drop schema public cascade; create schema public;") + err := kubernetes.Exec(client, postgresPod, buildCommand(PlainContentType, false), resetReader, os.Stdout, false) + if err != nil { + pw.Close() + ch <- err + return + } } - err = kubernetes.Exec(client, postgresPod, buildCommand(true), pr, os.Stdout, false) + err := kubernetes.Exec(client, postgresPod, buildCommand(inputFormat, true), pr, os.Stdout, false) + pw.Close() ch <- err }() - contentType, err := getFileContentType(infile) - if err != nil { - return err - } - - switch contentType { - case "application/x-gzip": + switch inputFormat { + case GzipContentType, CustomContentType: _, err = io.Copy(pw, fileReader) - default: + case PlainContentType: gzw := gzip.NewWriter(pw) _, err = io.Copy(gzw, fileReader) gzw.Close() @@ -105,31 +144,23 @@ func run(cmd *cobra.Command, args []string) (err error) { return err } -func buildCommand(gunzip bool) []string { - var cmd []string - if gunzip { - cmd = []string{"gunzip", "|"} +func buildCommand(inputFormat uint8, gunzip bool) []string { + cmd := []string{"PGPASSWORD=" + password} + switch inputFormat { + case GzipContentType, PlainContentType: + if gunzip { + cmd = append([]string{"gunzip", "--force", "|"}, cmd...) + } + cmd = append(cmd, "psql") + case CustomContentType: + cmd = append(cmd, "pg_restore", "--format=custom", "--verbose") + if noOwner { + cmd = append(cmd, "--no-owner") + } } - cmd = append(cmd, "PGPASSWORD="+password, "psql", "--username="+username, "--dbname="+dbname) + cmd = append(cmd, "--username="+username, "--dbname="+dbname) if singleTransaction { cmd = append(cmd, "--single-transaction") } return []string{"sh", "-c", strings.Join(cmd, " ")} } - -func getFileContentType(infile *os.File) (string, error) { - // Only the first 512 bytes are used to sniff the content type. - buffer := make([]byte, 512) - - _, err := infile.Read(buffer) - if err != nil { - return "", err - } - - // Use the net/http package's handy DectectContentType function. Always returns a valid - // content-type by returning "application/octet-stream" if no others seemed to match. - contentType := http.DetectContentType(buffer) - - _, err = infile.Seek(0, 0) - return contentType, err -}