diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py index a89d224ad1..4cd8b3cde4 100644 --- a/bigframes/core/array_value.py +++ b/bigframes/core/array_value.py @@ -23,7 +23,7 @@ import pandas import pyarrow as pa -from bigframes.core import agg_expressions +from bigframes.core import agg_expressions, bq_data import bigframes.core.expression as ex import bigframes.core.guid import bigframes.core.identifiers as ids @@ -63,7 +63,7 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session): def from_managed(cls, source: local_data.ManagedArrowTable, session: Session): scan_list = nodes.ScanList( tuple( - nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(ids.ColumnId(item.column), item.column) for item in source.schema.items ) ) @@ -100,7 +100,7 @@ def from_table( if offsets_col and primary_key: raise ValueError("must set at most one of 'offests', 'primary_key'") # define data source only for needed columns, this makes row-hashing cheaper - table_def = nodes.GbqTable.from_table(table, columns=schema.names) + table_def = bq_data.GbqTable.from_table(table, columns=schema.names) # create ordering from info ordering = None @@ -114,12 +114,13 @@ def from_table( # Scan all columns by default, we define this list as it can be pruned while preserving source_def scan_list = nodes.ScanList( tuple( - nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(ids.ColumnId(item.column), item.column) for item in schema.items ) ) - source_def = nodes.BigqueryDataSource( + source_def = bq_data.BigqueryDataSource( table=table_def, + schema=schema, at_time=at_time, sql_predicate=predicate, ordering=ordering, @@ -130,7 +131,7 @@ def from_table( @classmethod def from_bq_data_source( cls, - source: nodes.BigqueryDataSource, + source: bq_data.BigqueryDataSource, scan_list: nodes.ScanList, session: Session, ): diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index f9896784bb..1c36922e6e 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -37,7 +37,6 @@ Optional, Sequence, Tuple, - TYPE_CHECKING, Union, ) import warnings @@ -70,9 +69,6 @@ from bigframes.session import dry_runs, execution_spec from bigframes.session import executor as executors -if TYPE_CHECKING: - from bigframes.session.executor import ExecuteResult - # Type constraint for wherever column labels are used Label = typing.Hashable @@ -98,7 +94,6 @@ LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]] -@dataclasses.dataclass class PandasBatches(Iterator[pd.DataFrame]): """Interface for mutable objects with state represented by a block value object.""" @@ -271,10 +266,14 @@ def shape(self) -> typing.Tuple[int, int]: except Exception: pass - row_count = self.session._executor.execute( - self.expr.row_count(), - execution_spec.ExecutionSpec(promise_under_10gb=True, ordered=False), - ).to_py_scalar() + row_count = ( + self.session._executor.execute( + self.expr.row_count(), + execution_spec.ExecutionSpec(promise_under_10gb=True, ordered=False), + ) + .batches() + .to_py_scalar() + ) return (row_count, len(self.value_columns)) @property @@ -584,7 +583,7 @@ def to_arrow( ordered=ordered, ), ) - pa_table = execute_result.to_arrow_table() + pa_table = execute_result.batches().to_arrow_table() pa_index_labels = [] for index_level, index_label in enumerate(self._index_labels): @@ -636,17 +635,13 @@ def to_pandas( max_download_size, sampling_method, random_state ) - ex_result = self._materialize_local( + return 
self._materialize_local( materialize_options=MaterializationOptions( downsampling=sampling, allow_large_results=allow_large_results, ordered=ordered, ) ) - df = ex_result.to_pandas() - df = self._copy_index_to_pandas(df) - df.set_axis(self.column_labels, axis=1, copy=False) - return df, ex_result.query_job def _get_sampling_option( self, @@ -683,7 +678,7 @@ def try_peek( self.expr, execution_spec.ExecutionSpec(promise_under_10gb=under_10gb, peek=n), ) - df = result.to_pandas() + df = result.batches().to_pandas() return self._copy_index_to_pandas(df) else: return None @@ -704,13 +699,14 @@ def to_pandas_batches( if (allow_large_results is not None) else not bigframes.options._allow_large_results ) - execute_result = self.session._executor.execute( + execution_result = self.session._executor.execute( self.expr, execution_spec.ExecutionSpec( promise_under_10gb=under_10gb, ordered=True, ), ) + result_batches = execution_result.batches() # To reduce the number of edge cases to consider when working with the # results of this, always return at least one DataFrame. See: @@ -724,19 +720,21 @@ def to_pandas_batches( dfs = map( lambda a: a[0], itertools.zip_longest( - execute_result.to_pandas_batches(page_size, max_results), + result_batches.to_pandas_batches(page_size, max_results), [0], fillvalue=empty_val, ), ) dfs = iter(map(self._copy_index_to_pandas, dfs)) - total_rows = execute_result.total_rows + total_rows = result_batches.approx_total_rows if (total_rows is not None) and (max_results is not None): total_rows = min(total_rows, max_results) return PandasBatches( - dfs, total_rows, total_bytes_processed=execute_result.total_bytes_processed + dfs, + total_rows, + total_bytes_processed=execution_result.total_bytes_processed, ) def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame: @@ -754,7 +752,7 @@ def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame: def _materialize_local( self, materialize_options: MaterializationOptions = MaterializationOptions() - ) -> ExecuteResult: + ) -> tuple[pd.DataFrame, Optional[bigquery.QueryJob]]: """Run query and download results as a pandas DataFrame. Return the total number of results as well.""" # TODO(swast): Allow for dry run and timeout. under_10gb = ( @@ -769,9 +767,11 @@ def _materialize_local( ordered=materialize_options.ordered, ), ) + result_batches = execute_result.batches() + sample_config = materialize_options.downsampling - if execute_result.total_bytes is not None: - table_mb = execute_result.total_bytes / _BYTES_TO_MEGABYTES + if result_batches.approx_total_bytes is not None: + table_mb = result_batches.approx_total_bytes / _BYTES_TO_MEGABYTES max_download_size = sample_config.max_download_size fraction = ( max_download_size / table_mb @@ -792,7 +792,7 @@ def _materialize_local( # TODO: Maybe materialize before downsampling # Some downsampling methods - if fraction < 1 and (execute_result.total_rows is not None): + if fraction < 1 and (result_batches.approx_total_rows is not None): if not sample_config.enable_downsampling: raise RuntimeError( f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of " @@ -811,7 +811,7 @@ def _materialize_local( "the downloading limit." 
) warnings.warn(msg, category=UserWarning) - total_rows = execute_result.total_rows + total_rows = result_batches.approx_total_rows # Remove downsampling config from subsequent invocations, as otherwise could result in many # iterations if downsampling undershoots return self._downsample( @@ -823,7 +823,10 @@ def _materialize_local( MaterializationOptions(ordered=materialize_options.ordered) ) else: - return execute_result + df = result_batches.to_pandas() + df = self._copy_index_to_pandas(df) + df.set_axis(self.column_labels, axis=1, copy=False) + return df, execute_result.query_job def _downsample( self, total_rows: int, sampling_method: str, fraction: float, random_state @@ -1662,15 +1665,19 @@ def retrieve_repr_request_results( ordered=True, ), ) - row_count = self.session._executor.execute( - self.expr.row_count(), - execution_spec.ExecutionSpec( - promise_under_10gb=True, - ordered=False, - ), - ).to_py_scalar() + row_count = ( + self.session._executor.execute( + self.expr.row_count(), + execution_spec.ExecutionSpec( + promise_under_10gb=True, + ordered=False, + ), + ) + .batches() + .to_py_scalar() + ) - head_df = head_result.to_pandas() + head_df = head_result.batches().to_pandas() return self._copy_index_to_pandas(head_df), row_count, head_result.query_job def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]: diff --git a/bigframes/core/bq_data.py b/bigframes/core/bq_data.py new file mode 100644 index 0000000000..7c65c6b46a --- /dev/null +++ b/bigframes/core/bq_data.py @@ -0,0 +1,218 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import concurrent.futures +import dataclasses +import datetime +import functools +import os +import queue +import threading +import typing +from typing import Any, Iterator, Optional, Sequence, Tuple + +from google.cloud import bigquery_storage_v1 +import google.cloud.bigquery as bq +import google.cloud.bigquery_storage_v1.types as bq_storage_types +from google.protobuf import timestamp_pb2 +import pyarrow as pa + +from bigframes.core import pyarrow_utils +import bigframes.core.schema + +if typing.TYPE_CHECKING: + import bigframes.core.ordering as orderings + + +@dataclasses.dataclass(frozen=True) +class GbqTable: + project_id: str = dataclasses.field() + dataset_id: str = dataclasses.field() + table_id: str = dataclasses.field() + physical_schema: Tuple[bq.SchemaField, ...] 
= dataclasses.field() + is_physically_stored: bool = dataclasses.field() + cluster_cols: typing.Optional[Tuple[str, ...]] + + @staticmethod + def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable: + # Subsetting fields with columns can reduce cost of row-hash default ordering + if columns: + schema = tuple(item for item in table.schema if item.name in columns) + else: + schema = tuple(table.schema) + return GbqTable( + project_id=table.project, + dataset_id=table.dataset_id, + table_id=table.table_id, + physical_schema=schema, + is_physically_stored=(table.table_type in ["TABLE", "MATERIALIZED_VIEW"]), + cluster_cols=None + if table.clustering_fields is None + else tuple(table.clustering_fields), + ) + + def get_table_ref(self) -> bq.TableReference: + return bq.TableReference( + bq.DatasetReference(self.project_id, self.dataset_id), self.table_id + ) + + @property + @functools.cache + def schema_by_id(self): + return {col.name: col for col in self.physical_schema} + + +@dataclasses.dataclass(frozen=True) +class BigqueryDataSource: + """ + Google BigQuery Data source. + + This should not be modified once defined, as all attributes contribute to the default ordering. + """ + + def __post_init__(self): + assert [field.name for field in self.table.physical_schema] == list( + self.schema.names + ) + + table: GbqTable + schema: bigframes.core.schema.ArraySchema + at_time: typing.Optional[datetime.datetime] = None + # Added for backwards compatibility, not validated + sql_predicate: typing.Optional[str] = None + ordering: typing.Optional[orderings.RowOrdering] = None + # Optimization field + n_rows: Optional[int] = None + + +_WORKER_TIME_INCREMENT = 0.05 + + +def _iter_stream( + stream_name: str, + storage_read_client: bigquery_storage_v1.BigQueryReadClient, + result_queue: queue.Queue, + stop_event: threading.Event, +): + reader = storage_read_client.read_rows(stream_name) + for page in reader.rows().pages: + try: + result_queue.put(page.to_arrow(), timeout=_WORKER_TIME_INCREMENT) + except queue.Full: + continue + if stop_event.is_set(): + return + + +def _iter_streams( + streams: Sequence[bq_storage_types.ReadStream], + storage_read_client: bigquery_storage_v1.BigQueryReadClient, +) -> Iterator[pa.RecordBatch]: + stop_event = threading.Event() + result_queue: queue.Queue = queue.Queue( + len(streams) + ) # each response is large, so small queue is appropriate + + in_progress: list[concurrent.futures.Future] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=len(streams)) as pool: + for stream in streams: + in_progress.append( + pool.submit( + _iter_stream, + stream.name, + storage_read_client, + result_queue, + stop_event, + ) + ) + + while in_progress: + try: + yield result_queue.get(timeout=0.1) + except queue.Empty: + new_in_progress = [] + for future in in_progress: + if future.done(): + try: + future.result() + except Exception: + stop_event.set() + raise + else: + new_in_progress.append(future) + in_progress = new_in_progress + + +@dataclasses.dataclass +class ReadResult: + iter: Iterator[pa.RecordBatch] + approx_rows: int + approx_bytes: int + + +def get_arrow_batches( + data: BigqueryDataSource, + columns: Sequence[str], + storage_read_client: bigquery_storage_v1.BigQueryReadClient, + project_id: str, +) -> ReadResult: + table_mod_options = {} + read_options_dict: dict[str, Any] = {"selected_fields": list(columns)} + if data.sql_predicate: + read_options_dict["row_restriction"] = data.sql_predicate + read_options = 
bq_storage_types.ReadSession.TableReadOptions(**read_options_dict)
+
+    if data.at_time:
+        snapshot_time = timestamp_pb2.Timestamp()
+        snapshot_time.FromDatetime(data.at_time)
+        table_mod_options["snapshot_time"] = snapshot_time
+    table_mods = bq_storage_types.ReadSession.TableModifiers(**table_mod_options)
+
+    requested_session = bq_storage_types.stream.ReadSession(
+        table=data.table.get_table_ref().to_bqstorage(),
+        data_format=bq_storage_types.DataFormat.ARROW,
+        read_options=read_options,
+        table_modifiers=table_mods,
+    )
+    if data.ordering is not None:
+        # Single stream to maintain ordering
+        max_streams = 1
+    else:
+        max_streams = os.cpu_count() or 8
+
+    request = bq_storage_types.CreateReadSessionRequest(
+        parent=f"projects/{project_id}",
+        read_session=requested_session,
+        max_stream_count=max_streams,
+    )
+
+    session = storage_read_client.create_read_session(request=request)
+
+    if not session.streams:
+        batches: Iterator[pa.RecordBatch] = iter([])
+    else:
+        batches = _iter_streams(session.streams, storage_read_client)
+
+    def process_batch(pa_batch):
+        return pyarrow_utils.cast_batch(
+            pa_batch.select(columns), data.schema.select(columns).to_pyarrow()
+        )
+
+    batches = map(process_batch, batches)
+
+    return ReadResult(
+        batches, session.estimated_row_count, session.estimated_total_bytes_scanned
+    )
diff --git a/bigframes/core/compile/ibis_compiler/ibis_compiler.py b/bigframes/core/compile/ibis_compiler/ibis_compiler.py
index ff0441ea22..0436e05559 100644
--- a/bigframes/core/compile/ibis_compiler/ibis_compiler.py
+++ b/bigframes/core/compile/ibis_compiler/ibis_compiler.py
@@ -24,7 +24,7 @@
 import bigframes_vendored.ibis.expr.types as ibis_types
 
 from bigframes import dtypes, operations
-from bigframes.core import expression, pyarrow_utils
+from bigframes.core import bq_data, expression, pyarrow_utils
 import bigframes.core.compile.compiled as compiled
 import bigframes.core.compile.concat as concat_impl
 import bigframes.core.compile.configs as configs
@@ -186,7 +186,7 @@ def compile_readtable(node: nodes.ReadTableNode, *args):
     # TODO(b/395912450): Remove workaround solution once b/374784249 got resolved.
for scan_item in node.scan_list.items: if ( - scan_item.dtype == dtypes.JSON_DTYPE + node.source.schema.get_type(scan_item.source_id) == dtypes.JSON_DTYPE and ibis_table[scan_item.source_id].type() == ibis_dtypes.string ): json_column = scalar_op_registry.parse_json( @@ -204,7 +204,7 @@ def compile_readtable(node: nodes.ReadTableNode, *args): def _table_to_ibis( - source: nodes.BigqueryDataSource, + source: bq_data.BigqueryDataSource, scan_cols: typing.Sequence[str], ) -> ibis_types.Table: full_table_name = ( diff --git a/bigframes/core/indexes/base.py b/bigframes/core/indexes/base.py index a258c01195..48186cc5ce 100644 --- a/bigframes/core/indexes/base.py +++ b/bigframes/core/indexes/base.py @@ -290,9 +290,13 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: count_agg = ex_types.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id)) count_result = filtered_block._expr.aggregate([(count_agg, "count")]) - count_scalar = self._block.session._executor.execute( - count_result, ex_spec.ExecutionSpec(promise_under_10gb=True) - ).to_py_scalar() + count_scalar = ( + self._block.session._executor.execute( + count_result, ex_spec.ExecutionSpec(promise_under_10gb=True) + ) + .batches() + .to_py_scalar() + ) if count_scalar == 0: raise KeyError(f"'{key}' is not in index") @@ -301,9 +305,13 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: if count_scalar == 1: min_agg = ex_types.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)) position_result = filtered_block._expr.aggregate([(min_agg, "position")]) - position_scalar = self._block.session._executor.execute( - position_result, ex_spec.ExecutionSpec(promise_under_10gb=True) - ).to_py_scalar() + position_scalar = ( + self._block.session._executor.execute( + position_result, ex_spec.ExecutionSpec(promise_under_10gb=True) + ) + .batches() + .to_py_scalar() + ) return int(position_scalar) # Handle multiple matches based on index monotonicity @@ -333,10 +341,14 @@ def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice: combined_result = filtered_block._expr.aggregate(min_max_aggs) # Execute query and extract positions - result_df = self._block.session._executor.execute( - combined_result, - execution_spec=ex_spec.ExecutionSpec(promise_under_10gb=True), - ).to_pandas() + result_df = ( + self._block.session._executor.execute( + combined_result, + execution_spec=ex_spec.ExecutionSpec(promise_under_10gb=True), + ) + .batches() + .to_pandas() + ) min_pos = int(result_df["min_pos"].iloc[0]) max_pos = int(result_df["max_pos"].iloc[0]) diff --git a/bigframes/core/local_data.py b/bigframes/core/local_data.py index c214d0bb7e..fa18f00483 100644 --- a/bigframes/core/local_data.py +++ b/bigframes/core/local_data.py @@ -83,20 +83,39 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> ManagedArrowTable: return mat @classmethod - def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable: - columns: list[pa.ChunkedArray] = [] - fields: list[schemata.SchemaItem] = [] - for name, arr in zip(table.column_names, table.columns): - new_arr, bf_type = _adapt_chunked_array(arr) - columns.append(new_arr) - fields.append(schemata.SchemaItem(name, bf_type)) - - mat = ManagedArrowTable( - pa.table(columns, names=table.column_names), - schemata.ArraySchema(tuple(fields)), - ) - mat.validate() - return mat + def from_pyarrow( + cls, table: pa.Table, schema: Optional[schemata.ArraySchema] = None + ) -> ManagedArrowTable: + if schema is not None: + pa_fields = [] + for item in schema.items: + 
pa_type = _get_managed_storage_type(item.dtype) + pa_fields.append( + pyarrow.field( + item.column, + pa_type, + nullable=not pyarrow.types.is_list(pa_type), + ) + ) + pa_schema = pyarrow.schema(pa_fields) + # assumption: needed transformations can be handled by simple cast. + mat = ManagedArrowTable(table.cast(pa_schema), schema) + mat.validate() + return mat + else: # infer bigframes schema + columns: list[pa.ChunkedArray] = [] + fields: list[schemata.SchemaItem] = [] + for name, arr in zip(table.column_names, table.columns): + new_arr, bf_type = _adapt_chunked_array(arr) + columns.append(new_arr) + fields.append(schemata.SchemaItem(name, bf_type)) + + mat = ManagedArrowTable( + pa.table(columns, names=table.column_names), + schemata.ArraySchema(tuple(fields)), + ) + mat.validate() + return mat def to_arrow( self, diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 0d20509877..9e0fcb3ace 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -16,7 +16,6 @@ import abc import dataclasses -import datetime import functools import itertools import typing @@ -31,9 +30,7 @@ Tuple, ) -import google.cloud.bigquery as bq - -from bigframes.core import agg_expressions, identifiers, local_data, sequences +from bigframes.core import agg_expressions, bq_data, identifiers, local_data, sequences from bigframes.core.bigframe_node import BigFrameNode, COLUMN_SET import bigframes.core.expression as ex from bigframes.core.field import Field @@ -599,14 +596,13 @@ def transform_children(self, t: Callable[[BigFrameNode], BigFrameNode]) -> LeafN class ScanItem(typing.NamedTuple): id: identifiers.ColumnId - dtype: bigframes.dtypes.Dtype # Might be multiple logical types for a given physical source type source_id: str # Flexible enough for both local data and bq data def with_id(self, id: identifiers.ColumnId) -> ScanItem: - return ScanItem(id, self.dtype, self.source_id) + return ScanItem(id, self.source_id) def with_source_id(self, source_id: str) -> ScanItem: - return ScanItem(self.id, self.dtype, source_id) + return ScanItem(self.id, source_id) @dataclasses.dataclass(frozen=True) @@ -661,7 +657,7 @@ def remap_source_ids( def append( self, source_id: str, dtype: bigframes.dtypes.Dtype, id: identifiers.ColumnId ) -> ScanList: - return ScanList((*self.items, ScanItem(id, dtype, source_id))) + return ScanList((*self.items, ScanItem(id, source_id))) @dataclasses.dataclass(frozen=True, eq=False) @@ -677,8 +673,10 @@ class ReadLocalNode(LeafNode): @property def fields(self) -> Sequence[Field]: fields = tuple( - Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items + Field(col_id, self.local_data_source.schema.get_type(source_id)) + for col_id, source_id in self.scan_list.items ) + if self.offsets_col is not None: return tuple( itertools.chain( @@ -726,7 +724,7 @@ def remap_vars( ) -> ReadLocalNode: new_scan_list = ScanList( tuple( - ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id) + ScanItem(mappings.get(item.id, item.id), item.source_id) for item in self.scan_list.items ) ) @@ -745,64 +743,9 @@ def remap_refs( return self -@dataclasses.dataclass(frozen=True) -class GbqTable: - project_id: str = dataclasses.field() - dataset_id: str = dataclasses.field() - table_id: str = dataclasses.field() - physical_schema: Tuple[bq.SchemaField, ...] 
= dataclasses.field() - is_physically_stored: bool = dataclasses.field() - cluster_cols: typing.Optional[Tuple[str, ...]] - - @staticmethod - def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable: - # Subsetting fields with columns can reduce cost of row-hash default ordering - if columns: - schema = tuple(item for item in table.schema if item.name in columns) - else: - schema = tuple(table.schema) - return GbqTable( - project_id=table.project, - dataset_id=table.dataset_id, - table_id=table.table_id, - physical_schema=schema, - is_physically_stored=(table.table_type in ["TABLE", "MATERIALIZED_VIEW"]), - cluster_cols=None - if table.clustering_fields is None - else tuple(table.clustering_fields), - ) - - def get_table_ref(self) -> bq.TableReference: - return bq.TableReference( - bq.DatasetReference(self.project_id, self.dataset_id), self.table_id - ) - - @property - @functools.cache - def schema_by_id(self): - return {col.name: col for col in self.physical_schema} - - -@dataclasses.dataclass(frozen=True) -class BigqueryDataSource: - """ - Google BigQuery Data source. - - This should not be modified once defined, as all attributes contribute to the default ordering. - """ - - table: GbqTable - at_time: typing.Optional[datetime.datetime] = None - # Added for backwards compatibility, not validated - sql_predicate: typing.Optional[str] = None - ordering: typing.Optional[orderings.RowOrdering] = None - n_rows: Optional[int] = None - - -## Put ordering in here or just add order_by node above? @dataclasses.dataclass(frozen=True, eq=False) class ReadTableNode(LeafNode): - source: BigqueryDataSource + source: bq_data.BigqueryDataSource # Subset of physical schema column # Mapping of table schema ids to bfet id. scan_list: ScanList @@ -826,8 +769,12 @@ def session(self): @property def fields(self) -> Sequence[Field]: return tuple( - Field(col_id, dtype, self.source.table.schema_by_id[source_id].is_nullable) - for col_id, dtype, source_id in self.scan_list.items + Field( + col_id, + self.source.schema.get_type(source_id), + self.source.table.schema_by_id[source_id].is_nullable, + ) + for col_id, source_id in self.scan_list.items ) @property @@ -886,7 +833,7 @@ def remap_vars( ) -> ReadTableNode: new_scan_list = ScanList( tuple( - ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id) + ScanItem(mappings.get(item.id, item.id), item.source_id) for item in self.scan_list.items ) ) @@ -907,7 +854,6 @@ def with_order_cols(self): new_scan_cols = [ ScanItem( identifiers.ColumnId.unique(), - dtype=bigframes.dtypes.convert_schema_field(field)[1], source_id=field.name, ) for field in self.source.table.physical_schema diff --git a/bigframes/core/pyarrow_utils.py b/bigframes/core/pyarrow_utils.py index b9dc2ea2b3..bdbb220b95 100644 --- a/bigframes/core/pyarrow_utils.py +++ b/bigframes/core/pyarrow_utils.py @@ -84,6 +84,13 @@ def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch: ) +def rename_batch(batch: pa.RecordBatch, names: list[str]) -> pa.RecordBatch: + if batch.schema.names == names: + return batch + # TODO: Use RecordBatch.rename_columns once min pyarrow>=16.0 + return pa.RecordBatch.from_arrays(batch.columns, names) + + def truncate_pyarrow_iterable( batches: Iterable[pa.RecordBatch], max_results: int ) -> Iterator[pa.RecordBatch]: diff --git a/bigframes/core/rewrite/fold_row_count.py b/bigframes/core/rewrite/fold_row_count.py index 583343d68a..cc0b818fb9 100644 --- a/bigframes/core/rewrite/fold_row_count.py +++ 
b/bigframes/core/rewrite/fold_row_count.py @@ -15,7 +15,6 @@ import pyarrow as pa -from bigframes import dtypes from bigframes.core import local_data, nodes from bigframes.operations import aggregations @@ -34,10 +33,7 @@ def fold_row_counts(node: nodes.BigFrameNode) -> nodes.BigFrameNode: pa.table({"count": pa.array([node.child.row_count], type=pa.int64())}) ) scan_list = nodes.ScanList( - tuple( - nodes.ScanItem(out_id, dtypes.INT_DTYPE, "count") - for _, out_id in node.aggregations - ) + tuple(nodes.ScanItem(out_id, "count") for _, out_id in node.aggregations) ) return nodes.ReadLocalNode( local_data_source=local_data_source, scan_list=scan_list, session=node.session diff --git a/bigframes/core/schema.py b/bigframes/core/schema.py index b1a77d1259..4225725e72 100644 --- a/bigframes/core/schema.py +++ b/bigframes/core/schema.py @@ -17,7 +17,7 @@ from dataclasses import dataclass import functools import typing -from typing import Dict, List, Sequence +from typing import Dict, List import google.cloud.bigquery import pyarrow @@ -35,7 +35,7 @@ class SchemaItem: @dataclass(frozen=True) class ArraySchema: - items: Sequence[SchemaItem] + items: tuple[SchemaItem, ...] def __iter__(self): yield from self.items diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py index c830ca1e29..c5daf66d92 100644 --- a/bigframes/session/bq_caching_executor.py +++ b/bigframes/session/bq_caching_executor.py @@ -15,10 +15,8 @@ from __future__ import annotations import math -import os import threading from typing import Literal, Mapping, Optional, Sequence, Tuple -import warnings import weakref import google.api_core.exceptions @@ -31,13 +29,12 @@ from bigframes import exceptions as bfe import bigframes.constants import bigframes.core -from bigframes.core import compile, local_data, rewrite +from bigframes.core import bq_data, compile, local_data, rewrite import bigframes.core.compile.sqlglot.sqlglot_ir as sqlglot_ir import bigframes.core.events import bigframes.core.guid import bigframes.core.identifiers import bigframes.core.nodes as nodes -import bigframes.core.ordering as order import bigframes.core.schema as schemata import bigframes.core.tree_properties as tree_properties import bigframes.dtypes @@ -60,7 +57,6 @@ MAX_SUBTREE_FACTORINGS = 5 _MAX_CLUSTER_COLUMNS = 4 MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G -_MAX_READ_STREAMS = os.cpu_count() SourceIdMapping = Mapping[str, str] @@ -74,7 +70,7 @@ def __init__(self): ] = weakref.WeakKeyDictionary() self._uploaded_local_data: weakref.WeakKeyDictionary[ local_data.ManagedArrowTable, - tuple[nodes.BigqueryDataSource, SourceIdMapping], + tuple[bq_data.BigqueryDataSource, SourceIdMapping], ] = weakref.WeakKeyDictionary() @property @@ -84,23 +80,16 @@ def mapping(self) -> Mapping[nodes.BigFrameNode, nodes.BigFrameNode]: def cache_results_table( self, original_root: nodes.BigFrameNode, - table: bigquery.Table, - ordering: order.RowOrdering, - num_rows: Optional[int] = None, + data: bq_data.BigqueryDataSource, ): # Assumption: GBQ cached table uses field name as bq column name scan_list = nodes.ScanList( tuple( - nodes.ScanItem(field.id, field.dtype, field.id.sql) - for field in original_root.fields + nodes.ScanItem(field.id, field.id.sql) for field in original_root.fields ) ) cached_replacement = nodes.CachedTableNode( - source=nodes.BigqueryDataSource( - nodes.GbqTable.from_table(table), - ordering=ordering, - n_rows=num_rows, - ), + source=data, scan_list=scan_list, table_session=original_root.session, 
original_node=original_root, @@ -111,7 +100,7 @@ def cache_results_table( def cache_remote_replacement( self, local_data: local_data.ManagedArrowTable, - bq_data: nodes.BigqueryDataSource, + bq_data: bq_data.BigqueryDataSource, ): # bq table has one extra column for offsets, those are implicit for local data assert len(local_data.schema.items) + 1 == len(bq_data.table.physical_schema) @@ -331,7 +320,7 @@ def _export_gbq( # TODO(swast): plumb through the api_name of the user-facing api that # caused this query. - row_iter, query_job = self._run_execute_query( + iterator, job = self._run_execute_query( sql=sql, job_config=job_config, ) @@ -347,14 +336,11 @@ def _export_gbq( table.schema = array_value.schema.to_bigquery() self.bqclient.update_table(table, ["schema"]) - return executor.ExecuteResult( - row_iter.to_arrow_iterable( - bqstorage_client=self.bqstoragereadclient, - max_stream_count=_MAX_READ_STREAMS, + return executor.EmptyExecuteResult( + bf_schema=array_value.schema, + execution_metadata=executor.ExecutionMetadata.from_iterator_and_job( + iterator, job ), - array_value.schema, - query_job, - total_bytes_processed=row_iter.total_bytes_processed, ) def dry_run( @@ -637,18 +623,14 @@ def _execute_plan_gbq( create_table = True if not cache_spec.cluster_cols: + assert len(cache_spec.cluster_cols) <= _MAX_CLUSTER_COLUMNS offsets_id = bigframes.core.identifiers.ColumnId( bigframes.core.guid.generate_guid() ) plan = nodes.PromoteOffsetsNode(plan, offsets_id) cluster_cols = [offsets_id.sql] else: - cluster_cols = [ - col - for col in cache_spec.cluster_cols - if bigframes.dtypes.is_clusterable(plan.schema.get_type(col)) - ] - cluster_cols = cluster_cols[:_MAX_CLUSTER_COLUMNS] + cluster_cols = cache_spec.cluster_cols compiled = compile.compile_sql( compile.CompileRequest( @@ -676,41 +658,62 @@ def _execute_plan_gbq( query_with_job=(destination_table is not None), ) - table_info: Optional[bigquery.Table] = None + # we could actually cache even when caching is not explicitly requested, but being conservative for now + result_bq_data = None if query_job and query_job.destination: - table_info = self.bqclient.get_table(query_job.destination) - size_bytes = table_info.num_bytes - else: - size_bytes = None + # we might add extra sql columns in compilation, esp if caching w ordering, infer a bigframes type for them + result_bf_schema = _result_schema(og_schema, list(compiled.sql_schema)) + dst = query_job.destination + result_bq_data = bq_data.BigqueryDataSource( + table=bq_data.GbqTable( + dst.project, + dst.dataset_id, + dst.table_id, + tuple(compiled_schema), + is_physically_stored=True, + cluster_cols=tuple(cluster_cols), + ), + schema=result_bf_schema, + ordering=compiled.row_order, + n_rows=iterator.total_rows, + ) - # we could actually cache even when caching is not explicitly requested, but being conservative for now if cache_spec is not None: - assert table_info is not None + assert result_bq_data is not None assert compiled.row_order is not None - self.cache.cache_results_table( - og_plan, table_info, compiled.row_order, num_rows=table_info.num_rows - ) + self.cache.cache_results_table(og_plan, result_bq_data) - if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES: - msg = bfe.format_message( - "The query result size has exceeded 10 GB. In BigFrames 2.0 and " - "later, you might need to manually set `allow_large_results=True` in " - "the IO method or adjust the BigFrames option: " - "`bigframes.options.compute.allow_large_results=True`." 
+ execution_metadata = executor.ExecutionMetadata.from_iterator_and_job( + iterator, query_job + ) + result_mostly_cached = ( + hasattr(iterator, "_is_almost_completely_cached") + and iterator._is_almost_completely_cached() + ) + if result_bq_data is not None and not result_mostly_cached: + return executor.BQTableExecuteResult( + data=result_bq_data, + project_id=self.bqclient.project, + storage_client=self.bqstoragereadclient, + execution_metadata=execution_metadata, + selected_fields=tuple((col, col) for col in og_schema.names), + ) + else: + return executor.LocalExecuteResult( + data=iterator.to_arrow().select(og_schema.names), + bf_schema=plan.schema, + execution_metadata=execution_metadata, ) - warnings.warn(msg, FutureWarning) - return executor.ExecuteResult( - _arrow_batches=iterator.to_arrow_iterable( - bqstorage_client=self.bqstoragereadclient, - max_stream_count=_MAX_READ_STREAMS, - ), - schema=og_schema, - query_job=query_job, - total_bytes=size_bytes, - total_rows=iterator.total_rows, - total_bytes_processed=iterator.total_bytes_processed, - ) + +def _result_schema( + logical_schema: schemata.ArraySchema, sql_schema: list[bigquery.SchemaField] +) -> schemata.ArraySchema: + inferred_schema = bigframes.dtypes.bf_type_from_type_kind(sql_schema) + inferred_schema.update(logical_schema._mapping) + return schemata.ArraySchema( + tuple(schemata.SchemaItem(col, dtype) for col, dtype in inferred_schema.items()) + ) def _if_schema_match( diff --git a/bigframes/session/direct_gbq_execution.py b/bigframes/session/direct_gbq_execution.py index 9e7db87301..d76a1a7630 100644 --- a/bigframes/session/direct_gbq_execution.py +++ b/bigframes/session/direct_gbq_execution.py @@ -64,12 +64,13 @@ def execute( sql=compiled.sql, ) - return executor.ExecuteResult( - _arrow_batches=iterator.to_arrow_iterable(), - schema=plan.schema, - query_job=query_job, - total_rows=iterator.total_rows, - total_bytes_processed=iterator.total_bytes_processed, + # just immediately downlaod everything for simplicity + return executor.LocalExecuteResult( + data=iterator.to_arrow(), + bf_schema=plan.schema, + execution_metadata=executor.ExecutionMetadata.from_iterator_and_job( + iterator, query_job + ), ) def _run_execute_query( diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py index d0cfe5f4f7..a63bb962ab 100644 --- a/bigframes/session/executor.py +++ b/bigframes/session/executor.py @@ -18,15 +18,17 @@ import dataclasses import functools import itertools -from typing import Iterator, Literal, Optional, Union +from typing import Iterator, Literal, Optional, Sequence, Union -from google.cloud import bigquery +from google.cloud import bigquery, bigquery_storage_v1 +import google.cloud.bigquery.table as bq_table import pandas as pd import pyarrow +import pyarrow as pa import bigframes import bigframes.core -from bigframes.core import pyarrow_utils +from bigframes.core import bq_data, local_data, pyarrow_utils import bigframes.core.schema import bigframes.session._io.pandas as io_pandas import bigframes.session.execution_spec as ex_spec @@ -38,21 +40,39 @@ ) -@dataclasses.dataclass(frozen=True) -class ExecuteResult: - _arrow_batches: Iterator[pyarrow.RecordBatch] - schema: bigframes.core.schema.ArraySchema - query_job: Optional[bigquery.QueryJob] = None - total_bytes: Optional[int] = None - total_rows: Optional[int] = None - total_bytes_processed: Optional[int] = None +class ResultsIterator(Iterator[pa.RecordBatch]): + """ + Iterator for query results, with some extra metadata attached. 
+ """ + + def __init__( + self, + batches: Iterator[pa.RecordBatch], + schema: bigframes.core.schema.ArraySchema, + total_rows: Optional[int] = 0, + total_bytes: Optional[int] = 0, + ): + self._batches = batches + self._schema = schema + self._total_rows = total_rows + self._total_bytes = total_bytes + + @property + def approx_total_rows(self) -> Optional[int]: + return self._total_rows + + @property + def approx_total_bytes(self) -> Optional[int]: + return self._total_bytes + + def __next__(self) -> pa.RecordBatch: + return next(self._batches) @property def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]: result_rows = 0 - for batch in self._arrow_batches: - batch = pyarrow_utils.cast_batch(batch, self.schema.to_pyarrow()) + for batch in self._batches: result_rows += batch.num_rows maximum_result_rows = bigframes.options.compute.maximum_result_rows @@ -80,10 +100,10 @@ def to_arrow_table(self) -> pyarrow.Table: itertools.chain(peek_value, batches), # reconstruct ) else: - return self.schema.to_pyarrow().empty_table() + return self._schema.to_pyarrow().empty_table() def to_pandas(self) -> pd.DataFrame: - return io_pandas.arrow_to_pandas(self.to_arrow_table(), self.schema) + return io_pandas.arrow_to_pandas(self.to_arrow_table(), self._schema) def to_pandas_batches( self, page_size: Optional[int] = None, max_results: Optional[int] = None @@ -105,7 +125,7 @@ def to_pandas_batches( ) yield from map( - functools.partial(io_pandas.arrow_to_pandas, schema=self.schema), + functools.partial(io_pandas.arrow_to_pandas, schema=self._schema), batch_iter, ) @@ -121,6 +141,150 @@ def to_py_scalar(self): return column[0] +class ExecuteResult(abc.ABC): + @property + @abc.abstractmethod + def execution_metadata(self) -> ExecutionMetadata: + ... + + @property + @abc.abstractmethod + def schema(self) -> bigframes.core.schema.ArraySchema: + ... + + @abc.abstractmethod + def batches(self) -> ResultsIterator: + ... 
+ + @property + def query_job(self) -> Optional[bigquery.QueryJob]: + return self.execution_metadata.query_job + + @property + def total_bytes_processed(self) -> Optional[int]: + return self.execution_metadata.bytes_processed + + +@dataclasses.dataclass(frozen=True) +class ExecutionMetadata: + query_job: Optional[bigquery.QueryJob] = None + bytes_processed: Optional[int] = None + + @classmethod + def from_iterator_and_job( + cls, iterator: bq_table.RowIterator, job: Optional[bigquery.QueryJob] + ) -> ExecutionMetadata: + return cls(query_job=job, bytes_processed=iterator.total_bytes_processed) + + +class LocalExecuteResult(ExecuteResult): + def __init__( + self, + data: pa.Table, + bf_schema: bigframes.core.schema.ArraySchema, + execution_metadata: ExecutionMetadata = ExecutionMetadata(), + ): + self._data = local_data.ManagedArrowTable.from_pyarrow(data, bf_schema) + self._execution_metadata = execution_metadata + + @property + def execution_metadata(self) -> ExecutionMetadata: + return self._execution_metadata + + @property + def schema(self) -> bigframes.core.schema.ArraySchema: + return self._data.schema + + def batches(self) -> ResultsIterator: + return ResultsIterator( + iter(self._data.to_arrow()[1]), + self.schema, + self._data.metadata.row_count, + self._data.metadata.total_bytes, + ) + + +class EmptyExecuteResult(ExecuteResult): + def __init__( + self, + bf_schema: bigframes.core.schema.ArraySchema, + execution_metadata: ExecutionMetadata = ExecutionMetadata(), + ): + self._schema = bf_schema + self._execution_metadata = execution_metadata + + @property + def execution_metadata(self) -> ExecutionMetadata: + return self._execution_metadata + + @property + def schema(self) -> bigframes.core.schema.ArraySchema: + return self._schema + + def batches(self) -> ResultsIterator: + return ResultsIterator(iter([]), self.schema, 0, 0) + + +class BQTableExecuteResult(ExecuteResult): + def __init__( + self, + data: bq_data.BigqueryDataSource, + storage_client: bigquery_storage_v1.BigQueryReadClient, + project_id: str, + *, + execution_metadata: ExecutionMetadata = ExecutionMetadata(), + limit: Optional[int] = None, + selected_fields: Optional[Sequence[tuple[str, str]]] = None, + ): + self._data = data + self._project_id = project_id + self._execution_metadata = execution_metadata + self._storage_client = storage_client + self._limit = limit + self._selected_fields = selected_fields or [ + (name, name) for name in data.schema.names + ] + + @property + def execution_metadata(self) -> ExecutionMetadata: + return self._execution_metadata + + @property + @functools.cache + def schema(self) -> bigframes.core.schema.ArraySchema: + source_ids = [selection[0] for selection in self._selected_fields] + return self._data.schema.select(source_ids).rename(dict(self._selected_fields)) + + def batches(self) -> ResultsIterator: + read_batches = bq_data.get_arrow_batches( + self._data, + [x[0] for x in self._selected_fields], + self._storage_client, + self._project_id, + ) + arrow_batches: Iterator[pa.RecordBatch] = map( + functools.partial( + pyarrow_utils.rename_batch, names=list(self.schema.names) + ), + read_batches.iter, + ) + approx_bytes: Optional[int] = read_batches.approx_bytes + approx_rows: Optional[int] = self._data.n_rows or read_batches.approx_rows + + if self._limit is not None: + if approx_rows is not None: + approx_rows = min(approx_rows, self._limit) + arrow_batches = pyarrow_utils.truncate_pyarrow_iterable( + arrow_batches, self._limit + ) + + if self._data.sql_predicate: + approx_bytes = 
None + approx_rows = None + + return ResultsIterator(arrow_batches, self.schema, approx_rows, approx_bytes) + + @dataclasses.dataclass(frozen=True) class HierarchicalKey: columns: tuple[str, ...] diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py index 940fdc1352..7d549809bb 100644 --- a/bigframes/session/loader.py +++ b/bigframes/session/loader.py @@ -47,7 +47,15 @@ import pandas import pyarrow as pa -from bigframes.core import guid, identifiers, local_data, nodes, ordering, utils +from bigframes.core import ( + bq_data, + guid, + identifiers, + local_data, + nodes, + ordering, + utils, +) import bigframes.core as core import bigframes.core.blocks as blocks import bigframes.core.events @@ -324,9 +332,7 @@ def read_managed_data( source=gbq_source, scan_list=nodes.ScanList( tuple( - nodes.ScanItem( - identifiers.ColumnId(item.column), item.dtype, item.column - ) + nodes.ScanItem(identifiers.ColumnId(item.column), item.column) for item in data.schema.items ) ), @@ -337,7 +343,7 @@ def load_data( self, data: local_data.ManagedArrowTable, offsets_col: str, - ) -> nodes.BigqueryDataSource: + ) -> bq_data.BigqueryDataSource: """Load managed data into bigquery""" # JSON support incomplete @@ -379,8 +385,9 @@ def load_data( self._start_generic_job(load_job) # must get table metadata after load job for accurate metadata destination_table = self._bqclient.get_table(load_table_destination) - return nodes.BigqueryDataSource( - nodes.GbqTable.from_table(destination_table), + return bq_data.BigqueryDataSource( + bq_data.GbqTable.from_table(destination_table), + schema=schema_w_offsets, ordering=ordering.TotalOrdering.from_offset_col(offsets_col), n_rows=data.metadata.row_count, ) @@ -389,7 +396,7 @@ def stream_data( self, data: local_data.ManagedArrowTable, offsets_col: str, - ) -> nodes.BigqueryDataSource: + ) -> bq_data.BigqueryDataSource: """Load managed data into bigquery""" schema_w_offsets = data.schema.append( schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE) @@ -415,8 +422,9 @@ def stream_data( f"Problem loading at least one row from DataFrame: {errors}. 
{constants.FEEDBACK_LINK}" ) destination_table = self._bqclient.get_table(load_table_destination) - return nodes.BigqueryDataSource( - nodes.GbqTable.from_table(destination_table), + return bq_data.BigqueryDataSource( + bq_data.GbqTable.from_table(destination_table), + schema=schema_w_offsets, ordering=ordering.TotalOrdering.from_offset_col(offsets_col), n_rows=data.metadata.row_count, ) @@ -425,7 +433,7 @@ def write_data( self, data: local_data.ManagedArrowTable, offsets_col: str, - ) -> nodes.BigqueryDataSource: + ) -> bq_data.BigqueryDataSource: """Load managed data into bigquery""" schema_w_offsets = data.schema.append( schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE) @@ -469,8 +477,9 @@ def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]: assert response.row_count == data.data.num_rows destination_table = self._bqclient.get_table(bq_table_ref) - return nodes.BigqueryDataSource( - nodes.GbqTable.from_table(destination_table), + return bq_data.BigqueryDataSource( + bq_data.GbqTable.from_table(destination_table), + schema=schema_w_offsets, ordering=ordering.TotalOrdering.from_offset_col(offsets_col), n_rows=data.metadata.row_count, ) diff --git a/bigframes/session/local_scan_executor.py b/bigframes/session/local_scan_executor.py index 65f088e8a1..fee0f557ea 100644 --- a/bigframes/session/local_scan_executor.py +++ b/bigframes/session/local_scan_executor.py @@ -57,10 +57,7 @@ def execute( if (peek is not None) and (total_rows is not None): total_rows = min(peek, total_rows) - return executor.ExecuteResult( - _arrow_batches=arrow_table.to_batches(), - schema=plan.schema, - query_job=None, - total_bytes=None, - total_rows=total_rows, + return executor.LocalExecuteResult( + data=arrow_table, + bf_schema=plan.schema, ) diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index a1e1d436e1..00f8f37934 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -16,14 +16,11 @@ import itertools from typing import Optional, TYPE_CHECKING -import pyarrow as pa - from bigframes.core import ( agg_expressions, array_value, bigframe_node, expression, - local_data, nodes, ) import bigframes.operations @@ -153,23 +150,10 @@ def execute( if peek is not None: lazy_frame = lazy_frame.limit(peek) pa_table = lazy_frame.collect().to_arrow() - return executor.ExecuteResult( - _arrow_batches=iter(map(self._adapt_batch, pa_table.to_batches())), - schema=plan.schema, - total_bytes=pa_table.nbytes, - total_rows=pa_table.num_rows, + return executor.LocalExecuteResult( + data=pa_table, + bf_schema=plan.schema, ) def _can_execute(self, plan: bigframe_node.BigFrameNode): return all(_is_node_polars_executable(node) for node in plan.unique_nodes()) - - def _adapt_array(self, array: pa.Array) -> pa.Array: - target_type = local_data.logical_type_replacements(array.type) - if target_type != array.type: - # Safe is false to handle weird polars decimal scaling - return array.cast(target_type, safe=False) - return array - - def _adapt_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch: - new_arrays = [self._adapt_array(arr) for arr in batch.columns] - return pa.RecordBatch.from_arrays(new_arrays, names=batch.column_names) diff --git a/bigframes/session/read_api_execution.py b/bigframes/session/read_api_execution.py index 2530a1dc8d..0535bec0c4 100644 --- a/bigframes/session/read_api_execution.py +++ b/bigframes/session/read_api_execution.py @@ -13,12 +13,11 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Iterator, Optional +from typing import Optional from google.cloud import bigquery_storage_v1 -import pyarrow as pa -from bigframes.core import bigframe_node, nodes, pyarrow_utils, rewrite +from bigframes.core import bigframe_node, nodes, rewrite from bigframes.session import executor, semi_executor @@ -28,7 +27,9 @@ class ReadApiSemiExecutor(semi_executor.SemiExecutor): """ def __init__( - self, bqstoragereadclient: bigquery_storage_v1.BigQueryReadClient, project: str + self, + bqstoragereadclient: bigquery_storage_v1.BigQueryReadClient, + project: str, ): self.bqstoragereadclient = bqstoragereadclient self.project = project @@ -50,68 +51,14 @@ def execute( if peek is None or limit < peek: peek = limit - import google.cloud.bigquery_storage_v1.types as bq_storage_types - from google.protobuf import timestamp_pb2 - - bq_table = node.source.table.get_table_ref() - read_options: dict[str, Any] = { - "selected_fields": [item.source_id for item in node.scan_list.items] - } - if node.source.sql_predicate: - read_options["row_restriction"] = node.source.sql_predicate - read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options) - - table_mod_options = {} - if node.source.at_time: - snapshot_time = timestamp_pb2.Timestamp() - snapshot_time.FromDatetime(node.source.at_time) - table_mod_options["snapshot_time"] = snapshot_time = snapshot_time - table_mods = bq_storage_types.ReadSession.TableModifiers(**table_mod_options) - - requested_session = bq_storage_types.stream.ReadSession( - table=bq_table.to_bqstorage(), - data_format=bq_storage_types.DataFormat.ARROW, - read_options=read_options, - table_modifiers=table_mods, - ) - # Single stream to maintain ordering - request = bq_storage_types.CreateReadSessionRequest( - parent=f"projects/{self.project}", - read_session=requested_session, - max_stream_count=1, - ) - session = self.bqstoragereadclient.create_read_session(request=request) - - if not session.streams: - batches: Iterator[pa.RecordBatch] = iter([]) - else: - reader = self.bqstoragereadclient.read_rows(session.streams[0].name) - rowstream = reader.rows() - - def process_page(page): - pa_batch = page.to_arrow() - pa_batch = pa_batch.select( - [item.source_id for item in node.scan_list.items] - ) - return pa.RecordBatch.from_arrays( - pa_batch.columns, names=[id.sql for id in node.ids] - ) - - batches = map(process_page, rowstream.pages) - - if peek: - batches = pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=peek) - - rows = node.source.n_rows or session.estimated_row_count - if peek and rows: - rows = min(peek, rows) - - return executor.ExecuteResult( - _arrow_batches=batches, - schema=plan.schema, - query_job=None, - total_bytes=None, - total_rows=rows, + return executor.BQTableExecuteResult( + data=node.source, + project_id=self.project, + storage_client=self.bqstoragereadclient, + limit=peek, + selected_fields=[ + (item.source_id, item.id.sql) for item in node.scan_list.items + ], ) def _try_adapt_plan( diff --git a/bigframes/testing/engine_utils.py b/bigframes/testing/engine_utils.py index 625d1727ee..edb68c3a9b 100644 --- a/bigframes/testing/engine_utils.py +++ b/bigframes/testing/engine_utils.py @@ -29,6 +29,6 @@ def assert_equivalence_execution( assert e2_result is not None # Convert to pandas, as pandas has better comparison utils than arrow assert e1_result.schema == e2_result.schema - e1_table = e1_result.to_pandas() - e2_table = e2_result.to_pandas() + e1_table = 
e1_result.batches().to_pandas() + e2_table = e2_result.batches().to_pandas() pandas.testing.assert_frame_equal(e1_table, e2_table, rtol=1e-5) diff --git a/bigframes/testing/polars_session.py b/bigframes/testing/polars_session.py index ba6d502fcc..ca1fa329a2 100644 --- a/bigframes/testing/polars_session.py +++ b/bigframes/testing/polars_session.py @@ -51,11 +51,9 @@ def execute( pa_table = lazy_frame.collect().to_arrow() # Currently, pyarrow types might not quite be exactly the ones in the bigframes schema. # Nullability may be different, and might use large versions of list, string datatypes. - return bigframes.session.executor.ExecuteResult( - _arrow_batches=pa_table.to_batches(), - schema=array_value.schema, - total_bytes=pa_table.nbytes, - total_rows=pa_table.num_rows, + return bigframes.session.executor.LocalExecuteResult( + data=pa_table, + bf_schema=array_value.schema, ) def cached( diff --git a/tests/system/small/engines/test_read_local.py b/tests/system/small/engines/test_read_local.py index bf1a10beec..abdd29c4ac 100644 --- a/tests/system/small/engines/test_read_local.py +++ b/tests/system/small/engines/test_read_local.py @@ -31,7 +31,7 @@ def test_engines_read_local( engine, ): scan_list = nodes.ScanList.from_items( - nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(identifiers.ColumnId(item.column), item.column) for item in managed_data_source.schema.items ) local_node = nodes.ReadLocalNode( @@ -46,7 +46,7 @@ def test_engines_read_local_w_offsets( engine, ): scan_list = nodes.ScanList.from_items( - nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(identifiers.ColumnId(item.column), item.column) for item in managed_data_source.schema.items ) local_node = nodes.ReadLocalNode( @@ -64,7 +64,7 @@ def test_engines_read_local_w_col_subset( engine, ): scan_list = nodes.ScanList.from_items( - nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(identifiers.ColumnId(item.column), item.column) for item in managed_data_source.schema.items[::-2] ) local_node = nodes.ReadLocalNode( @@ -79,7 +79,7 @@ def test_engines_read_local_w_zero_row_source( engine, ): scan_list = nodes.ScanList.from_items( - nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(identifiers.ColumnId(item.column), item.column) for item in zero_row_source.schema.items ) local_node = nodes.ReadLocalNode( @@ -96,7 +96,7 @@ def test_engines_read_local_w_nested_source( engine, ): scan_list = nodes.ScanList.from_items( - nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(identifiers.ColumnId(item.column), item.column) for item in nested_data_source.schema.items ) local_node = nodes.ReadLocalNode( @@ -111,7 +111,7 @@ def test_engines_read_local_w_repeated_source( engine, ): scan_list = nodes.ScanList.from_items( - nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column) + nodes.ScanItem(identifiers.ColumnId(item.column), item.column) for item in repeated_data_source.schema.items ) local_node = nodes.ReadLocalNode( diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index 001e02c2fa..52f317ae25 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -122,7 +122,7 @@ def test_read_gbq_tokyo( assert exec_result.query_job is not None assert exec_result.query_job.location == tokyo_location - assert len(expected) == exec_result.total_rows + assert 
len(expected) == exec_result.batches().approx_total_rows @pytest.mark.parametrize( @@ -951,7 +951,7 @@ def test_read_pandas_tokyo( assert result.query_job is not None assert result.query_job.location == tokyo_location - assert len(expected) == result.total_rows + assert len(expected) == result.batches().approx_total_rows @all_write_engines diff --git a/tests/unit/session/test_local_scan_executor.py b/tests/unit/session/test_local_scan_executor.py index 30b1b5f78d..fc59253153 100644 --- a/tests/unit/session/test_local_scan_executor.py +++ b/tests/unit/session/test_local_scan_executor.py @@ -16,7 +16,6 @@ import pyarrow import pytest -from bigframes import dtypes from bigframes.core import identifiers, local_data, nodes from bigframes.session import local_scan_executor from bigframes.testing import mocks @@ -37,9 +36,6 @@ def create_read_local_node(arrow_table: pyarrow.Table): items=tuple( nodes.ScanItem( id=identifiers.ColumnId(column_name), - dtype=dtypes.arrow_dtype_to_bigframes_dtype( - arrow_table.field(column_name).type - ), source_id=column_name, ) for column_name in arrow_table.column_names @@ -77,7 +73,7 @@ def test_local_scan_executor_with_slice(start, stop, expected_rows, object_under ) result = object_under_test.execute(plan, ordered=True) - result_table = pyarrow.Table.from_batches(result.arrow_batches) + result_table = pyarrow.Table.from_batches(result.batches().arrow_batches) assert result_table.num_rows == expected_rows
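
A minimal consumer-side sketch of the executor interface this change converges on, using only names introduced in the patch above (ExecuteResult.batches(), ResultsIterator.to_pandas(), total_bytes_processed); the consume_result helper itself is illustrative and not part of the change. Job-level metadata stays on the ExecuteResult, while row data is materialized through the ResultsIterator returned by batches():

from __future__ import annotations

from typing import Optional

import pandas as pd

from bigframes.session import executor


def consume_result(
    result: executor.ExecuteResult,
) -> tuple[pd.DataFrame, Optional[int]]:
    # Job-level metadata (query job, bytes processed) lives on the ExecuteResult itself.
    bytes_processed = result.total_bytes_processed

    # Row data comes from the ResultsIterator; to_arrow_table(), to_pandas_batches()
    # and to_py_scalar() are also available on the same object.
    batches = result.batches()
    df = batches.to_pandas()
    return df, bytes_processed

Keeping the size estimates (approx_total_rows, approx_total_bytes) on the ResultsIterator is what lets callers such as Block._materialize_local inspect the result size before deciding whether to download or downsample.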