Skip to content

Commit 1319cb9

Browse files
committed
Change IdentityProviderResponse interface
1 parent 1fca220 commit 1319cb9

7 files changed

+40
-50
lines changed

identity/azure_default_identity_provider_test.go

+13-13
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) {
5252
mCredFactory := &mockCredFactory{}
5353
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil)
5454
provider.credFactory = mCredFactory
55-
token, err = provider.RequestToken(context.Background())
56-
assert.NotNil(t, token, "token should not be nil")
57-
assert.NoError(t, err, "failed to request token")
58-
assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token")
59-
assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTToken")
55+
resp, err := provider.RequestToken(context.Background())
56+
assert.NotNil(t, resp, "resp should not be nil")
57+
assert.NoError(t, err, "failed to request resp")
58+
assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp")
59+
assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access token should be equal to testJWTToken")
6060
}
6161

6262
func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) {
@@ -84,19 +84,19 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) {
8484
mCredFactory := &mockCredFactory{}
8585
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil)
8686
provider.credFactory = mCredFactory
87-
token, err = provider.RequestToken(context.Background())
88-
assert.NotNil(t, token, "token should not be nil")
89-
assert.NoError(t, err, "failed to request token")
90-
assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token")
91-
assert.Equal(t, mToken, token.AccessToken(), "access token should be equal to testJWTToken")
87+
resp, err := provider.RequestToken(context.Background())
88+
assert.NotNil(t, resp, "resp should not be nil")
89+
assert.NoError(t, err, "failed to request resp")
90+
assert.Equal(t, shared.ResponseTypeAccessToken, resp.Type(), "resp type should be access resp")
91+
assert.Equal(t, mToken, resp.(shared.AccessTokenIDPResponse).AccessToken(), "access resp should be equal to testJWTToken")
9292
})
9393
t.Run("RequestToken with error from credFactory", func(t *testing.T) {
9494
// use mockAzureCredential to simulate the environment
9595
mCredFactory := &mockCredFactory{}
9696
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(nil, assert.AnError)
9797
provider.credFactory = mCredFactory
98-
token, err := provider.RequestToken(context.Background())
99-
assert.Nil(t, token, "token should be nil")
100-
assert.Error(t, err, "failed to request token")
98+
resp, err := provider.RequestToken(context.Background())
99+
assert.Nil(t, resp, "resp should be nil")
100+
assert.Error(t, err, "failed to request resp")
101101
})
102102
}

