Skip to content

Commit

Permalink
Add refresh token test
Browse files Browse the repository at this point in the history
Signed-off-by: Jan-Otto Kröpke <[email protected]>
  • Loading branch information
jkroepke committed Nov 2, 2024
1 parent 5c7e81b commit 9a3ffd5
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 18 deletions.
2 changes: 1 addition & 1 deletion internal/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
func BenchmarkFull(b *testing.B) {
b.StopTimer()

_, client, managementInterface, _, _, httpClient, _, shutdownFn := testutils.SetupMockEnvironment(context.Background(), b, config.Config{})
_, client, managementInterface, _, _, httpClient, _, shutdownFn := testutils.SetupMockEnvironment(context.Background(), b, config.Config{}, nil)
defer shutdownFn()

wg := sync.WaitGroup{}
Expand Down
2 changes: 1 addition & 1 deletion internal/oauth2/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func TestHandler(t *testing.T) {

ctx, cancel := context.WithCancel(context.Background())

conf, client, managementInterface, _, httpClientListener, httpClient, logger, shutdownFn := testutils.SetupMockEnvironment(ctx, t, tt.conf)
conf, client, managementInterface, _, httpClientListener, httpClient, logger, shutdownFn := testutils.SetupMockEnvironment(ctx, t, tt.conf, nil)
defer shutdownFn()

wg := sync.WaitGroup{}
Expand Down
4 changes: 2 additions & 2 deletions internal/oauth2/providers/github/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestValidateGroups(t *testing.T) {
}

httpClient := &http.Client{
Transport: testutils.NewRoundTripperFunc(func(_ *http.Request) (*http.Response, error) {
Transport: testutils.NewRoundTripperFunc(nil, func(_ http.RoundTripper, _ *http.Request) (*http.Response, error) {
resp := httptest.NewRecorder()
if strings.Contains(tt.userOrgs, "error") {
resp.WriteHeader(http.StatusInternalServerError)
Expand Down Expand Up @@ -223,7 +223,7 @@ func TestValidateRoles(t *testing.T) {
}

httpClient := &http.Client{
Transport: testutils.NewRoundTripperFunc(func(_ *http.Request) (*http.Response, error) {
Transport: testutils.NewRoundTripperFunc(nil, func(_ http.RoundTripper, _ *http.Request) (*http.Response, error) {
resp := httptest.NewRecorder()
if strings.Contains(tt.userTeams, "error") {
resp.WriteHeader(http.StatusInternalServerError)
Expand Down
2 changes: 1 addition & 1 deletion internal/oauth2/providers/github/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestGetUser(t *testing.T) {
}

httpClient := &http.Client{
Transport: testutils.NewRoundTripperFunc(func(_ *http.Request) (*http.Response, error) {
Transport: testutils.NewRoundTripperFunc(nil, func(_ http.RoundTripper, _ *http.Request) (*http.Response, error) {
resp := httptest.NewRecorder()
if strings.Contains(tt.user, "error") {
resp.WriteHeader(http.StatusInternalServerError)
Expand Down
2 changes: 1 addition & 1 deletion internal/oauth2/providers/google/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func TestValidateGroups(t *testing.T) {
}

httpClient := &http.Client{
Transport: testutils.NewRoundTripperFunc(func(req *http.Request) (*http.Response, error) {
Transport: testutils.NewRoundTripperFunc(nil, func(_ http.RoundTripper, req *http.Request) (*http.Response, error) {
resp := httptest.NewRecorder()
if strings.Contains(tt.tokenGroups, "error") {
resp.WriteHeader(http.StatusInternalServerError)
Expand Down
59 changes: 55 additions & 4 deletions internal/oauth2/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package oauth2_test

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -13,13 +15,15 @@ import (
"time"

"github.com/jkroepke/openvpn-auth-oauth2/internal/config"
"github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2/providers/generic"
"github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2/providers/github"
"github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2/providers/google"
"github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2/types"
"github.com/jkroepke/openvpn-auth-oauth2/internal/openvpn"
"github.com/jkroepke/openvpn-auth-oauth2/internal/utils/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
)

func TestRefreshReAuth(t *testing.T) {
Expand All @@ -28,6 +32,7 @@ func TestRefreshReAuth(t *testing.T) {
for _, tt := range []struct {
name string
conf config.Config
rt http.RoundTripper
}{
{
name: "Refresh",
Expand All @@ -36,6 +41,7 @@ func TestRefreshReAuth(t *testing.T) {
Refresh: config.OAuth2Refresh{Enabled: true, ValidateUser: true, UseSessionID: false},
},
},
rt: http.DefaultTransport,
},
{
name: "Refresh with ValidateUser=false",
Expand All @@ -44,6 +50,7 @@ func TestRefreshReAuth(t *testing.T) {
Refresh: config.OAuth2Refresh{Enabled: true, ValidateUser: false, UseSessionID: false},
},
},
rt: http.DefaultTransport,
},
{
name: "Refresh with SessionID=true + ValidateUser=false",
Expand All @@ -52,6 +59,7 @@ func TestRefreshReAuth(t *testing.T) {
Refresh: config.OAuth2Refresh{Enabled: true, ValidateUser: false, UseSessionID: true},
},
},
rt: http.DefaultTransport,
},
{
name: "Refresh with provider=google",
Expand All @@ -62,6 +70,7 @@ func TestRefreshReAuth(t *testing.T) {
Refresh: config.OAuth2Refresh{Enabled: true, ValidateUser: true, UseSessionID: false},
},
},
rt: http.DefaultTransport,
},
{
name: "Refresh with provider=github",
Expand All @@ -71,12 +80,47 @@ func TestRefreshReAuth(t *testing.T) {
Refresh: config.OAuth2Refresh{Enabled: true, ValidateUser: true, UseSessionID: false},
},
},
rt: http.DefaultTransport,
},
{
name: "Refresh without refresh token",
conf: config.Config{
OAuth2: config.OAuth2{
Provider: generic.Name,
Refresh: config.OAuth2Refresh{Enabled: true, ValidateUser: true, UseSessionID: false},
},
},
rt: testutils.NewRoundTripperFunc(http.DefaultTransport, func(rt http.RoundTripper, req *http.Request) (*http.Response, error) {

Check failure on line 93 in internal/oauth2/refresh_test.go

View workflow job for this annotation

GitHub Actions / lint

parameter name 'rt' is too short for the scope of its usage (varnamelen)
_ = req
if req.URL.Path == "/oauth/token" && req.Header.Get("Authorization") == "" {
res, err := rt.RoundTrip(req)

var tokenResponse oidc.AccessTokenResponse
if err := json.NewDecoder(res.Body).Decode(&tokenResponse); err != nil {
return nil, err
}

tokenResponse.RefreshToken = ""

var buf bytes.Buffer

if err := json.NewEncoder(&buf).Encode(tokenResponse); err != nil {
return nil, err
}

res.Body = io.NopCloser(&buf)

return res, err
}

return rt.RoundTrip(req)
}),
},
} {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

conf, openVPNClient, managementInterface, _, _, httpClient, logger, shutdownFn := testutils.SetupMockEnvironment(context.Background(), t, tt.conf)
conf, openVPNClient, managementInterface, _, _, httpClient, logger, shutdownFn := testutils.SetupMockEnvironment(context.Background(), t, tt.conf, tt.rt)

t.Cleanup(func() {
if t.Failed() {
Expand Down Expand Up @@ -145,18 +189,25 @@ func TestRefreshReAuth(t *testing.T) {
)
testutils.SendMessage(t, managementInterfaceConn, "SUCCESS: client-auth command succeeded")

// Testing ReAuth
testutils.SendAndExpectMessage(t, managementInterfaceConn, reader,
">CLIENT:REAUTH,1,4\r\n>CLIENT:ENV,untrusted_ip=127.0.0.1\r\n>CLIENT:ENV,common_name=test\r\n>CLIENT:ENV,session_id=session_id\r\n>CLIENT:ENV,session_state=AuthenticatedEmptyUser\r\n>CLIENT:ENV,IV_SSO=webauth\r\n>CLIENT:ENV,END",
"client-auth-nt 1 4",
)
testutils.SendMessage(t, managementInterfaceConn, "SUCCESS: client-auth command succeeded")

// Test Disconnect
testutils.SendMessage(t, managementInterfaceConn, ">CLIENT:DISCONNECT,1\r\n>CLIENT:ENV,untrusted_ip=127.0.0.1\r\n>CLIENT:ENV,common_name=test\r\n>CLIENT:ENV,session_id=session_id\r\n>CLIENT:ENV,session_state=AuthenticatedEmptyUser\r\n>CLIENT:ENV,IV_SSO=webauth\r\n>CLIENT:ENV,END")

// Test ReAuth after DC
testutils.SendMessage(t, managementInterfaceConn, ">CLIENT:REAUTH,1,3\r\n>CLIENT:ENV,untrusted_ip=127.0.0.1\r\n>CLIENT:ENV,common_name=test\r\n>CLIENT:ENV,session_id=session_id\r\n>CLIENT:ENV,session_state=AuthenticatedEmptyUser\r\n>CLIENT:ENV,IV_SSO=webauth\r\n>CLIENT:ENV,END")
testutils.SendMessage(t, managementInterfaceConn, ">CLIENT:REAUTH,1,4\r\n>CLIENT:ENV,untrusted_ip=127.0.0.1\r\n>CLIENT:ENV,common_name=test\r\n>CLIENT:ENV,session_id=session_id\r\n>CLIENT:ENV,session_state=AuthenticatedEmptyUser\r\n>CLIENT:ENV,IV_SSO=webauth\r\n>CLIENT:ENV,END")

auth = testutils.ReadLine(t, managementInterfaceConn, reader)

if conf.OAuth2.Refresh.UseSessionID {
assert.Contains(t, auth, "client-auth-nt 1 3")
assert.Contains(t, auth, "client-auth-nt 1 4")
} else {
assert.Contains(t, auth, "client-pending-auth 1 3 \"WEB_AUTH::")
assert.Contains(t, auth, "client-pending-auth 1 4 \"WEB_AUTH::")
}

testutils.SendMessage(t, managementInterfaceConn, "SUCCESS: %s command succeeded", strings.SplitN(auth, " ", 2)[0])
Expand Down
9 changes: 5 additions & 4 deletions internal/utils/testutils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ import (
)

type RoundTripperFunc struct {
fn func(req *http.Request) (*http.Response, error)
fn func(rt http.RoundTripper, req *http.Request) (*http.Response, error)
rt http.RoundTripper
}

func NewRoundTripperFunc(fn func(req *http.Request) (*http.Response, error)) *RoundTripperFunc {
return &RoundTripperFunc{fn}
func NewRoundTripperFunc(rt http.RoundTripper, fn func(rt http.RoundTripper, req *http.Request) (*http.Response, error)) *RoundTripperFunc {
return &RoundTripperFunc{fn, rt}
}

func (f *RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f.fn(req)
return f.fn(f.rt, req)
}

type MockRoundTripper struct {
Expand Down
6 changes: 2 additions & 4 deletions internal/utils/testutils/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ func SetupResourceServer(tb testing.TB, clientListener net.Listener) (*httptest.
// SetupMockEnvironment setups an OpenVPN and IDP mock
//
//nolint:cyclop
func SetupMockEnvironment(ctx context.Context, tb testing.TB, conf config.Config) (config.Config, *openvpn.Client, net.Listener, *oauth2.Provider,
*httptest.Server, *http.Client, *Logger, func(),
) {
func SetupMockEnvironment(ctx context.Context, tb testing.TB, conf config.Config, rt http.RoundTripper) (config.Config, *openvpn.Client, net.Listener, *oauth2.Provider, *httptest.Server, *http.Client, *Logger, func()) {
tb.Helper()

logger := NewTestLogger()
Expand Down Expand Up @@ -282,7 +280,7 @@ func SetupMockEnvironment(ctx context.Context, tb testing.TB, conf config.Config
conf.OAuth2.Refresh.Expires = time.Hour
}

httpClient := &http.Client{Transport: NewMockRoundTripper(utils.NewUserAgentTransport(nil))}
httpClient := &http.Client{Transport: NewMockRoundTripper(utils.NewUserAgentTransport(rt))}
storageClient := storage.New(ctx, Secret, conf.OAuth2.Refresh.Expires)
provider := oauth2.New(logger.Logger, conf, storageClient, httpClient)
openvpnClient := openvpn.New(ctx, logger.Logger, conf, provider)
Expand Down

0 comments on commit 9a3ffd5

Please sign in to comment.