Skip to content

Commit bb45db8

Browse files
authored
feat: add bpd.options.bigquery.requests_transport_adapters option (#1755)
* feat: add `bpd.options.bigquery.requests_transport_adapters` option

  This allows for overriding requests-based settings such as the maximum connection pool size.

* add unit test
1 parent bd07e05 commit bb45db8

File tree

7 files changed: 95 additions (+95) and 13 deletions (-13).

bigframes/_config/bigquery_options.py

Lines changed: 46 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,11 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Literal, Optional
19+
from typing import Literal, Optional, Sequence, Tuple
2020
import warnings
2121

2222
import google.auth.credentials
23+
import requests.adapters
2324

2425
import bigframes.enums
2526
import bigframes.exceptions as bfe
@@ -90,6 +91,9 @@ def __init__(
9091
allow_large_results: bool = False,
9192
ordering_mode: Literal["strict", "partial"] = "strict",
9293
client_endpoints_override: Optional[dict] = None,
94+
requests_transport_adapters: Sequence[
95+
Tuple[str, requests.adapters.BaseAdapter]
96+
] = (),
9397
):
9498
self._credentials = credentials
9599
self._project = project
@@ -100,6 +104,7 @@ def __init__(
100104
self._kms_key_name = kms_key_name
101105
self._skip_bq_connection_check = skip_bq_connection_check
102106
self._allow_large_results = allow_large_results
107+
self._requests_transport_adapters = requests_transport_adapters
103108
self._session_started = False
104109
# Determines the ordering strictness for the session.
105110
self._ordering_mode = _validate_ordering_mode(ordering_mode)
@@ -379,3 +384,43 @@ def client_endpoints_override(self, value: dict):
379384
)
380385

381386
self._client_endpoints_override = value
387+
388+
@property
389+
def requests_transport_adapters(
390+
self,
391+
) -> Sequence[Tuple[str, requests.adapters.BaseAdapter]]:
392+
"""Transport adapters for requests-based REST clients such as the
393+
google-cloud-bigquery package.
394+
395+
For more details, see the explanation in `requests guide to transport
396+
adapters
397+
<https://requests.readthedocs.io/en/latest/user/advanced/#transport-adapters>`_.
398+
399+
**Examples:**
400+
401+
Increase the connection pool size using the requests `HTTPAdapter
402+
<https://requests.readthedocs.io/en/latest/api/#requests.adapters.HTTPAdapter>`_.
403+
404+
>>> import bigframes.pandas as bpd
405+
>>> bpd.options.bigquery.requests_transport_adapters = (
406+
... ("http://", requests.adapters.HTTPAdapter(pool_maxsize=100)),
407+
... ("https://", requests.adapters.HTTPAdapter(pool_maxsize=100)),
408+
... ) # doctest: +SKIP
409+
410+
Returns:
411+
Sequence[Tuple[str, requests.adapters.BaseAdapter]]:
412+
Prefixes and corresponding transport adapters to `mount
413+
<https://requests.readthedocs.io/en/latest/api/#requests.Session.mount>`_
414+
in requests-based REST clients.
415+
"""
416+
return self._requests_transport_adapters
417+
418+
@requests_transport_adapters.setter
419+
def requests_transport_adapters(
420+
self, value: Sequence[Tuple[str, requests.adapters.BaseAdapter]]
421+
) -> None:
422+
if self._session_started and self._requests_transport_adapters != value:
423+
raise ValueError(
424+
SESSION_STARTED_MESSAGE.format(attribute="requests_transport_adapters")
425+
)
426+
self._requests_transport_adapters = value

bigframes/pandas/io/api.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -496,6 +496,7 @@ def _set_default_session_location_if_possible(query):
496496
application_name=config.options.bigquery.application_name,
497497
bq_kms_key_name=config.options.bigquery.kms_key_name,
498498
client_endpoints_override=config.options.bigquery.client_endpoints_override,
499+
requests_transport_adapters=config.options.bigquery.requests_transport_adapters,
499500
)
500501

501502
bqclient = clients_provider.bqclient

bigframes/session/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,7 @@ def __init__(
172172
application_name=context.application_name,
173173
bq_kms_key_name=self._bq_kms_key_name,
174174
client_endpoints_override=context.client_endpoints_override,
175+
requests_transport_adapters=context.requests_transport_adapters,
175176
)
176177

177178
# TODO(shobs): Remove this logic after https://github.com/ibis-project/ibis/issues/8494

bigframes/session/clients.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,18 +17,20 @@
1717
import os
1818
import threading
1919
import typing
20-
from typing import Optional
20+
from typing import Optional, Sequence, Tuple
2121

2222
import google.api_core.client_info
2323
import google.api_core.client_options
2424
import google.api_core.gapic_v1.client_info
2525
import google.auth.credentials
26+
import google.auth.transport.requests
2627
import google.cloud.bigquery as bigquery
2728
import google.cloud.bigquery_connection_v1
2829
import google.cloud.bigquery_storage_v1
2930
import google.cloud.functions_v2
3031
import google.cloud.resourcemanager_v3
3132
import pydata_google_auth
33+
import requests
3234