identity/confidential_identity_provider_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ func TestConfidentialIdentityProvider_RequestToken(t *testing.T) {
268268
}
269269
assert.NotEmpty(t, token, "RequestToken() token should not be empty")
270270
assert.Equal(t, token.Type(), shared.ResponseTypeAuthResult, "RequestToken() token type should be AuthResult")
271-
assert.Equal(t, token.AuthResult().ExpiresOn, expiresOn, "RequestToken() token expiration should match")
271+
assert.Equal(t, token.(shared.AuthResultIDPResponse).AuthResult().ExpiresOn, expiresOn, "RequestToken() token expiration should match")
272272
})
273273
t.Run("with error", func(t *testing.T) {
274274
t.Parallel()

internal/idp_response.go

-15
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,6 @@ func (a *IDPResp) AuthResult() public.AuthResult {
8080
return *a.authResultVal
8181
}
8282

83-
// HasAuthResult returns true if an AuthResult is set
84-
func (a *IDPResp) HasAuthResult() bool {
85-
return a.authResultVal != nil
86-
}
87-
8883
// AccessToken returns the AccessToken if present, or an empty AccessToken if not set
8984
// Use HasAccessToken() to check if the value is actually set
9085
func (a *IDPResp) AccessToken() azcore.AccessToken {
@@ -94,17 +89,7 @@ func (a *IDPResp) AccessToken() azcore.AccessToken {
9489
return *a.accessTokenVal
9590
}
9691

97-
// HasAccessToken returns true if an AccessToken is set
98-
func (a *IDPResp) HasAccessToken() bool {
99-
return a.accessTokenVal != nil
100-
}
101-
10292
// RawToken returns the raw token string
10393
func (a *IDPResp) RawToken() string {
10494
return a.rawTokenVal
10595
}
106-
107-
// HasRawToken returns true if a raw token is set
108-
func (a *IDPResp) HasRawToken() bool {
109-
return a.rawTokenVal != ""
110-
}

internal/idp_response_test.go

-10
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,7 @@ func TestNewIDPResp(t *testing.T) {
171171
},
172172
wantErr: false,
173173
checkResult: func(t *testing.T, resp *IDPResp) {
174-
assert.True(t, resp.HasAuthResult())
175174
assert.Equal(t, "test-token", resp.AuthResult().AccessToken)
176-
assert.False(t, resp.HasAccessToken())
177-
assert.False(t, resp.HasRawToken())
178175
},
179176
},
180177
{
@@ -185,7 +182,6 @@ func TestNewIDPResp(t *testing.T) {
185182
},
186183
wantErr: false,
187184
checkResult: func(t *testing.T, resp *IDPResp) {
188-
assert.True(t, resp.HasAuthResult())
189185
assert.Equal(t, "test-token", resp.AuthResult().AccessToken)
190186
},
191187
},
@@ -198,7 +194,6 @@ func TestNewIDPResp(t *testing.T) {
198194
},
199195
wantErr: false,
200196
checkResult: func(t *testing.T, resp *IDPResp) {
201-
assert.True(t, resp.HasAccessToken())
202197
assert.Equal(t, "test-token", resp.AccessToken().Token)
203198
assert.Equal(t, "test-token", resp.RawToken())
204199
},
@@ -212,7 +207,6 @@ func TestNewIDPResp(t *testing.T) {
212207
},
213208
wantErr: false,
214209
checkResult: func(t *testing.T, resp *IDPResp) {
215-
assert.True(t, resp.HasAccessToken())
216210
assert.Equal(t, "test-token", resp.AccessToken().Token)
217211
assert.Equal(t, "test-token", resp.RawToken())
218212
},
@@ -223,10 +217,7 @@ func TestNewIDPResp(t *testing.T) {
223217
result: "test-token",
224218
wantErr: false,
225219
checkResult: func(t *testing.T, resp *IDPResp) {
226-
assert.True(t, resp.HasRawToken())
227220
assert.Equal(t, "test-token", resp.RawToken())
228-
assert.False(t, resp.HasAuthResult())
229-
assert.False(t, resp.HasAccessToken())
230221
},
231222
},
232223
{
@@ -235,7 +226,6 @@ func TestNewIDPResp(t *testing.T) {
235226
result: stringPtr("test-token"),
236227
wantErr: false,
237228
checkResult: func(t *testing.T, resp *IDPResp) {
238-
assert.True(t, resp.HasRawToken())
239229
assert.Equal(t, "test-token", resp.RawToken())
240230
},
241231
},

manager/defaults.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
104104

105105
switch response.Type() {
106106
case shared.ResponseTypeAuthResult:
107-
authResult := response.AuthResult()
107+
authResult := response.(shared.AuthResultIDPResponse).AuthResult()
108108
if authResult.ExpiresOn.IsZero() {
109109
return nil, fmt.Errorf("auth result expiration time is not set")
110110
}
@@ -117,10 +117,10 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
117117
expiresOn = authResult.ExpiresOn.UTC()
118118

119119
case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken:
120-
tokenStr := response.RawToken()
120+
tokenStr := response.(shared.RawTokenIDPResponse).RawToken()
121121

122122
if response.Type() == shared.ResponseTypeAccessToken {
123-
accessToken := response.AccessToken()
123+
accessToken := response.(shared.AccessTokenIDPResponse).AccessToken()
124124
if accessToken.Token == "" {
125125
return nil, fmt.Errorf("access token value is empty")
126126
}

shared/identity_provider_response.go

+18-3
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,28 @@ type IdentityProviderResponseParser interface {
2525
ParseResponse(response IdentityProviderResponse) (*token.Token, error)
2626
}
2727

28-
// IdentityProviderResponse is an interface that defines the methods for an identity provider authentication result.
29-
// It is used to get the type of the authentication result, the authentication result itself (can be AuthResult or AccessToken),
28+
// IdentityProviderResponse is an interface that defines the
29+
// type method for the identity provider response. It is used to
30+
// identify the type of response returned by the identity provider.
31+
// The type can be either AuthResult, AccessToken, or RawToken. You can
32+
// use this interface to check the type of the response and handle it accordingly.
3033
type IdentityProviderResponse interface {
31-
// Type returns the type of the auth result
34+
// Type returns the type of identity provider response
3235
Type() string
36+
}
37+
38+
// AuthResultIDPResponse is an interface that defines the method for getting the auth result.
39+
type AuthResultIDPResponse interface {
3340
AuthResult() public.AuthResult
41+
}
42+
43+
// AccessTokenIDPResponse is an interface that defines the method for getting the access token.
44+
type AccessTokenIDPResponse interface {
3445
AccessToken() azcore.AccessToken
46+
}
47+
48+
// RawTokenIDPResponse is an interface that defines the method for getting the raw token.
49+
type RawTokenIDPResponse interface {
3550
RawToken() string
3651
}
3752

shared/identity_provider_response_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,12 @@ func TestNewIDPResponse(t *testing.T) {
156156

157157
switch tt.responseType {
158158
case ResponseTypeAuthResult:
159-
assert.NotNil(t, resp.AuthResult())
159+
assert.NotNil(t, resp.(AuthResultIDPResponse).AuthResult())
160160
case ResponseTypeAccessToken:
161-
assert.NotNil(t, resp.AccessToken())
162-
assert.NotEmpty(t, resp.AccessToken().Token)
161+
assert.NotNil(t, resp.(AccessTokenIDPResponse).AccessToken())
162+
assert.NotEmpty(t, resp.(AccessTokenIDPResponse).AccessToken().Token)
163163
case ResponseTypeRawToken:
164-
assert.NotEmpty(t, resp.RawToken())
164+
assert.NotEmpty(t, resp.(RawTokenIDPResponse).RawToken())
165165
}
166166
})
167167
}
@@ -271,7 +271,7 @@ func TestIdentityProvider(t *testing.T) {
271271
assert.NoError(t, err)
272272
assert.NotNil(t, response)
273273
assert.Equal(t, ResponseTypeRawToken, response.Type())
274-
assert.Equal(t, "test-token", response.RawToken())
274+
assert.Equal(t, "test-token", response.(RawTokenIDPResponse).RawToken())
275275
}
276276
})
277277
}

0 commit comments

Comments
 (0)