Skip to content

Commit

Permalink
fix(storage): added context state interface
Browse files Browse the repository at this point in the history
  • Loading branch information
iyear committed Apr 30, 2023
1 parent b7e0426 commit 82c04d7
Showing 1 changed file with 31 additions and 30 deletions.
61 changes: 31 additions & 30 deletions pkg/storage/state.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"encoding/json"
"errors"
"github.com/gotd/td/telegram/updates"
Expand All @@ -16,7 +17,7 @@ func NewState(kv kv.KV) updates.StateStorage {
return &State{kv: kv}
}

func (s *State) Get(key string, v interface{}) error {
func (s *State) Get(_ context.Context, key string, v interface{}) error {
data, err := s.kv.Get(key)
if err != nil {
return err
Expand All @@ -25,7 +26,7 @@ func (s *State) Get(key string, v interface{}) error {
return json.Unmarshal(data, v)
}

func (s *State) Set(key string, v interface{}) error {
func (s *State) Set(_ context.Context, key string, v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return err
Expand All @@ -34,10 +35,10 @@ func (s *State) Set(key string, v interface{}) error {
return s.kv.Set(key, data)
}

func (s *State) GetState(userID int64) (updates.State, bool, error) {
func (s *State) GetState(ctx context.Context, userID int64) (updates.State, bool, error) {
state := updates.State{}

if err := s.Get(key.State(userID), &state); err != nil {
if err := s.Get(ctx, key.State(userID), &state); err != nil {
if errors.Is(err, kv.ErrNotFound) {
return state, false, nil
}
Expand All @@ -47,69 +48,69 @@ func (s *State) GetState(userID int64) (updates.State, bool, error) {
return state, true, nil
}

func (s *State) SetState(userID int64, state updates.State) error {
if err := s.Set(key.State(userID), state); err != nil {
func (s *State) SetState(ctx context.Context, userID int64, state updates.State) error {
if err := s.Set(ctx, key.State(userID), state); err != nil {
return err
}

return s.Set(key.StateChannel(userID), struct{}{})
return s.Set(ctx, key.StateChannel(userID), struct{}{})
}

func (s *State) SetPts(userID int64, pts int) error {
func (s *State) SetPts(ctx context.Context, userID int64, pts int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
if err := s.Get(ctx, k, &state); err != nil {
return err
}
state.Pts = pts
return s.Set(k, state)
return s.Set(ctx, k, state)
}

func (s *State) SetQts(userID int64, qts int) error {
func (s *State) SetQts(ctx context.Context, userID int64, qts int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
if err := s.Get(ctx, k, &state); err != nil {
return err
}
state.Qts = qts
return s.Set(k, state)
return s.Set(ctx, k, state)
}

func (s *State) SetDate(userID int64, date int) error {
func (s *State) SetDate(ctx context.Context, userID int64, date int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
if err := s.Get(ctx, k, &state); err != nil {
return err
}
state.Date = date
return s.Set(k, state)
return s.Set(ctx, k, state)
}

func (s *State) SetSeq(userID int64, seq int) error {
func (s *State) SetSeq(ctx context.Context, userID int64, seq int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
if err := s.Get(ctx, k, &state); err != nil {
return err
}
state.Seq = seq
return s.Set(k, state)
return s.Set(ctx, k, state)
}

func (s *State) SetDateSeq(userID int64, date, seq int) error {
func (s *State) SetDateSeq(ctx context.Context, userID int64, date, seq int) error {
state, k := updates.State{}, key.State(userID)

if err := s.Get(k, &state); err != nil {
if err := s.Get(ctx, k, &state); err != nil {
return err
}
state.Date = date
state.Seq = seq
return s.Set(k, state)
return s.Set(ctx, k, state)
}

func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) {
func (s *State) GetChannelPts(ctx context.Context, userID, channelID int64) (int, bool, error) {
c := make(map[int64]int)

if err := s.Get(key.StateChannel(userID), &c); err != nil {
if err := s.Get(ctx, key.StateChannel(userID), &c); err != nil {
if errors.Is(err, kv.ErrNotFound) {
return 0, false, nil
}
Expand All @@ -124,25 +125,25 @@ func (s *State) GetChannelPts(userID, channelID int64) (int, bool, error) {
return pts, true, nil
}

func (s *State) SetChannelPts(userID, channelID int64, pts int) error {
func (s *State) SetChannelPts(ctx context.Context, userID, channelID int64, pts int) error {
c, k := make(map[int64]int), key.StateChannel(userID)

if err := s.Get(k, &c); err != nil {
if err := s.Get(ctx, k, &c); err != nil {
return err
}
c[channelID] = pts
return s.Set(k, c)
return s.Set(ctx, k, c)
}

func (s *State) ForEachChannels(userID int64, f func(channelID int64, pts int) error) error {
func (s *State) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
c := make(map[int64]int)

if err := s.Get(key.StateChannel(userID), &c); err != nil {
if err := s.Get(ctx, key.StateChannel(userID), &c); err != nil {
return err
}

for channelID, pts := range c {
if err := f(channelID, pts); err != nil {
if err := f(ctx, channelID, pts); err != nil {
return err
}
}
Expand Down

0 comments on commit 82c04d7

Please sign in to comment.