@@ -28,10 +28,14 @@ var safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
2828// reasons for CSRF check failures
2929var (
3030 ErrNoReferer = errors .New ("A secure request contained no Referer or its value was malformed" )
31- ErrBadReferer = errors .New ("A secure request's Referer comes from a different Origin " +
31+ ErrBadReferer = errors .New ("A secure request's Referer comes from a different origin " +
3232 " from the request's URL" )
33- ErrBadToken = errors .New ("The CSRF token in the cookie doesn't match the one" +
33+ ErrBadOrigin = errors .New ("Request was made with a disallowed origin specified in the Origin header" )
34+ ErrBadToken = errors .New ("The CSRF token in the cookie doesn't match the one" +
3435 " received in a form/header." )
36+
37+ // Internal error. When this is raised, and the request is secure, we additionally check for Referer.
38+ errNoOrigin = errors .New ("Origin header was not present" )
3539)
3640
3741type CSRFHandler struct {
@@ -45,7 +49,9 @@ type CSRFHandler struct {
4549 baseCookie http.Cookie
4650
4751 // Slices of paths that are exempt from CSRF checks.
48- // They can be specified by...
52+ // All of those will be matched against Request.URL.Path,
53+ // So they should take the leading slash into account
54+ // Paths can be specified by...
4955 // ...an exact path,
5056 exemptPaths []string
5157 // ...a regexp,
@@ -55,8 +61,8 @@ type CSRFHandler struct {
5561 // ...or a custom matcher function
5662 exemptFunc func (r * http.Request ) bool
5763
58- // All of those will be matched against Request.URL.Path,
59- // So they should take the leading slash into account
64+ isTLS func ( r * http. Request ) bool
65+ isAllowedOrigin func ( r * url. URL ) bool
6066}
6167
6268func defaultFailureHandler (w http.ResponseWriter , r * http.Request ) {
@@ -95,6 +101,7 @@ func New(handler http.Handler) *CSRFHandler {
95101 csrf := & CSRFHandler {successHandler : handler ,
96102 failureHandler : http .HandlerFunc (defaultFailureHandler ),
97103 baseCookie : baseCookie ,
104+ isTLS : func (r * http.Request ) bool { return true },
98105 }
99106
100107 return csrf
@@ -145,26 +152,10 @@ func (h *CSRFHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
145152 return
146153 }
147154
148- // if the request is secure, we enforce origin check
149- // for referer to prevent MITM of http->https requests
150- if r .URL .Scheme == "https" {
151- referer , err := url .Parse (r .Header .Get ("Referer" ))
152-
153- // if we can't parse the referer or it's empty,
154- // we assume it's not specified
155- if err != nil || referer .String () == "" {
156- ctxSetReason (r , ErrNoReferer )
157- h .handleFailure (w , r )
158- return
159- }
160-
161- // if the referer doesn't share origin with the request URL,
162- // we have another error for that
163- if ! sameOrigin (referer , r .URL ) {
164- ctxSetReason (r , ErrBadReferer )
165- h .handleFailure (w , r )
166- return
167- }
155+ if err := h .ensureSameOrigin (r ); err != nil {
156+ ctxSetReason (r , err )
157+ h .handleFailure (w , r )
158+ return
168159 }
169160
170161 // Finally, we check the token itself.
@@ -193,6 +184,75 @@ func (h *CSRFHandler) handleFailure(w http.ResponseWriter, r *http.Request) {
193184 h .failureHandler .ServeHTTP (w , r )
194185}
195186
187+ func (h * CSRFHandler ) ensureSameOrigin (r * http.Request ) error {
188+ selfOrigin := & url.URL {
189+ Scheme : "http" ,
190+ Host : r .Host ,
191+ }
192+ isTLS := h .isTLS (r )
193+ if isTLS {
194+ selfOrigin .Scheme = "https"
195+ }
196+
197+ secFetchSite := r .Header .Get ("Sec-Fetch-Site" )
198+ if secFetchSite == "same-origin" {
199+ return nil
200+ }
201+
202+ // If no `Sec-Fetch-Site: same-origin` is present, fallback to Origin or Referer,
203+ // including considering custom allowed origins.
204+ err := h .checkOrigin (selfOrigin , r )
205+ if err == nil {
206+ return nil
207+ } else if ! errors .Is (err , errNoOrigin ) {
208+ return err
209+ }
210+
211+ // If Origin header was not present, fall back on Referer check for both secure and insecure requests.
212+ // This is opposite of Django's behavior, but should be fine, as neither of the three headers existing is an edge case.
213+ // https://github.com/django/django/blob/8be0c0d6901669661fca578f474cd51cd284d35a/django/middleware/csrf.py#L460
214+ return h .checkReferer (selfOrigin , r )
215+ }
216+
217+ func (h * CSRFHandler ) checkReferer (selfOrigin * url.URL , r * http.Request ) error {
218+ referer , err := url .Parse (r .Referer ())
219+ if err != nil || referer .String () == "" {
220+ return ErrNoReferer
221+ }
222+
223+ if sameOrigin (selfOrigin , referer ) {
224+ return nil
225+ }
226+
227+ if h .isAllowedOrigin != nil && h .isAllowedOrigin (referer ) {
228+ return nil
229+ }
230+
231+ return ErrBadReferer
232+ }
233+
234+ func (h * CSRFHandler ) checkOrigin (selfOrigin * url.URL , r * http.Request ) error {
235+ originStr := r .Header .Get ("Origin" )
236+ if originStr == "" || originStr == "null" {
237+ return errNoOrigin
238+ }
239+
240+ origin , err := url .Parse (originStr )
241+ if err != nil {
242+ return err
243+ }
244+
245+ if sameOrigin (selfOrigin , origin ) {
246+ return nil
247+ }
248+
249+ if h .isAllowedOrigin != nil && h .isAllowedOrigin (origin ) {
250+ return nil
251+ }
252+
253+ return ErrBadOrigin
254+ }
255+
196256// Generates a new token, sets it on the given request and returns it
197257func (h * CSRFHandler ) RegenerateToken (w http.ResponseWriter , r * http.Request ) string {
198258 token := generateToken ()
@@ -224,3 +284,73 @@ func (h *CSRFHandler) SetFailureHandler(handler http.Handler) {
224284func (h * CSRFHandler ) SetBaseCookie (cookie http.Cookie ) {
225285 h .baseCookie = cookie
226286}
287+
288+ // SetIsTLSFunc sets a delegate function which determines, on a per-request basis, whether the request is made over a secure connection.
289+ // This should return `true` iff the URL that the user uses to access the application begins with https://.
290+ // For example, if the Go web application is served via plain-text HTTP,
291+ // but the user is accessing it through HTTPS via a TLS-terminating reverse-proxy, this should return `true`.
292+ //
293+ // Examples:
294+ //
295+ // 1. If you're using the Go TLS stack (no TLS-terminating proxies in between the user and the app), you may use:
296+ //
297+ // h.SetIsTLSFunc(func(r *http.Request) bool { return r.TLS != nil })
298+ //
299+ // 2. If your application is behind a reverse proxy that terminates TLS, you should configure the reverse proxy
300+ // to report the protocol that the request was made over via an HTTP header,
301+ // e.g. `X-Forwarded-Proto`.
302+ // You should also validate that the request is coming in from an IP of a trusted reverse proxy
303+ // to ensure that this header has not been spoofed by an attacker. For example:
304+ //
305+ // var trustedProxies = []string{"198.51.100.1", "198.51.100.2"}
306+ // h.SetIsTLSFunc(func(r *http.Request) bool {
307+ // ip, _, _ := strings.Cut(r.RemoteAddr, ":")
308+ // proto := r.Header.Get("X-Forwarded-Proto")
309+ // return slices.Contains(trustedProxies, ip) && proto == "https"
310+ // })
311+ func (h * CSRFHandler ) SetIsTLSFunc (f func (* http.Request ) bool ) {
312+ h .isTLS = f
313+ }
314+
315+ // SetAllowedOrigins defines a function that checks whether the request comes from an allowed origin.
316+ // This function will be invoked when the request is not considered a same-origin request.
317+ // If this function returns `false`, request will be disallowed.
318+ //
319+ // In most cases, this will be used with [StaticOrigins].
320+ func (h * CSRFHandler ) SetIsAllowedOriginFunc (f func (* url.URL ) bool ) {
321+ h .isAllowedOrigin = f
322+ }
323+
324+ // StaticOrigins returns a delegate, suitable for passing to [CSRFHandler.SetIsAllowedOriginFunc],
325+ // that validates the request origin against a static list of allowed origins.
326+ // This function expects each element to be of form `scheme://host`, e.g.: `https://example.com`, `http://example.org`.
327+ // If any element of the slice is an invalid URL, this function will return an error.
328+ // If an element includes additional URL parts (e.g. a path), these parts will be ignored,
329+ // as origin checks only take the scheme and host into account.
330+ //
331+ // Example:
332+ //
333+ // h := nosurf.New()
334+ // origins, err := nosurf.StaticOrigins("https://api.example.com", "http://insecure.example.com")
335+ // if err != nil {
336+ // panic(err)
337+ // }
338+ // h.SetIsAllowedOriginFunc(origins)
339+ func StaticOrigins (origins ... string ) (func (r * url.URL ) bool , error ) {
340+ var allowedOrigins []* url.URL
341+ for _ , o := range origins {
342+ url , err := url .Parse (o )
343+ if err != nil {
344+ return nil , err
345+ }
346+ allowedOrigins = append (allowedOrigins , url )
347+ }
348+ return func (u * url.URL ) bool {
349+ for _ , candidate := range allowedOrigins {
350+ if sameOrigin (candidate , u ) {
351+ return true
352+ }
353+ }
354+ return false
355+ }, nil
356+ }
0 commit comments