Skip to content

Commit

Permalink
retry fetching agent ssh key
Browse files Browse the repository at this point in the history
  • Loading branch information
johnstcn committed May 3, 2024
1 parent 8b31faa commit 9e13995
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 21 deletions.
2 changes: 1 addition & 1 deletion envbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func Run(ctx context.Context, options Options) error {
CABundle: caBundle,
}

cloneOpts.RepoAuth = SetupRepoAuth(&options)
cloneOpts.RepoAuth = SetupRepoAuth(ctx, &options)
if options.GitHTTPProxyURL != "" {
cloneOpts.ProxyOptions = transport.ProxyOptions{
URL: options.GitHTTPProxyURL,
Expand Down
43 changes: 39 additions & 4 deletions git.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/go-git/go-billy/v5"
Expand Down Expand Up @@ -186,7 +188,7 @@ func LogHostKeyCallback(log LoggerFunc) gossh.HostKeyCallback {
// If SSH_KNOWN_HOSTS is not set, the SSH auth method will be configured
// to accept and log all host keys. Otherwise, host key checking will be
// performed as usual.
func SetupRepoAuth(options *Options) transport.AuthMethod {
func SetupRepoAuth(ctx context.Context, options *Options) transport.AuthMethod {
if options.GitURL == "" {
options.Logger(codersdk.LogLevelInfo, "#1: ❔ No Git URL supplied!")
return nil
Expand Down Expand Up @@ -231,14 +233,14 @@ func SetupRepoAuth(options *Options) transport.AuthMethod {
// an SSH key from Coder!
if signer == nil && options.CoderAgentURL != "" && options.CoderAgentToken != "" {
options.Logger(codersdk.LogLevelInfo, "#1: 🔑 Fetching key from %s!", options.CoderAgentURL)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
fetchCtx, cancel := context.WithCancel(ctx)
defer cancel()
s, err := FetchCoderSSHKey(ctx, options.CoderAgentURL, options.CoderAgentToken)
s, err := FetchCoderSSHKeyRetry(fetchCtx, options.Logger, options.CoderAgentURL, options.CoderAgentToken)
if err == nil {
signer = s
options.Logger(codersdk.LogLevelInfo, "#1: 🔑 Fetched %s key %s !", signer.PublicKey().Type(), keyFingerprint(signer)[:8])
} else {
options.Logger(codersdk.LogLevelInfo, "#1: ❌ Failed to fetch SSH key: %w", options.CoderAgentURL, err)
options.Logger(codersdk.LogLevelInfo, "#1: ❌ Failed to fetch SSH key from %s: %w", options.CoderAgentURL, err)
}
}

Expand Down Expand Up @@ -276,6 +278,39 @@ func SetupRepoAuth(options *Options) transport.AuthMethod {
return auth
}

// FetchCoderSSHKeyRetry wraps FetchCoderSSHKey in backoff.Retry.
// Retries are attempted if Coder responds with a 401 Unauthorized.
// This indicates that the workspace build has not yet completed.
// It will retry for up to 1 minute with exponential backoff.
// Any other error is considered a permanent failure.
func FetchCoderSSHKeyRetry(ctx context.Context, log LoggerFunc, coderURL, agentToken string) (gossh.Signer, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

signerChan := make(chan gossh.Signer, 1)
eb := backoff.NewExponentialBackOff()
eb.MaxElapsedTime = 0
eb.MaxInterval = time.Minute
bkoff := backoff.WithContext(eb, ctx)
err := backoff.Retry(func() error {
s, err := FetchCoderSSHKey(ctx, coderURL, agentToken)
if err != nil {
var sdkErr *codersdk.Error
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusUnauthorized {
// Retry, as this may just mean that the workspace build has not yet
// completed.
log(codersdk.LogLevelInfo, "#1: 🕐 Backing off as the workspace build has not yet completed...")
return err
}
close(signerChan)
return backoff.Permanent(err)
}
signerChan <- s
return nil
}, bkoff)
return <-signerChan, err
}

// FetchCoderSSHKey fetches the user's Git SSH key from Coder using the supplied
// Coder URL and agent token.
func FetchCoderSSHKey(ctx context.Context, coderURL string, agentToken string) (gossh.Signer, error) {
Expand Down
66 changes: 50 additions & 16 deletions git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"regexp"
"sync/atomic"
"testing"

"github.com/coder/coder/v2/codersdk"
Expand Down Expand Up @@ -268,11 +269,12 @@ func TestCloneRepoSSH(t *testing.T) {
// nolint:paralleltest // t.Setenv for SSH_AUTH_SOCK
func TestSetupRepoAuth(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")
ctx := context.Background()
t.Run("Empty", func(t *testing.T) {
opts := &envbuilder.Options{
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
require.Nil(t, auth)
})

Expand All @@ -281,7 +283,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitURL: "http://host.tld/repo",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
require.Nil(t, auth)
})

Expand All @@ -292,7 +294,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitPassword: "pass",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
ba, ok := auth.(*githttp.BasicAuth)
require.True(t, ok)
require.Equal(t, opts.GitUsername, ba.Username)
Expand All @@ -306,7 +308,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitPassword: "pass",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
ba, ok := auth.(*githttp.BasicAuth)
require.True(t, ok)
require.Equal(t, opts.GitUsername, ba.Username)
Expand All @@ -320,7 +322,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitSSHPrivateKeyPath: kPath,
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
_, ok := auth.(*gitssh.PublicKeys)
require.True(t, ok)
})
Expand All @@ -332,7 +334,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitSSHPrivateKeyPath: kPath,
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
_, ok := auth.(*gitssh.PublicKeys)
require.True(t, ok)
})
Expand All @@ -345,7 +347,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitSSHPrivateKeyPath: kPath,
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
_, ok := auth.(*gitssh.PublicKeys)
require.True(t, ok)
})
Expand All @@ -358,7 +360,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitUsername: "user",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
_, ok := auth.(*gitssh.PublicKeys)
require.True(t, ok)
})
Expand All @@ -370,7 +372,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitSSHPrivateKeyPath: kPath,
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
pk, ok := auth.(*gitssh.PublicKeys)
require.True(t, ok)
require.NotNil(t, pk.Signer)
Expand All @@ -384,7 +386,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitURL: "ssh://[email protected]:repo/path",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
require.Nil(t, auth) // TODO: actually test SSH_AUTH_SOCK
})

