Skip to content

Commit

Permalink
Passing context storage (#3941)
Browse files Browse the repository at this point in the history
Signed-off-by: Bob Maertz <[email protected]>
  • Loading branch information
bobmaertz authored Feb 4, 2025
1 parent 8c587b2 commit ad31b5d
Show file tree
Hide file tree
Showing 35 changed files with 527 additions and 500 deletions.
28 changes: 14 additions & 14 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type dexAPI struct {
}

func (d dexAPI) GetClient(ctx context.Context, req *api.GetClientReq) (*api.GetClientResp, error) {
c, err := d.s.GetClient(req.Id)
c, err := d.s.GetClient(ctx, req.Id)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -108,7 +108,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap
return nil, errors.New("update client: no client ID supplied")
}

err := d.s.UpdateClient(req.Id, func(old storage.Client) (storage.Client, error) {
err := d.s.UpdateClient(ctx, req.Id, func(old storage.Client) (storage.Client, error) {
if req.RedirectUris != nil {
old.RedirectURIs = req.RedirectUris
}
Expand All @@ -134,7 +134,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap
}

func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*api.DeleteClientResp, error) {
err := d.s.DeleteClient(req.Id)
err := d.s.DeleteClient(ctx, req.Id)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeleteClientResp{NotFound: true}, nil
Expand Down Expand Up @@ -219,7 +219,7 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq)
return old, nil
}

if err := d.s.UpdatePassword(req.Email, updater); err != nil {
if err := d.s.UpdatePassword(ctx, req.Email, updater); err != nil {
if err == storage.ErrNotFound {
return &api.UpdatePasswordResp{NotFound: true}, nil
}
Expand All @@ -235,7 +235,7 @@ func (d dexAPI) DeletePassword(ctx context.Context, req *api.DeletePasswordReq)
return nil, errors.New("no email supplied")
}

err := d.s.DeletePassword(req.Email)
err := d.s.DeletePassword(ctx, req.Email)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeletePasswordResp{NotFound: true}, nil
Expand Down Expand Up @@ -268,7 +268,7 @@ func (d dexAPI) GetDiscovery(ctx context.Context, req *api.DiscoveryReq) (*api.D
}

func (d dexAPI) ListPasswords(ctx context.Context, req *api.ListPasswordReq) (*api.ListPasswordResp, error) {
passwordList, err := d.s.ListPasswords()
passwordList, err := d.s.ListPasswords(ctx)
if err != nil {
d.logger.Error("failed to list passwords", "err", err)
return nil, fmt.Errorf("list passwords: %v", err)
Expand Down Expand Up @@ -298,7 +298,7 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq)
return nil, errors.New("no password to verify supplied")
}

password, err := d.s.GetPassword(req.Email)
password, err := d.s.GetPassword(ctx, req.Email)
if err != nil {
if err == storage.ErrNotFound {
return &api.VerifyPasswordResp{
Expand Down Expand Up @@ -327,7 +327,7 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.
return nil, err
}

offlineSessions, err := d.s.GetOfflineSessions(id.UserId, id.ConnId)
offlineSessions, err := d.s.GetOfflineSessions(ctx, id.UserId, id.ConnId)
if err != nil {
if err == storage.ErrNotFound {
// This means that this user-client pair does not have a refresh token yet.
Expand Down Expand Up @@ -381,7 +381,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
return old, nil
}

if err := d.s.UpdateOfflineSessions(id.UserId, id.ConnId, updater); err != nil {
if err := d.s.UpdateOfflineSessions(ctx, id.UserId, id.ConnId, updater); err != nil {
if err == storage.ErrNotFound {
return &api.RevokeRefreshResp{NotFound: true}, nil
}
Expand All @@ -397,7 +397,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
//
// TODO(ericchiang): we don't have any good recourse if this call fails.
// Consider garbage collection of refresh tokens with no associated ref.
if err := d.s.DeleteRefresh(refreshID); err != nil {
if err := d.s.DeleteRefresh(ctx, refreshID); err != nil {
d.logger.Error("failed to delete refresh token", "err", err)
return nil, err
}
Expand Down Expand Up @@ -448,7 +448,7 @@ func (d dexAPI) CreateConnector(ctx context.Context, req *api.CreateConnectorReq
return &api.CreateConnectorResp{}, nil
}

func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) {
func (d dexAPI) UpdateConnector(ctx context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) {
if !featureflags.APIConnectorsCRUD.Enabled() {
return nil, fmt.Errorf("%s feature flag is not enabled", featureflags.APIConnectorsCRUD.Name)
}
Expand Down Expand Up @@ -485,7 +485,7 @@ func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq)
return old, nil
}

if err := d.s.UpdateConnector(req.Id, updater); err != nil {
if err := d.s.UpdateConnector(ctx, req.Id, updater); err != nil {
if err == storage.ErrNotFound {
return &api.UpdateConnectorResp{NotFound: true}, nil
}
Expand All @@ -505,7 +505,7 @@ func (d dexAPI) DeleteConnector(ctx context.Context, req *api.DeleteConnectorReq
return nil, errors.New("no id supplied")
}

err := d.s.DeleteConnector(req.Id)
err := d.s.DeleteConnector(ctx, req.Id)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeleteConnectorResp{NotFound: true}, nil
Expand All @@ -521,7 +521,7 @@ func (d dexAPI) ListConnectors(ctx context.Context, req *api.ListConnectorReq) (
return nil, fmt.Errorf("%s feature flag is not enabled", featureflags.APIConnectorsCRUD.Name)
}

connectorList, err := d.s.ListConnectors()
connectorList, err := d.s.ListConnectors(ctx)
if err != nil {
d.logger.Error("api: failed to list connectors", "err", err)
return nil, fmt.Errorf("list connectors: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestPassword(t *testing.T) {
t.Fatalf("Unable to update password: %v", err)
}

pass, err := s.GetPassword(updateReq.Email)
pass, err := s.GetPassword(ctx, updateReq.Email)
if err != nil {
t.Fatalf("Unable to retrieve password: %v", err)
}
Expand Down Expand Up @@ -449,7 +449,7 @@ func TestUpdateClient(t *testing.T) {
t.Errorf("expected in response NotFound: %t", tc.want.NotFound)
}

client, err := s.GetClient(tc.req.Id)
client, err := s.GetClient(ctx, tc.req.Id)
if err != nil {
t.Errorf("no client found in the storage: %v", err)
}
Expand Down
18 changes: 10 additions & 8 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Requ
}

func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
Expand All @@ -208,7 +209,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
now := s.now()

// Grab the device token, check validity
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
deviceToken, err := s.storage.GetDeviceToken(ctx, deviceCode)
if err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
Expand Down Expand Up @@ -240,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
return old, nil
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
if err := s.storage.UpdateDeviceToken(ctx, deviceCode, updater); err != nil {
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
Expand Down Expand Up @@ -299,7 +300,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
return
}

authCode, err := s.storage.GetAuthCode(code)
authCode, err := s.storage.GetAuthCode(ctx, code)
if err != nil || s.now().After(authCode.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
Expand All @@ -311,7 +312,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}

// Grab the device request from storage
deviceReq, err := s.storage.GetDeviceRequest(userCode)
deviceReq, err := s.storage.GetDeviceRequest(ctx, userCode)
if err != nil || s.now().After(deviceReq.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
Expand All @@ -322,7 +323,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
return
}

client, err := s.storage.GetClient(deviceReq.ClientID)
client, err := s.storage.GetClient(ctx, deviceReq.ClientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get client", "err", err)
Expand All @@ -345,7 +346,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}

// Grab the device token from storage
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
old, err := s.storage.GetDeviceToken(ctx, deviceReq.DeviceCode)
if err != nil || s.now().After(old.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
Expand Down Expand Up @@ -373,7 +374,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}

// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
if err := s.storage.UpdateDeviceToken(ctx, deviceReq.DeviceCode, updater); err != nil {
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
Expand All @@ -391,6 +392,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
Expand All @@ -409,7 +411,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
userCode = strings.ToUpper(userCode)

// Find the user code in the available requests
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
deviceRequest, err := s.storage.GetDeviceRequest(ctx, userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != nil && err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get device request", "err", err)
Expand Down
Loading

0 comments on commit ad31b5d

Please sign in to comment.