From c98e5ce18117b83b9999eeb819d2d2cb071f943f Mon Sep 17 00:00:00 2001 From: BohuTANG Date: Tue, 14 Oct 2025 10:31:43 +0800 Subject: [PATCH] Improve stage mapping integration --- python/README.md | 9 + python/README_CLIENT.md | 121 ------ python/databend_udf/client.py | 92 +++- python/databend_udf/udf.py | 579 +++++++++++++++++++++++-- python/tests/README.md | 46 +- python/tests/conftest.py | 26 ++ python/tests/servers/demo_server.py | 145 +++++++ python/tests/servers/stage_server.py | 49 +++ python/tests/test_demo_server.py | 73 ++++ python/tests/test_stage_integration.py | 164 +++++++ python/tests/test_stage_location.py | 208 +++++++++ 11 files changed, 1298 insertions(+), 214 deletions(-) delete mode 100644 python/README_CLIENT.md create mode 100644 python/tests/servers/demo_server.py create mode 100644 python/tests/servers/stage_server.py create mode 100644 python/tests/test_demo_server.py create mode 100644 python/tests/test_stage_integration.py create mode 100644 python/tests/test_stage_location.py diff --git a/python/README.md b/python/README.md index 772b9d0..5a6b4e1 100644 --- a/python/README.md +++ b/python/README.md @@ -147,3 +147,12 @@ python3 examples/server.py ### Acknowledgement Databend Python UDF Server API is inspired by [RisingWave Python API](https://pypi.org/project/risingwave/). + +### Code Formatting + +Use Ruff to keep the Python sources consistent: + +```bash +python -m pip install ruff # once +python -m ruff format python/databend_udf python/tests +``` diff --git a/python/README_CLIENT.md b/python/README_CLIENT.md deleted file mode 100644 index 6ee7a2a..0000000 --- a/python/README_CLIENT.md +++ /dev/null @@ -1,121 +0,0 @@ -# Databend UDF Client Library - -Simple Python client library for testing Databend UDF servers. - -## Installation - -```bash -pip install databend-udf -``` - -## Quick Start - -```python -from databend_udf import UDFClient, create_client - -# Create client (default: localhost:8815) -client = create_client() - -# Or specify host/port -client = UDFClient(host="localhost", port=8815) - -# Health check -if client.health_check(): - print("Server is running!") - -# Echo test -result = client.echo("Hello, Databend!") -print(result) # "Hello, Databend!" - -# Call UDF function -result = client.call_function("gcd", 48, 18) -print(result[0]) # 6 -``` - -## API Reference - -### UDFClient - -#### Methods - -- `__init__(host="localhost", port=8815)` - Create client -- `health_check() -> bool` - Check server health -- `echo(message: str) -> str` - Echo test message -- `call_function(name, *args) -> List[Any]` - Call UDF with arguments -- `call_function_batch(name, **kwargs) -> List[Any]` - Call UDF with batch data -- `get_function_info(name) -> FlightInfo` - Get function schema -- `list_functions() -> List[str]` - List available functions - -### Examples - -#### Single Value Calls - -```python -client = create_client() - -# Numeric functions -result = client.call_function("add_signed", 1, 2, 3, 4) -print(result[0]) # 10 - -# String functions -result = client.call_function("split_and_join", "a,b,c", ",", "-") -print(result[0]) # "a-b-c" - -# Array functions -result = client.call_function("array_access", ["hello", "world"], 1) -print(result[0]) # "hello" -``` - -#### Batch Calls - -```python -client = create_client() - -# Process multiple values at once -x_values = [48, 56, 72] -y_values = [18, 21, 24] -results = client.call_function_batch("gcd_batch", a=x_values, b=y_values) -print(results) # [6, 7, 24] -``` - -## Testing - -Run the example server: - -```bash -cd python/example -python server.py -``` - -In another terminal, run tests: - -```bash -# Simple test -python simple_test.py - -# Comprehensive test suite -python client_test.py -``` - -## Error Handling - -```python -try: - result = client.call_function("my_function", arg1, arg2) - print(result[0]) -except Exception as e: - print(f"Function call failed: {e}") -``` - -## Performance Testing - -The client supports testing concurrent I/O operations: - -```python -# Sequential calls (slow) -for i in range(10): - client.call_function("slow_function", i) - -# Batch call (fast with io_threads) -client.call_function_batch("slow_function", a=list(range(10))) -``` \ No newline at end of file diff --git a/python/databend_udf/client.py b/python/databend_udf/client.py index 74cb4c5..51015dd 100644 --- a/python/databend_udf/client.py +++ b/python/databend_udf/client.py @@ -1,10 +1,10 @@ -""" -Simple client library for testing Databend UDF servers. -""" +"""Simple client library for testing Databend UDF servers.""" + +import json +from typing import Any, Dict, Iterable, List, Sequence, Tuple import pyarrow as pa import pyarrow.flight as fl -from typing import List, Any class UDFClient: @@ -135,7 +135,69 @@ def get_function_info(self, function_name: str) -> fl.FlightInfo: descriptor = fl.FlightDescriptor.for_path(function_name) return self.client.get_flight_info(descriptor) - def call_function(self, function_name: str, *args) -> List[Any]: + @staticmethod + def format_stage_mapping(stage_locations: Iterable[Dict[str, Any]]) -> str: + """Serialize stage mapping entries to the Databend header payload.""" + + serialized_entries: List[Dict[str, Any]] = [] + for entry in stage_locations: + if not isinstance(entry, dict): + raise ValueError("stage_locations entries must be dictionaries") + if "param_name" not in entry: + raise ValueError("stage_locations entry requires 'param_name'") + serialized_entries.append(entry) + + return json.dumps(serialized_entries) + + @staticmethod + def _build_flight_headers( + headers: Dict[str, Any] = None, + stage_locations: Iterable[Dict[str, Any]] = None, + ) -> Sequence[Tuple[str, str]]: + """Construct Flight headers for a UDF call. + + ``stage_locations`` becomes a single header named ``databend-stage-mapping`` + whose value is a JSON array. This mirrors what Databend Query sends to + external UDF servers. Example HTTP-style representation:: + + databend-stage-mapping: [ + { + "param_name": "stage_loc", + "relative_path": "input/2024/", + "stage_info": { ... StageInfo JSON ... } + } + ] + + Multiple stage parameters simply append more objects to the array. + Additional custom headers can be supplied through ``headers``. + """ + headers = headers or {} + flight_headers: List[Tuple[bytes, bytes]] = [] + + for key, value in headers.items(): + if isinstance(value, (list, tuple)): + for item in value: + flight_headers.append( + (str(key).encode("utf-8"), str(item).encode("utf-8")) + ) + else: + flight_headers.append( + (str(key).encode("utf-8"), str(value).encode("utf-8")) + ) + + if stage_locations: + payload = UDFClient.format_stage_mapping(stage_locations) + flight_headers.append((b"databend-stage-mapping", payload.encode("utf-8"))) + + return flight_headers + + def call_function( + self, + function_name: str, + *args, + headers: Dict[str, Any] = None, + stage_locations: Iterable[Dict[str, Any]] = None, + ) -> List[Any]: """ Call a UDF function with given arguments. @@ -150,7 +212,11 @@ def call_function(self, function_name: str, *args) -> List[Any]: # Call function descriptor = fl.FlightDescriptor.for_path(function_name) - writer, reader = self.client.do_exchange(descriptor=descriptor) + flight_headers = self._build_flight_headers(headers, stage_locations) + options = ( + fl.FlightCallOptions(headers=flight_headers) if flight_headers else None + ) + writer, reader = self.client.do_exchange(descriptor=descriptor, options=options) with writer: writer.begin(input_schema) @@ -166,7 +232,13 @@ def call_function(self, function_name: str, *args) -> List[Any]: return results - def call_function_batch(self, function_name: str, **kwargs) -> List[Any]: + def call_function_batch( + self, + function_name: str, + headers: Dict[str, Any] = None, + stage_locations: Iterable[Dict[str, Any]] = None, + **kwargs, + ) -> List[Any]: """ Call a UDF function with batch data. @@ -181,7 +253,11 @@ def call_function_batch(self, function_name: str, **kwargs) -> List[Any]: # Call function descriptor = fl.FlightDescriptor.for_path(function_name) - writer, reader = self.client.do_exchange(descriptor=descriptor) + flight_headers = self._build_flight_headers(headers, stage_locations) + options = ( + fl.FlightCallOptions(headers=flight_headers) if flight_headers else None + ) + writer, reader = self.client.do_exchange(descriptor=descriptor, options=options) with writer: writer.begin(input_schema) diff --git a/python/databend_udf/udf.py b/python/databend_udf/udf.py index ee13ac8..a8c839b 100644 --- a/python/databend_udf/udf.py +++ b/python/databend_udf/udf.py @@ -15,14 +15,31 @@ import json import logging import inspect +import time +from dataclasses import dataclass from concurrent.futures import ThreadPoolExecutor -from typing import Iterator, Callable, Optional, Union, List, Dict +from typing import ( + Iterator, + Callable, + Optional, + Union, + List, + Dict, + Any, + Tuple, +) +from typing import get_args, get_origin from prometheus_client import Counter, Gauge, Histogram from prometheus_client import start_http_server import threading import pyarrow as pa -from pyarrow.flight import FlightServerBase, FlightInfo +from pyarrow.flight import ( + FlightServerBase, + FlightInfo, + ServerMiddleware, + ServerMiddlewareFactory, +) # comes from Databend MAX_DECIMAL128_PRECISION = 38 @@ -35,6 +52,314 @@ logger = logging.getLogger(__name__) +class QueryState: + """Represents the lifecycle state of a query request.""" + + def __init__(self) -> None: + self._cancelled = False + self._start_time = time.time() + + def is_cancelled(self) -> bool: + return self._cancelled + + def cancel(self) -> None: + self._cancelled = True + logger.warning("Query cancelled") + + +class HeadersMiddleware(ServerMiddleware): + """Flight middleware used to capture request headers for each call.""" + + def __init__(self, headers) -> None: + self.headers = headers + + def call_completed(self, exception): # pragma: no cover - thin wrapper + if exception: + logger.error("Call failed", exc_info=exception) + + +class HeadersMiddlewareFactory(ServerMiddlewareFactory): + """Creates `HeadersMiddleware` instances for each Flight call.""" + + def start_call(self, info, headers) -> HeadersMiddleware: + return HeadersMiddleware(headers) + + +def _safe_json_loads(payload: Union[str, bytes, None]) -> Optional[Any]: + if payload is None: + return None + if isinstance(payload, bytes): + payload = payload.decode("utf-8") + if not isinstance(payload, str): + return None + payload = payload.strip() + if not payload: + return None + try: + return json.loads(payload) + except json.JSONDecodeError: + logger.debug("Failed to decode JSON payload: %s", payload) + return None + + +def _ensure_dict(value: Any) -> Dict[str, Any]: + if isinstance(value, dict): + return value + decoded = _safe_json_loads(value) if isinstance(value, (str, bytes)) else None + return decoded if isinstance(decoded, dict) else {} + + +def _extract_param_name(entry: Dict[str, Any]) -> Optional[str]: + for key in ("param_name", "name", "arg_name", "stage_param", "parameter", "param"): + value = entry.get(key) + if value: + return str(value) + return None + + +@dataclass +class StageLocation: + """Structured representation of a stage argument resolved by Databend.""" + + name: str + stage_name: str + stage_type: str + storage: Dict[str, Any] + relative_path: str + raw_info: Dict[str, Any] + + @classmethod + def from_header_entry( + cls, param_name: str, entry: Dict[str, Any] + ) -> "StageLocation": + entry = entry or {} + if not isinstance(entry, dict): + entry = {} + + stage_info = entry.get("stage_info") or entry.get("stage") or entry.get("info") + if not isinstance(stage_info, dict): + stage_info = _ensure_dict(stage_info) + raw_info = stage_info if stage_info else entry + + stage_name = ( + entry.get("stage_name") or stage_info.get("stage_name") + if isinstance(stage_info, dict) + else None + ) + if not stage_name: + stage_name = entry.get("name") or param_name + + stage_type_raw = entry.get("stage_type") + if stage_type_raw is None and isinstance(stage_info, dict): + stage_type_raw = stage_info.get("stage_type") + + stage_type = "" + if isinstance(stage_type_raw, str): + stage_type = stage_type_raw + elif isinstance(stage_type_raw, dict): + for candidate in ("type", "stage_type"): + candidate_value = stage_type_raw.get(candidate) + if candidate_value: + stage_type = str(candidate_value) + break + if not stage_type and stage_type_raw: + first_value = next(iter(stage_type_raw.values()), None) + if first_value: + stage_type = str(first_value) + elif stage_type_raw is not None: + stage_type = str(stage_type_raw) + + stage_params = ( + stage_info.get("stage_params") if isinstance(stage_info, dict) else {} + ) + storage = entry.get("storage") + if not isinstance(storage, dict): + storage = _ensure_dict(storage) + if not storage and isinstance(stage_params, dict): + storage = _ensure_dict(stage_params.get("storage")) + if not isinstance(storage, dict): + storage = {} + + relative_path = ( + entry.get("relative_path") + or entry.get("path") + or entry.get("stage_path") + or entry.get("prefix") + or entry.get("pattern") + or entry.get("file_path") + ) + if isinstance(relative_path, dict): + relative_path = relative_path.get("path") or relative_path.get("value") + if relative_path is None: + relative_path = "" + + return cls( + name=str(param_name), + stage_name=str(stage_name) if stage_name is not None else "", + stage_type=stage_type, + storage=storage, + relative_path=str(relative_path), + raw_info=raw_info if isinstance(raw_info, dict) else {}, + ) + + +def _annotation_matches_stage_location(annotation: Any) -> bool: + if annotation is None: + return False + if annotation is StageLocation: + return True + if isinstance(annotation, str): + return annotation == "StageLocation" + + origin = get_origin(annotation) + if origin is Union: + return any( + _annotation_matches_stage_location(arg) for arg in get_args(annotation) + ) + + return False + + +def _parse_stage_mapping_payload(payload: Any) -> Dict[str, StageLocation]: + mapping: Dict[str, StageLocation] = {} + + def add_entry(param: str, value: Dict[str, Any]) -> None: + try: + mapping[param] = StageLocation.from_header_entry(param, value) + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to parse stage mapping for %s: %s", param, exc) + + if isinstance(payload, list): + for entry in payload: + if not isinstance(entry, dict): + continue + param_name = _extract_param_name(entry) + if not param_name: + continue + add_entry(param_name, entry) + elif isinstance(payload, dict): + param_name = _extract_param_name(payload) + if param_name: + add_entry(param_name, payload) + else: + for key, value in payload.items(): + if isinstance(value, dict): + add_entry(str(key), value) + return mapping + + +def _load_stage_mapping(header_value: Any) -> Dict[str, StageLocation]: + """Parse the ``databend-stage-mapping`` header into StageLocation objects. + + The Flight client sends a single header whose *key* is ``databend-stage-mapping`` + (case-insensitive). The *value* is a JSON array describing every + ``STAGE_LOCATION`` argument. For example:: + + databend-stage-mapping: [ + { + "param_name": "input_stage", + "relative_path": "data/input/", + "stage_info": { + "stage_name": "input_stage", + "stage_type": "External", + "stage_params": {"storage": {"type": "s3", ...}} + } + }, + { + "param_name": "output_stage", + "relative_path": "data/output/", + "stage_info": { ... } + } + ] + + ``stage_info`` is the JSON form of Databend's ``StageInfo`` structure and is + forwarded verbatim so UDF handlers can access any extended metadata. + """ + if header_value is None: + return {} + if isinstance(header_value, (list, tuple)): + header_value = header_value[0] if header_value else None + if header_value is None: + return {} + if isinstance(header_value, bytes): + header_value = header_value.decode("utf-8") + payload: Any = header_value + if isinstance(header_value, str): + header_value = header_value.strip() + if not header_value: + return {} + try: + payload = json.loads(header_value) + except json.JSONDecodeError: + logger.warning("Failed to decode Databend-Stage-Mapping header") + return {} + if not isinstance(payload, (list, dict)): + return {} + return _parse_stage_mapping_payload(payload) + + +class Headers: + """Wrapper providing convenient accessors for Databend request headers.""" + + def __init__(self, headers: Optional[Dict[str, Any]] = None) -> None: + self.raw_headers: Dict[str, Any] = headers or {} + self.query_state = QueryState() + self._normalized: Dict[str, List[str]] = {} + if headers: + for key, value in headers.items(): + values: List[str] = [] + if isinstance(value, (list, tuple)): + iterable = value + else: + iterable = [value] + for item in iterable: + if isinstance(item, bytes): + values.append(item.decode("utf-8")) + else: + values.append(str(item) if not isinstance(item, str) else item) + self._normalized[key.lower()] = values + self.tenant = self._get_first("x-databend-tenant", "") or "" + self.queryid = self._get_first("x-databend-query-id", "") or "" + self._stage_mapping_cache: Optional[Dict[str, StageLocation]] = None + + def _get_first(self, key: str, default: Optional[str] = None) -> Optional[str]: + values = self._normalized.get(key.lower()) + if not values: + return default + return values[0] + + def get(self, key: str, default: Optional[str] = None) -> Optional[str]: + result = self._get_first(key, default) + return result if result is not None else default + + def get_all(self, key: str) -> List[str]: + return list(self._normalized.get(key.lower(), [])) + + def _stage_mapping(self) -> Dict[str, StageLocation]: + if self._stage_mapping_cache is None: + raw_value = self.get("databend-stage-mapping") + if raw_value is None: + raw_value = self.get("databend_stage_mapping") + self._stage_mapping_cache = _load_stage_mapping(raw_value) + return self._stage_mapping_cache + + @property + def stage_locations(self) -> Dict[str, StageLocation]: + return dict(self._stage_mapping()) + + def get_stage_location(self, name: str) -> Optional[StageLocation]: + return self._stage_mapping().get(name) + + def require_stage_locations(self, names: List[str]) -> Dict[str, StageLocation]: + mapping = self._stage_mapping() + missing = [name for name in names if name not in mapping] + if missing: + raise ValueError( + "Missing stage mapping for parameter(s): " + ", ".join(sorted(missing)) + ) + return {name: mapping[name] for name in names} + + class UserDefinedFunction: """ Base interface for user-defined function. @@ -44,7 +369,9 @@ class UserDefinedFunction: _input_schema: pa.Schema _result_schema: pa.Schema - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: + def eval_batch( + self, batch: pa.RecordBatch, headers: Optional[Headers] = None + ) -> Iterator[pa.RecordBatch]: """ Apply the function on a batch of inputs. """ @@ -70,18 +397,78 @@ def __init__( input_types, result_type, name=None, + stage_refs: Optional[List[str]] = None, io_threads=None, skip_null=None, batch_mode=False, ): self._func = func - self._input_schema = pa.schema( - field.with_name(arg_name) - for arg_name, field in zip( - inspect.getfullargspec(func)[0], - [_to_arrow_field(t) for t in _to_list(input_types)], + self._stage_ref_names = list(stage_refs or []) + if len(self._stage_ref_names) != len(set(self._stage_ref_names)): + raise ValueError("stage_refs contains duplicate parameter names") + + spec = inspect.getfullargspec(func) + arg_names = list(spec.args) + self._headers_param: Optional[str] = None + if arg_names and arg_names[-1] == "headers": + self._headers_param = arg_names.pop() + + annotations = getattr(spec, "annotations", {}) + annotated_stage_params = [ + name + for name in arg_names + if _annotation_matches_stage_location(annotations.get(name)) + ] + + stage_param_python_names: List[str] + if self._stage_ref_names: + if all(ref in arg_names for ref in self._stage_ref_names): + stage_param_python_names = list(self._stage_ref_names) + else: + stage_param_python_names = annotated_stage_params + if len(stage_param_python_names) != len(self._stage_ref_names): + raise ValueError( + f"Unable to map stage_refs to function parameters for {func.__name__}" + ) + else: + stage_param_python_names = annotated_stage_params + self._stage_ref_names = list(stage_param_python_names) + + if self._stage_ref_names and not stage_param_python_names: + raise ValueError( + f"stage_refs specified for function {func.__name__} but no StageLocation parameters found" ) - ) + + if len(stage_param_python_names) != len(set(stage_param_python_names)): + raise ValueError("Stage parameters must be unique in function signature") + + self._stage_param_to_ref = { + param: ref + for param, ref in zip(stage_param_python_names, self._stage_ref_names) + } + self._stage_param_names = stage_param_python_names + self._stage_param_set = set(self._stage_param_names) + + self._arg_order = list(arg_names) + self._data_arg_names = [ + name for name in self._arg_order if name not in self._stage_param_set + ] + input_type_list = _to_list(input_types) + if len(self._data_arg_names) != len(input_type_list): + raise ValueError( + f"Function {func.__name__} expects {len(self._data_arg_names)} data argument(s) " + f"but {len(input_type_list)} input type(s) were provided" + ) + + data_fields = [ + _to_arrow_field(type_def).with_name(arg_name) + for arg_name, type_def in zip(self._data_arg_names, input_type_list) + ] + self._input_schema = pa.schema(data_fields) + self._data_arg_indices = { + name: idx for idx, name in enumerate(self._data_arg_names) + } + self._result_schema = pa.schema( [_to_arrow_field(result_type).with_name("output")] ) @@ -96,6 +483,27 @@ def __init__( else None ) + self._call_arg_layout: List[Tuple[str, str]] = [] + for parameter in self._arg_order: + if parameter in self._stage_param_set: + self._call_arg_layout.append(("stage", parameter)) + else: + self._call_arg_layout.append(("data", parameter)) + if self._headers_param: + self._call_arg_layout.append(("headers", self._headers_param)) + + data_field_map = {field.name: field for field in self._input_schema} + self._sql_parameter_defs: List[str] = [] + for kind, identifier in self._call_arg_layout: + if kind == "stage": + stage_ref_name = self._stage_param_to_ref.get(identifier, identifier) + self._sql_parameter_defs.append(f"STAGE_LOCATION {stage_ref_name}") + elif kind == "data": + field = data_field_map[identifier] + self._sql_parameter_defs.append( + f"{field.name} {_arrow_field_to_string(field)}" + ) + if skip_null and not self._result_schema.field(0).nullable: raise ValueError( f"Return type of function {self._name} must be nullable when skip_null is True" @@ -104,47 +512,109 @@ def __init__( self._skip_null = skip_null or False super().__init__() - def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: - inputs = [[v.as_py() for v in array] for array in batch] - inputs = [ - _input_process_func(_list_field(field))(array) - for array, field in zip(inputs, self._input_schema) - ] + def eval_batch( + self, batch: pa.RecordBatch, headers: Optional[Headers] = None + ) -> Iterator[pa.RecordBatch]: + headers_obj = headers if isinstance(headers, Headers) else Headers(headers) + + stage_locations: Dict[str, StageLocation] = {} + if self._stage_ref_names: + stage_locations_by_ref = headers_obj.require_stage_locations( + self._stage_ref_names + ) + stage_locations = { + param: stage_locations_by_ref[ref] + for param, ref in self._stage_param_to_ref.items() + } + for name, location in stage_locations.items(): + stage_type = location.stage_type.lower() if location.stage_type else "" + if stage_type and stage_type != "external": + raise ValueError( + f"Stage parameter '{name}' must reference an External stage" + ) + if not location.storage: + raise ValueError( + f"Stage parameter '{name}' is missing storage configuration" + ) + storage_type = str(location.storage.get("type", "")).lower() + if storage_type == "fs": + raise ValueError( + f"Stage parameter '{name}' must not use 'fs' storage" + ) + + processed_inputs: List[List[Any]] = [] + for array, field in zip(batch, self._input_schema): + python_values = [value.as_py() for value in array] + processed_inputs.append( + _input_process_func(_list_field(field))(python_values) + ) - # evaluate the function for each row if self._batch_mode: - column = self._func(*inputs) + call_args = self._assemble_args( + processed_inputs, stage_locations, headers_obj, row_idx=None + ) + column = self._func(*call_args) elif self._executor is not None: - # concurrently evaluate the function for each row - if self._skip_null: - tasks = [] - for row in range(batch.num_rows): - args = [col[row] for col in inputs] - func = _null_func if None in args else self._func - tasks.append(self._executor.submit(func, *args)) - else: - tasks = [ - self._executor.submit(self._func, *[col[row] for col in inputs]) - for row in range(batch.num_rows) - ] - column = [future.result() for future in tasks] + row_count = batch.num_rows + column = [None] * row_count + futures = [] + future_rows: List[int] = [] + for row in range(row_count): + if self._skip_null and self._row_has_null(processed_inputs, row): + column[row] = None + continue + call_args = self._assemble_args( + processed_inputs, stage_locations, headers_obj, row + ) + futures.append(self._executor.submit(self._func, *call_args)) + future_rows.append(row) + for row, future in zip(future_rows, futures): + column[row] = future.result() else: - if self._skip_null: - column = [] - for row in range(batch.num_rows): - args = [col[row] for col in inputs] - column.append(None if None in args else self._func(*args)) - else: - column = [ - self._func(*[col[row] for col in inputs]) - for row in range(batch.num_rows) - ] + column = [] + for row in range(batch.num_rows): + if self._skip_null and self._row_has_null(processed_inputs, row): + column.append(None) + continue + call_args = self._assemble_args( + processed_inputs, stage_locations, headers_obj, row + ) + column.append(self._func(*call_args)) column = _output_process_func(_list_field(self._result_schema.field(0)))(column) array = pa.array(column, type=self._result_schema.types[0]) yield pa.RecordBatch.from_arrays([array], schema=self._result_schema) + def _assemble_args( + self, + data_inputs: List[List[Any]], + stage_locations: Dict[str, StageLocation], + headers: Headers, + row_idx: Optional[int], + ) -> List[Any]: + args: List[Any] = [] + for kind, identifier in self._call_arg_layout: + if kind == "data": + data_index = self._data_arg_indices[identifier] + values = data_inputs[data_index] + args.append(values if row_idx is None else values[row_idx]) + elif kind == "stage": + if identifier not in stage_locations: + raise ValueError( + f"Missing stage mapping for parameter '{identifier}'" + ) + args.append(stage_locations[identifier]) + elif kind == "headers": + args.append(headers) + return args + + def _row_has_null(self, data_inputs: List[List[Any]], row_idx: int) -> bool: + for values in data_inputs: + if values[row_idx] is None: + return True + return False + def __call__(self, *args): return self._func(*args) @@ -153,6 +623,7 @@ def udf( input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], result_type: Union[str, pa.DataType], name: Optional[str] = None, + stage_refs: Optional[List[str]] = None, io_threads: Optional[int] = 32, skip_null: Optional[bool] = False, batch_mode: Optional[bool] = False, @@ -165,6 +636,9 @@ def udf( - result_type: A string or an Arrow data type that specifies the return value type. - name: An optional string specifying the function name. If not provided, the original name will be used. + - stage_refs: Optional list of parameter names that should be resolved as stage + locations. These parameters will be injected as `StageLocation` + objects when the function executes. - io_threads: Number of I/O threads used per data chunk for I/O bound functions. - skip_null: A boolean value specifying whether to skip NULL value. If it is set to True, NULL values will not be passed to the function, @@ -202,6 +676,7 @@ def gcd(x, y): input_types, result_type, name, + stage_refs=stage_refs, io_threads=io_threads, skip_null=skip_null, batch_mode=batch_mode, @@ -212,6 +687,7 @@ def gcd(x, y): input_types, result_type, name, + stage_refs=stage_refs, skip_null=skip_null, batch_mode=batch_mode, ) @@ -233,7 +709,11 @@ class UDFServer(FlightServerBase): _functions: Dict[str, UserDefinedFunction] def __init__(self, location="0.0.0.0:8815", metric_location=None, **kwargs): - super(UDFServer, self).__init__("grpc://" + location, **kwargs) + middleware = dict(kwargs.pop("middleware", {})) + middleware.setdefault("log_headers", HeadersMiddlewareFactory()) + super(UDFServer, self).__init__( + "grpc://" + location, middleware=middleware, **kwargs + ) self._location = location self._metric_location = metric_location self._functions = {} @@ -330,6 +810,11 @@ def do_exchange(self, context, descriptor, reader, writer): udf = self._functions[func_name] writer.begin(udf._result_schema) + headers_middleware = context.get_middleware("log_headers") + request_headers = Headers( + headers_middleware.headers if headers_middleware else None + ) + # Increment request counter self.requests_count.labels(function_name=func_name).inc() # Increment running requests gauge @@ -344,7 +829,7 @@ def do_exchange(self, context, descriptor, reader, writer): self.running_rows.labels(function_name=func_name).inc(batch_rows) try: - for output_batch in udf.eval_batch(batch.data): + for output_batch in udf.eval_batch(batch.data, request_headers): writer.write_batch(output_batch) finally: # Decrease running rows gauge after processing @@ -368,9 +853,13 @@ def add_function(self, udf: UserDefinedFunction): if name in self._functions: raise ValueError("Function already exists: " + name) self._functions[name] = udf - input_types = ", ".join( - _arrow_field_to_string(field) for field in udf._input_schema - ) + parameter_defs = getattr(udf, "_sql_parameter_defs", None) + if parameter_defs is None: + parameter_defs = [ + f"{field.name} {_arrow_field_to_string(field)}" + for field in udf._input_schema + ] + input_types = ", ".join(parameter_defs) output_type = _arrow_field_to_string(udf._result_schema[0]) sql = ( f"CREATE FUNCTION {name} ({input_types}) " diff --git a/python/tests/README.md b/python/tests/README.md index 6d6a868..94a95dc 100644 --- a/python/tests/README.md +++ b/python/tests/README.md @@ -1,47 +1,13 @@ -# Databend UDF Tests +# Running the Test Suite -Professional UDF test suite with isolated server instances for each test. +Use the project virtualenv and Pytest: -## Test Structure - -``` -tests/ -├── conftest.py # Pytest fixtures and server management -├── test_connectivity.py # Basic connectivity tests -├── test_simple_udf.py # Simple UDF function tests -├── servers/ -│ ├── minimal_server.py # Minimal server (built-in functions only) -│ └── basic_server.py # Basic server (simple arithmetic functions) +```bash +python/venv/bin/python -m pytest ``` -## Test Principles - -1. **Isolation** - Each test runs with its own server instance -2. **Progressive** - Test basic connectivity first, then functionality -3. **Focused** - Test core capabilities only, avoid complex types - -## Running Tests +For targeted runs, point to the desired file or test case, e.g.: ```bash -# Run all tests -pytest - -# Run specific test file -pytest test_connectivity.py - -# Run specific test function -pytest test_connectivity.py::test_health_check +python/venv/bin/python -m pytest python/tests/test_stage_integration.py::test_stage_integration_single_stage ``` - -## Test Content - -### Connectivity Tests -- Server startup -- Health check -- Built-in functions (builtin_healthy, builtin_echo) - -### Simple UDF Tests -- Integer addition -- GCD algorithm - -Each test uses an isolated server instance to ensure complete isolation between tests. \ No newline at end of file diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 4c68f93..0b9b9d1 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -131,6 +131,32 @@ def basic_server(): manager.stop() +@pytest.fixture +def stage_server(): + """Server exposing stage-aware UDFs.""" + script_path = os.path.join(os.path.dirname(__file__), "servers", "stage_server.py") + manager = ServerManager(script_path) + + if not manager.start(): + pytest.fail("Failed to start stage server") + + yield manager + manager.stop() + + +@pytest.fixture +def demo_server(): + """Self-contained demo server for example UDFs.""" + script_path = os.path.join(os.path.dirname(__file__), "servers", "demo_server.py") + manager = ServerManager(script_path) + + if not manager.start(): + pytest.fail("Failed to start demo server") + + yield manager + manager.stop() + + @pytest.fixture def full_server(): """Full server with all example functions.""" diff --git a/python/tests/servers/demo_server.py b/python/tests/servers/demo_server.py new file mode 100644 index 0000000..ed1bcd4 --- /dev/null +++ b/python/tests/servers/demo_server.py @@ -0,0 +1,145 @@ +"""Standalone demo server for integration tests.""" + +import datetime +import json +import time +from decimal import Decimal +from typing import Any, Dict, List, Optional + +from databend_udf import UDFServer, udf + + +@udf(input_types=["INT", "INT", "INT", "INT"], result_type="INT") +def add_four(a: int, b: int, c: int, d: int) -> int: + return a + b + c + d + + +@udf(input_types=["BOOLEAN", "INT", "INT"], result_type="INT") +def select_if(flag: bool, lhs: int, rhs: int) -> int: + return lhs if flag else rhs + + +@udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR") +def split_and_join(text: str, delimiter: str, glue: str) -> str: + return glue.join(text.split(delimiter)) + + +@udf(input_types=["BINARY"], result_type="BINARY") +def reverse_bytes(payload: bytes) -> bytes: + return payload[::-1] + + +@udf(input_types=["DECIMAL(36, 18)", "DECIMAL(36, 18)"], result_type="DECIMAL(72, 28)") +def decimal_div(lhs: Decimal, rhs: Decimal) -> Decimal: + result = lhs / rhs + return result.quantize(Decimal("0." + "0" * 28)) + + +@udf(input_types=["DATE", "INT"], result_type="DATE") +def add_days(base: datetime.date, days: int) -> datetime.date: + return base + datetime.timedelta(days=days) + + +@udf(input_types=["ARRAY(VARCHAR)", "INT"], result_type="VARCHAR") +def array_access(values: List[str], index: int) -> Optional[str]: + if index <= 0 or index > len(values): + return None + return values[index - 1] + + +@udf(input_types=["MAP(VARCHAR,VARCHAR)", "VARCHAR"], result_type="VARCHAR") +def map_lookup(mapping: Dict[str, str], key: str) -> Optional[str]: + return mapping.get(key) + + +@udf(input_types=["INT"], result_type="INT") +def wait(value: int) -> int: + time.sleep(0.05) + return value + + +@udf( + input_types=["ARRAY(VARCHAR)", "INT", "INT"], + result_type="ARRAY(VARCHAR)", +) +def tuple_slice(values: List[Any], i: int, j: int) -> List[Any]: + first = values[i] if 0 <= i < len(values) else None + second = values[j] if 0 <= j < len(values) else None + return [first, second] + + +@udf( + input_types=[ + "BOOLEAN", + "TINYINT", + "SMALLINT", + "INT", + "BIGINT", + "FLOAT", + "DOUBLE", + "DATE", + "TIMESTAMP", + "VARCHAR", + "VARIANT", + ], + result_type="VARIANT", +) +def all_types_snapshot( + flag, + tiny, + small, + integer, + big, + flt, + dbl, + date, + timestamp, + text, + variant, +): + if isinstance(variant, str): + try: + variant_value = json.loads(variant) + except json.JSONDecodeError: + variant_value = variant + else: + variant_value = variant + return { + "flag": flag, + "tiny": tiny, + "small": small, + "integer": integer, + "big": big, + "flt": flt, + "dbl": dbl, + "date": date.isoformat() if date else None, + "timestamp": timestamp.isoformat() if timestamp else None, + "text": text, + "variant": variant_value, + } + + +def create_server(port: int) -> UDFServer: + server = UDFServer(f"0.0.0.0:{port}") + for func in [ + add_four, + select_if, + split_and_join, + reverse_bytes, + decimal_div, + add_days, + array_access, + map_lookup, + wait, + tuple_slice, + all_types_snapshot, + ]: + server.add_function(func) + return server + + +if __name__ == "__main__": + import sys + + port = int(sys.argv[1]) if len(sys.argv) > 1 else 8815 + create_server(port).serve() diff --git a/python/tests/servers/stage_server.py b/python/tests/servers/stage_server.py new file mode 100644 index 0000000..f4b3df6 --- /dev/null +++ b/python/tests/servers/stage_server.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +"""Stage-aware UDF server used in integration tests.""" + +import logging +import sys + +from databend_udf import StageLocation, UDFServer, udf + +logging.basicConfig(level=logging.INFO) + + +@udf(stage_refs=["data_stage"], input_types=["INT"], result_type="VARCHAR") +def stage_summary(stage: StageLocation, value: int) -> str: + assert stage.stage_type.lower() == "external" + assert stage.storage + bucket = stage.storage.get("bucket", stage.storage.get("container", "")) + return f"{stage.stage_name}:{bucket}:{stage.relative_path}:{value}" + + +@udf( + stage_refs=["input_stage", "output_stage"], + input_types=["INT"], + result_type="INT", +) +def multi_stage_process( # Align with documentation name + input_stage: StageLocation, output_stage: StageLocation, value: int +) -> int: + assert input_stage.storage and output_stage.storage + assert input_stage.stage_type.lower() == "external" + assert output_stage.stage_type.lower() == "external" + # Simple deterministic behaviour for testing + return ( + value + + len(input_stage.storage.get("bucket", "")) + + len(output_stage.storage.get("bucket", "")) + ) + + +def create_server(port: int): + server = UDFServer(f"0.0.0.0:{port}") + server.add_function(stage_summary) + server.add_function(multi_stage_process) + return server + + +if __name__ == "__main__": + port = int(sys.argv[1]) if len(sys.argv) > 1 else 8815 + server = create_server(port) + server.serve() diff --git a/python/tests/test_demo_server.py b/python/tests/test_demo_server.py new file mode 100644 index 0000000..1167b64 --- /dev/null +++ b/python/tests/test_demo_server.py @@ -0,0 +1,73 @@ +"""Integration tests against the self-contained demo server.""" + +import datetime +from decimal import Decimal + + +def test_simple_functions(demo_server): + client = demo_server.get_client() + assert client.call_function("add_four", 1, 2, 3, 4) == [10] + assert client.call_function("select_if", True, 5, 6) == [5] + + +def test_strings_and_binary(demo_server): + client = demo_server.get_client() + assert client.call_function("split_and_join", "a,b", ",", ":") == ["a:b"] + assert client.call_function("reverse_bytes", b"xyz") == [b"zyx"] + + +def test_decimal_and_dates(demo_server): + client = demo_server.get_client() + dec = client.call_function("decimal_div", Decimal("1"), Decimal("3"))[0] + assert dec == Decimal("0.3333333333333333333333333333") + assert client.call_function("add_days", datetime.date(2024, 1, 1), 2) == [ + datetime.date(2024, 1, 3) + ] + + +def test_collections(demo_server): + client = demo_server.get_client() + array_result = client.call_function_batch( + "array_access", values=[["foo", "bar"]], index=[2] + ) + assert array_result == ["bar"] + assert client.call_function("map_lookup", {"x": "y"}, "x") == ["y"] + + +def test_wait_and_tuple(demo_server): + client = demo_server.get_client() + assert client.call_function("wait", 9) == [9] + tuple_result = client.call_function_batch( + "tuple_slice", values=[["a", "b", "c"]], i=[0], j=[2] + ) + assert tuple_result == [["a", "c"]] + + +def test_return_all_types(demo_server): + client = demo_server.get_client() + values = [ + True, + -1, + 2, + 3, + 4, + 1.5, + 2.5, + datetime.date(2024, 2, 1), + datetime.datetime(2024, 2, 1, 12, 30, 0), + "hello", + '{"k":"v"}', + ] + import json + + snapshot_raw = client.call_function("all_types_snapshot", *values)[0] + if isinstance(snapshot_raw, (bytes, bytearray)): + snapshot = json.loads(snapshot_raw.decode("utf-8")) + elif isinstance(snapshot_raw, str): + snapshot = json.loads(snapshot_raw) + else: + snapshot = snapshot_raw + assert snapshot["flag"] is True + assert snapshot["tiny"] == -1 + assert snapshot["text"] == "hello" + assert snapshot["date"] == "2024-02-01" diff --git a/python/tests/test_stage_integration.py b/python/tests/test_stage_integration.py new file mode 100644 index 0000000..96e992c --- /dev/null +++ b/python/tests/test_stage_integration.py @@ -0,0 +1,164 @@ +"""End-to-end tests for stage-aware UDF functions.""" + +import json + +import pyarrow as pa +import pytest + + +def _s3_stage(param_name: str, bucket: str, path: str, stage_name: str = None) -> dict: + stage_name = stage_name or param_name + return { + "param_name": param_name, + "relative_path": path, + "stage_info": { + "stage_name": stage_name, + "stage_type": "External", + "stage_params": { + "storage": { + "type": "s3", + "bucket": bucket, + "access_key_id": f"ak-{bucket}", + "secret_access_key": f"sk-{bucket}", + } + }, + }, + } + + +def _gcs_stage(param_name: str, bucket: str, path: str, stage_name: str = None) -> dict: + stage_name = stage_name or param_name + return { + "param_name": param_name, + "relative_path": path, + "stage_info": { + "stage_name": stage_name, + "stage_type": "External", + "stage_params": { + "storage": { + "type": "gcs", + "bucket": bucket, + "credential": json.dumps( + { + "type": "service_account", + "client_email": "udf@databend.dev", + "private_key": "-----BEGIN PRIVATE KEY-----", + } + ), + } + }, + }, + } + + +def _azblob_stage( + param_name: str, container: str, path: str, stage_name: str = None +) -> dict: + stage_name = stage_name or param_name + return { + "param_name": param_name, + "relative_path": path, + "stage_info": { + "stage_name": stage_name, + "stage_type": "External", + "stage_params": { + "storage": { + "type": "azblob", + "container": container, + "account_name": "account", + "account_key": "key==", + } + }, + }, + } + + +def test_stage_integration_single_stage(stage_server): + client = stage_server.get_client() + + stage_locations = [_s3_stage("data_stage", "input-bucket", "data/input/")] + result = client.call_function("stage_summary", 5, stage_locations=stage_locations) + + assert result == ["data_stage:input-bucket:data/input/:5"] + + +def test_stage_integration_multiple_stages(stage_server): + client = stage_server.get_client() + + stage_locations = [ + _s3_stage("input_stage", "input-bucket", "data/input/"), + _s3_stage("output_stage", "output-bucket", "data/output/"), + ] + + result = client.call_function( + "multi_stage_process", 10, stage_locations=stage_locations + ) + + expected = 10 + len("input-bucket") + len("output-bucket") + assert result == [expected] + + +def test_stage_integration_gcs(stage_server): + client = stage_server.get_client() + + stage_locations = [ + _gcs_stage("data_stage", "gcs-bucket", "gcs/path/", stage_name="gcs_stage") + ] + result = client.call_function("stage_summary", 3, stage_locations=stage_locations) + + assert result == ["gcs_stage:gcs-bucket:gcs/path/:3"] + + +def test_stage_integration_azblob(stage_server): + client = stage_server.get_client() + + stage_locations = [ + _azblob_stage("data_stage", "container", "azure/path/", stage_name="az_stage") + ] + result = client.call_function("stage_summary", 4, stage_locations=stage_locations) + + assert result == ["az_stage:container:azure/path/:4"] + + +def test_stage_integration_rejects_fs(stage_server): + client = stage_server.get_client() + + stage_locations = [ + { + "param_name": "data_stage", + "relative_path": "data/", + "stage_info": { + "stage_name": "data_stage", + "stage_type": "External", + "stage_params": {"storage": {"type": "fs", "root": "/tmp"}}, + }, + } + ] + + with pytest.raises(pa.ArrowInvalid) as exc: + client.call_function("stage_summary", 1, stage_locations=stage_locations) + + assert "'fs' storage" in str(exc.value) + + +def test_stage_integration_rejects_internal(stage_server): + client = stage_server.get_client() + + stage_locations = [ + { + "param_name": "data_stage", + "relative_path": "data/", + "stage_info": { + "stage_name": "data_stage", + "stage_type": "Internal", + "stage_params": { + "storage": {"type": "s3", "bucket": "internal-bucket"} + }, + }, + } + ] + + with pytest.raises(pa.ArrowInvalid) as exc: + client.call_function("stage_summary", 1, stage_locations=stage_locations) + + assert "External stage" in str(exc.value) diff --git a/python/tests/test_stage_location.py b/python/tests/test_stage_location.py new file mode 100644 index 0000000..9ae93ef --- /dev/null +++ b/python/tests/test_stage_location.py @@ -0,0 +1,208 @@ +"""Unit tests for StageLocation parsing and injection helpers.""" + +import json + +import pyarrow as pa +import pytest + +from databend_udf import StageLocation, UDFClient, udf +from databend_udf.udf import Headers + + +def _make_batch(values): + schema = pa.schema([pa.field("value", pa.int32())]) + return pa.RecordBatch.from_arrays([pa.array(values, pa.int32())], schema=schema) + + +def _collect(func, batch, headers): + results = [] + for output in func.eval_batch(batch, headers): + results.extend(output.column(0).to_pylist()) + return results + + +@udf(stage_refs=["stage_loc"], input_types=["INT"], result_type="VARCHAR") +def describe_stage(stage: StageLocation, value: int) -> str: + assert stage.stage_type.lower() == "external" + assert stage.storage + return f"{stage.stage_name}:{stage.relative_path}:{value}" + + +@udf(stage_refs=["input_stage"], input_types=["INT"], result_type="INT") +def renamed_stage(input_stage: StageLocation, value: int) -> int: + assert input_stage.storage + return value + + +@udf(input_types=["INT"], result_type="INT") +def annotated_stage(stage: StageLocation, value: int) -> int: + assert stage.storage + return value + + +@udf(stage_refs=["input_stage", "output_stage"], input_types=["INT"], result_type="INT") +def multi_stage( + input_stage: StageLocation, output_stage: StageLocation, value: int +) -> int: + assert input_stage.storage and output_stage.storage + return value + + +def test_stage_mapping_basic_list(): + payload = UDFClient.format_stage_mapping( + [ + { + "param_name": "stage_loc", + "relative_path": "input/2024/", + "stage_info": { + "stage_name": "stage_loc", + "stage_type": "External", + "stage_params": {"storage": {"type": "s3", "bucket": "demo"}}, + }, + } + ] + ) + headers = Headers({"databend-stage-mapping": [payload]}) + result = _collect(describe_stage, _make_batch([1]), headers) + assert result == ["stage_loc:input/2024/:1"] + + +def test_stage_mapping_dict_payload(): + payload = json.dumps( + { + "stage_loc": { + "relative_path": "path/", + "stage_info": { + "stage_name": "dict_stage", + "stage_type": "External", + "stage_params": {"storage": {"type": "s3", "bucket": "demo"}}, + }, + } + } + ) + headers = Headers({"Databend-Stage-Mapping": [payload]}) + stage = headers.require_stage_locations(["stage_loc"])["stage_loc"] + assert stage.stage_name == "dict_stage" + assert stage.relative_path == "path/" + + +def test_stage_refs_rename(): + payload = UDFClient.format_stage_mapping( + [ + { + "param_name": "input_stage", + "relative_path": "input/", + "stage_info": { + "stage_name": "input_stage", + "stage_type": "External", + "stage_params": {"storage": {"type": "s3", "bucket": "alias"}}, + }, + } + ] + ) + headers = Headers({"Databend-Stage-Mapping": [payload]}) + assert _collect(renamed_stage, _make_batch([7]), headers) == [7] + + +def test_type_annotation_detection(): + payload = UDFClient.format_stage_mapping( + [ + { + "param_name": "stage", + "relative_path": "annotated/", + "stage_info": { + "stage_name": "annotated", + "stage_type": "External", + "stage_params": {"storage": {"type": "s3", "bucket": "anno"}}, + }, + } + ] + ) + headers = Headers({"databend-stage-mapping": [payload]}) + assert _collect(annotated_stage, _make_batch([5]), headers) == [5] + + +def test_multiple_stage_entries(): + payload = UDFClient.format_stage_mapping( + [ + { + "param_name": "input_stage", + "relative_path": "input/", + "stage_info": { + "stage_name": "input_stage", + "stage_type": "External", + "stage_params": {"storage": {"type": "s3", "bucket": "input"}}, + }, + }, + { + "param_name": "output_stage", + "relative_path": "output/", + "stage_info": { + "stage_name": "output_stage", + "stage_type": "External", + "stage_params": {"storage": {"type": "s3", "bucket": "output"}}, + }, + }, + ] + ) + headers = Headers({"Databend-Stage-Mapping": [payload]}) + assert _collect(multi_stage, _make_batch([2]), headers) == [2] + + +def test_missing_stage_mapping(): + with pytest.raises(ValueError, match="Missing stage mapping"): + _collect(describe_stage, _make_batch([1]), Headers()) + + +def test_missing_storage_rejected(): + payload = UDFClient.format_stage_mapping( + [ + { + "param_name": "stage_loc", + "stage_info": { + "stage_name": "no_storage", + "stage_type": "External", + "stage_params": {}, + }, + } + ] + ) + headers = Headers({"databend-stage-mapping": [payload]}) + with pytest.raises(ValueError, match="storage configuration"): + _collect(describe_stage, _make_batch([1]), headers) + + +def test_internal_stage_rejected(): + payload = UDFClient.format_stage_mapping( + [ + { + "param_name": "stage_loc", + "stage_info": { + "stage_name": "internal", + "stage_type": "Internal", + "stage_params": {"storage": {"type": "s3", "bucket": "demo"}}, + }, + } + ] + ) + headers = Headers({"databend-stage-mapping": [payload]}) + with pytest.raises(ValueError, match="External stage"): + _collect(describe_stage, _make_batch([1]), headers) + + +def test_fs_storage_rejected(): + payload = UDFClient.format_stage_mapping( + [ + { + "param_name": "stage_loc", + "stage_info": { + "stage_name": "bad", + "stage_type": "External", + "stage_params": {"storage": {"type": "fs", "root": "/tmp"}}, + }, + } + ] + ) + headers = Headers({"databend-stage-mapping": [payload]}) + with pytest.raises(ValueError, match="'fs' storage"): + _collect(describe_stage, _make_batch([1]), headers)