Expand Down Expand Up @@ -415,19 +417,51 @@ func TestSetupRepoAuth(t *testing.T) {
GitURL: "ssh://[email protected]:repo/path",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
pk, ok := auth.(*gitssh.PublicKeys)
require.True(t, ok)
require.NotNil(t, pk.Signer)
require.Equal(t, actualSigner, pk.Signer)
})

t.Run("SSH/CoderForbidden", func(t *testing.T) {
t.Run("SSH/CoderRetry", func(t *testing.T) {
token := uuid.NewString()
actualSigner, err := gossh.ParsePrivateKey([]byte(testKey))
require.NoError(t, err)
var count atomic.Int64
// Return 401 initially, but eventually 200.
handler := func(w http.ResponseWriter, r *http.Request) {
hdr := r.Header.Get(codersdk.SessionTokenHeader)
assert.Equal(t, hdr, token)
w.WriteHeader(http.StatusForbidden)
c := count.Add(1)
if c < 3 {
hdr := r.Header.Get(codersdk.SessionTokenHeader)
assert.Equal(t, hdr, token)
w.WriteHeader(http.StatusUnauthorized)
return
}
_ = json.NewEncoder(w).Encode(&agentsdk.GitSSHKey{
PublicKey: string(actualSigner.PublicKey().Marshal()),
PrivateKey: string(testKey),
})
}
srv := httptest.NewServer(http.HandlerFunc(handler))
opts := &envbuilder.Options{
CoderAgentURL: srv.URL,
CoderAgentToken: token,
GitURL: "ssh://[email protected]:repo/path",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(ctx, opts)
pk, ok := auth.(*gitssh.PublicKeys)
require.True(t, ok)
require.NotNil(t, pk.Signer)
require.Equal(t, actualSigner, pk.Signer)
})

t.Run("SSH/NotCoder", func(t *testing.T) {
token := uuid.NewString()
handler := func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
_, _ = w.Write([]byte("I'm a teapot!"))
}
srv := httptest.NewServer(http.HandlerFunc(handler))
opts := &envbuilder.Options{
Expand All @@ -436,7 +470,7 @@ func TestSetupRepoAuth(t *testing.T) {
GitURL: "ssh://[email protected]:repo/path",
Logger: testLog(t),
}
auth := envbuilder.SetupRepoAuth(opts)
auth := envbuilder.SetupRepoAuth(ctx, opts)
require.Nil(t, auth)
})
}
Expand Down

0 comments on commit 9e13995

Please sign in to comment.