Skip to content

Commit

Permalink
refactor: proof store indices
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanchriswhite committed Dec 21, 2023
1 parent 95138be commit fabfbef
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 81 deletions.
91 changes: 55 additions & 36 deletions x/supplier/keeper/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,58 @@ import (

// UpsertProof inserts or updates a specific proof in the store by index.
func (k Keeper) UpsertProof(ctx sdk.Context, proof types.Proof) {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix))
b := k.cdc.MustMarshal(&proof)
// TODO_NEXT(@bryanchriswhite #141): Refactor keys to support multiple indices.
store.Set(types.ProofKey(
proof.GetSessionHeader().GetSessionId(),
), b)
}
logger := k.Logger(ctx).With("method", "UpsertProof")

// GetProof returns a proof from its index
func (k Keeper) GetProof(
ctx sdk.Context,
sessionId string,

) (val types.Proof, found bool) {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix))

// TODO_NEXT(@bryanchriswhite #141): Refactor proof keys to support multiple indices.
b := store.Get(types.ProofKey(
sessionId,
))
if b == nil {
return val, false
proofBz := k.cdc.MustMarshal(&proof)
parentStore := ctx.KVStore(k.storeKey)

primaryStore := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofPrimaryKeyPrefix))
sessionId := proof.GetSessionHeader().GetSessionId()
primaryKey := types.ProofPrimaryKey(sessionId, proof.GetSupplierAddress())
primaryStore.Set(primaryKey, proofBz)

logger.Info("upserted proof for supplier %s with primaryKey %s", proof.GetSupplierAddress(), primaryKey)

// Update the address index: supplierAddress -> [ProofPrimaryKey]
addressStoreIndex := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofSupplierAddressPrefix))
addressKey := types.ProofSupplierAddressKey(proof.GetSupplierAddress(), primaryKey)
addressStoreIndex.Set(addressKey, primaryKey)

logger.Info("indexed Proof for supplier %s with primaryKey %s", proof.GetSupplierAddress(), primaryKey)

claim, found := k.GetClaim(ctx, sessionId, proof.GetSupplierAddress())
if !found {
// TOOD_IN_THIS_COMMIT: error: claim not found...
//panic("claim not found")
return
}

k.cdc.MustUnmarshal(b, &val)
return val, true
// Update the session end height index: sessionEndHeight -> [ProofPrimaryKey]
sessionHeightStoreIndex := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofSessionEndHeightPrefix))
sessionEndHeight := claim.GetSessionHeader().GetSessionEndBlockHeight()
heightKey := types.ProofSupplierEndSessionHeightKey(sessionEndHeight, primaryKey)
sessionHeightStoreIndex.Set(heightKey, primaryKey)
}

// GetProof returns a proof from its index
func (k Keeper) GetProof(ctx sdk.Context, sessionId, supplierAdd string) (val types.Proof, found bool) {
primaryKey := types.ProofPrimaryKey(sessionId, supplierAdd)
return k.getProofByPrimaryKey(ctx, primaryKey)
}

// RemoveProof removes a proof from the store
func (k Keeper) RemoveProof(
ctx sdk.Context,
// TODO_NEXT(@bryanchriswhite #141): Refactor proof keys to support multiple indices.
index string,

) {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix))
// TODO_NEXT(@bryanchriswhite #141): Refactor proof keys to support multiple indices.
store.Delete(types.ProofKey(
index,
))
func (k Keeper) RemoveProof(ctx sdk.Context, sessionId, supplierAddr string) {
parentStore := ctx.KVStore(k.storeKey)
proofPrimaryStore := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofPrimaryKeyPrefix))
proofPrimaryKey := types.ProofPrimaryKey(sessionId, supplierAddr)
proofPrimaryStore.Delete(proofPrimaryKey)
}

// GetAllProofs returns all proof
func (k Keeper) GetAllProofs(ctx sdk.Context) (list []types.Proof) {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofKeyPrefix))
iterator := sdk.KVStorePrefixIterator(store, []byte{})
parentStore := ctx.KVStore(k.storeKey)
primaryStore := prefix.NewStore(parentStore, types.KeyPrefix(types.ProofPrimaryKeyPrefix))
iterator := sdk.KVStorePrefixIterator(primaryStore, []byte{})

defer iterator.Close()

Expand All @@ -66,3 +72,16 @@ func (k Keeper) GetAllProofs(ctx sdk.Context) (list []types.Proof) {

return
}

