import dataclasses
import math
import os
+ import threading
from typing import cast, Literal, Mapping, Optional, Sequence, Tuple, Union
import warnings
import weakref

import google.cloud.bigquery.table as bq_table
import google.cloud.bigquery_storage_v1

+ import bigframes.constants
import bigframes.core
- from bigframes.core import compile, rewrite
+ from bigframes.core import compile, local_data, rewrite
import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir
import bigframes.core.guid
import bigframes.core.nodes as nodes
import bigframes.dtypes
import bigframes.exceptions as bfe
import bigframes.features
- from bigframes.session import executor, local_scan_executor, read_api_execution
+ from bigframes.session import executor, loader, local_scan_executor, read_api_execution
import bigframes.session._io.bigquery as bq_io
import bigframes.session.metrics
import bigframes.session.planner
@@ -67,12 +69,19 @@ def _get_default_output_spec() -> OutputSpec:
    )


+ SourceIdMapping = Mapping[str, str]
+
+
class ExecutionCache:
    def __init__(self):
        # current assumption is only 1 cache of a given node
        # in future, might have multiple caches, with different layout, localities
        self._cached_executions: weakref.WeakKeyDictionary[
-             nodes.BigFrameNode, nodes.BigFrameNode
+             nodes.BigFrameNode, nodes.CachedTableNode
+         ] = weakref.WeakKeyDictionary()
+         self._uploaded_local_data: weakref.WeakKeyDictionary[
+             local_data.ManagedArrowTable,
+             tuple[nodes.BigqueryDataSource, SourceIdMapping],
        ] = weakref.WeakKeyDictionary()

    @property
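Note: both caches key on object identity through weakref.WeakKeyDictionary, so a cached execution or upload is dropped automatically once the plan node or local table it belongs to is garbage collected. A minimal standalone sketch of that behavior (the Node class below is a made-up stand-in, not a bigframes type):

import gc
import weakref


class Node:
    """Hypothetical stand-in for a plan node or local table; not a bigframes type."""


cache: "weakref.WeakKeyDictionary[Node, str]" = weakref.WeakKeyDictionary()

node = Node()
cache[node] = "uploaded-table-id"
assert len(cache) == 1  # entry present while a strong reference to `node` exists

del node       # drop the last strong reference
gc.collect()   # CPython frees it right away; collect() just makes the point explicit
assert len(cache) == 0  # the cache entry disappeared along with its key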
@@ -105,6 +114,19 @@ def cache_results_table(
        assert original_root.schema == cached_replacement.schema
        self._cached_executions[original_root] = cached_replacement

+     def cache_remote_replacement(
+         self,
+         local_data: local_data.ManagedArrowTable,
+         bq_data: nodes.BigqueryDataSource,
+     ):
+         # bq table has one extra column for offsets, those are implicit for local data
+         assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema)
+         mapping = {
+             local_data.schema.items[i].column: bq_data.table.physical_schema[i].name
+             for i in range(len(local_data.schema))
+         }
+         self._uploaded_local_data[local_data] = (bq_data, mapping)
+

class BigQueryCachingExecutor(executor.Executor):
    """Computes BigFrames values using BigQuery Engine.
@@ -120,6 +142,7 @@ def __init__(
        bqclient: bigquery.Client,
        storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager,
        bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient,
+         loader: loader.GbqDataLoader,
        *,
        strictly_ordered: bool = True,
        metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
@@ -129,6 +152,7 @@ def __init__(
        self.strictly_ordered: bool = strictly_ordered
        self.cache: ExecutionCache = ExecutionCache()
        self.metrics = metrics
+         self.loader = loader
        self.bqstoragereadclient = bqstoragereadclient
        # Simple left-to-right precedence for now
        self._semi_executors = (
@@ -138,6 +162,7 @@ def __init__(
            ),
            local_scan_executor.LocalScanExecutor(),
        )
+         self._upload_lock = threading.Lock()

    def to_sql(
        self,
@@ -149,6 +174,7 @@ def to_sql(
        if offset_column:
            array_value, _ = array_value.promote_offsets()
        node = self.logical_plan(array_value.node) if enable_cache else array_value.node
+         node = self._substitute_large_local_sources(node)
        compiled = compile.compile_sql(compile.CompileRequest(node, sort_rows=ordered))
        return compiled.sql

@@ -402,6 +428,7 @@ def _cache_with_cluster_cols(
    ):
        """Executes the query and uses the resulting table to rewrite future executions."""
        plan = self.logical_plan(array_value.node)
+         plan = self._substitute_large_local_sources(plan)
        compiled = compile.compile_sql(
            compile.CompileRequest(
                plan, sort_rows=False, materialize_all_order_keys=True
@@ -422,7 +449,7 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
        w_offsets, offset_column = array_value.promote_offsets()
        compiled = compile.compile_sql(
            compile.CompileRequest(
-                 self.logical_plan(w_offsets.node),
+                 self.logical_plan(self._substitute_large_local_sources(w_offsets.node)),
                sort_rows=False,
            )
        )
@@ -532,6 +559,54 @@ def _validate_result_schema(
                f"This error should only occur while testing. Ibis schema: {ibis_schema} does not match actual schema: {actual_schema}"
            )

+     def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode):
+         """
+         Replace large local sources with the uploaded version of those datasources.
+         """
+         # Step 1: Upload all previously un-uploaded data
+         for leaf in original_root.unique_nodes():
+             if isinstance(leaf, nodes.ReadLocalNode):
+                 if (
+                     leaf.local_data_source.metadata.total_bytes
+                     > bigframes.constants.MAX_INLINE_BYTES
+                 ):
+                     self._upload_local_data(leaf.local_data_source)
+
+         # Step 2: Replace local scans with remote scans
+         def map_local_scans(node: nodes.BigFrameNode):
+             if not isinstance(node, nodes.ReadLocalNode):
+                 return node
+             if node.local_data_source not in self.cache._uploaded_local_data:
+                 return node
+             bq_source, source_mapping = self.cache._uploaded_local_data[
+                 node.local_data_source
+             ]
+             scan_list = node.scan_list.remap_source_ids(source_mapping)
+             # offsets_col isn't part of ReadTableNode, so emulate by adding to end of scan_list
+             if node.offsets_col is not None:
+                 # Offsets are always implicitly the final column of uploaded data
+                 # See: Loader.load_data
+                 scan_list = scan_list.append(
+                     bq_source.table.physical_schema[-1].name,
+                     bigframes.dtypes.INT_DTYPE,
+                     node.offsets_col,
+                 )
+             return nodes.ReadTableNode(bq_source, scan_list, node.session)
+
+         return original_root.bottom_up(map_local_scans)
+
+     def _upload_local_data(self, local_table: local_data.ManagedArrowTable):
+         if local_table in self.cache._uploaded_local_data:
+             return
+         # Lock prevents concurrent repeated work, but slows things down.
+         # Might be better as a queue and a worker thread
+         with self._upload_lock:
+             if local_table not in self.cache._uploaded_local_data:
+                 uploaded = self.loader.load_data(
+                     local_table, bigframes.core.guid.generate_guid()
+                 )
+                 self.cache.cache_remote_replacement(local_table, uploaded)
+
    def _execute_plan(
        self,
        plan: nodes.BigFrameNode,
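Note: _upload_local_data uses the familiar check/lock/re-check pattern — the unlocked check keeps the common case cheap, and the second check under _upload_lock ensures only one thread actually uploads a given table. A self-contained sketch of that pattern, with a hypothetical _expensive_upload standing in for loader.load_data:

import threading

_uploaded: dict[str, str] = {}
_upload_lock = threading.Lock()


def _expensive_upload(key: str) -> str:
    # Hypothetical stand-in for loader.load_data; imagine a slow network call here.
    return f"uploaded::{key}"


def get_or_upload(key: str) -> str:
    # Fast path: skip the lock entirely if another thread already finished the upload.
    if key in _uploaded:
        return _uploaded[key]
    with _upload_lock:
        # Re-check under the lock so two racing callers don't upload the same data twice.
        if key not in _uploaded:
            _uploaded[key] = _expensive_upload(key)
    return _uploaded[key]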
@@ -562,6 +637,8 @@ def _execute_plan(
        # Use explicit destination to avoid 10GB limit of temporary table
        if destination_table is not None:
            job_config.destination = destination_table
+
+         plan = self._substitute_large_local_sources(plan)
        compiled = compile.compile_sql(
            compile.CompileRequest(plan, sort_rows=ordered, peek_count=peek)
        )