3335
import bigframes.constants
3436
import bigframes.version
@@ -79,6 +81,10 @@ def __init__(
7981
application_name: Optional[str] = None,
8082
bq_kms_key_name: Optional[str] = None,
8183
client_endpoints_override: dict = {},
84+
*,
85+
requests_transport_adapters: Sequence[
86+
Tuple[str, requests.adapters.BaseAdapter]
87+
] = (),
8288
):
8389
credentials_project = None
8490
if credentials is None:
@@ -124,6 +130,7 @@ def __init__(
124130
)
125131
self._location = location
126132
self._use_regional_endpoints = use_regional_endpoints
133+
self._requests_transport_adapters = requests_transport_adapters
127134

128135
self._credentials = credentials
129136
self._bq_kms_key_name = bq_kms_key_name
@@ -173,12 +180,21 @@ def _create_bigquery_client(self):
173180
user_agent=self._application_name
174181
)
175182

183+
requests_session = google.auth.transport.requests.AuthorizedSession(
184+
self._credentials
185+
)
186+
for prefix, adapter in self._requests_transport_adapters:
187+
requests_session.mount(prefix, adapter)
188+
176189
bq_client = bigquery.Client(
177190
client_info=bq_info,
178191
client_options=bq_options,
179-
credentials=self._credentials,
180192
project=self._project,
181193
location=self._location,
194+
# Instead of credentials, use _http so that users can override
195+
# requests options with transport adapters. See internal issue
196+
# b/419106112.
197+
_http=requests_session,
182198
)
183199

184200
# If a new enough client library is available, we opt-in to the faster

tests/system/small/test_pandas_options.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -279,16 +279,18 @@ def test_credentials_need_reauthentication(
279279
# Call get_global_session() *after* read_gbq so that our location detection
280280
# has a chance to work.
281281
session = bpd.get_global_session()
282-
assert session.bqclient._credentials.valid
282+
assert session.bqclient._http.credentials.valid
283283

284284
with monkeypatch.context() as m:
285285
# Simulate expired credentials to trigger the credential refresh flow
286-
m.setattr(session.bqclient._credentials, "expiry", datetime.datetime.utcnow())
287-
assert not session.bqclient._credentials.valid
286+
m.setattr(
287+
session.bqclient._http.credentials, "expiry", datetime.datetime.utcnow()
288+
)
289+
assert not session.bqclient._http.credentials.valid
288290

289291
# Simulate an exception during the credential refresh flow
290292
m.setattr(
291-
session.bqclient._credentials,
293+
session.bqclient._http.credentials,
292294
"refresh",
293295
mock.Mock(side_effect=google.auth.exceptions.RefreshError()),
294296
)

tests/unit/_config/test_bigquery_options.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
("skip_bq_connection_check", False, True),
3939
("client_endpoints_override", {}, {"bqclient": "endpoint_address"}),
4040
("ordering_mode", "strict", "partial"),
41+
("requests_transport_adapters", object(), object()),
4142
],
4243
)
4344
def test_setter_raises_if_session_started(attribute, original_value, new_value):

tests/unit/session/test_clients.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -15,25 +15,22 @@
1515
import os
1616
import pathlib
1717
import tempfile
18-
from typing import Optional
18+
from typing import cast, Optional
1919
import unittest.mock as mock
2020

21-
import google.api_core.client_info
22-
import google.api_core.client_options
23-
import google.api_core.exceptions
24-
import google.api_core.gapic_v1.client_info
2521
import google.auth.credentials
2622
import google.cloud.bigquery
2723
import google.cloud.bigquery_connection_v1
2824
import google.cloud.bigquery_storage_v1
2925
import google.cloud.functions_v2
3026
import google.cloud.resourcemanager_v3
27+
import requests.adapters
3128

3229
import bigframes.session.clients as clients
3330
import bigframes.version
3431

3532

36-
def create_clients_provider(application_name: Optional[str] = None):
33+
def create_clients_provider(application_name: Optional[str] = None, **kwargs):
3734
credentials = mock.create_autospec(google.auth.credentials.Credentials)
3835
return clients.ClientsProvider(
3936
project="test-project",
@@ -42,6 +39,7 @@ def create_clients_provider(application_name: Optional[str] = None):
4239
credentials=credentials,
4340
application_name=application_name,
4441
bq_kms_key_name="projects/my-project/locations/us/keyRings/myKeyRing/cryptoKeys/myKey",
42+
**kwargs,
4543
)
4644

4745

@@ -136,6 +134,24 @@ def assert_clients_wo_user_agent(
136134
)
137135

138136

137+
def test_requests_transport_adapters_pool_maxsize(monkeypatch):
138+
monkeypatch_client_constructors(monkeypatch)
139+
requests_transport_adapters = (
140+
("http://", requests.adapters.HTTPAdapter(pool_maxsize=123)),
141+
("https://", requests.adapters.HTTPAdapter(pool_maxsize=123)),
142+
) # doctest: +SKIP
143+
provider = create_clients_provider(
144+
requests_transport_adapters=requests_transport_adapters
145+
)
146+
147+
_, kwargs = cast(mock.Mock, provider.bqclient).call_args
148+
requests_session = kwargs.get("_http")
149+
adapter: requests.adapters.HTTPAdapter = requests_session.get_adapter(
150+
"https://bigquery.googleapis.com/"
151+
)
152+
assert adapter._pool_maxsize == 123 # type: ignore
153+
154+
139155
def test_user_agent_default(monkeypatch):
140156
monkeypatch_client_constructors(monkeypatch)
141157
provider = create_clients_provider(application_name=None)

0 commit comments

Comments
 (0)