Skip to content

Commit 234b6f2

Browse files
authored
feat: load session only once when middleware is used (#4187)
1 parent 5665f20 commit 234b6f2

File tree

5 files changed

+43
-6
lines changed

5 files changed

+43
-6
lines changed

selfservice/flow/settings/handler.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ type createNativeSettingsFlow struct {
214214
// default: errorGeneric
215215
func (h *Handler) createNativeSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
216216
ctx := r.Context()
217-
s, err := h.d.SessionManager().FetchFromRequest(ctx, r)
217+
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
218218
if err != nil {
219219
h.d.Writer().WriteError(w, r, err)
220220
return
@@ -298,7 +298,7 @@ type createBrowserSettingsFlow struct {
298298
// default: errorGeneric
299299
func (h *Handler) createBrowserSettingsFlow(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
300300
ctx := r.Context()
301-
s, err := h.d.SessionManager().FetchFromRequest(ctx, r)
301+
s, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
302302
if err != nil {
303303
h.d.SelfServiceErrorManager().Forward(ctx, w, r, err)
304304
return
@@ -404,7 +404,7 @@ func (h *Handler) getSettingsFlow(w http.ResponseWriter, r *http.Request, _ http
404404
return
405405
}
406406

407-
sess, err := h.d.SessionManager().FetchFromRequest(ctx, r)
407+
sess, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
408408
if err != nil {
409409
h.d.Writer().WriteError(w, r, err)
410410
return
@@ -574,7 +574,7 @@ func (h *Handler) updateSettingsFlow(w http.ResponseWriter, r *http.Request, ps
574574
return
575575
}
576576

577-
ss, err := h.d.SessionManager().FetchFromRequest(ctx, r)
577+
ss, err := h.d.SessionManager().FetchFromRequestContext(ctx, r)
578578
if err != nil {
579579
h.d.SettingsFlowErrorHandler().WriteFlowError(w, r, node.DefaultGroup, f, nil, err)
580580
return

session/handler.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package session
55

66
import (
7+
"context"
78
"fmt"
89
"net/http"
910
"strconv"
@@ -837,9 +838,17 @@ func (h *Handler) listMySessions(w http.ResponseWriter, r *http.Request, _ httpr
837838
h.r.Writer().Write(w, r, sess)
838839
}
839840

841+
type sessionInContext int
842+
843+
const (
844+
sessionInContextKey sessionInContext = iota
845+
)
846+
840847
func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated httprouter.Handle) httprouter.Handle {
841848
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
842-
if _, err := h.r.SessionManager().FetchFromRequest(r.Context(), r); err != nil {
849+
ctx := r.Context()
850+
sess, err := h.r.SessionManager().FetchFromRequest(ctx, r)
851+
if err != nil {
843852
if onUnauthenticated != nil {
844853
onUnauthenticated(w, r, ps)
845854
return
@@ -849,7 +858,7 @@ func (h *Handler) IsAuthenticated(wrap httprouter.Handle, onUnauthenticated http
849858
return
850859
}
851860

852-
wrap(w, r, ps)
861+
wrap(w, r.WithContext(context.WithValue(ctx, sessionInContextKey, sess)), ps)
853862
}
854863
}
855864

session/manager.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ type Manager interface {
133133
// FetchFromRequest creates an HTTP session using cookies.
134134
FetchFromRequest(context.Context, *http.Request) (*Session, error)
135135

136+
// FetchFromRequestContext returns the session from the context or if that is unset, falls back to FetchFromRequest.
137+
FetchFromRequestContext(context.Context, *http.Request) (*Session, error)
138+
136139
// PurgeFromRequest removes an HTTP session.
137140
PurgeFromRequest(context.Context, http.ResponseWriter, *http.Request) error
138141

session/manager_http.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,17 @@ func (s *ManagerHTTP) extractToken(r *http.Request) string {
227227
return token
228228
}
229229

230+
func (s *ManagerHTTP) FetchFromRequestContext(ctx context.Context, r *http.Request) (_ *Session, err error) {
231+
ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.FetchFromRequestContext")
232+
otelx.End(span, &err)
233+
234+
if sess, ok := ctx.Value(sessionInContextKey).(*Session); ok {
235+
return sess, nil
236+
}
237+
238+
return s.FetchFromRequest(ctx, r)
239+
}
240+
230241
func (s *ManagerHTTP) FetchFromRequest(ctx context.Context, r *http.Request) (_ *Session, err error) {
231242
ctx, span := s.r.Tracer(ctx).Tracer().Start(ctx, "sessions.ManagerHTTP.FetchFromRequest")
232243
defer func() {

session/manager_http_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,16 @@ func TestManagerHTTP(t *testing.T) {
244244
reg.Writer().Write(w, r, sess)
245245
})
246246

247+
rp.GET("/session/get-middleware", reg.SessionHandler().IsAuthenticated(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
248+
sess, err := reg.SessionManager().FetchFromRequestContext(r.Context(), r)
249+
if err != nil {
250+
t.Logf("Got error on lookup: %s %T", err, errors.Unwrap(err))
251+
reg.Writer().WriteError(w, r, err)
252+
return
253+
}
254+
reg.Writer().Write(w, r, sess)
255+
}, session.RedirectOnUnauthenticated("https://failed.com")))
256+
247257
pts := httptest.NewServer(x.NewTestCSRFHandler(rp, reg))
248258
t.Cleanup(pts.Close)
249259
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, pts.URL)
@@ -263,6 +273,10 @@ func TestManagerHTTP(t *testing.T) {
263273
res, err := c.Get(pts.URL + "/session/get")
264274
require.NoError(t, err)
265275
assert.EqualValues(t, http.StatusOK, res.StatusCode)
276+
277+
res, err = c.Get(pts.URL + "/session/get-middleware")
278+
require.NoError(t, err)
279+
assert.EqualValues(t, http.StatusOK, res.StatusCode)
266280
})
267281

268282
t.Run("case=key rotation", func(t *testing.T) {

0 commit comments

Comments
 (0)