Skip to content

Commit d6a25ad

Browse files
committed
GODRIVER-3215 Fix default auth source for auth specified via ClientOptions.
1 parent a766876 commit d6a25ad

File tree

12 files changed

+194
-185
lines changed

12 files changed

+194
-185
lines changed

mongo/client.go

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import (
2626
"go.mongodb.org/mongo-driver/mongo/writeconcern"
2727
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
2828
"go.mongodb.org/mongo-driver/x/mongo/driver"
29-
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
3029
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
3130
mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
3231
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
@@ -211,43 +210,16 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) {
211210
clientOpt.SetMaxPoolSize(defaultMaxPoolSize)
212211
}
213212

214-
if clientOpt.Auth != nil {
215-
var oidcMachineCallback auth.OIDCCallback
216-
if clientOpt.Auth.OIDCMachineCallback != nil {
217-
oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
218-
cred, err := clientOpt.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(args))
219-
return (*driver.OIDCCredential)(cred), err
220-
}
221-
}
222-
223-
var oidcHumanCallback auth.OIDCCallback
224-
if clientOpt.Auth.OIDCHumanCallback != nil {
225-
oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
226-
cred, err := clientOpt.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(args))
227-
return (*driver.OIDCCredential)(cred), err
228-
}
229-
}
230-
231-
// Create an authenticator for the client
232-
client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{
233-
Source: clientOpt.Auth.AuthSource,
234-
Username: clientOpt.Auth.Username,
235-
Password: clientOpt.Auth.Password,
236-
PasswordSet: clientOpt.Auth.PasswordSet,
237-
Props: clientOpt.Auth.AuthMechanismProperties,
238-
OIDCMachineCallback: oidcMachineCallback,
239-
OIDCHumanCallback: oidcHumanCallback,
240-
}, clientOpt.HTTPClient)
241-
if err != nil {
242-
return nil, err
243-
}
213+
client.authenticator, err = topology.NewAuthenticator(clientOpt.Auth, clientOpt.HTTPClient)
214+
if err != nil {
215+
return nil, fmt.Errorf("error creating authenticator: %w", err)
244216
}
245217

246218
cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator)
247-
248219
if err != nil {
249220
return nil, err
250221
}
222+
251223
client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts)
252224

253225
if client.deployment == nil {
@@ -266,19 +238,6 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) {
266238
return client, nil
267239
}
268240

269-
// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent
270-
// public type *options.OIDCArgs.
271-
func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs {
272-
if args == nil {
273-
return nil
274-
}
275-
return &options.OIDCArgs{
276-
Version: args.Version,
277-
IDPInfo: (*options.IDPInfo)(args.IDPInfo),
278-
RefreshToken: args.RefreshToken,
279-
}
280-
}
281-
282241
// Connect initializes the Client by starting background monitoring goroutines.
283242
// If the Client was created using the NewClient function, this method must be called before a Client can be used.
284243
//

