@@ -12,16 +12,26 @@ import (
12
12
"net/http/httptest"
13
13
"net/url"
14
14
"testing"
15
+
16
+ "golang.org/x/oauth2"
15
17
)
16
18
17
- func newConf (serverURL string ) * Config {
18
- return & Config {
19
+ func newConf (serverURL string , assertion bool ) * Config {
20
+ conf := & Config {
19
21
ClientID : "CLIENT_ID" ,
20
- ClientSecret : "CLIENT_SECRET" ,
21
22
Scopes : []string {"scope1" , "scope2" },
22
23
TokenURL : serverURL + "/token" ,
23
24
EndpointParams : url.Values {"audience" : {"audience1" }},
25
+ AuthStyle : oauth2 .AuthStyleInParams ,
26
+ }
27
+ if assertion {
28
+ conf .ClientAssertionFn = func (ctx context.Context ) (string , error ) {
29
+ return "CLIENT_ASSERTION" , nil
30
+ }
31
+ } else {
32
+ conf .ClientSecret = "CLIENT_SECRET"
24
33
}
34
+ return conf
25
35
}
26
36
27
37
type mockTransport struct {
@@ -69,45 +79,70 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) {
69
79
}
70
80
}
71
81
82
+ func assert (t * testing.T , want , got string ) {
83
+ t .Helper ()
84
+ if got != want {
85
+ t .Errorf ("got %q; want %q" , got , want )
86
+ }
87
+ }
88
+
72
89
func TestTokenRequest (t * testing.T ) {
73
90
ts := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
74
91
if r .URL .String () != "/token" {
75
92
t .Errorf ("authenticate client request URL = %q; want %q" , r .URL , "/token" )
76
93
}
77
- headerAuth := r .Header .Get ("Authorization" )
78
- if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
79
- t .Errorf ("Unexpected authorization header, %v is found." , headerAuth )
80
- }
94
+
81
95
if got , want := r .Header .Get ("Content-Type" ), "application/x-www-form-urlencoded" ; got != want {
82
96
t .Errorf ("Content-Type header = %q; want %q" , got , want )
83
97
}
84
- body , err := ioutil .ReadAll (r .Body )
85
- if err != nil {
86
- r .Body .Close ()
87
- }
88
- if err != nil {
89
- t .Errorf ("failed reading request body: %s." , err )
90
- }
91
- if string (body ) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" {
92
- t .Errorf ("payload = %q; want %q" , string (body ), "grant_type=client_credentials&scope=scope1+scope2" )
98
+
99
+ assert (t , "audience1" , r .FormValue ("audience" ))
100
+ assert (t , "CLIENT_ID" , r .FormValue ("client_id" ))
101
+ assert (t , "client_credentials" , r .FormValue ("grant_type" ))
102
+ assert (t , "scope1 scope2" , r .FormValue ("scope" ))
103
+ if r .FormValue ("client_secret" ) != "" {
104
+ assert (t , "CLIENT_SECRET" , r .FormValue ("client_secret" ))
105
+ } else {
106
+ assert (t , "CLIENT_ASSERTION" , r .FormValue ("client_assertion" ))
107
+ assert (t , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" , r .FormValue ("client_assertion_type" ))
93
108
}
94
109
w .Header ().Set ("Content-Type" , "application/x-www-form-urlencoded" )
95
110
w .Write ([]byte ("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer" ))
96
111
}))
97
112
defer ts .Close ()
98
- conf := newConf (ts .URL )
99
- tok , err := conf .Token (context .Background ())
100
- if err != nil {
101
- t .Error (err )
102
- }
103
- if ! tok .Valid () {
104
- t .Fatalf ("token invalid. got: %#v" , tok )
113
+
114
+ type testCase struct {
115
+ name string
116
+ conf * Config
105
117
}
106
- if tok .AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
107
- t .Errorf ("Access token = %q; want %q" , tok .AccessToken , "90d64460d14870c08c81352a05dedd3465940a7c" )
118
+
119
+ tests := []testCase {
120
+ {
121
+ name : "client id and client_secret" ,
122
+ conf : newConf (ts .URL , false ),
123
+ },
124
+ {
125
+ name : "client id and client_assertion" ,
126
+ conf : newConf (ts .URL , true ),
127
+ },
108
128
}
109
- if tok .TokenType != "bearer" {
110
- t .Errorf ("token type = %q; want %q" , tok .TokenType , "bearer" )
129
+
130
+ for _ , tc := range tests {
131
+ t .Run (tc .name , func (t * testing.T ) {
132
+ tok , err := tc .conf .Token (context .Background ())
133
+ if err != nil {
134
+ t .Error (err )
135
+ }
136
+ if ! tok .Valid () {
137
+ t .Fatalf ("token invalid. got: %#v" , tok )
138
+ }
139
+ if tok .AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
140
+ t .Errorf ("Access token = %q; want %q" , tok .AccessToken , "90d64460d14870c08c81352a05dedd3465940a7c" )
141
+ }
142
+ if tok .TokenType != "bearer" {
143
+ t .Errorf ("token type = %q; want %q" , tok .TokenType , "bearer" )
144
+ }
145
+ })
111
146
}
112
147
}
113
148
@@ -132,7 +167,7 @@ func TestTokenRefreshRequest(t *testing.T) {
132
167
io .WriteString (w , `{"access_token": "foo", "refresh_token": "bar"}` )
133
168
}))
134
169
defer ts .Close ()
135
- conf := newConf (ts .URL )
170
+ conf := newConf (ts .URL , false )
136
171
c := conf .Client (context .Background ())
137
172
c .Get (ts .URL + "/somethingelse" )
138
173
}
0 commit comments