Skip to content

Commit

Permalink
chore: make credential overrides cred context aware
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Jan 22, 2025
1 parent 3f876b2 commit c02c4cb
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 60 deletions.
60 changes: 46 additions & 14 deletions pkg/credentials/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package credentials

import (
"context"
"strings"

"github.com/docker/docker-credential-helpers/client"
"github.com/gptscript-ai/gptscript/pkg/config"
Expand All @@ -13,12 +14,32 @@ type ProgramLoaderRunner interface {
Run(ctx context.Context, prg types.Program, input string) (output string, err error)
}

func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRunner) (StoreFactory, error) {
func NewFactory(ctx context.Context, cfg *config.CLIConfig, overrides []string, plr ProgramLoaderRunner) (StoreFactory, error) {
creds, err := ParseCredentialOverrides(overrides)
if err != nil {
return StoreFactory{}, err
}

overrideMap := make(map[string]map[string]map[string]string)
for k, v := range creds {
contextName, toolName, ok := strings.Cut(k, ctxSeparator)
if !ok {
continue
}
toolMap, ok := overrideMap[contextName]
if !ok {
toolMap = make(map[string]map[string]string)
}
toolMap[toolName] = v
overrideMap[contextName] = toolMap
}

toolName := translateToolName(cfg.CredentialsStore)
if toolName == config.FileCredHelper {
return StoreFactory{
file: true,
cfg: cfg,
file: true,
cfg: cfg,
overrides: overrideMap,
}, nil
}

Expand All @@ -28,10 +49,11 @@ func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRun
}

return StoreFactory{
ctx: ctx,
prg: prg,
runner: plr,
cfg: cfg,
ctx: ctx,
prg: prg,
runner: plr,
cfg: cfg,
overrides: overrideMap,
}, nil
}

Expand All @@ -41,22 +63,32 @@ type StoreFactory struct {
file bool
runner ProgramLoaderRunner
cfg *config.CLIConfig
// That's a lot of maps: context -> toolName -> key -> value
overrides map[string]map[string]map[string]string
}

