Skip to content

chore: add credential checkParam field #964

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions pkg/credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ const (
)

type Credential struct {
Context string `json:"context"`
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
Ephemeral bool `json:"ephemeral,omitempty"`
ExpiresAt *time.Time `json:"expiresAt"`
RefreshToken string `json:"refreshToken"`
Context string `json:"context"`
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
// If the CheckParam that is stored is different from the param on the tool,
// then the credential will be re-authed as if it does not exist.
CheckParam string `json:"checkParam"`
Ephemeral bool `json:"ephemeral,omitempty"`
ExpiresAt *time.Time `json:"expiresAt"`
RefreshToken string `json:"refreshToken"`
}

func (c Credential) IsExpired() bool {
Expand Down Expand Up @@ -82,6 +85,7 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
Context: ctx,
ToolName: tool,
Type: CredentialType(credType),
CheckParam: cred.CheckParam,
Env: cred.Env,
ExpiresAt: cred.ExpiresAt,
RefreshToken: cred.RefreshToken,
Expand Down
7 changes: 4 additions & 3 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env

var nearestExpiration *time.Time
for _, ref := range credToolRefs {
toolName, credentialAlias, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input)
toolName, credentialAlias, checkParam, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input)
if err != nil {
return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err)
}
Expand Down Expand Up @@ -830,9 +830,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env

// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
// and save it in the store.
if !exists || c.IsExpired() {
if !exists || c.IsExpired() || checkParam != c.CheckParam {
// If the existing credential is expired, we need to provide it to the cred tool through the environment.
if exists && c.IsExpired() {
// If the check parameter is different, then we don't refresh. We should re-auth below.
if exists && c.IsExpired() && checkParam == c.CheckParam {
refresh = true
credJSON, err := json.Marshal(c)
if err != nil {
Expand Down
52 changes: 44 additions & 8 deletions pkg/types/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ import (

func TestParseCredentialArgs(t *testing.T) {
tests := []struct {
name string
toolName string
input string
expectedName string
expectedAlias string
expectedArgs map[string]string
wantErr bool
name string
toolName string
input string
expectedName string
expectedAlias string
expectedCheckParam string
expectedArgs map[string]string
wantErr bool
}{
{
name: "empty",
Expand Down Expand Up @@ -94,6 +95,40 @@ func TestParseCredentialArgs(t *testing.T) {
"arg2": "value2",
},
},
{
name: "tool name with check parameter",
toolName: `myCredentialTool checked with myCheckParam`,
expectedName: "myCredentialTool",
expectedCheckParam: "myCheckParam",
},
{
name: "tool name with alias and check parameter",
toolName: `myCredentialTool as myAlias checked with myCheckParam`,
expectedName: "myCredentialTool",
expectedAlias: "myAlias",
expectedCheckParam: "myCheckParam",
},
{
name: "tool name with alias, check parameter, and args",
toolName: `myCredentialTool as myAlias checked with myCheckParam with value1 as arg1 and value2 as arg2`,
expectedName: "myCredentialTool",
expectedAlias: "myAlias",
expectedCheckParam: "myCheckParam",
expectedArgs: map[string]string{
"arg1": "value1",
"arg2": "value2",
},
},
{
name: "check parameter without with",
toolName: `myCredentialTool checked myCheckParam`,
wantErr: true,
},
{
name: "invalid check parameter",
toolName: `myCredentialTool checked with`,
wantErr: true,
},
{
name: "tool name with alias but no 'as' (invalid)",
toolName: "myCredentialTool myAlias",
Expand Down Expand Up @@ -136,7 +171,7 @@ func TestParseCredentialArgs(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalName, alias, args, err := ParseCredentialArgs(tt.toolName, tt.input)
originalName, alias, checkParam, args, err := ParseCredentialArgs(tt.toolName, tt.input)
if tt.wantErr {
require.Error(t, err, "expected an error but got none")
return
Expand All @@ -145,6 +180,7 @@ func TestParseCredentialArgs(t *testing.T) {
require.NoError(t, err, "did not expect an error but got one")
require.Equal(t, tt.expectedName, originalName, "unexpected original name")
require.Equal(t, tt.expectedAlias, alias, "unexpected alias")
require.Equal(t, tt.expectedCheckParam, checkParam, "unexpected checkParam")
require.Equal(t, len(tt.expectedArgs), len(args), "unexpected number of args")

for k, v := range tt.expectedArgs {
Expand Down
38 changes: 26 additions & 12 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ func SplitArg(hasArg string) (prefix, arg string) {
// - toolName: "toolName with ${var1} as arg1 and ${var2} as arg2"
// - input: `{"var1": "value1", "var2": "value2"}`
// result: toolName, "", map[string]any{"arg1": "value1", "arg2": "value2"}, nil
func ParseCredentialArgs(toolName string, input string) (string, string, map[string]any, error) {
func ParseCredentialArgs(toolName string, input string) (string, string, string, map[string]any, error) {
if toolName == "" {
return "", "", nil, nil
return "", "", "", nil, nil
}

inputMap := make(map[string]any)
Expand All @@ -287,12 +287,12 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str

fields, err := shlex.Split(toolName)
if err != nil {
return "", "", nil, err
return "", "", "", nil, err
}

// If it's just the tool name, return it
if len(fields) == 1 {
return toolName, "", nil, nil
return toolName, "", "", nil, nil
}

// Next field is "as" if there is an alias, otherwise it should be "with"
Expand All @@ -301,25 +301,39 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str
fields = fields[1:]
if fields[0] == "as" {
if len(fields) < 2 {
return "", "", nil, fmt.Errorf("expected alias after 'as'")
return "", "", "", nil, fmt.Errorf("expected alias after 'as'")
}
alias = fields[1]
fields = fields[2:]
}

if len(fields) == 0 { // Nothing left, so just return
return originalName, alias, nil, nil
return originalName, alias, "", nil, nil
}

var checkParam string
if fields[0] == "checked" {
if len(fields) < 3 || fields[1] != "with" {
return "", "", "", nil, fmt.Errorf("expected 'checked with some_value' but got %v", fields)
}

checkParam = fields[2]
fields = fields[3:]
}

if len(fields) == 0 { // Nothing left, so just return
return originalName, alias, checkParam, nil, nil
}

// Next we should have "with" followed by the args
if fields[0] != "with" {
return "", "", nil, fmt.Errorf("expected 'with' but got %s", fields[0])
return "", "", "", nil, fmt.Errorf("expected 'with' but got %s", fields[0])
}
fields = fields[1:]

// If there are no args, return an error
if len(fields) == 0 {
return "", "", nil, fmt.Errorf("expected args after 'with'")
return "", "", "", nil, fmt.Errorf("expected args after 'with'")
}

args := make(map[string]any)
Expand All @@ -332,22 +346,22 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str
prev = "value"
case "value":
if field != "as" {
return "", "", nil, fmt.Errorf("expected 'as' but got %s", field)
return "", "", "", nil, fmt.Errorf("expected 'as' but got %s", field)
}
prev = "as"
case "as":
args[field] = argValue
prev = "name"
case "name":
if field != "and" {
return "", "", nil, fmt.Errorf("expected 'and' but got %s", field)
return "", "", "", nil, fmt.Errorf("expected 'and' but got %s", field)
}
prev = "and"
}
}

if prev == "and" {
return "", "", nil, fmt.Errorf("expected arg name after 'and'")
return "", "", "", nil, fmt.Errorf("expected arg name after 'and'")
}

// Check and see if any of the arg values are references to an input
Expand All @@ -360,7 +374,7 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str
}
}

return originalName, alias, args, nil
return originalName, alias, checkParam, args, nil
}

func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ error) {
Expand Down