Skip to content

Commit

Permalink
net/http: add enforcement hook to Transport.RoundTrip, like our net ones
Browse files Browse the repository at this point in the history
Updates #55
Updates tailscale/corp#12702

Signed-off-by: Brad Fitzpatrick <[email protected]>
  • Loading branch information
bradfitz authored and awly committed Feb 7, 2024
1 parent dee2ceb commit be9ef3d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
1 change: 1 addition & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pkg net, func SetDialEnforcer(func(context.Context, []Addr) error) #55
pkg net, func SetResolveEnforcer(func(context.Context, string, string, string, Addr) error) #55
pkg net/http, func SetRoundTripEnforcer(func(*Request) error) #55
pkg net, func WithSockTrace(context.Context, *SockTrace) context.Context #58
pkg net, func ContextSockTrace(context.Context) *SockTrace #58
pkg net, type SockTrace struct #58
Expand Down
21 changes: 21 additions & 0 deletions src/net/http/tailscale.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package http

var roundTripEnforcer func(*Request) error

// SetRoundTripEnforcer set a program-global resolver enforcer that can cause
// RoundTrip calls to fail based on the request and its context.
//
// f must be non-nil.
//
// SetRoundTripEnforcer can only be called once, and must not be called
// concurrent with any RoundTrip call; it's expected to be registered during
// init.
func SetRoundTripEnforcer(f func(*Request) error) {
if f == nil {
panic("nil func")
}
if roundTripEnforcer != nil {
panic("already called")
}
roundTripEnforcer = f
}
5 changes: 5 additions & 0 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,11 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {

// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (*Response, error) {
if roundTripEnforcer != nil {
if err := roundTripEnforcer(req); err != nil {
return nil, err
}
}
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
Expand Down

0 comments on commit be9ef3d

Please sign in to comment.