Skip to content

Commit 91327fe

Browse files
github-actions[bot]maoanran
authored andcommitted
Squashed commit of the following:
commit e4779e1 Author: splaunov <[email protected]> Date: Tue Sep 17 13:49:38 2024 +0300 fix: transient_payload lost in API flow with session token exchange code (PS-482)
1 parent 36e624c commit 91327fe

File tree

5 files changed

+94
-10
lines changed

5 files changed

+94
-10
lines changed

selfservice/flow/transient_payload.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package flow
2+
3+
import (
4+
"github.com/ory/x/sqlxx"
5+
"github.com/tidwall/gjson"
6+
"github.com/tidwall/sjson"
7+
)
8+
9+
const internalContextTransientPayloadPath = "transient_payload"
10+
11+
func SetTransientPayloadIntoInternalContext(flow InternalContexter, transientPayload sqlxx.JSONRawMessage) error {
12+
if flow.GetInternalContext() == nil {
13+
flow.EnsureInternalContext()
14+
}
15+
bytes, err := sjson.SetBytes(
16+
flow.GetInternalContext(),
17+
internalContextTransientPayloadPath,
18+
transientPayload,
19+
)
20+
if err != nil {
21+
return err
22+
}
23+
flow.SetInternalContext(bytes)
24+
25+
return nil
26+
}
27+
28+
func GetTransientPayloadFromInternalContext(flow InternalContexter) (sqlxx.JSONRawMessage, error) {
29+
if flow.GetInternalContext() == nil {
30+
flow.EnsureInternalContext()
31+
}
32+
raw := gjson.GetBytes(flow.GetInternalContext(), internalContextTransientPayloadPath)
33+
if !raw.IsObject() {
34+
return nil, nil
35+
}
36+
37+
return sqlxx.JSONRawMessage(raw.Raw), nil
38+
}

selfservice/strategy/oidc/strategy.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ func (s *Strategy) ValidateCallback(w http.ResponseWriter, r *http.Request, ps h
316316
}
317317
cntnr.State = stateParam
318318
cntnr.FlowID = uuid.FromBytesOrNil(state.FlowId).String()
319+
internalContexter, ok := f.(flow.InternalContexter)
320+
if ok {
321+
transientPayload, err := flow.GetTransientPayloadFromInternalContext(internalContexter)
322+
if err != nil {
323+
return nil, state, &cntnr, err
324+
}
325+
cntnr.TransientPayload = json.RawMessage(transientPayload)
326+
}
319327
}
320328