func (s *StoreFactory) NewStore(credCtxs []string) (CredentialStore, error) {
if err := validateCredentialCtx(credCtxs); err != nil {
return nil, err
}
if s.file {
return Store{
credCtxs: credCtxs,
cfg: s.cfg,
return withOverride{
target: Store{
credCtxs: credCtxs,
cfg: s.cfg,
},
overrides: s.overrides,
credContext: credCtxs,
}, nil
}
return Store{
credCtxs: credCtxs,
cfg: s.cfg,
program: s.program,
return withOverride{
target: Store{
credCtxs: credCtxs,
cfg: s.cfg,
program: s.program,
},
overrides: s.overrides,
credContext: credCtxs,
}, nil
}

Expand Down
149 changes: 149 additions & 0 deletions pkg/credentials/overrides.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package credentials

import (
"context"
"fmt"
"maps"
"os"
"strings"
)

// ParseCredentialOverrides parses a string of credential overrides that the user provided as a command line arg.
// The format of credential overrides can be one of two things:
// cred1:ENV1,ENV2 (direct mapping of environment variables)
// cred1:ENV1=VALUE1,ENV2=VALUE2 (key-value pairs)
//
// This function turns it into a map[string]map[string]string like this:
//
// {
// "cred1": {
// "ENV1": "VALUE1",
// "ENV2": "VALUE2",
// }
// }
func ParseCredentialOverrides(overrides []string) (map[string]map[string]string, error) {
credentialOverrides := make(map[string]map[string]string)

for _, o := range overrides {
credName, envs, found := strings.Cut(o, ":")
if !found {
return nil, fmt.Errorf("invalid credential override: %s", o)
}
envMap, ok := credentialOverrides[credName]
if !ok {
envMap = make(map[string]string)
}
for _, env := range strings.Split(envs, ",") {
for _, env := range strings.Split(env, "|") {
key, value, found := strings.Cut(env, "=")
if !found {
// User just passed an env var name as the key, so look up the value.
value = os.Getenv(key)
}
envMap[key] = value
}
}
credentialOverrides[credName] = envMap
}

return credentialOverrides, nil
}

type withOverride struct {
target CredentialStore
credContext []string
overrides map[string]map[string]map[string]string
}

func (w withOverride) Get(ctx context.Context, toolName string) (*Credential, bool, error) {
for _, credCtx := range w.credContext {
overrides, ok := w.overrides[credCtx]
if !ok {
continue
}
override, ok := overrides[toolName]
if !ok {
continue
}

return &Credential{
Context: credCtx,
ToolName: toolName,
Type: CredentialTypeTool,
Env: maps.Clone(override),
}, true, nil
}

return w.target.Get(ctx, toolName)
}

func (w withOverride) Add(ctx context.Context, cred Credential) error {
for _, credCtx := range w.credContext {
if override, ok := w.overrides[credCtx]; ok {
if _, ok := override[cred.ToolName]; ok {
return fmt.Errorf("cannot add credential with context %q and tool %q because it is statically configure", cred.Context, cred.ToolName)
}
}
}
return w.target.Add(ctx, cred)
}

func (w withOverride) Refresh(ctx context.Context, cred Credential) error {
if override, ok := w.overrides[cred.Context]; ok {
if _, ok := override[cred.ToolName]; ok {
return nil
}
}
return w.target.Refresh(ctx, cred)
}

func (w withOverride) Remove(ctx context.Context, toolName string) error {
for _, credCtx := range w.credContext {
if override, ok := w.overrides[credCtx]; ok {
if _, ok := override[toolName]; ok {
return fmt.Errorf("cannot remove credential with context %q and tool %q because it is statically configure", credCtx, toolName)
}
}
}
return w.target.Remove(ctx, toolName)
}

func (w withOverride) List(ctx context.Context) ([]Credential, error) {
creds, err := w.target.List(ctx)
if err != nil {
return nil, err
}

added := make(map[string]map[string]bool)
for i, cred := range creds {
if override, ok := w.overrides[cred.Context]; ok {
if _, ok := override[cred.ToolName]; ok {
creds[i].Type = CredentialTypeTool
creds[i].Env = maps.Clone(override[cred.ToolName])
}
}
tools, ok := added[cred.Context]
if !ok {
tools = make(map[string]bool)
}
tools[cred.ToolName] = true
added[cred.Context] = tools
}

for _, credCtx := range w.credContext {
tools := w.overrides[credCtx]
for toolName := range tools {
if _, ok := added[credCtx][toolName]; ok {
continue
}
creds = append(creds, Credential{
Context: credCtx,
ToolName: toolName,
Type: CredentialTypeTool,
Env: maps.Clone(tools[toolName]),
})
}
}

return creds, nil
}
2 changes: 1 addition & 1 deletion pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
return nil, err
}

storeFactory, err := credentials.NewFactory(ctx, cliCfg, simplerRunner)
storeFactory, err := credentials.NewFactory(ctx, cliCfg, opts.Runner.CredentialOverrides, simplerRunner)
if err != nil {
return nil, err
}
Expand Down
43 changes: 0 additions & 43 deletions pkg/runner/credentials.go

This file was deleted.

3 changes: 2 additions & 1 deletion pkg/runner/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"os"
"testing"

"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -119,7 +120,7 @@ func TestParseCredentialOverrides(t *testing.T) {
_ = os.Setenv(k, v)
}

out, err := parseCredentialOverrides(tc.in)
out, err := credentials.ParseCredentialOverrides(tc.in)
if tc.expectErr {
require.Error(t, err)
return
Expand Down
2 changes: 1 addition & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
err error
)
if r.credOverrides != nil {
credOverrides, err = parseCredentialOverrides(r.credOverrides)
credOverrides, err = credentials.ParseCredentialOverrides(r.credOverrides)
if err != nil {
return nil, fmt.Errorf("failed to parse credential overrides: %w", err)
}
Expand Down

0 comments on commit c02c4cb

Please sign in to comment.