Skip to content

Commit

Permalink
fix: do not update refresh token on session refresh, if empty
Browse files Browse the repository at this point in the history
Signed-off-by: Jan-Otto Kröpke <[email protected]>
  • Loading branch information
jkroepke committed Nov 1, 2024
1 parent 31f7872 commit 49e962f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 77 deletions.
119 changes: 43 additions & 76 deletions internal/openvpn/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,6 @@ import (
)

func (c *Client) processClient(client connection.Client) error {
switch client.Reason {
case "CONNECT":
return c.clientConnect(client)
case "REAUTH":
return c.clientReAuthentication(client)
case "ESTABLISHED":
c.clientEstablished(client)
case "DISCONNECT":
c.clientDisconnect(client)
default:
return fmt.Errorf("unknown client reason: %s", client.Reason)
}

return nil
}

// clientConnect handles CONNECT events from the OpenVPN management interface.
func (c *Client) clientConnect(client connection.Client) error {
logger := c.logger.With(
slog.String("ip", fmt.Sprintf("%s:%s", client.IPAddr, client.IPPort)),
slog.Uint64("cid", client.CID),
slog.Uint64("kid", client.KID),
slog.String("common_name", client.CommonName),
slog.String("reason", client.Reason),
slog.String("session_id", client.SessionID),
slog.String("session_state", client.SessionState),
)

logger.Info("new client connection")

return c.handleClientAuthentication(logger, client)
}

