Skip to content

feat: add bigquery_client as a parameter for read_gbq and to_gbq #878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions pandas_gbq/gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def __init__(
client_secret=None,
user_agent=None,
rfc9110_delimiter=False,
bigquery_client=None,
):
global context
from google.api_core.exceptions import ClientError, GoogleAPIError
Expand All @@ -288,6 +289,14 @@ def __init__(
self.client_secret = client_secret
self.user_agent = user_agent
self.rfc9110_delimiter = rfc9110_delimiter
self.use_bqstorage_api = use_bqstorage_api

if bigquery_client is not None:
# If a bq client is already provided, use it to populate auth fields.
self.project_id = bigquery_client.project
self.credentials = bigquery_client._credentials
self.client = bigquery_client
return

default_project = None

Expand Down Expand Up @@ -325,8 +334,9 @@ def __init__(
if context.project is None:
context.project = self.project_id

self.client = self.get_client()
self.use_bqstorage_api = use_bqstorage_api
self.client = _get_client(
self.user_agent, self.rfc9110_delimiter, self.project_id, self.credentials
)

def _start_timer(self):
self.start = time.time()
Expand Down Expand Up @@ -702,6 +712,7 @@ def read_gbq(
client_secret=None,
*,
col_order=None,
bigquery_client=None,
):
r"""Read data from Google BigQuery to a pandas DataFrame.

Expand Down Expand Up @@ -849,6 +860,9 @@ def read_gbq(
the user is attempting to connect to.
col_order : list(str), optional
Alias for columns, retained for backwards compatibility.
bigquery_client : google.cloud.bigquery.Client, optional
A Google Cloud BigQuery Python Client instance. If provided, it will be used for reading
data, while the project and credentials parameters will be ignored.

Returns
-------
Expand Down Expand Up @@ -900,6 +914,7 @@ def read_gbq(
auth_redirect_uri=auth_redirect_uri,
client_id=client_id,
client_secret=client_secret,
bigquery_client=bigquery_client,
)

if _is_query(query_or_table):
Expand Down Expand Up @@ -971,6 +986,7 @@ def to_gbq(
client_secret=None,
user_agent=None,
rfc9110_delimiter=False,
bigquery_client=None,
):
"""Write a DataFrame to a Google BigQuery table.

Expand Down Expand Up @@ -1087,6 +1103,9 @@ def to_gbq(
rfc9110_delimiter : bool
Sets user agent delimiter to a hyphen or a slash.
Default is False, meaning a hyphen will be used.
bigquery_client : google.cloud.bigquery.Client, optional
A Google Cloud BigQuery Python Client instance. If provided, it will be used for reading
data, while the project, user_agent, and credentials parameters will be ignored.

.. versionadded:: 0.23.3
"""
Expand Down Expand Up @@ -1157,6 +1176,7 @@ def to_gbq(
client_secret=client_secret,
user_agent=user_agent,
rfc9110_delimiter=rfc9110_delimiter,
bigquery_client=bigquery_client,
)
bqclient = connector.client

Expand Down Expand Up @@ -1492,3 +1512,22 @@ def create_user_agent(
user_agent = f"{user_agent} {identity}"

return user_agent


def _get_client(user_agent, rfc9110_delimiter, project_id, credentials):
import google.api_core.client_info

bigquery = FEATURES.bigquery_try_import()

user_agent = create_user_agent(
user_agent=user_agent, rfc9110_delimiter=rfc9110_delimiter
)

client_info = google.api_core.client_info.ClientInfo(
user_agent=user_agent,
)
return bigquery.Client(
project=project_id,
credentials=credentials,
client_info=client_info,
)
14 changes: 14 additions & 0 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def to_gbq(credentials, project_id):
)


@pytest.fixture
def to_gbq_with_bq_client(bigquery_client):
import pandas_gbq

return functools.partial(pandas_gbq.to_gbq, bigquery_client=bigquery_client)


@pytest.fixture
def read_gbq(credentials, project_id):
import pandas_gbq
Expand All @@ -63,6 +70,13 @@ def read_gbq(credentials, project_id):
)


@pytest.fixture
def read_gbq_with_bq_client(bigquery_client):
import pandas_gbq

return functools.partial(pandas_gbq.read_gbq, bigquery_client=bigquery_client)


@pytest.fixture()
def random_dataset_id(bigquery_client: bigquery.Client, project_id: str):
dataset_id = prefixer.create_prefix()
Expand Down
10 changes: 10 additions & 0 deletions tests/system/test_gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,3 +1398,13 @@ def test_to_gbq_does_not_override_mode(gbq_table, gbq_connector):
)

assert verify_schema(gbq_connector, gbq_table.dataset_id, table_id, table_schema)


def test_gbqconnector_init_with_bq_client(bigquery_client):
gbq_connector = gbq.GbqConnector(
project_id="project_id", credentials=None, bigquery_client=bigquery_client
)

assert gbq_connector.project_id == bigquery_client.project
assert gbq_connector.credentials is bigquery_client._credentials
assert gbq_connector.client is bigquery_client
11 changes: 11 additions & 0 deletions tests/system/test_read_gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,14 @@ def test_dml_query(read_gbq, writable_table: str):
"""
result = read_gbq(query)
assert result is not None


def test_read_gbq_with_bq_client(read_gbq_with_bq_client):
query = "SELECT * FROM UNNEST([1, 2, 3]) AS numbers"

actual_result = read_gbq_with_bq_client(query)

expected_result = pandas.DataFrame(
{"numbers": pandas.Series([1, 2, 3], dtype="Int64")}
)
pandas.testing.assert_frame_equal(actual_result, expected_result)
14 changes: 14 additions & 0 deletions tests/system/test_to_gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,17 @@ def test_dataframe_round_trip_with_table_schema(
pandas.testing.assert_frame_equal(
expected_df.set_index("row_num").sort_index(), round_trip
)


def test_dataframe_round_trip_with_bq_client(
to_gbq_with_bq_client, read_gbq_with_bq_client, random_dataset_id
):
table_id = (
f"{random_dataset_id}.round_trip_w_bq_client_{random.randrange(1_000_000)}"
)
df = pandas.DataFrame({"numbers": pandas.Series([1, 2, 3], dtype="Int64")})

to_gbq_with_bq_client(df, table_id)
result = read_gbq_with_bq_client(table_id)

pandas.testing.assert_frame_equal(result, df)