// getProofByPrimaryKey is a helper that retrieves, if exists, the Proof associated with the key provided
func (k Keeper) getProofByPrimaryKey(ctx sdk.Context, primaryKey []byte) (val types.Proof, found bool) {
store := prefix.NewStore(ctx.KVStore(k.storeKey), types.KeyPrefix(types.ProofPrimaryKeyPrefix))

proofBz := store.Get(primaryKey)
if proofBz == nil {
return val, false
}

k.cdc.MustUnmarshal(proofBz, &val)
return val, true
}
18 changes: 9 additions & 9 deletions x/supplier/keeper/proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
// Prevent strconv unused error
var _ = strconv.IntSize

// TODO_IN_THIS_COMMIT: consider if this should be in terms of sessionIds and supplierAddrs instead of n.
func createNProofs(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.Proof {
proofs := make([]types.Proof, n)
for i := range proofs {
Expand All @@ -45,8 +46,10 @@ func TestProofGet(t *testing.T) {
keeper, ctx := keepertest.SupplierKeeper(t, nil)
proofs := createNProofs(keeper, ctx, 10)
for _, proof := range proofs {
rst, found := keeper.GetProof(ctx,
rst, found := keeper.GetProof(
ctx,
proof.GetSessionHeader().GetSessionId(),
proof.GetSupplierAddress(),
)
require.True(t, found)
require.Equal(t,
Expand All @@ -57,14 +60,11 @@ func TestProofGet(t *testing.T) {
}
func TestProofRemove(t *testing.T) {
keeper, ctx := keepertest.SupplierKeeper(t, nil)
items := createNProofs(keeper, ctx, 10)
for _, item := range items {
keeper.RemoveProof(ctx,
item.GetSessionHeader().GetSessionId(),
)
_, found := keeper.GetProof(ctx,
item.GetSessionHeader().GetSessionId(),
)
proofs := createNProofs(keeper, ctx, 10)
for _, proof := range proofs {
sessionId := proof.GetSessionHeader().GetSessionId()
keeper.RemoveProof(ctx, sessionId, proof.GetSupplierAddress())
_, found := keeper.GetProof(ctx, sessionId, proof.GetSupplierAddress())
require.False(t, found)
}
}
Expand Down
58 changes: 46 additions & 12 deletions x/supplier/keeper/query_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper

import (
"context"
"fmt"

"github.com/cosmos/cosmos-sdk/store/prefix"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -17,19 +18,53 @@ func (k Keeper) AllProofs(goCtx context.Context, req *types.QueryAllProofsReques
return nil, status.Error(codes.InvalidArgument, "invalid request")
}

var proofs []types.Proof
ctx := sdk.UnwrapSDKContext(goCtx)

store := ctx.KVStore(k.storeKey)
proofStore := prefix.NewStore(store, types.KeyPrefix(types.ProofKeyPrefix))

var (
// isCustomIndex is used to determined if we'll be using the store that points
// to the actual Claim values, or a secondary index that points to the primary keys.
isCustomIndex bool
keyPrefix []byte
)

switch filter := req.Filter.(type) {
case *types.QueryAllProofsRequest_SupplierAddress:
isCustomIndex = true
keyPrefix = types.KeyPrefix(types.ProofSupplierAddressPrefix)
keyPrefix = append(keyPrefix, []byte(filter.SupplierAddress)...)
case *types.QueryAllProofsRequest_SessionEndHeight:
isCustomIndex = true
keyPrefix = types.KeyPrefix(types.ProofSessionEndHeightPrefix)
keyPrefix = append(keyPrefix, []byte(fmt.Sprintf("%d", filter.SessionEndHeight))...)
case *types.QueryAllProofsRequest_SessionId:
isCustomIndex = false
keyPrefix = types.KeyPrefix(types.ProofPrimaryKeyPrefix)
keyPrefix = append(keyPrefix, []byte(filter.SessionId)...)
default:
isCustomIndex = false
keyPrefix = types.KeyPrefix(types.ProofPrimaryKeyPrefix)
}
proofStore := prefix.NewStore(store, keyPrefix)

var proofs []types.Proof
pageRes, err := query.Paginate(proofStore, req.Pagination, func(key []byte, value []byte) error {
var proof types.Proof
if err := k.cdc.Unmarshal(value, &proof); err != nil {
return err
if isCustomIndex {
// We retrieve the primaryKey, and need to query the actual proof before decoding it.
proof, proofFound := k.getProofByPrimaryKey(ctx, value)
if proofFound {
proofs = append(proofs, proof)
}
} else {
// The value is an encoded proof.
var proof types.Proof
if err := k.cdc.Unmarshal(value, &proof); err != nil {
return err
}

proofs = append(proofs, proof)
}

proofs = append(proofs, proof)
return nil
})

Expand All @@ -42,14 +77,13 @@ func (k Keeper) AllProofs(goCtx context.Context, req *types.QueryAllProofsReques

func (k Keeper) Proof(goCtx context.Context, req *types.QueryGetProofRequest) (*types.QueryGetProofResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request")
err := types.ErrSupplierInvalidSessionId.Wrap("request cannot be nil")
return nil, status.Error(codes.InvalidArgument, err.Error())
}

ctx := sdk.UnwrapSDKContext(goCtx)

val, found := k.GetProof(
ctx,
req.Index,
)
val, found := k.GetProof(ctx, req.GetSessionId(), req.GetSupplierAddress())
if !found {
return nil, status.Error(codes.NotFound, "not found")
}
Expand Down
95 changes: 83 additions & 12 deletions x/supplier/keeper/query_proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

keepertest "github.com/pokt-network/poktroll/testutil/keeper"
"github.com/pokt-network/poktroll/testutil/nullify"
"github.com/pokt-network/poktroll/testutil/sample"
"github.com/pokt-network/poktroll/x/supplier/types"
)

Expand All @@ -21,37 +22,107 @@ var _ = strconv.IntSize
func TestProofQuerySingle(t *testing.T) {
keeper, ctx := keepertest.SupplierKeeper(t, nil)
wctx := sdk.WrapSDKContext(ctx)
msgs := createNProofs(keeper, ctx, 2)
proofs := createNProofs(keeper, ctx, 2)

var randSupplierAddr = sample.AccAddress()
tests := []struct {
desc string
request *types.QueryGetProofRequest
desc string

request *types.QueryGetProofRequest

response *types.QueryGetProofResponse
err error
}{
{
desc: "First",
request: &types.QueryGetProofRequest{
Index: msgs[0].GetSessionHeader().GetSessionId(),
SessionId: proofs[0].GetSessionHeader().GetSessionId(),
SupplierAddress: proofs[0].SupplierAddress,
},
response: &types.QueryGetProofResponse{Proof: msgs[0]},
response: &types.QueryGetProofResponse{Proof: proofs[0]},
},
{
desc: "Second",
request: &types.QueryGetProofRequest{
Index: msgs[1].GetSessionHeader().GetSessionId(),
SessionId: proofs[1].GetSessionHeader().GetSessionId(),
SupplierAddress: proofs[1].SupplierAddress,
},
response: &types.QueryGetProofResponse{Proof: proofs[1]},
},
{
desc: "Proof Not Found - Random SessionId",

request: &types.QueryGetProofRequest{
SessionId: "not a real session id",
SupplierAddress: proofs[0].GetSupplierAddress(),
},

err: status.Error(
codes.NotFound,
types.ErrSupplierProofNotFound.Wrapf(
// TODO_CONSIDERATION: factor out error message format strings to constants.
"session ID %q and supplier %q",
"not a real session id",
proofs[0].GetSupplierAddress(),
).Error(),
),
},
{
desc: "Proof Not Found - Random Supplier Address",

request: &types.QueryGetProofRequest{
SessionId: proofs[0].GetSessionHeader().GetSessionId(),
SupplierAddress: randSupplierAddr,
},
response: &types.QueryGetProofResponse{Proof: msgs[1]},

err: status.Error(
codes.NotFound,
types.ErrSupplierProofNotFound.Wrapf(
"session ID %q and supplier %q",
proofs[0].GetSessionHeader().GetSessionId(),
randSupplierAddr,
).Error(),
),
},
{
desc: "KeyNotFound",
desc: "InvalidRequest - Missing SessionId",
request: &types.QueryGetProofRequest{
Index: strconv.Itoa(100000),
// SessionId: Intentionally Omitted
SupplierAddress: proofs[0].GetSupplierAddress(),
},
err: status.Error(codes.NotFound, "not found"),

err: status.Error(
codes.InvalidArgument,
types.ErrSupplierInvalidSessionId.Wrapf(
"invalid session ID for proof being retrieved %s",
"",
).Error(),
),
},
{
desc: "InvalidRequest",
err: status.Error(codes.InvalidArgument, "invalid request"),
desc: "InvalidRequest - Missing SupplierAddress",
request: &types.QueryGetProofRequest{
SessionId: proofs[0].GetSessionHeader().GetSessionId(),
// SupplierAddress: Intentionally Omitted,
},

err: status.Error(
codes.InvalidArgument,
types.ErrSupplierInvalidAddress.Wrap(
"invalid supplier address for proof being retrieved ; (empty address string is not allowed)",
).Error(),
),
},
{
desc: "InvalidRequest - nil QueryGetProofRequest",
request: nil,

err: status.Error(
codes.InvalidArgument,
types.ErrSupplierInvalidQueryRequest.Wrap(
"request cannot be nil",
).Error(),
),
},
}
for _, tc := range tests {
Expand Down
Loading

0 comments on commit fabfbef

Please sign in to comment.