diff --git a/storage/sql.go b/storage/crdb/sql.go similarity index 83% rename from storage/sql.go rename to storage/crdb/sql.go index 372063994c..5cc380154f 100644 --- a/storage/sql.go +++ b/storage/crdb/sql.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package crdb import ( "database/sql" @@ -25,29 +25,29 @@ import ( ) // ToMillisSinceEpoch converts a timestamp into milliseconds since epoch -func ToMillisSinceEpoch(t time.Time) int64 { +func toMillisSinceEpoch(t time.Time) int64 { return t.UnixNano() / 1000000 } // FromMillisSinceEpoch converts -func FromMillisSinceEpoch(ts int64) time.Time { +func fromMillisSinceEpoch(ts int64) time.Time { return time.Unix(0, ts*1000000) } // SetNullStringIfValid assigns src to dest if src is Valid. -func SetNullStringIfValid(src sql.NullString, dest *string) { +func setNullStringIfValid(src sql.NullString, dest *string) { if src.Valid { *dest = src.String } } -// Row defines a common interface between sql.Row and sql.Rows(!) -type Row interface { +// row defines a common interface between sql.Row and sql.Rows(!) +type row interface { Scan(dest ...interface{}) error } // ReadTree takes a sql row and returns a tree -func ReadTree(row Row) (*trillian.Tree, error) { +func readTree(r row) (*trillian.Tree, error) { tree := &trillian.Tree{} // Enums and Datetimes need an extra conversion step @@ -57,7 +57,7 @@ func ReadTree(row Row) (*trillian.Tree, error) { var privateKey, publicKey []byte var deleted sql.NullBool var deleteMillis sql.NullInt64 - err := row.Scan( + err := r.Scan( &tree.TreeId, &treeState, &treeType, @@ -78,8 +78,8 @@ func ReadTree(row Row) (*trillian.Tree, error) { return nil, err } - SetNullStringIfValid(displayName, &tree.DisplayName) - SetNullStringIfValid(description, &tree.Description) + setNullStringIfValid(displayName, &tree.DisplayName) + setNullStringIfValid(description, &tree.Description) // Convert all things! if ts, ok := trillian.TreeState_value[treeState]; ok { @@ -106,11 +106,11 @@ func ReadTree(row Row) (*trillian.Tree, error) { treeState, treeType, hashStrategy, hashAlgorithm, signatureAlgorithm) } - tree.CreateTime = timestamppb.New(FromMillisSinceEpoch(createMillis)) + tree.CreateTime = timestamppb.New(fromMillisSinceEpoch(createMillis)) if err := tree.CreateTime.CheckValid(); err != nil { return nil, fmt.Errorf("failed to parse create time: %w", err) } - tree.UpdateTime = timestamppb.New(FromMillisSinceEpoch(updateMillis)) + tree.UpdateTime = timestamppb.New(fromMillisSinceEpoch(updateMillis)) if err := tree.UpdateTime.CheckValid(); err != nil { return nil, fmt.Errorf("failed to parse update time: %w", err) } @@ -118,7 +118,7 @@ func ReadTree(row Row) (*trillian.Tree, error) { tree.Deleted = deleted.Valid && deleted.Bool if tree.Deleted && deleteMillis.Valid { - tree.DeleteTime = timestamppb.New(FromMillisSinceEpoch(deleteMillis.Int64)) + tree.DeleteTime = timestamppb.New(fromMillisSinceEpoch(deleteMillis.Int64)) if err := tree.DeleteTime.CheckValid(); err != nil { return nil, fmt.Errorf("failed to parse delete time: %w", err) } diff --git a/storage/crdb/sqladminstorage.go b/storage/crdb/sqladminstorage.go index ec805163d3..c6da7bcfb2 100644 --- a/storage/crdb/sqladminstorage.go +++ b/storage/crdb/sqladminstorage.go @@ -144,7 +144,7 @@ func (t *adminTX) GetTree(ctx context.Context, treeID int64) (*trillian.Tree, er }() // GetTree is an entry point for most RPCs, let's provide somewhat nicer error messages. - tree, err := storage.ReadTree(stmt.QueryRowContext(ctx, treeID)) + tree, err := readTree(stmt.QueryRowContext(ctx, treeID)) switch { case err == sql.ErrNoRows: // ErrNoRows doesn't provide useful information, so we don't forward it. @@ -183,7 +183,7 @@ func (t *adminTX) ListTrees(ctx context.Context, includeDeleted bool) ([]*trilli }() trees := []*trillian.Tree{} for rows.Next() { - tree, err := storage.ReadTree(rows) + tree, err := readTree(rows) if err != nil { return nil, err } @@ -206,8 +206,8 @@ func (t *adminTX) CreateTree(ctx context.Context, tree *trillian.Tree) (*trillia } // Use the time truncated-to-millis throughout, as that's what's stored. - nowMillis := storage.ToMillisSinceEpoch(time.Now()) - now := storage.FromMillisSinceEpoch(nowMillis) + nowMillis := toMillisSinceEpoch(time.Now()) + now := fromMillisSinceEpoch(nowMillis) newTree := proto.Clone(tree).(*trillian.Tree) newTree.TreeId = id @@ -328,8 +328,8 @@ func (t *adminTX) UpdateTree(ctx context.Context, treeID int64, updateFunc func( // ensure all entries in SequencedLeafData are integrated. // Use the time truncated-to-millis throughout, as that's what's stored. - nowMillis := storage.ToMillisSinceEpoch(time.Now()) - now := storage.FromMillisSinceEpoch(nowMillis) + nowMillis := toMillisSinceEpoch(time.Now()) + now := fromMillisSinceEpoch(nowMillis) tree.UpdateTime = timestamppb.New(now) if err != nil { return nil, fmt.Errorf("failed to build update time: %v", err) @@ -366,7 +366,7 @@ func (t *adminTX) UpdateTree(ctx context.Context, treeID int64, updateFunc func( } func (t *adminTX) SoftDeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) { - return t.updateDeleted(ctx, treeID, true /* deleted */, storage.ToMillisSinceEpoch(time.Now()) /* deleteTimeMillis */) + return t.updateDeleted(ctx, treeID, true /* deleted */, toMillisSinceEpoch(time.Now()) /* deleteTimeMillis */) } func (t *adminTX) UndeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) { diff --git a/storage/mysql/admin_storage.go b/storage/mysql/admin_storage.go index 16c62404b6..16ed7c911b 100644 --- a/storage/mysql/admin_storage.go +++ b/storage/mysql/admin_storage.go @@ -143,7 +143,7 @@ func (t *adminTX) GetTree(ctx context.Context, treeID int64) (*trillian.Tree, er }() // GetTree is an entry point for most RPCs, let's provide somewhat nicer error messages. - tree, err := storage.ReadTree(stmt.QueryRowContext(ctx, treeID)) + tree, err := readTree(stmt.QueryRowContext(ctx, treeID)) switch { case err == sql.ErrNoRows: // ErrNoRows doesn't provide useful information, so we don't forward it. @@ -182,7 +182,7 @@ func (t *adminTX) ListTrees(ctx context.Context, includeDeleted bool) ([]*trilli }() trees := []*trillian.Tree{} for rows.Next() { - tree, err := storage.ReadTree(rows) + tree, err := readTree(rows) if err != nil { return nil, err } @@ -205,8 +205,8 @@ func (t *adminTX) CreateTree(ctx context.Context, tree *trillian.Tree) (*trillia } // Use the time truncated-to-millis throughout, as that's what's stored. - nowMillis := storage.ToMillisSinceEpoch(time.Now()) - now := storage.FromMillisSinceEpoch(nowMillis) + nowMillis := toMillisSinceEpoch(time.Now()) + now := fromMillisSinceEpoch(nowMillis) newTree := proto.Clone(tree).(*trillian.Tree) newTree.TreeId = id @@ -327,8 +327,8 @@ func (t *adminTX) UpdateTree(ctx context.Context, treeID int64, updateFunc func( // ensure all entries in SequencedLeafData are integrated. // Use the time truncated-to-millis throughout, as that's what's stored. - nowMillis := storage.ToMillisSinceEpoch(time.Now()) - now := storage.FromMillisSinceEpoch(nowMillis) + nowMillis := toMillisSinceEpoch(time.Now()) + now := fromMillisSinceEpoch(nowMillis) tree.UpdateTime = timestamppb.New(now) if err != nil { return nil, fmt.Errorf("failed to build update time: %v", err) @@ -365,7 +365,7 @@ func (t *adminTX) UpdateTree(ctx context.Context, treeID int64, updateFunc func( } func (t *adminTX) SoftDeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) { - return t.updateDeleted(ctx, treeID, true /* deleted */, storage.ToMillisSinceEpoch(time.Now()) /* deleteTimeMillis */) + return t.updateDeleted(ctx, treeID, true /* deleted */, toMillisSinceEpoch(time.Now()) /* deleteTimeMillis */) } func (t *adminTX) UndeleteTree(ctx context.Context, treeID int64) (*trillian.Tree, error) { diff --git a/storage/mysql/sql.go b/storage/mysql/sql.go new file mode 100644 index 0000000000..a929952727 --- /dev/null +++ b/storage/mysql/sql.go @@ -0,0 +1,128 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "database/sql" + "fmt" + "time" + + "github.com/google/trillian" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// toMillisSinceEpoch converts a timestamp into milliseconds since epoch +func toMillisSinceEpoch(t time.Time) int64 { + return t.UnixNano() / 1000000 +} + +// fromMillisSinceEpoch converts +func fromMillisSinceEpoch(ts int64) time.Time { + return time.Unix(0, ts*1000000) +} + +// setNullStringIfValid assigns src to dest if src is Valid. +func setNullStringIfValid(src sql.NullString, dest *string) { + if src.Valid { + *dest = src.String + } +} + +// row defines a common interface between sql.Row and sql.Rows(!) +type row interface { + Scan(dest ...interface{}) error +} + +// readTree takes a sql row and returns a tree +func readTree(r row) (*trillian.Tree, error) { + tree := &trillian.Tree{} + + // Enums and Datetimes need an extra conversion step + var treeState, treeType, hashStrategy, hashAlgorithm, signatureAlgorithm string + var createMillis, updateMillis, maxRootDurationMillis int64 + var displayName, description sql.NullString + var privateKey, publicKey []byte + var deleted sql.NullBool + var deleteMillis sql.NullInt64 + err := r.Scan( + &tree.TreeId, + &treeState, + &treeType, + &hashStrategy, + &hashAlgorithm, + &signatureAlgorithm, + &displayName, + &description, + &createMillis, + &updateMillis, + &privateKey, + &publicKey, + &maxRootDurationMillis, + &deleted, + &deleteMillis, + ) + if err != nil { + return nil, err + } + + setNullStringIfValid(displayName, &tree.DisplayName) + setNullStringIfValid(description, &tree.Description) + + // Convert all things! + if ts, ok := trillian.TreeState_value[treeState]; ok { + tree.TreeState = trillian.TreeState(ts) + } else { + return nil, fmt.Errorf("unknown TreeState: %v", treeState) + } + if tt, ok := trillian.TreeType_value[treeType]; ok { + tree.TreeType = trillian.TreeType(tt) + } else { + return nil, fmt.Errorf("unknown TreeType: %v", treeType) + } + if hashStrategy != "RFC6962_SHA256" { + return nil, fmt.Errorf("unknown HashStrategy: %v", hashStrategy) + } + + // Let's make sure we didn't mismatch any of the casts above + ok := tree.TreeState.String() == treeState && + tree.TreeType.String() == treeType + if !ok { + return nil, fmt.Errorf( + "mismatched enum: tree = %v, enums = [%v, %v, %v, %v, %v]", + tree, + treeState, treeType, hashStrategy, hashAlgorithm, signatureAlgorithm) + } + + tree.CreateTime = timestamppb.New(fromMillisSinceEpoch(createMillis)) + if err := tree.CreateTime.CheckValid(); err != nil { + return nil, fmt.Errorf("failed to parse create time: %w", err) + } + tree.UpdateTime = timestamppb.New(fromMillisSinceEpoch(updateMillis)) + if err := tree.UpdateTime.CheckValid(); err != nil { + return nil, fmt.Errorf("failed to parse update time: %w", err) + } + tree.MaxRootDuration = durationpb.New(time.Duration(maxRootDurationMillis * int64(time.Millisecond))) + + tree.Deleted = deleted.Valid && deleted.Bool + if tree.Deleted && deleteMillis.Valid { + tree.DeleteTime = timestamppb.New(fromMillisSinceEpoch(deleteMillis.Int64)) + if err := tree.DeleteTime.CheckValid(); err != nil { + return nil, fmt.Errorf("failed to parse delete time: %w", err) + } + } + + return tree, nil +}