// clientReAuthentication handles REAUTH events from the OpenVPN management interface.
func (c *Client) clientReAuthentication(client connection.Client) error {
logger := c.logger.With(
slog.String("ip", fmt.Sprintf("%s:%s", client.IPAddr, client.IPPort)),
slog.Uint64("cid", client.CID),
Expand All @@ -58,16 +23,30 @@ func (c *Client) clientReAuthentication(client connection.Client) error {
slog.String("session_state", client.SessionState),
)

logger.Info("new client reauth")
switch client.Reason {
case "CONNECT", "REAUTH":
c.handleClientAuthentication(logger, client)
case "ESTABLISHED":
c.clientEstablished(logger, client)
case "DISCONNECT":
c.clientDisconnect(logger, client)
default:
return fmt.Errorf("unknown client reason: %s", client.Reason)
}

return c.handleClientAuthentication(logger, client)
return nil
}

// handleClientAuthentication holds the shared authentication logic for CONNECT and REAUTH events.
func (c *Client) handleClientAuthentication(logger *slog.Logger, client connection.Client) error {
func (c *Client) handleClientAuthentication(logger *slog.Logger, client connection.Client) {
logger.Info("new client authentication")

// Check if the client is allowed to bypass authentication. If so, accept the client.
if c.checkAuthBypass(logger, client) {
return nil
if c.checkAuthBypass(client) {
logger.Info("client bypass authentication")
c.AcceptClient(logger, state.ClientIdentifier{CID: client.CID, KID: client.KID}, client.CommonName)

return
}

// Check if the client supports SSO authentication via webauth.
Expand All @@ -76,27 +55,39 @@ func (c *Client) handleClientAuthentication(logger *slog.Logger, client connecti
logger.Warn(errorSsoNotSupported)
c.DenyClient(logger, state.ClientIdentifier{CID: client.CID, KID: client.KID}, errorSsoNotSupported)

return nil
return
}

// Check if the client is already authenticated and refresh the client's authentication if enabled.
// If the client is successfully re-authenticated, accept the client.
if c.conf.OAuth2.Refresh.Enabled {
ok, err := c.silentReAuthentication(logger, client)
if err != nil {
logger.Warn(fmt.Errorf("%w. denying client", err).Error())
c.DenyClient(logger, state.ClientIdentifier{CID: client.CID, KID: client.KID}, ReasonStateExpiredOrInvalid)

return nil
logger.Error("error refreshing client auth",
slog.Any("err", err),
)

return
} else if ok {
c.AcceptClient(logger, state.ClientIdentifier{CID: client.CID, KID: client.KID}, client.CommonName)

return nil
return
}
}

// Start the authentication process for the client.
return c.startClientAuth(logger, client)
if err := c.startClientAuth(logger, client); err != nil {
// Deny the client if an error occurred during the authentication process.
logger.Error("error starting client auth",
slog.Any("err", err),
)

c.DenyClient(logger, state.ClientIdentifier{CID: client.CID, KID: client.KID}, "internal error")
}

return

Check failure on line 90 in internal/openvpn/client.go

View workflow job for this annotation

GitHub Actions / lint

S1023: redundant `return` statement (gosimple)
}

// startClientAuth initiates the authentication process for the client.
Expand Down Expand Up @@ -130,32 +121,23 @@ func (c *Client) startClientAuth(logger *slog.Logger, client connection.Client)
startURL := utils.StringConcat(strings.TrimSuffix(c.conf.HTTP.BaseURL.String(), "/"), "/oauth2/start?state=", encodedSession)

if len(startURL) >= 245 {
c.DenyClient(logger, clientIdentifier, "internal error")

return fmt.Errorf("url %s (%d chars) too long! OpenVPN support up to 245 chars. "+
"Try --openvpn.common-name.mode=omit or --log.vpn-client-ip=false to avoid this error",
startURL, len(startURL))
}

logger.Info("start pending auth")
logger.Info("sent client-pending-auth command")

_, err = c.SendCommandf(`client-pending-auth %d %d "WEB_AUTH::%s" %.0f`, client.CID, client.KID, startURL, c.conf.OpenVpn.AuthPendingTimeout.Seconds())
if err != nil {
logger.Warn("error from sending client-pending-auth command", slog.Any("err", err))
return fmt.Errorf("error sending client-pending-auth command: %w", err)
}

return nil
}

func (c *Client) checkAuthBypass(logger *slog.Logger, client connection.Client) bool {
if !slices.Contains(c.conf.OpenVpn.Bypass.CommonNames, client.CommonName) {
return false
}

logger.Info("client bypass authentication")
c.AcceptClient(logger, state.ClientIdentifier{CID: client.CID, KID: client.KID}, client.CommonName)

return true
func (c *Client) checkAuthBypass(client connection.Client) bool {
return slices.Contains(c.conf.OpenVpn.Bypass.CommonNames, client.CommonName)
}

func (c *Client) silentReAuthentication(logger *slog.Logger, client connection.Client) (bool, error) {
Expand All @@ -175,29 +157,14 @@ func (c *Client) silentReAuthentication(logger *slog.Logger, client connection.C
return ok, nil
}

func (c *Client) clientEstablished(client connection.Client) {
c.logger.LogAttrs(context.Background(),
func (c *Client) clientEstablished(logger *slog.Logger, client connection.Client) {
logger.LogAttrs(context.Background(),
slog.LevelInfo, "client established",
slog.String("ip", fmt.Sprintf("%s:%s", client.IPAddr, client.IPPort)),
slog.String("vpn_ip", client.VPNAddress),
slog.Uint64("cid", client.CID),
slog.String("common_name", client.CommonName),
slog.String("reason", client.Reason),
slog.String("session_id", client.SessionID),
slog.String("session_state", client.SessionState),
)
}

func (c *Client) clientDisconnect(client connection.Client) {
logger := c.logger.With(
slog.String("ip", fmt.Sprintf("%s:%s", client.IPAddr, client.IPPort)),
slog.Uint64("cid", client.CID),
slog.String("common_name", client.CommonName),
slog.String("reason", client.Reason),
slog.String("session_id", client.SessionID),
slog.String("session_state", client.SessionState),
)

func (c *Client) clientDisconnect(logger *slog.Logger, client connection.Client) {
logger.Info("client disconnected")

c.oauth2.ClientDisconnect(c.ctx, logger, client)
Expand Down
2 changes: 1 addition & 1 deletion internal/openvpn/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func TestClientFull(t *testing.T) {
if strings.Contains(tt.expect, "WEB_AUTH") {
assert.Contains(t, auth, tt.expect)
} else {
assert.Equal(t, tt.expect, auth)
assert.Equal(t, tt.expect, auth, logger.String())
}

testutils.SendMessage(t, conn, "SUCCESS: %s command succeeded\r\n", strings.SplitN(auth, " ", 2)[0])
Expand Down

0 comments on commit 49e962f

Please sign in to comment.