Skip to content

Commit 6b872e8

Browse files
committed
Do not fail creating the provisioner HTTP client
This commit avoids an error starting the CA if the `http.DefaultTransport` is not an `*http.Transport`. If the DefaultTransport is overwritten, the newHTTPClient method will return a simple *http.Client. With an *http.Transport, it will return a client that trusts the system certificate pool and the CA roots.
1 parent 03c4b18 commit 6b872e8

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

authority/http_client.go

+20-20
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,28 @@ import (
77
"net/http"
88
)
99

10-
// newHTTPClient returns an HTTP client that trusts the system cert pool and the
11-
// given roots.
10+
// newHTTPClient will return an HTTP client that trusts the system cert pool and
11+
// the given roots, but only if the http.DefaultTransport is an *http.Transport.
12+
// If not, it will return the default HTTP client.
1213
func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) {
13-
pool, err := x509.SystemCertPool()
14-
if err != nil {
15-
return nil, fmt.Errorf("error initializing http client: %w", err)
16-
}
17-
for _, crt := range roots {
18-
pool.AddCert(crt)
19-
}
14+
if tr, ok := http.DefaultTransport.(*http.Transport); ok {
15+
pool, err := x509.SystemCertPool()
16+
if err != nil {
17+
return nil, fmt.Errorf("error initializing http client: %w", err)
18+
}
19+
for _, crt := range roots {
20+
pool.AddCert(crt)
21+
}
2022

21-
tr, ok := http.DefaultTransport.(*http.Transport)
22-
if !ok {
23-
return nil, fmt.Errorf("error initializing http client: type is not *http.Transport")
24-
}
25-
tr = tr.Clone()
26-
tr.TLSClientConfig = &tls.Config{
27-
MinVersion: tls.VersionTLS12,
28-
RootCAs: pool,
23+
tr = tr.Clone()
24+
tr.TLSClientConfig = &tls.Config{
25+
MinVersion: tls.VersionTLS12,
26+
RootCAs: pool,
27+
}
28+
return &http.Client{
29+
Transport: tr,
30+
}, nil
2931
}
3032

31-
return &http.Client{
32-
Transport: tr,
33-
}, nil
33+
return &http.Client{}, nil
3434
}

authority/http_client_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,19 @@ func Test_newHTTPClient(t *testing.T) {
102102
assert.Error(t, err)
103103
})
104104
})
105+
106+
t.Run("custom transport", func(t *testing.T) {
107+
tmp := http.DefaultTransport
108+
t.Cleanup(func() {
109+
http.DefaultTransport = tmp
110+
})
111+
transport := struct {
112+
http.RoundTripper
113+
}{http.DefaultTransport}
114+
http.DefaultTransport = transport
115+
116+
client, err := newHTTPClient(auth.rootX509Certs...)
117+
assert.NoError(t, err)
118+
assert.Equal(t, &http.Client{}, client)
119+
})
105120
}

0 commit comments

Comments
 (0)