321329
if errorParam != "" {

selfservice/strategy/oidc/strategy_login.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"bytes"
88
"context"
99
"encoding/json"
10+
"github.com/ory/x/sqlxx"
1011
"net/http"
1112
"strings"
1213
"time"
@@ -198,6 +199,12 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
198199
f.IDToken = p.IDToken
199200
f.RawIDTokenNonce = p.IDTokenNonce
200201
f.TransientPayload = p.TransientPayload
202+
if err := flow.SetTransientPayloadIntoInternalContext(f, sqlxx.JSONRawMessage(p.TransientPayload)); err != nil {
203+
return nil, err
204+
}
205+
if err := s.d.LoginFlowPersister().UpdateLoginFlow(ctx, f); err != nil {
206+
return nil, err
207+
}
201208

202209
pid := p.Provider // this can come from both url query and post body
203210
if pid == "" {

selfservice/strategy/oidc/strategy_registration.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
168168
f.TransientPayload = p.TransientPayload
169169
f.IDToken = p.IDToken
170170
f.RawIDTokenNonce = p.IDTokenNonce
171+
if err := flow.SetTransientPayloadIntoInternalContext(f, sqlxx.JSONRawMessage(p.TransientPayload)); err != nil {
172+
return s.handleError(ctx, w, r, f, pid, nil, err)
173+
}
174+
if err := s.d.RegistrationFlowPersister().UpdateRegistrationFlow(ctx, f); err != nil {
175+
return s.handleError(ctx, w, r, f, pid, nil, err)
176+
}
171177

172178
if !strings.EqualFold(strings.ToLower(p.Method), s.SettingsStrategyID()) && p.Method != "" {
173179
// the user is sending a method that is not oidc, but the payload includes a provider

selfservice/strategy/oidc/strategy_test.go

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,13 @@ func TestStrategy(t *testing.T) {
199199
return res, body
200200
}
201201

202-
makeAPICodeFlowRequest := func(t *testing.T, provider, action string) (returnToURL *url.URL) {
203-
res, err := testhelpers.NewDebugClient(t).Post(action, "application/json", strings.NewReader(fmt.Sprintf(`{
204-
"method": "oidc",
205-
"provider": %q
206-
}`, provider)))
202+
makeAPICodeFlowRequest := func(t *testing.T, provider, action string, transientPayload string) (returnToURL *url.URL) {
203+
res, err := testhelpers.NewDebugClient(t).Post(action, "application/json",
204+
strings.NewReader(fmt.Sprintf(`{
205+
"method": "oidc",
206+
"provider": %q,
207+
"transient_payload": %q
208+
}`, provider, transientPayload)))
207209
require.NoError(t, err)
208210
require.Equal(t, http.StatusUnprocessableEntity, res.StatusCode)
209211
var changeLocation flow.BrowserLocationChangeRequiredError
@@ -834,14 +836,25 @@ func TestStrategy(t *testing.T) {
834836
})
835837

836838
t.Run("suite=API with session token exchange code", func(t *testing.T) {
839+
postRegistrationWebhook := hooktest.NewServer()
840+
t.Cleanup(postRegistrationWebhook.Close)
841+
postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx),
842+
config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String()))
843+
844+
postLoginWebhook := hooktest.NewServer()
845+
t.Cleanup(postLoginWebhook.Close)
846+
postLoginWebhook.SetConfig(t, conf.GetProvider(ctx),
847+
config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, config.HookGlobal))
848+
837849
scope = []string{"openid"}
850+
transientPayload := `{"data": "registration"}`
838851

839852
loginOrRegister := func(t *testing.T, flowID uuid.UUID, code string) {
840853
_, err := exchangeCodeForToken(t, sessiontokenexchange.Codes{InitCode: code})
841854
require.Error(t, err)
842855

843856
action := assertFormValues(t, flowID, "valid")
844-
returnToURL := makeAPICodeFlowRequest(t, "valid", action)
857+
returnToURL := makeAPICodeFlowRequest(t, "valid", action, transientPayload)
845858
returnToCode := returnToURL.Query().Get("code")
846859
assert.NotEmpty(t, code, "code query param was empty in the return_to URL")
847860

@@ -857,27 +870,39 @@ func TestStrategy(t *testing.T) {
857870
performRegistration := func(t *testing.T) {
858871
f := newAPIRegistrationFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute)
859872
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode)
873+
postRegistrationWebhook.AssertTransientPayload(t, transientPayload)
874+
}
875+
startRegistrationButLogin := func(t *testing.T) {
876+
f := newAPIRegistrationFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute)
877+
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode)
878+
postLoginWebhook.AssertTransientPayload(t, transientPayload)
860879
}
861880
performLogin := func(t *testing.T) {
862881
f := newAPILoginFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute)
863882
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode)
883+
postLoginWebhook.AssertTransientPayload(t, transientPayload)
884+
}
885+
startLoginButRegister := func(t *testing.T) {
886+
f := newAPILoginFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute)
887+
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode)
888+
postRegistrationWebhook.AssertTransientPayload(t, transientPayload)
864889
}
865890

866891
for _, tc := range []struct {
867892
name string
868893
first, then func(*testing.T)
869894
}{{
870895
name: "login-twice",
871-
first: performLogin, then: performLogin,
896+
first: startLoginButRegister, then: performLogin,
872897
}, {
873898
name: "login-then-register",
874-
first: performLogin, then: performRegistration,
899+
first: startLoginButRegister, then: startRegistrationButLogin,
875900
}, {
876901
name: "register-then-login",
877902
first: performRegistration, then: performLogin,
878903
}, {
879904
name: "register-twice",
880-
first: performRegistration, then: performRegistration,
905+
first: performRegistration, then: startRegistrationButLogin,
881906
}} {
882907
t.Run("case="+tc.name, func(t *testing.T) {
883908
subject = tc.name + "[email protected]"
@@ -902,7 +927,7 @@ func TestStrategy(t *testing.T) {
902927
require.Error(t, err)
903928

904929
action := assertFormValues(t, f.ID, "valid")
905-
returnToURL := makeAPICodeFlowRequest(t, "valid", action)
930+
returnToURL := makeAPICodeFlowRequest(t, "valid", action, "{}")
906931
returnedFlow := returnToURL.Query().Get("flow")
907932

908933
require.NotEmpty(t, returnedFlow, "flow query param was empty in the return_to URL")

0 commit comments

Comments
 (0)