Skip to content

Commit

Permalink
perf(dump,restore): Optimize tracking of written bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe565 committed Oct 22, 2024
1 parent 7a5e985 commit 4f4b671
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 54 deletions.
24 changes: 14 additions & 10 deletions internal/actions/dump/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log/slog"
"os"
"path/filepath"
"sync/atomic"
"time"

"github.com/charmbracelet/lipgloss"
Expand All @@ -23,6 +24,7 @@ import (
"github.com/clevyr/kubedb/internal/storage"
"github.com/clevyr/kubedb/internal/tui"
"github.com/clevyr/kubedb/internal/util"
"github.com/dustin/go-humanize"
"github.com/muesli/termenv"
"github.com/spf13/viper"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -134,7 +136,7 @@ func (action Dump) Run(ctx context.Context) error {
pr = gzPipeReader
}

sizeW := &util.SizeWriter{}
var written atomic.Int64
errGroup.Go(func() error {
// Begin copying export to local file
defer func(pr io.ReadCloser) {
Expand All @@ -151,14 +153,16 @@ func (action Dump) Run(ctx context.Context) error {
}
}

if _, err := io.Copy(io.MultiWriter(f, bar, sizeW), r); err != nil {
n, err := io.Copy(io.MultiWriter(f, bar), r)
written.Add(n)
if err != nil {
return err
}
return f.Close()
})

util.OnFinalize(func(err error) {
action.printSummary(err, time.Since(startTime).Truncate(10*time.Millisecond), sizeW)
action.printSummary(err, time.Since(startTime).Truncate(10*time.Millisecond), written.Load())
})

if err := errGroup.Wait(); err != nil {
Expand All @@ -169,12 +173,12 @@ func (action Dump) Run(ctx context.Context) error {

actionLog.Info("Dump complete",
"took", time.Since(startTime).Truncate(10*time.Millisecond),
"size", sizeW,
"size", written.Load(),
)

if handler, ok := notifier.FromContext(ctx); ok {
if logger, ok := handler.(notifier.Logs); ok {
logger.SetLog(action.summary(nil, time.Since(startTime).Truncate(10*time.Millisecond), sizeW, true))
logger.SetLog(action.summary(nil, time.Since(startTime).Truncate(10*time.Millisecond), written.Load(), true))
}
}
return nil
Expand All @@ -200,7 +204,7 @@ func (action Dump) buildCommand() (*command.Builder, error) {
return cmd, nil
}

func (action Dump) summary(err error, took time.Duration, size *util.SizeWriter, plain bool) string {
func (action Dump) summary(err error, took time.Duration, written int64, plain bool) string {
var r *lipgloss.Renderer
if plain {
r = lipgloss.NewRenderer(os.Stdout, termenv.WithTTY(false))
Expand All @@ -218,8 +222,8 @@ func (action Dump) summary(err error, took time.Duration, size *util.SizeWriter,
Row("Took", took.String())
if err != nil {
t.Row("Error", tui.ErrStyle(r).Render(err.Error()))
} else if size != nil {
t.Row("Size", size.String())
} else {
t.Row("Size", humanize.IBytes(uint64(written))) //nolint:gosec
}

if plain {
Expand All @@ -232,10 +236,10 @@ func (action Dump) summary(err error, took time.Duration, size *util.SizeWriter,
)
}

func (action Dump) printSummary(err error, took time.Duration, size *util.SizeWriter) {
func (action Dump) printSummary(err error, took time.Duration, written int64) {
out := os.Stdout
if action.Filename == "-" {
out = os.Stderr
}
_, _ = io.WriteString(out, "\n"+action.summary(err, took, size, false)+"\n")
_, _ = io.WriteString(out, "\n"+action.summary(err, took, written, false)+"\n")
}
55 changes: 34 additions & 21 deletions internal/actions/restore/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log/slog"
"os"
"strings"
"sync/atomic"
"time"

"github.com/charmbracelet/huh"
Expand All @@ -23,6 +24,7 @@ import (
"github.com/clevyr/kubedb/internal/storage"
"github.com/clevyr/kubedb/internal/tui"
"github.com/clevyr/kubedb/internal/util"
"github.com/dustin/go-humanize"
"github.com/muesli/termenv"
"github.com/spf13/viper"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -90,20 +92,22 @@ func (action Restore) Run(ctx context.Context) error {
return action.runInDatabasePod(ctx, pr, bar.Logger(), bar.Logger(), action.Format)
})

sizeW := &util.SizeWriter{}
var written atomic.Int64
errGroup.Go(func() error {
defer func(pw io.WriteCloser) {
_ = pw.Close()
}(pw)

w := io.MultiWriter(pw, sizeW, bar)
w := io.MultiWriter(pw, bar)

// Clean database
if action.Clean && action.Format != sqlformat.Custom {
if db, ok := action.Dialect.(config.DBDatabaseDropper); ok {
dropQuery := db.DatabaseDropQuery(action.Database)
actionLog.Info("Cleaning existing data")
if err := action.copy(w, strings.NewReader(dropQuery)); err != nil {
n, err := action.copy(w, strings.NewReader(dropQuery))
written.Add(n)
if err != nil {
return err
}
}
Expand All @@ -123,11 +127,15 @@ func (action Restore) Run(ctx context.Context) error {
}(f)
}

if _, err := io.Copy(w, f); err != nil {
n, err := io.Copy(w, f)
written.Add(n)
if err != nil {
return err
}
case sqlformat.Plain, sqlformat.Custom:
if err := action.copy(w, f); err != nil {
n, err := action.copy(w, f)
written.Add(n)
if err != nil {
return err
}
}
Expand All @@ -151,11 +159,15 @@ func (action Restore) Run(ctx context.Context) error {
defer func() {
_ = pw.Close()
}()
return action.copy(pw, strings.NewReader(analyzeQuery))
n, err := action.copy(pw, strings.NewReader(analyzeQuery))
written.Add(n)
return err
})
}()
} else {
if err := action.copy(w, strings.NewReader(analyzeQuery)); err != nil {
n, err := action.copy(w, strings.NewReader(analyzeQuery))
written.Add(n)
if err != nil {
return err
}
}
Expand All @@ -171,7 +183,7 @@ func (action Restore) Run(ctx context.Context) error {
})

util.OnFinalize(func(err error) {
action.printSummary(err, time.Since(startTime).Truncate(10*time.Millisecond), sizeW)
action.printSummary(err, time.Since(startTime).Truncate(10*time.Millisecond), written.Load())
})

if err := errGroup.Wait(); err != nil {
Expand All @@ -182,12 +194,12 @@ func (action Restore) Run(ctx context.Context) error {

actionLog.Info("Restore complete",
"took", time.Since(startTime).Truncate(10*time.Millisecond),
"size", sizeW,
"size", written.Load(),
)

if handler, ok := notifier.FromContext(ctx); ok {
if logger, ok := handler.(notifier.Logs); ok {
logger.SetLog(action.summary(nil, time.Since(startTime).Truncate(10*time.Millisecond), sizeW, true))
logger.SetLog(action.summary(nil, time.Since(startTime).Truncate(10*time.Millisecond), written.Load(), true))
}
}
return nil
Expand All @@ -213,17 +225,18 @@ func (action Restore) buildCommand(inputFormat sqlformat.Format) (*command.Build
return cmd, nil
}

func (action Restore) copy(w io.Writer, r io.Reader) error {
func (action Restore) copy(w io.Writer, r io.Reader) (int64, error) {
if action.RemoteGzip {
gzw := gzip.NewWriter(w)
if _, err := io.Copy(gzw, r); err != nil {
return err
n, err := io.Copy(gzw, r)
if err != nil {
return n, err
}
return gzw.Close()
return n, gzw.Close()
}

_, err := io.Copy(w, r)
return err
n, err := io.Copy(w, r)
return n, err
}

func (action Restore) runInDatabasePod(ctx context.Context, r *io.PipeReader, stdout, stderr io.Writer, inputFormat sqlformat.Format) error {
Expand Down Expand Up @@ -289,7 +302,7 @@ func (action Restore) Confirm() (bool, error) {
return response, err
}

func (action Restore) summary(err error, took time.Duration, size *util.SizeWriter, plain bool) string {
func (action Restore) summary(err error, took time.Duration, written int64, plain bool) string {
var r *lipgloss.Renderer
if plain {
r = lipgloss.NewRenderer(os.Stdout, termenv.WithTTY(false))
Expand All @@ -302,8 +315,8 @@ func (action Restore) summary(err error, took time.Duration, size *util.SizeWrit
Row("Took", took.String())
if err != nil {
t.Row("Error", tui.ErrStyle(r).Render(err.Error()))
} else if size != nil {
t.Row("Size", size.String())
} else {
t.Row("Size", humanize.IBytes(uint64(written))) //nolint:gosec
}

if plain {
Expand All @@ -316,10 +329,10 @@ func (action Restore) summary(err error, took time.Duration, size *util.SizeWrit
)
}

func (action Restore) printSummary(err error, took time.Duration, size *util.SizeWriter) {
func (action Restore) printSummary(err error, took time.Duration, written int64) {
out := os.Stdout
if action.Filename == "-" {
out = os.Stderr
}
_, _ = io.WriteString(out, "\n"+action.summary(err, took, size, false)+"\n")
_, _ = io.WriteString(out, "\n"+action.summary(err, took, written, false)+"\n")
}
6 changes: 3 additions & 3 deletions internal/actions/restore/restore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestRestore_copy(t *testing.T) {
name string
fields fields
args args
wantW string
want string
wantErr require.ErrorAssertionFunc
}{
{"gzip", fields{Restore: config.Restore{Global: config.Global{RemoteGzip: true}}}, args{strings.NewReader(input)}, gzipped.String(), require.NoError},
Expand All @@ -106,9 +106,9 @@ func TestRestore_copy(t *testing.T) {
Analyze: tt.fields.Analyze,
}
w := &bytes.Buffer{}
err := action.copy(w, tt.args.r)
_, err := action.copy(w, tt.args.r)
tt.wantErr(t, err)
assert.Equal(t, tt.wantW, w.String())
assert.Equal(t, tt.want, w.String())
})
}
}
20 changes: 0 additions & 20 deletions internal/util/sizewriter.go

This file was deleted.

0 comments on commit 4f4b671

Please sign in to comment.