Skip to content

Commit e9b2eb7

Browse files
committed
support client assertion
1 parent 3e64809 commit e9b2eb7

File tree

2 files changed

+80
-28
lines changed

2 files changed

+80
-28
lines changed

clientcredentials/clientcredentials.go

+17
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ import (
2424
"golang.org/x/oauth2/internal"
2525
)
2626

27+
const (
28+
ClientJWTAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
29+
)
30+
2731
// Config describes a 2-legged OAuth2 flow, with both the
2832
// client application information and the server's endpoint URLs.
2933
type Config struct {
@@ -33,6 +37,9 @@ type Config struct {
3337
// ClientSecret is the application's secret.
3438
ClientSecret string
3539

40+
// ClientAssertionFn is a function to generate a client assertion value.
41+
ClientAssertionFn func(ctx context.Context) (string, error)
42+
3643
// TokenURL is the resource server's token endpoint
3744
// URL. This is a constant specific to each server.
3845
TokenURL string
@@ -107,6 +114,16 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
107114
v[k] = p
108115
}
109116

117+
if c.conf.ClientAssertionFn != nil {
118+
clientAssertion, err := c.conf.ClientAssertionFn(c.ctx)
119+
if err != nil {
120+
return nil, err
121+
}
122+
123+
v.Set("client_assertion", clientAssertion)
124+
v.Set("client_assertion_type", ClientJWTAssertionType)
125+
}
126+
110127
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
111128
if err != nil {
112129
if rErr, ok := err.(*internal.RetrieveError); ok {

clientcredentials/clientcredentials_test.go

+63-28
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,26 @@ import (
1212
"net/http/httptest"
1313
"net/url"
1414
"testing"
15+
16+
"golang.org/x/oauth2"
1517
)
1618

17-
func newConf(serverURL string) *Config {
18-
return &Config{
19+
func newConf(serverURL string, assertion bool) *Config {
20+
conf := &Config{
1921
ClientID: "CLIENT_ID",
20-
ClientSecret: "CLIENT_SECRET",
2122
Scopes: []string{"scope1", "scope2"},
2223
TokenURL: serverURL + "/token",
2324
EndpointParams: url.Values{"audience": {"audience1"}},
25+
AuthStyle: oauth2.AuthStyleInParams,
26+
}
27+
if assertion {
28+
conf.ClientAssertionFn = func(ctx context.Context) (string, error) {
29+
return "CLIENT_ASSERTION", nil
30+
}
31+
} else {
32+
conf.ClientSecret = "CLIENT_SECRET"
2433
}
34+
return conf
2535
}
2636

2737
type mockTransport struct {
@@ -69,45 +79,70 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) {
6979
}
7080
}
7181

82+
func assert(t *testing.T, want, got string) {
83+
t.Helper()
84+
if got != want {
85+
t.Errorf("got %q; want %q", got, want)
86+
}
87+
}
88+
7289
func TestTokenRequest(t *testing.T) {
7390
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
7491
if r.URL.String() != "/token" {
7592
t.Errorf("authenticate client request URL = %q; want %q", r.URL, "/token")
7693
}
77-
headerAuth := r.Header.Get("Authorization")
78-
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
79-
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
80-
}
94+
8195
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
8296
t.Errorf("Content-Type header = %q; want %q", got, want)
8397
}
84-
body, err := ioutil.ReadAll(r.Body)
85-
if err != nil {
86-
r.Body.Close()
87-
}
88-
if err != nil {
89-
t.Errorf("failed reading request body: %s.", err)
90-
}
91-
if string(body) != "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" {
92-
t.Errorf("payload = %q; want %q", string(body), "grant_type=client_credentials&scope=scope1+scope2")
98+
99+
assert(t, "audience1", r.FormValue("audience"))
100+
assert(t, "CLIENT_ID", r.FormValue("client_id"))
101+
assert(t, "client_credentials", r.FormValue("grant_type"))
102+
assert(t, "scope1 scope2", r.FormValue("scope"))
103+
if r.FormValue("client_secret") != "" {
104+
assert(t, "CLIENT_SECRET", r.FormValue("client_secret"))
105+
} else {
106+
assert(t, "CLIENT_ASSERTION", r.FormValue("client_assertion"))
107+
assert(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", r.FormValue("client_assertion_type"))
93108
}
94109
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
95110
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&token_type=bearer"))
96111
}))
97112
defer ts.Close()
98-
conf := newConf(ts.URL)
99-
tok, err := conf.Token(context.Background())
100-
if err != nil {
101-
t.Error(err)
102-
}
103-
if !tok.Valid() {
104-
t.Fatalf("token invalid. got: %#v", tok)
113+
114+
type testCase struct {
115+
name string
116+
conf *Config
105117
}
106-
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
107-
t.Errorf("Access token = %q; want %q", tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c")
118+
119+
tests := []testCase{
120+
{
121+
name: "client id and client_secret",
122+
conf: newConf(ts.URL, false),
123+
},
124+
{
125+
name: "client id and client_assertion",
126+
conf: newConf(ts.URL, true),
127+
},
108128
}
109-
if tok.TokenType != "bearer" {
110-
t.Errorf("token type = %q; want %q", tok.TokenType, "bearer")
129+
130+
for _, tc := range tests {
131+
t.Run(tc.name, func(t *testing.T) {
132+
tok, err := tc.conf.Token(context.Background())
133+
if err != nil {
134+
t.Error(err)
135+
}
136+
if !tok.Valid() {
137+
t.Fatalf("token invalid. got: %#v", tok)
138+
}
139+
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
140+
t.Errorf("Access token = %q; want %q", tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c")
141+
}
142+
if tok.TokenType != "bearer" {
143+
t.Errorf("token type = %q; want %q", tok.TokenType, "bearer")
144+
}
145+
})
111146
}
112147
}
113148

@@ -132,7 +167,7 @@ func TestTokenRefreshRequest(t *testing.T) {
132167
io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`)
133168
}))
134169
defer ts.Close()
135-
conf := newConf(ts.URL)
170+
conf := newConf(ts.URL, false)
136171
c := conf.Client(context.Background())
137172
c.Get(ts.URL + "/somethingelse")
138173
}

0 commit comments

Comments
 (0)