mongo/client_test.go

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,18 @@ import (
1111
"errors"
1212
"math"
1313
"os"
14-
"reflect"
1514
"testing"
1615
"time"
1716

1817
"go.mongodb.org/mongo-driver/bson"
1918
"go.mongodb.org/mongo-driver/event"
2019
"go.mongodb.org/mongo-driver/internal/assert"
2120
"go.mongodb.org/mongo-driver/internal/integtest"
22-
"go.mongodb.org/mongo-driver/internal/require"
2321
"go.mongodb.org/mongo-driver/mongo/options"
2422
"go.mongodb.org/mongo-driver/mongo/readconcern"
2523
"go.mongodb.org/mongo-driver/mongo/readpref"
2624
"go.mongodb.org/mongo-driver/mongo/writeconcern"
2725
"go.mongodb.org/mongo-driver/tag"
28-
"go.mongodb.org/mongo-driver/x/mongo/driver"
2926
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
3027
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
3128
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
@@ -505,76 +502,3 @@ func TestClient(t *testing.T) {
505502
}
506503
})
507504
}
508-
509-
// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs
510-
// into an options.OIDCArgs.
511-
func TestConvertOIDCArgs(t *testing.T) {
512-
refreshToken := "test refresh token"
513-
514-
testCases := []struct {
515-
desc string
516-
args *driver.OIDCArgs
517-
}{
518-
{
519-
desc: "populated args",
520-
args: &driver.OIDCArgs{
521-
Version: 9,
522-
IDPInfo: &driver.IDPInfo{
523-
Issuer: "test issuer",
524-
ClientID: "test client ID",
525-
RequestScopes: []string{"test scope 1", "test scope 2"},
526-
},
527-
RefreshToken: &refreshToken,
528-
},
529-
},
530-
{
531-
desc: "nil",
532-
args: nil,
533-
},
534-
{
535-
desc: "nil IDPInfo and RefreshToken",
536-
args: &driver.OIDCArgs{
537-
Version: 9,
538-
IDPInfo: nil,
539-
RefreshToken: nil,
540-
},
541-
},
542-
}
543-
544-
for _, tc := range testCases {
545-
tc := tc // Capture range variable.
546-
547-
t.Run(tc.desc, func(t *testing.T) {
548-
t.Parallel()
549-
550-
got := convertOIDCArgs(tc.args)
551-
552-
if tc.args == nil {
553-
assert.Nil(t, got, "expected nil when input is nil")
554-
return
555-
}
556-
557-
require.Equal(t,
558-
3,
559-
reflect.ValueOf(*tc.args).NumField(),
560-
"expected the driver.OIDCArgs struct to have exactly 3 fields")
561-
require.Equal(t,
562-
3,
563-
reflect.ValueOf(*got).NumField(),
564-
"expected the options.OIDCArgs struct to have exactly 3 fields")
565-
566-
assert.Equal(t,
567-
tc.args.Version,
568-
got.Version,
569-
"expected Version field to be equal")
570-
assert.EqualValues(t,
571-
tc.args.IDPInfo,
572-
got.IDPInfo,
573-
"expected IDPInfo field to be convertible to equal values")
574-
assert.Equal(t,
575-
tc.args.RefreshToken,
576-
got.RefreshToken,
577-
"expected RefreshToken field to be equal")
578-
})
579-
}
580-
}

