Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,12 @@ python3 examples/server.py

### Acknowledgement
Databend Python UDF Server API is inspired by [RisingWave Python API](https://pypi.org/project/risingwave/).

### Code Formatting

Use Ruff to keep the Python sources consistent:

```bash
python -m pip install ruff # once
python -m ruff format python/databend_udf python/tests
```
121 changes: 0 additions & 121 deletions python/README_CLIENT.md

This file was deleted.

92 changes: 84 additions & 8 deletions python/databend_udf/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
Simple client library for testing Databend UDF servers.
"""
"""Simple client library for testing Databend UDF servers."""

import json
from typing import Any, Dict, Iterable, List, Sequence, Tuple

import pyarrow as pa
import pyarrow.flight as fl
from typing import List, Any


class UDFClient:
Expand Down Expand Up @@ -135,7 +135,69 @@ def get_function_info(self, function_name: str) -> fl.FlightInfo:
descriptor = fl.FlightDescriptor.for_path(function_name)
return self.client.get_flight_info(descriptor)

def call_function(self, function_name: str, *args) -> List[Any]:
@staticmethod
def format_stage_mapping(stage_locations: Iterable[Dict[str, Any]]) -> str:
"""Serialize stage mapping entries to the Databend header payload."""

serialized_entries: List[Dict[str, Any]] = []
for entry in stage_locations:
if not isinstance(entry, dict):
raise ValueError("stage_locations entries must be dictionaries")
if "param_name" not in entry:
raise ValueError("stage_locations entry requires 'param_name'")
serialized_entries.append(entry)

return json.dumps(serialized_entries)

@staticmethod
def _build_flight_headers(
headers: Dict[str, Any] = None,
stage_locations: Iterable[Dict[str, Any]] = None,
) -> Sequence[Tuple[str, str]]:
"""Construct Flight headers for a UDF call.

``stage_locations`` becomes a single header named ``databend-stage-mapping``
whose value is a JSON array. This mirrors what Databend Query sends to
external UDF servers. Example HTTP-style representation::

databend-stage-mapping: [
{
"param_name": "stage_loc",
"relative_path": "input/2024/",
"stage_info": { ... StageInfo JSON ... }
}
]

Multiple stage parameters simply append more objects to the array.
Additional custom headers can be supplied through ``headers``.
"""
headers = headers or {}
flight_headers: List[Tuple[bytes, bytes]] = []

for key, value in headers.items():
if isinstance(value, (list, tuple)):
for item in value:
flight_headers.append(
(str(key).encode("utf-8"), str(item).encode("utf-8"))
)
else:
flight_headers.append(
(str(key).encode("utf-8"), str(value).encode("utf-8"))
)

if stage_locations:
payload = UDFClient.format_stage_mapping(stage_locations)
flight_headers.append((b"databend-stage-mapping", payload.encode("utf-8")))

return flight_headers

def call_function(
self,
function_name: str,
*args,
headers: Dict[str, Any] = None,
stage_locations: Iterable[Dict[str, Any]] = None,
) -> List[Any]:
"""
Call a UDF function with given arguments.

Expand All @@ -150,7 +212,11 @@ def call_function(self, function_name: str, *args) -> List[Any]:

# Call function
descriptor = fl.FlightDescriptor.for_path(function_name)
writer, reader = self.client.do_exchange(descriptor=descriptor)
flight_headers = self._build_flight_headers(headers, stage_locations)
options = (
fl.FlightCallOptions(headers=flight_headers) if flight_headers else None
)
writer, reader = self.client.do_exchange(descriptor=descriptor, options=options)

with writer:
writer.begin(input_schema)
Expand All @@ -166,7 +232,13 @@ def call_function(self, function_name: str, *args) -> List[Any]:

return results

def call_function_batch(self, function_name: str, **kwargs) -> List[Any]:
def call_function_batch(
self,
function_name: str,
headers: Dict[str, Any] = None,
stage_locations: Iterable[Dict[str, Any]] = None,
**kwargs,
) -> List[Any]:
"""
Call a UDF function with batch data.

Expand All @@ -181,7 +253,11 @@ def call_function_batch(self, function_name: str, **kwargs) -> List[Any]:

# Call function
descriptor = fl.FlightDescriptor.for_path(function_name)
writer, reader = self.client.do_exchange(descriptor=descriptor)
flight_headers = self._build_flight_headers(headers, stage_locations)
options = (
fl.FlightCallOptions(headers=flight_headers) if flight_headers else None
)
writer, reader = self.client.do_exchange(descriptor=descriptor, options=options)

with writer:
writer.begin(input_schema)
Expand Down
Loading