
from __future__ import annotations

+import io
import logging
from typing import (
    List,
    Optional,
    Any,
-    Callable,
    cast,
    TYPE_CHECKING,
)
    from databricks.sql.backend.sea.result_set import SeaResultSet

from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.sea.models.base import ResultData
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.utils import CloudFetchQueue, ArrowQueue
+
+try:
+    import pyarrow
+    import pyarrow.compute as pc
+except ImportError:
+    pyarrow = None
+    pc = None

logger = logging.getLogger(__name__)

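As a quick reference for reviewers, here is a minimal standalone sketch of the pc.is_in filtering pattern these new imports enable; the toy table and column names are made up for illustration, and it assumes pyarrow is installed:

# Hypothetical toy example, not part of the connector code.
import pyarrow
import pyarrow.compute as pc

table = pyarrow.table(
    {"TABLE_NAME": ["t1", "v1", "tmp"], "TABLE_TYPE": ["TABLE", "view", "FOREIGN"]}
)

allowed = [v.upper() for v in ["table", "view"]]           # normalize allowed values
column = pc.utf8_upper(table["TABLE_TYPE"])                # normalize the column values
mask = pc.is_in(column, value_set=pyarrow.array(allowed))  # boolean mask
filtered = table.filter(mask)                              # keeps the "t1" and "v1" rows
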
@@ -30,32 +40,18 @@ class ResultSetFilter:
    """

    @staticmethod
-    def _filter_sea_result_set(
-        result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
-    ) -> SeaResultSet:
+    def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse:
        """
-        Filter a SEA result set using the provided filter function.
+        Create an ExecuteResponse with parameters from the original result set.

        Args:
-            result_set: The SEA result set to filter
-            filter_func: Function that takes a row and returns True if the row should be included
+            result_set: Original result set to copy parameters from

        Returns:
-            A filtered SEA result set
+            ExecuteResponse: New execute response object
        """
-
-        # Get all remaining rows
-        all_rows = result_set.results.remaining_rows()
-
-        # Filter rows
-        filtered_rows = [row for row in all_rows if filter_func(row)]
-
-        # Reuse the command_id from the original result set
-        command_id = result_set.command_id
-
-        # Create an ExecuteResponse for the filtered data
-        execute_response = ExecuteResponse(
-            command_id=command_id,
+        return ExecuteResponse(
+            command_id=result_set.command_id,
            status=result_set.status,
            description=result_set.description,
            has_been_closed_server_side=result_set.has_been_closed_server_side,
@@ -64,32 +60,145 @@ def _filter_sea_result_set(
            is_staging_operation=False,
        )

-        # Create a new ResultData object with filtered data
-        from databricks.sql.backend.sea.models.base import ResultData
+    @staticmethod
+    def _update_manifest(result_set: SeaResultSet, new_row_count: int):
+        """
+        Return the result set's manifest with its total row count updated.
+
+        Args:
+            result_set: Original result set whose manifest is updated
+            new_row_count: New total row count for the filtered data

-        result_data = ResultData(data=filtered_rows, external_links=None)
+        Returns:
+            The manifest with the updated total row count
+        """
+        filtered_manifest = result_set.manifest
+        filtered_manifest.total_row_count = new_row_count
+        return filtered_manifest

-        from databricks.sql.backend.sea.backend import SeaDatabricksClient
+    @staticmethod
+    def _create_filtered_result_set(
+        result_set: SeaResultSet,
+        result_data: ResultData,
+        row_count: int,
+    ) -> "SeaResultSet":
+        """
+        Create a new filtered SeaResultSet with the provided data.
+
+        Args:
+            result_set: Original result set to copy parameters from
+            result_data: New result data for the filtered set
+            row_count: Number of rows in the filtered data
+
+        Returns:
+            New filtered SeaResultSet
+        """
        from databricks.sql.backend.sea.result_set import SeaResultSet

-        # Create a new SeaResultSet with the filtered data
-        manifest = result_set.manifest
-        manifest.total_row_count = len(filtered_rows)
+        execute_response = ResultSetFilter._create_execute_response(result_set)
+        filtered_manifest = ResultSetFilter._update_manifest(result_set, row_count)

-        filtered_result_set = SeaResultSet(
+        return SeaResultSet(
            connection=result_set.connection,
            execute_response=execute_response,
            sea_client=cast(SeaDatabricksClient, result_set.backend),
            result_data=result_data,
-            manifest=manifest,
+            manifest=filtered_manifest,
            buffer_size_bytes=result_set.buffer_size_bytes,
            arraysize=result_set.arraysize,
        )

-        return filtered_result_set
+    @staticmethod
+    def _filter_arrow_table(
+        table: Any,  # pyarrow.Table
+        column_name: str,
+        allowed_values: List[str],
+        case_sensitive: bool = True,
+    ) -> Any:  # returns pyarrow.Table
+        """
+        Filter a PyArrow table by column values.
+
+        Args:
+            table: The PyArrow table to filter
+            column_name: The name of the column to filter on
+            allowed_values: List of allowed values for the column
+            case_sensitive: Whether to perform case-sensitive comparison
+
+        Returns:
+            A filtered PyArrow table
+        """
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow table filtering")
+
+        if table.num_rows == 0:
+            return table
+
+        # Handle case-insensitive filtering by normalizing both column and allowed values
+        if not case_sensitive:
+            # Convert allowed values to uppercase
+            allowed_values = [v.upper() for v in allowed_values]
+            # Get column values as uppercase
+            column = pc.utf8_upper(table[column_name])
+        else:
+            # Use column as-is
+            column = table[column_name]
+
+        # Convert allowed_values to a PyArrow array
+        allowed_array = pyarrow.array(allowed_values)
+
+        # Construct a boolean mask: True where the column value is in allowed_array
+        mask = pc.is_in(column, value_set=allowed_array)
+        return table.filter(mask)
+
+    @staticmethod
+    def _filter_arrow_result_set(
+        result_set: SeaResultSet,
+        column_index: int,
+        allowed_values: List[str],
+        case_sensitive: bool = True,
+    ) -> SeaResultSet:
+        """
+        Filter a SEA result set that contains Arrow tables.
+
+        Args:
+            result_set: The SEA result set to filter (containing Arrow data)
+            column_index: The index of the column to filter on
+            allowed_values: List of allowed values for the column
+            case_sensitive: Whether to perform case-sensitive comparison
+
+        Returns:
+            A filtered SEA result set
+        """
+        # Validate column index and get column name
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")
+        column_name = result_set.description[column_index][0]
+
+        # Get all remaining rows as an Arrow table and filter it
+        arrow_table = result_set.results.remaining_rows()
+        filtered_table = ResultSetFilter._filter_arrow_table(
+            arrow_table, column_name, allowed_values, case_sensitive
+        )
+
+        # Convert the filtered table to Arrow stream format for ResultData
+        sink = io.BytesIO()
+        with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
+            writer.write_table(filtered_table)
+        arrow_stream_bytes = sink.getvalue()
+
+        # Create ResultData with an attachment containing the filtered data
+        result_data = ResultData(
+            data=None,  # No JSON data
+            external_links=None,  # No external links
+            attachment=arrow_stream_bytes,  # Arrow data as attachment
+        )
+
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, filtered_table.num_rows
+        )

    @staticmethod
-    def filter_by_column_values(
+    def _filter_json_result_set(
        result_set: SeaResultSet,
        column_index: int,
        allowed_values: List[str],
@@ -107,22 +216,35 @@ def filter_by_column_values(
        Returns:
            A filtered result set
        """
+        # Validate the column index
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")

-        # Convert to uppercase for case-insensitive comparison if needed
+        # Extract rows
+        all_rows = result_set.results.remaining_rows()
+
+        # Convert allowed values if case-insensitive
        if not case_sensitive:
            allowed_values = [v.upper() for v in allowed_values]
+        # Helper lambda to get the column value based on case sensitivity
+        get_column_value = (
+            lambda row: row[column_index].upper()
+            if not case_sensitive
+            else row[column_index]
+        )
+
+        # Filter rows based on allowed values
+        filtered_rows = [
+            row
+            for row in all_rows
+            if len(row) > column_index and get_column_value(row) in allowed_values
+        ]
+
+        # Create filtered result set
+        result_data = ResultData(data=filtered_rows, external_links=None)

-        return ResultSetFilter._filter_sea_result_set(
-            result_set,
-            lambda row: (
-                len(row) > column_index
-                and (
-                    row[column_index].upper()
-                    if not case_sensitive
-                    else row[column_index]
-                )
-                in allowed_values
-            ),
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, len(filtered_rows)
        )

    @staticmethod
@@ -143,14 +265,25 @@ def filter_tables_by_type(
        Returns:
            A filtered result set containing only tables of the specified types
        """
-
        # Default table types if none specified
        DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
-        valid_types = (
-            table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
-        )
+        valid_types = table_types if table_types else DEFAULT_TABLE_TYPES

+        # Check if we have an Arrow table (cloud fetch) or JSON data
        # Table type is the 6th column (index 5)
-        return ResultSetFilter.filter_by_column_values(
-            result_set, 5, valid_types, case_sensitive=True
-        )
+        if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)):
+            # For Arrow tables, we need to handle filtering differently
+            return ResultSetFilter._filter_arrow_result_set(
+                result_set,
+                column_index=5,
+                allowed_values=valid_types,
+                case_sensitive=True,
+            )
+        else:
+            # For JSON data, use the existing filter method
+            return ResultSetFilter._filter_json_result_set(
+                result_set,
+                column_index=5,
+                allowed_values=valid_types,
+                case_sensitive=True,
+            )
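
For context on the attachment path, a minimal sketch of the Arrow IPC stream round-trip that _filter_arrow_result_set relies on; it assumes pyarrow is installed and uses a toy table rather than real SEA data. The bytes written with pyarrow.ipc.new_stream can be read back with pyarrow.ipc.open_stream:

# Hypothetical round-trip check, not part of the connector code.
import io
import pyarrow
import pyarrow.ipc

filtered_table = pyarrow.table({"TABLE_TYPE": ["TABLE", "VIEW"]})

# Serialize to Arrow stream bytes, as done for ResultData.attachment
sink = io.BytesIO()
with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
    writer.write_table(filtered_table)
arrow_stream_bytes = sink.getvalue()

# Reading the attachment back yields an equivalent table
restored = pyarrow.ipc.open_stream(io.BytesIO(arrow_stream_bytes)).read_all()
assert restored.equals(filtered_table)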