x/mongo/driver/auth/mongodbaws.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica
2828
return nil, errors.New("httpClient must not be nil")
2929
}
3030
return &MongoDBAWSAuthenticator{
31-
source: cred.Source,
3231
credentials: &credproviders.StaticProvider{
3332
Value: credentials.Value{
34-
ProviderName: cred.Source,
3533
AccessKeyID: cred.Username,
3634
SecretAccessKey: cred.Password,
3735
SessionToken: cred.Props["AWS_SESSION_TOKEN"],
@@ -43,7 +41,6 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica
4341

4442
// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
4543
type MongoDBAWSAuthenticator struct {
46-
source string
4744
credentials *credproviders.StaticProvider
4845
httpClient *http.Client
4946
}
@@ -56,7 +53,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error {
5653
credentials: providers.Cred,
5754
},
5855
}
59-
err := ConductSaslConversation(ctx, cfg, a.source, adapter)
56+
err := ConductSaslConversation(ctx, cfg, "$external", adapter)
6057
if err != nil {
6158
return newAuthError("sasl conversation error", err)
6259
}

x/mongo/driver/auth/mongodbcr.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ import (
3030
const MONGODBCR = "MONGODB-CR"
3131

3232
func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
33+
source := cred.Source
34+
if source == "" {
35+
source = "admin"
36+
}
3337
return &MongoDBCRAuthenticator{
34-
DB: cred.Source,
38+
DB: source,
3539
Username: cred.Username,
3640
Password: cred.Password,
3741
}, nil

x/mongo/driver/auth/oidc.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) {
109109
}
110110

111111
func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
112+
if cred.Source != "" && cred.Source != "$external" {
113+
return nil, newAuthError("MONGODB-OIDC source must be empty or $external", nil)
114+
}
112115
if cred.Password != "" {
113116
return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC)
114117
}

x/mongo/driver/auth/plain.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,27 @@ import (
1717
const PLAIN = "PLAIN"
1818

1919
func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
20+
source := cred.Source
21+
if source == "" {
22+
source = "$external"
23+
}
2024
return &PlainAuthenticator{
2125
Username: cred.Username,
2226
Password: cred.Password,
27+
Source: source,
2328
}, nil
2429
}
2530

2631
// PlainAuthenticator uses the PLAIN algorithm over SASL to authenticate a connection.
2732
type PlainAuthenticator struct {
2833
Username string
2934
Password string
35+
Source string
3036
}
3137

3238
// Auth authenticates the connection.
3339
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error {
34-
return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{
40+
return ConductSaslConversation(ctx, cfg, a.Source, &plainSaslClient{
3541
username: a.Username,
3642
password: a.Password,
3743
})

x/mongo/driver/auth/plain_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@ package auth_test
88

99
import (
1010
"context"
11+
"encoding/base64"
1112
"strings"
1213
"testing"
1314

14-
"encoding/base64"
15-
1615
"go.mongodb.org/mongo-driver/internal/require"
1716
"go.mongodb.org/mongo-driver/mongo/description"
1817
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
@@ -26,6 +25,7 @@ func TestPlainAuthenticator_Fails(t *testing.T) {
2625
authenticator := PlainAuthenticator{
2726
Username: "user",
2827
Password: "pencil",
28+
Source: "$external",
2929
}
3030

3131
resps := make(chan []byte, 1)
@@ -65,6 +65,7 @@ func TestPlainAuthenticator_Extra_server_message(t *testing.T) {
6565
authenticator := PlainAuthenticator{
6666
Username: "user",
6767
Password: "pencil",
68+
Source: "$external",
6869
}
6970

7071
resps := make(chan []byte, 2)
@@ -108,6 +109,7 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) {
108109
authenticator := PlainAuthenticator{
109110
Username: "user",
110111
Password: "pencil",
112+
Source: "$external",
111113
}
112114

113115
resps := make(chan []byte, 1)
@@ -153,6 +155,7 @@ func TestPlainAuthenticator_SucceedsBoolean(t *testing.T) {
153155
authenticator := PlainAuthenticator{
154156
Username: "user",
155157
Password: "pencil",
158+
Source: "$external",
156159
}
157160

158161
resps := make(chan []byte, 1)

x/mongo/driver/auth/scram.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ var (
3838
)
3939

4040
func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
41+
source := cred.Source
42+
if source == "" {
43+
source = "admin"
44+
}
4145
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
4246
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
4347
if err != nil {
@@ -46,12 +50,16 @@ func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error
4650
client.WithMinIterations(4096)
4751
return &ScramAuthenticator{
4852
mechanism: SCRAMSHA1,
49-
source: cred.Source,
53+
source: source,
5054
client: client,
5155
}, nil
5256
}
5357

5458
func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
59+
source := cred.Source
60+
if source == "" {
61+
source = "admin"
62+
}
5563
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
5664
if err != nil {
5765
return nil, newAuthError("error SASLprepping password", err)
@@ -63,7 +71,7 @@ func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, err
6371
client.WithMinIterations(4096)
6472
return &ScramAuthenticator{
6573
mechanism: SCRAMSHA256,
66-
source: cred.Source,
74+
source: source,
6775
client: client,
6876
}, nil
6977
}

x/mongo/driver/auth/x509.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ import (
1919
const MongoDBX509 = "MONGODB-X509"
2020

2121
func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
22+
// TODO(GODRIVER-3309): Validate that cred.Source is either empty or
23+
// "$external" to make validation uniform with other auth mechanisms that
24+
// require Source to be "$external" (e.g. MONGODB-AWS, MONGODB-OIDC, etc).
2225
return &MongoDBX509Authenticator{User: cred.Username}, nil
2326
}
2427

x/mongo/driver/connstring/connstring.go

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error {
296296
u.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
297297
}
298298
fallthrough
299-
case "mongodb-aws", "mongodb-x509":
299+
case "mongodb-aws", "mongodb-x509", "mongodb-oidc":
300300
if u.AuthSource == "" {
301301
u.AuthSource = "$external"
302302
} else if u.AuthSource != "$external" {
@@ -313,13 +313,6 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error {
313313
u.AuthSource = "admin"
314314
}
315315
}
316-
case "mongodb-oidc":
317-
if u.AuthSource == "" {
318-
u.AuthSource = dbName
319-
if u.AuthSource == "" {
320-
u.AuthSource = "$external"
321-
}
322-
}
323316
case "":
324317
// Only set auth source if there is a request for authentication via non-empty credentials.
325318
if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) {

0 commit comments

Comments
 (0)