Skip to content

Commit

Permalink
✨ Ask for user confirmation before restoring
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe565 committed Jun 11, 2021
1 parent 6c9c5eb commit b027664
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
5 changes: 5 additions & 0 deletions cmd/restore/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package restore
import (
"bufio"
"compress/gzip"
"github.com/clevyr/kubedb/internal/cli"
"github.com/clevyr/kubedb/internal/database/sqlformat"
"github.com/clevyr/kubedb/internal/kubernetes"
"github.com/clevyr/kubedb/internal/postgres"
Expand Down Expand Up @@ -91,6 +92,10 @@ func run(cmd *cobra.Command, args []string) (err error) {

log.Println("Restoring \"" + args[0] + "\" to \"" + postgresPod.Name + "\"")

if err = cli.Confirm(os.Stdin, false); err != nil {
return err
}

ch := make(chan error)
go func() {
if clean {
Expand Down
40 changes: 40 additions & 0 deletions internal/cli/input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package cli

import (
"bufio"
"errors"
"fmt"
"io"
"strings"
)

var UserDeclinedErr = errors.New("user declined")

func Confirm(r io.Reader, defaultVal bool) (err error) {
fmt.Print("Continue? ")
if defaultVal {
fmt.Print("[Y/n]: ")
} else {
fmt.Print("[y/N]: ")
}

buf := bufio.NewReader(r)
var response string
response, err = buf.ReadString('\n')
if err != nil {
return err
}

response = strings.ToLower(strings.TrimSpace(response))
switch response {
case "yes", "y":
return nil
case "no", "n":
return UserDeclinedErr
}
if defaultVal {
return nil
} else {
return UserDeclinedErr
}
}
39 changes: 39 additions & 0 deletions internal/cli/input_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package cli

import (
"os"
"strings"
"testing"
)

func testConfirm(response string, defaultVal bool) error {
temp := os.Stdout
os.Stdout = nil
result := Confirm(strings.NewReader(response + "\n"), defaultVal)
os.Stdout = temp
return result
}

func TestConfirm(t *testing.T) {
type confirmTestCase struct {
response string
defaultVal bool
error error
}

testCases := []confirmTestCase{
{response: "yes", defaultVal: true, error: nil},
{response: "yes", defaultVal: false, error: nil},
{response: "no", defaultVal: true, error: UserDeclinedErr},
{response: "no", defaultVal: false, error: UserDeclinedErr},
{response: "", defaultVal: true, error: nil},
{response: "", defaultVal: false, error: UserDeclinedErr},
}

for key, testCase := range testCases {
err := testConfirm(testCase.response, testCase.defaultVal)
if err != testCase.error {
t.Errorf("case %d: got %v; expected %v", key, err, testCase.error)
}
}
}

0 comments on commit b027664

Please sign in to comment.