diff --git a/_duckdb-stubs/_func.pyi b/_duckdb-stubs/_func.pyi index 68484499..8bf0480b 100644 --- a/_duckdb-stubs/_func.pyi +++ b/_duckdb-stubs/_func.pyi @@ -1,6 +1,14 @@ import typing as pytyping -__all__: list[str] = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] +__all__: list[str] = [ + "ARROW", + "DEFAULT", + "NATIVE", + "SPECIAL", + "FunctionNullHandling", + "PythonTableUDFType", + "PythonUDFType", +] class FunctionNullHandling: DEFAULT: pytyping.ClassVar[FunctionNullHandling] # value = @@ -21,6 +29,25 @@ class FunctionNullHandling: @property def value(self) -> int: ... +class PythonTableUDFType: + ARROW_TABLE: pytyping.ClassVar[PythonTableUDFType] # value = + TUPLES: pytyping.ClassVar[PythonTableUDFType] # value = + __members__: pytyping.ClassVar[ + dict[str, PythonTableUDFType] + ] # value = {'TUPLES': , 'ARROW_TABLE': } + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + class PythonUDFType: ARROW: pytyping.ClassVar[PythonUDFType] # value = NATIVE: pytyping.ClassVar[PythonUDFType] # value = diff --git a/duckdb/func/__init__.py b/duckdb/func/__init__.py index 5d73f490..fdb1adfd 100644 --- a/duckdb/func/__init__.py +++ b/duckdb/func/__init__.py @@ -1,3 +1,11 @@ -from _duckdb._func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType # noqa: D104 +from _duckdb._func import ( # noqa: D104 + ARROW, + DEFAULT, + NATIVE, + SPECIAL, + FunctionNullHandling, + PythonTableUDFType, + PythonUDFType, +) -__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTableUDFType", "PythonUDFType"] diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index 5114629b..9830c94f 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -2,9 +2,9 @@ import warnings -from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType +from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonTableUDFType, PythonUDFType -__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTableUDFType", "PythonUDFType"] warnings.warn( "`duckdb.functional` is deprecated and will be removed in a future version. 
Please use `duckdb.func` instead.", diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index a87b992f..b11cce49 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -107,6 +107,51 @@ ], "return": "DuckDBPyConnection" }, + { + "name": "create_table_function", + "function": "RegisterTableFunction", + "docs": "Register a table valued function via Callable", + "args": [ + { + "name": "name", + "type": "str" + }, + { + "name": "callable", + "type": "Callable" + } + ], + "kwargs": [ + { + "name": "parameters", + "type": "Optional[Any]", + "default": "None" + }, + { + "name": "schema", + "type": "Optional[Any]", + "default": "None" + }, + { + "name": "type", + "type": "Optional[PythonTableUDFType]", + "default": "PythonTableUDFType.TUPLES" + } + ], + "return": "DuckDBPyConnection" + }, + { + "name": "unregister_table_function", + "function": "UnregisterTableFunction", + "docs": "Unregister a table valued function", + "args": [ + { + "name": "name", + "type": "str" + } + ], + "return": "DuckDBPyConnection" + }, { "name": [ "sqltype", @@ -412,7 +457,6 @@ "fetch_record_batch", "arrow" ], - "function": "FetchRecordBatchReader", "docs": "Fetch an Arrow RecordBatchReader following execute()", "args": [ @@ -1094,4 +1138,4 @@ ], "return": "None" } -] +] \ No newline at end of file diff --git a/src/duckdb_py/CMakeLists.txt b/src/duckdb_py/CMakeLists.txt index 3d06b062..1fcede3e 100644 --- a/src/duckdb_py/CMakeLists.txt +++ b/src/duckdb_py/CMakeLists.txt @@ -28,6 +28,7 @@ add_library( python_dependency.cpp python_import_cache.cpp python_replacement_scan.cpp + python_table_udf.cpp python_udf.cpp) target_link_libraries(python_src PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/functional/functional.cpp b/src/duckdb_py/functional/functional.cpp index 252634b1..32e64145 100644 --- a/src/duckdb_py/functional/functional.cpp +++ b/src/duckdb_py/functional/functional.cpp @@ -10,6 +10,11 @@ void DuckDBPyFunctional::Initialize(py::module_ &parent) { .value("ARROW", duckdb::PythonUDFType::ARROW) .export_values(); + py::enum_(m, "PythonTableUDFType") + .value("TUPLES", duckdb::PythonTableUDFType::TUPLES) + .value("ARROW_TABLE", duckdb::PythonTableUDFType::ARROW_TABLE) + .export_values(); + py::enum_(m, "FunctionNullHandling") .value("DEFAULT", duckdb::FunctionNullHandling::DEFAULT_NULL_HANDLING) .value("SPECIAL", duckdb::FunctionNullHandling::SPECIAL_HANDLING) diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp new file mode 100644 index 00000000..d61deeca --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" + +using duckdb::InvalidInputException; +using duckdb::string; +using duckdb::StringUtil; + +namespace duckdb { + +enum class PythonTableUDFType : uint8_t { TUPLES, ARROW_TABLE }; + +} // namespace duckdb + +using duckdb::PythonTableUDFType; + +namespace py = pybind11; + +static PythonTableUDFType PythonTableUDFTypeFromString(const string &type) { + auto ltype = StringUtil::Lower(type); + if (ltype.empty() || ltype == "tuples") { + return PythonTableUDFType::TUPLES; + } else if (ltype == "arrow_table") { + return PythonTableUDFType::ARROW_TABLE; + } else { + throw InvalidInputException("'%s' is not a 
recognized type for 'tvf_type'", type); + } +} + +static PythonTableUDFType PythonTableUDFTypeFromInteger(int64_t value) { + if (value == 0) { + return PythonTableUDFType::TUPLES; + } else if (value == 1) { + return PythonTableUDFType::ARROW_TABLE; + } else { + throw InvalidInputException("'%d' is not a recognized type for 'tvf_type'", value); + } +} + +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + PythonTableUDFType tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + tmp = PythonTableUDFTypeFromString(py::str(src)); + value = &tmp; + return true; + } else if (py::isinstance(src)) { + tmp = PythonTableUDFTypeFromInteger(src.cast()); + value = &tmp; + return true; + } + return false; + } + + static handle cast(PythonTableUDFType src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; + +} // namespace detail +} // namespace PYBIND11_NAMESPACE \ No newline at end of file diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index 48ee055e..9109d5c5 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -23,6 +23,7 @@ #include "duckdb/function/scalar_function.hpp" #include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp" #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" +#include "duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp" #include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/shared_ptr.hpp" @@ -169,6 +170,8 @@ struct DuckDBPyConnection : public enable_shared_from_this { //! 
MemoryFileSystem used to temporarily store file-like objects for reading shared_ptr internal_object_filesystem; case_insensitive_map_t> registered_functions; + case_insensitive_map_t> registered_table_functions; + case_insensitive_set_t registered_objects; public: @@ -232,6 +235,13 @@ struct DuckDBPyConnection : public enable_shared_from_this { PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR, bool side_effects = false); + shared_ptr RegisterTableFunction(const string &name, const py::function &function, + const py::object &schema, + PythonTableUDFType type = PythonTableUDFType::TUPLES, + const py::object ¶meters = py::none()); + + shared_ptr UnregisterTableFunction(const string &name); + shared_ptr UnregisterUDF(const string &name); shared_ptr ExecuteMany(const py::object &query, py::object params = py::list()); @@ -355,6 +365,11 @@ struct DuckDBPyConnection : public enable_shared_from_this { const shared_ptr &return_type, bool vectorized, FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, bool side_effects); + + duckdb::TableFunction CreateTableFunctionFromCallable(const std::string &name, const py::function &callable, + const py::object ¶meters, const py::object &schema, + PythonTableUDFType type); + void RegisterArrowObject(const py::object &arrow_object, const string &name); vector> GetStatements(const py::object &query); diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index b88b88ed..32645690 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -393,7 +393,6 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud auto scalar_function = CreateScalarUDF(name, udf, parameters_p, return_type_p, type == PythonUDFType::ARROW, null_handling, exception_handling, side_effects); CreateScalarFunctionInfo info(scalar_function); - context.RegisterFunction(info); auto dependency = make_uniq(); @@ -403,6 +402,55 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud return shared_from_this(); } +shared_ptr +DuckDBPyConnection::RegisterTableFunction(const string &name, const py::function &function, const py::object &schema, + PythonTableUDFType type, const py::object ¶meters) { + + auto &connection = con.GetConnection(); + auto &context = *connection.context; + + if (context.transaction.HasActiveTransaction()) { + context.CancelTransaction(); + } + + if (registered_table_functions.find(name) != registered_table_functions.end()) { + throw NotImplementedException("A table function by the name of '%s' is already registered, " + "please unregister it first", + name); + } + + auto table_function = CreateTableFunctionFromCallable(name, function, parameters, schema, type); + CreateTableFunctionInfo info(table_function); + + // re-registration: changing the callable to another + info.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; + + context.RegisterFunction(info); + + auto dependency = make_uniq(); + dependency->AddDependency("function", PythonDependencyItem::Create(function)); + registered_table_functions[name] = std::move(dependency); + + return shared_from_this(); +} + +shared_ptr DuckDBPyConnection::UnregisterTableFunction(const string &name) { + auto entry = registered_table_functions.find(name); + if (entry == registered_table_functions.end()) { + throw InvalidInputException( + "No table function by the name of '%s' was found in the list of registered table functions", name); + } + + auto &connection = 
con.GetConnection(); + auto &context = *connection.context; + + // Remove from our registry. + // TODO: Callable still exists in the function catalog, since duckdb doesn't (yet?) support removal + registered_table_functions.erase(entry); + + return shared_from_this(); +} + void DuckDBPyConnection::Initialize(py::handle &m) { auto connection_module = py::class_>(m, "DuckDBPyConnection", py::module_local()); @@ -411,6 +459,14 @@ void DuckDBPyConnection::Initialize(py::handle &m) { .def("__exit__", &DuckDBPyConnection::Exit, py::arg("exc_type"), py::arg("exc"), py::arg("traceback")); connection_module.def("__del__", &DuckDBPyConnection::Close); + connection_module.def("create_table_function", &DuckDBPyConnection::RegisterTableFunction, + "Register a table user defined function via Callable", py::arg("name"), py::arg("callable"), + py::arg("schema"), py::kw_only(), py::arg("type") = PythonTableUDFType::TUPLES, + py::arg("parameters") = py::none()); + + connection_module.def("unregister_table_function", &DuckDBPyConnection::UnregisterTableFunction, + "Unregister a table user defined function", py::arg("name")); + InitializeConnectionMethods(connection_module); connection_module.def_property_readonly("description", &DuckDBPyConnection::GetDescription, "Get result set attributes, mainly column names"); @@ -1575,7 +1631,12 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer } if (res->type == QueryResultType::STREAM_RESULT) { auto &stream_result = res->Cast(); - res = stream_result.Materialize(); + { + // Release the GIL, as Materialize *may* need the GIL (TVFs, for instance) + D_ASSERT(py::gil_check()); + py::gil_scoped_release release; + res = stream_result.Materialize(); + } } auto &materialized_result = res->Cast(); relation = make_shared_ptr(connection.context, materialized_result.TakeCollection(), @@ -1826,6 +1887,7 @@ void DuckDBPyConnection::Close() { // https://peps.python.org/pep-0249/#Connection.close cursors.ClearCursors(); registered_functions.clear(); + registered_table_functions.clear(); } void DuckDBPyConnection::Interrupt() { diff --git a/src/duckdb_py/python_table_udf.cpp b/src/duckdb_py/python_table_udf.cpp new file mode 100644 index 00000000..ed45f105 --- /dev/null +++ b/src/duckdb_py/python_table_udf.cpp @@ -0,0 +1,374 @@ +#include "duckdb_python/pybind11/pybind_wrapper.hpp" +#include "duckdb_python/pytype.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb_python/arrow/arrow_array_stream.hpp" +#include "duckdb/function/table/arrow.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb_python/python_conversion.hpp" +#include "duckdb_python/python_objects.hpp" +#include "duckdb_python/pybind11/python_object_container.hpp" + +namespace duckdb { + +struct PyTableUDFInfo : public TableFunctionInfo { + py::function callable; + vector return_types; + vector return_names; + PythonTableUDFType return_type; + + PyTableUDFInfo(py::function callable_p, vector types_p, vector names_p, + PythonTableUDFType return_type_p) + : callable(std::move(callable_p)), return_types(std::move(types_p)), return_names(std::move(names_p)), + return_type(return_type_p) { + } + + ~PyTableUDFInfo() override { + py::gil_scoped_acquire acquire; + callable = py::function(); + } +}; + +struct PyTableUDFBindData : public TableFunctionData { + string func_name; + vector args; + named_parameter_map_t kwargs; + vector 
return_types; + vector return_names; + PythonObjectContainer python_objects; // Holds the callable + + PyTableUDFBindData(string func_name, vector args, named_parameter_map_t kwargs, + vector return_types, vector return_names, py::function callable) + : func_name(std::move(func_name)), args(std::move(args)), kwargs(std::move(kwargs)), + return_types(std::move(return_types)), return_names(std::move(return_names)) { + // gil acquired inside push + python_objects.Push(std::move(callable)); + } +}; + +struct PyTableUDFTuplesGlobalState : public GlobalTableFunctionState { + PythonObjectContainer python_objects; + bool iterator_exhausted = false; + + PyTableUDFTuplesGlobalState() : iterator_exhausted(false) { + } +}; + +struct PyTableUDFArrowGlobalState : public GlobalTableFunctionState { + unique_ptr arrow_factory; + unique_ptr arrow_bind_data; + unique_ptr arrow_global_state; + PythonObjectContainer python_objects; + idx_t num_columns; + + PyTableUDFArrowGlobalState() { + } +}; + +static void PyTableUDFTuplesScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { + auto &gs = input.global_state->Cast(); + auto &bd = input.bind_data->Cast(); + + if (gs.iterator_exhausted) { + output.SetCardinality(0); + return; + } + + py::gil_scoped_acquire gil; + auto &it = gs.python_objects.LastAddedObject(); + + idx_t row_idx = 0; + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + py::object next_item; + try { + next_item = it.attr("__next__")(); + } catch (py::error_already_set &e) { + if (e.matches(PyExc_StopIteration)) { + gs.iterator_exhausted = true; + PyErr_Clear(); + break; + } + throw; + } + + try { + // Extract each column from the tuple/list + for (idx_t col_idx = 0; col_idx < bd.return_types.size(); col_idx++) { + auto py_val = next_item[py::int_(col_idx)]; + Value duck_val = TransformPythonValue(py_val, bd.return_types[col_idx]); + output.SetValue(col_idx, row_idx, duck_val); + } + } catch (py::error_already_set &e) { + throw InvalidInputException("Table function '%s' returned invalid data: %s", bd.func_name, e.what()); + } + row_idx++; + } + output.SetCardinality(row_idx); +} + +struct PyTableUDFArrowLocalState : public LocalTableFunctionState { + unique_ptr arrow_local_state; + + explicit PyTableUDFArrowLocalState(unique_ptr arrow_local) + : arrow_local_state(std::move(arrow_local)) { + } +}; + +static void PyTableUDFArrowScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { + // Delegates to ArrowScanFunction + auto &gs = input.global_state->Cast(); + auto &ls = input.local_state->Cast(); + + TableFunctionInput arrow_input(gs.arrow_bind_data.get(), ls.arrow_local_state.get(), gs.arrow_global_state.get()); + ArrowTableFunction::ArrowScanFunction(context, arrow_input, output); +} + +static unique_ptr PyTableUDFBindInternal(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector &return_names) { + if (!in.info) { + throw InvalidInputException("Table function '%s' missing function info", in.table_function.name); + } + + auto &tableudf_info = in.info->Cast(); + return_types = tableudf_info.return_types; + return_names = tableudf_info.return_names; + + // Acquire gil before copying py::function + py::gil_scoped_acquire gil; + return make_uniq(in.table_function.name, in.inputs, in.named_parameters, return_types, + return_names, tableudf_info.callable); +} + +static unique_ptr PyTableUDFTuplesBindFunction(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector &return_names) { + 
auto bd = PyTableUDFBindInternal(context, in, return_types, return_names); + return std::move(bd); +} + +static unique_ptr PyTableUDFArrowBindFunction(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector &return_names) { + auto bd = PyTableUDFBindInternal(context, in, return_types, return_names); + return std::move(bd); +} + +static py::object CallPythonTableUDF(ClientContext &context, PyTableUDFBindData &bd) { + py::gil_scoped_acquire gil; + + // positional arguments + py::tuple args(bd.args.size()); + for (idx_t i = 0; i < bd.args.size(); i++) { + args[i] = PythonObject::FromValue(bd.args[i], bd.args[i].type(), context.GetClientProperties()); + } + + // keyword arguments + py::dict kwargs; + for (auto &kv : bd.kwargs) { + kwargs[py::str(kv.first)] = PythonObject::FromValue(kv.second, kv.second.type(), context.GetClientProperties()); + } + + // Call Python function + auto &callable = bd.python_objects.LastAddedObject(); + py::object result = callable(*args, **kwargs); + + if (result.is_none()) { + throw InvalidInputException("Table function '%s' returned None, expected iterable or Arrow table", + bd.func_name); + } + + return result; +} + +static unique_ptr PyTableUDFTuplesInitGlobal(ClientContext &context, + TableFunctionInitInput &in) { + auto &bd = in.bind_data->Cast(); + auto gs = make_uniq(); + + { + py::gil_scoped_acquire gil; + // const_cast is safe here - we only read from python_objects, not modify bind_data structure + py::object result = CallPythonTableUDF(context, const_cast(bd)); + try { + py::iterator it = py::iter(result); + gs->python_objects.Push(std::move(it)); + gs->iterator_exhausted = false; + } catch (const py::error_already_set &e) { + throw InvalidInputException("Table function '%s' returned non-iterable result: %s", bd.func_name, e.what()); + } + } + + return std::move(gs); +} + +static unique_ptr PyTableUDFArrowInitGlobal(ClientContext &context, + TableFunctionInitInput &in) { + auto &bd = in.bind_data->Cast(); + auto gs = make_uniq(); + + { + py::gil_scoped_acquire gil; + + py::object result = CallPythonTableUDF(context, const_cast(bd)); + PyObject *ptr = result.ptr(); + + gs->python_objects.Push(std::move(result)); + + gs->arrow_factory = make_uniq(ptr, context.GetClientProperties(), + DBConfig::GetConfig(context)); + } + + // Build bind input for Arrow scan + vector children; + children.push_back(Value::POINTER(CastPointerToValue(gs->arrow_factory.get()))); + children.push_back(Value::POINTER(CastPointerToValue(PythonTableArrowArrayStreamFactory::Produce))); + children.push_back(Value::POINTER(CastPointerToValue(PythonTableArrowArrayStreamFactory::GetSchema))); + + TableFunctionRef empty_ref; + duckdb::TableFunction dummy_tf; + dummy_tf.name = "PyTableUDFArrowWrapper"; + + named_parameter_map_t named_params; + vector input_types; + vector input_names; + + TableFunctionBindInput bind_input(children, named_params, input_types, input_names, nullptr, nullptr, dummy_tf, + empty_ref); + + vector return_types; + vector return_names; + gs->arrow_bind_data = ArrowTableFunction::ArrowScanBind(context, bind_input, return_types, return_names); + + // Validate Arrow schema matches declared + if (return_types.size() != bd.return_types.size()) { + throw InvalidInputException("Schema mismatch in table function '%s': " + "Arrow table has %lu columns but %lu were declared", + bd.func_name, return_types.size(), bd.return_types.size()); + } + + // Check column types match + for (idx_t i = 0; i < return_types.size(); i++) { + if (return_types[i] != 
bd.return_types[i]) { + throw InvalidInputException("Schema mismatch in table function '%s' at column %lu: " + "Arrow table has type %s but %s was declared", + bd.func_name, i, return_types[i].ToString().c_str(), + bd.return_types[i].ToString().c_str()); + } + } + + gs->num_columns = return_types.size(); + vector all_columns; + for (idx_t i = 0; i < gs->num_columns; i++) { + all_columns.push_back(i); + } + + TableFunctionInitInput init_input(gs->arrow_bind_data.get(), all_columns, all_columns, in.filters.get()); + gs->arrow_global_state = ArrowTableFunction::ArrowScanInitGlobal(context, init_input); + + return std::move(gs); +} + +static unique_ptr +PyTableUDFArrowInitLocal(ExecutionContext &context, TableFunctionInitInput &in, GlobalTableFunctionState *gstate) { + auto &gs = gstate->Cast(); + + vector all_columns; + for (idx_t i = 0; i < gs.num_columns; i++) { + all_columns.push_back(i); + } + + TableFunctionInitInput arrow_init(gs.arrow_bind_data.get(), all_columns, all_columns, in.filters.get()); + auto arrow_local_state = + ArrowTableFunction::ArrowScanInitLocalInternal(context.client, arrow_init, gs.arrow_global_state.get()); + + return make_uniq(std::move(arrow_local_state)); +} + +duckdb::TableFunction DuckDBPyConnection::CreateTableFunctionFromCallable(const std::string &name, + const py::function &callable, + const py::object ¶meters, + const py::object &schema, + PythonTableUDFType type) { + + // Schema + if (schema.is_none()) { + throw InvalidInputException("Table functions require a schema."); + } + + vector types; + vector names; + + // Schema must be dict format: {"col1": DuckDBPyType, "col2": DuckDBPyType} + if (!py::isinstance(schema)) { + throw InvalidInputException("Table function '%s' schema must be a dict mapping column names to duckdb.sqltypes " + "(e.g., {\"col1\": INTEGER, \"col2\": VARCHAR})", + name); + } + + auto schema_dict = py::cast(schema); + for (auto item : schema_dict) { + // schema is a dict of str => DuckDBPyType + + string col_name = py::str(item.first); + names.emplace_back(col_name); + + auto type_obj = py::cast(item.second); + + // Check for string BEFORE DuckDBPyType because pybind11 has implicit conversion from str to DuckDBPyType + if (py::isinstance(type_obj)) { + throw InvalidInputException("Invalid schema format: type for column '%s' must be a duckdb.sqltype (e.g., " + "INTEGER, VARCHAR), not a string. 
" + "Use sqltypes.%s instead of \"%s\"", + col_name, py::str(type_obj).cast().c_str(), + py::str(type_obj).cast().c_str()); + } + + if (!py::isinstance(type_obj)) { + throw InvalidInputException( + "Invalid schema format: type for column '%s' must be a duckdb.sqltype (e.g., INTEGER, VARCHAR), got %s", + col_name, py::str(type_obj.get_type()).cast()); + } + auto pytype = py::cast>(type_obj); + types.emplace_back(pytype->Type()); + } + + if (types.empty()) { + throw InvalidInputException("Table function '%s' schema cannot be empty", name); + } + + duckdb::TableFunction tf; + switch (type) { + case PythonTableUDFType::TUPLES: + tf = duckdb::TableFunction(name, {}, PyTableUDFTuplesScanFunction, PyTableUDFTuplesBindFunction, + PyTableUDFTuplesInitGlobal); + break; + case PythonTableUDFType::ARROW_TABLE: + tf = duckdb::TableFunction(name, {}, PyTableUDFArrowScanFunction, PyTableUDFArrowBindFunction, + PyTableUDFArrowInitGlobal, PyTableUDFArrowInitLocal); + break; + default: + throw InvalidInputException("Unknown return type for table function '%s'", name); + } + + // Store the Python callable and schema + tf.function_info = make_shared_ptr(callable, types, names, type); + + // args + tf.varargs = LogicalType::ANY; + tf.named_parameters["args"] = LogicalType::ANY; + + // kwargs + if (!parameters.is_none()) { + for (auto ¶m : py::cast(parameters)) { + string param_name = py::str(param); + tf.named_parameters[param_name] = LogicalType::ANY; + } + } + + return tf; +} + +} // namespace duckdb diff --git a/tests/fast/table_udf/test_arrow.py b/tests/fast/table_udf/test_arrow.py new file mode 100644 index 00000000..c709f4bf --- /dev/null +++ b/tests/fast/table_udf/test_arrow.py @@ -0,0 +1,244 @@ +from collections.abc import Iterator + +import pytest + +import duckdb +import duckdb.sqltypes as sqltypes +from duckdb.functional import PythonTableUDFType + + +def tuple_generator(count: int = 10) -> Iterator[tuple[str, int]]: + for i in range(count): + yield (f"name_{i}", i) + + +def simple_arrow_table(count: int): + pa = pytest.importorskip("pyarrow") + + data = { + "id": list(range(count)), + "value": [i * 2 for i in range(count)], + "name": [f"row_{i}" for i in range(count)], + } + return pa.table(data) + + +def arrow_all_types(count: int): + pa = pytest.importorskip("pyarrow") + from datetime import datetime, timedelta, timezone + from decimal import Decimal + + now = datetime.now(timezone.utc) + data = { + "col_tinyint": pa.array(range(count), type=pa.int8()), + "col_smallint": pa.array(range(count), type=pa.int16()), + "col_int": pa.array(range(count), type=pa.int32()), + "col_bigint": pa.array(range(count), type=pa.int64()), + "col_utinyint": pa.array(range(count), type=pa.uint8()), + "col_usmallint": pa.array(range(count), type=pa.uint16()), + "col_uint": pa.array(range(count), type=pa.uint32()), + "col_ubigint": pa.array(range(count), type=pa.uint64()), + "col_float": pa.array((i * 1.5 for i in range(count)), type=pa.float32()), + "col_double": pa.array((i * 2.5 for i in range(count)), type=pa.float64()), + "col_varchar": pa.array((f"row_{i}" for i in range(count)), type=pa.string()), + "col_bool": pa.array((i % 2 == 0 for i in range(count)), type=pa.bool_()), + "col_timestamp": pa.array( + (now + timedelta(seconds=i) for i in range(count)), type=pa.timestamp("us", tz="UTC") + ), + "col_date": pa.array((now.date() + timedelta(days=i) for i in range(count)), type=pa.date32()), + "col_time": pa.array(((now + timedelta(microseconds=i)).time() for i in range(count)), type=pa.time64("ns")), + 
"col_decimal": pa.array((Decimal(i) / 10 for i in range(count)), type=pa.decimal128(10, 2)), + "col_blob": pa.array((f"bin_{i}".encode() for i in range(count)), type=pa.binary()), + "col_list": pa.array(([i, i + 1] for i in range(count)), type=pa.list_(pa.int32())), + "col_struct": pa.array( + ({"x": i, "y": float(i)} for i in range(count)), type=pa.struct([("x", pa.int32()), ("y", pa.float32())]) + ), + } + return pa.table(data) + + +ALL_TYPES_SCHEMA = { + "col_tinyint": sqltypes.TINYINT, + "col_smallint": sqltypes.SMALLINT, + "col_int": sqltypes.INTEGER, + "col_bigint": sqltypes.BIGINT, + "col_utinyint": sqltypes.UTINYINT, + "col_usmallint": sqltypes.USMALLINT, + "col_uint": sqltypes.UINTEGER, + "col_ubigint": sqltypes.UBIGINT, + "col_float": sqltypes.FLOAT, + "col_double": sqltypes.DOUBLE, + "col_varchar": sqltypes.VARCHAR, + "col_bool": sqltypes.BOOLEAN, + "col_timestamp": sqltypes.TIMESTAMP_TZ, + "col_date": sqltypes.DATE, + "col_time": sqltypes.TIME, + "col_decimal": duckdb.decimal_type(10, 2), + "col_blob": sqltypes.BLOB, + "col_list": duckdb.list_type(sqltypes.INTEGER), + "col_struct": duckdb.struct_type({"x": sqltypes.INTEGER, "y": sqltypes.FLOAT}), +} + + +def test_arrow_small(tmp_path): + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "simple_arrow", + simple_arrow_table, + schema={"x": sqltypes.BIGINT, "y": sqltypes.BIGINT, "name": sqltypes.VARCHAR}, + type=PythonTableUDFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() + + assert len(result) == 5 + + # Should fail because it's not defined in this conn + with duckdb.connect(tmp_path / "test2.duckdb") as conn, pytest.raises(duckdb.CatalogException): + result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # Should fail because schema is missing a col + conn.create_table_function( + "simple_arrow", + simple_arrow_table, + schema={"x": sqltypes.BIGINT, "y": sqltypes.BIGINT}, + type=PythonTableUDFType.ARROW_TABLE, + ) + with pytest.raises(duckdb.InvalidInputException) as exc_info: + result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() + assert "Vector::Reference" in str(exc_info.value) or "schema" in str(exc_info.value).lower() + + +def test_arrow_large_1(tmp_path): + """tests: more rows, aggregation, limits, named parameters, parameters.""" + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + n = 2048 * 1000 + + conn.create_table_function( + "large_arrow", + simple_arrow_table, + schema={"id": sqltypes.BIGINT, "value": sqltypes.BIGINT, "name": sqltypes.VARCHAR}, + type="arrow_table", + parameters=["count"], + ) + + result = conn.execute("SELECT COUNT(*) FROM large_arrow(count:=?)", parameters=(n,)).fetchone() + assert result[0] == n + + df = conn.sql(f"SELECT * FROM large_arrow({n}) LIMIT 10").df() + assert len(df) == 10 + assert df["id"].tolist() == list(range(10)) + + arrow_result = conn.execute("SELECT * FROM large_arrow(?)", parameters=(n,)).fetch_arrow_table() + assert len(arrow_result) == n + + result = conn.sql("SELECT SUM(value) FROM large_arrow(count:=$count)", params={"count": n}).fetchone() + expected_sum = sum(i * 2 for i in range(n)) + assert result[0] == expected_sum + + +def test_arrowbatched_execute(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": 
sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=tuple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.execute( + "SELECT * FROM gen_function(?)", + parameters=(count,), + ).fetch_record_batch() + + result = conn.execute( + f"SELECT * FROM gen_function({count})", + ).fetch_record_batch() + + c = 0 + for batch in result: + c += batch.num_rows + assert c == count + + +def test_arrowbatched_sql_relation(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=tuple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.sql( + f"SELECT * FROM gen_function({count})", + ).fetch_arrow_reader() + + c = 0 + for batch in result: + c += batch.num_rows + assert c == count + + +def test_arrow_types(tmp_path): + """Return many types from an arrow table UDF, and verify the results are correct.""" + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "all_types_arrow", + arrow_all_types, + schema=ALL_TYPES_SCHEMA, + type=PythonTableUDFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM all_types_arrow(3)").fetchall() + assert len(result) == 3 + + first_row = result[0] + assert first_row[0] == 0 # col_tinyint + assert first_row[1] == 0 # col_smallint + assert first_row[2] == 0 # col_int + assert first_row[3] == 0 # col_bigint + assert first_row[10] == "row_0" # col_varchar + assert first_row[11] is True # col_bool + + result = conn.execute("SELECT SUM(col_int), AVG(col_float) FROM all_types_arrow(100)").fetchone() + expected_sum = sum(range(100)) + assert result[0] == expected_sum + + result = conn.execute("SELECT COUNT(*) FROM all_types_arrow(50) WHERE col_bool = true").fetchone() + assert result[0] == 25 # Half should be true (even numbers) + + result = conn.execute("SELECT col_varchar, col_int FROM all_types_arrow(5)").fetchall() + assert len(result) == 5 + assert result[2] == ("row_2", 2) + + result = conn.execute("SELECT col_list FROM all_types_arrow(2)").fetchall() + assert result[0][0] == [0, 1] + assert result[1][0] == [1, 2] + + result = conn.sql("SELECT col_struct FROM all_types_arrow(2)").fetchall() + assert result[0][0] == {"x": 0, "y": 0.0} + assert result[1][0] == {"x": 1, "y": 1.0} + + schema_result = conn.sql("DESCRIBE SELECT * FROM all_types_arrow(1)").fetchall() + column_names = [row[0] for row in schema_result] + assert list(ALL_TYPES_SCHEMA.keys()) == column_names diff --git a/tests/fast/table_udf/test_arrow_schema.py b/tests/fast/table_udf/test_arrow_schema.py new file mode 100644 index 00000000..a11b77c9 --- /dev/null +++ b/tests/fast/table_udf/test_arrow_schema.py @@ -0,0 +1,116 @@ +import pytest + +import duckdb +import duckdb.sqltypes as sqltypes +from duckdb.functional import PythonTableUDFType + + +def simple_arrow_table(count: int = 10): + import pyarrow as pa + + data = { + "id": list(range(count)), + "value": [i * 2 for i in range(count)], + "name": [f"row_{i}" for i in range(count)], + } + return pa.table(data) + + +def test_arrow_correct_schema(tmp_path): + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema={"id": sqltypes.BIGINT, "value": sqltypes.BIGINT, "name": sqltypes.VARCHAR}, + 
type=PythonTableUDFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM arrow_func(5)").fetchall() + assert len(result) == 5 + assert result[0] == (0, 0, "row_0") + + +def test_arrow_more_columns(tmp_path): + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # table has 3 cols, but declare only 2 + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema={"x": sqltypes.BIGINT, "y": sqltypes.BIGINT}, # Missing third column + type=PythonTableUDFType.ARROW_TABLE, + ) + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.execute("SELECT * FROM arrow_func(5)").fetchall() + + error_msg = str(exc_info.value).lower() + assert "schema mismatch" in error_msg or "3 columns" in error_msg or "2 were declared" in error_msg + + +def test_arrow_fewer_columns(tmp_path): + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # table has 3 columns, but declare 4 + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema={ + "id": sqltypes.BIGINT, + "value": sqltypes.BIGINT, + "name": sqltypes.VARCHAR, + "extra": sqltypes.INTEGER, # Extra column that doesn't exist + }, + type=PythonTableUDFType.ARROW_TABLE, + ) + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.execute("SELECT * FROM arrow_func(5)").fetchall() + + error_msg = str(exc_info.value).lower() + assert "schema mismatch" in error_msg or "3 columns" in error_msg or "4 were declared" in error_msg + + +def test_arrow_type_mismatch(tmp_path): + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema={ + "id": sqltypes.VARCHAR, # Wrong type - should be BIGINT + "value": sqltypes.BIGINT, + "name": sqltypes.VARCHAR, + }, + type=PythonTableUDFType.ARROW_TABLE, + ) + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.execute("SELECT * FROM arrow_func(5)").fetchall() + + error_msg = str(exc_info.value).lower() + assert "type" in error_msg or "mismatch" in error_msg + + +def test_arrow_name_mismatch_allowed(tmp_path): + pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema={ + "a": sqltypes.BIGINT, # Arrow has 'id' + "b": sqltypes.BIGINT, # Arrow has 'value' + "c": sqltypes.VARCHAR, # Arrow has 'name' + }, + type=PythonTableUDFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM arrow_func(3)").fetchall() + assert len(result) == 3 diff --git a/tests/fast/table_udf/test_register.py b/tests/fast/table_udf/test_register.py new file mode 100644 index 00000000..44ef91ae --- /dev/null +++ b/tests/fast/table_udf/test_register.py @@ -0,0 +1,313 @@ +import pytest + +import duckdb +import duckdb.sqltypes as sqltypes + + +def test_registry_collision(tmp_path): + """Two table_udfs on different connections with same name""" "" + conn1 = duckdb.connect(tmp_path / "db1.db") + conn2 = duckdb.connect(tmp_path / "db2.db") + + def func_for_conn1(): + return [("conn1_data", 1)] + + def func_for_conn2(): + return [("conn2_data", 2)] + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn1.create_table_function( + name="same_name", + callable=func_for_conn1, + parameters=None, + schema=schema, + type="tuples", + ) + + conn2.create_table_function( + name="same_name", + callable=func_for_conn2, + parameters=None, + schema=schema, + type="tuples", + ) + + 
result1 = conn1.execute("SELECT * FROM same_name()").fetchall() + assert result1[0][0] == "conn1_data" + assert result1[0][1] == 1 + + result2 = conn2.execute("SELECT * FROM same_name()").fetchall() + assert result2[0][0] == "conn2_data" + assert result2[0][1] == 2 + + result1 = conn1.sql("SELECT * FROM same_name()").fetchall() + assert result1[0][0] == "conn1_data" + assert result1[0][1] == 1 + + conn1.close() + conn2.close() + + +def test_replace_without_unregister(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("version_1", 1)] + + def func_v2(): + return [("version_2", 2)] + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function("test_func", func_v1, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_1" + assert result[0][1] == 1 + + with pytest.raises(duckdb.NotImplementedException) as exc_info: + conn.create_table_function("test_func", func_v2, schema=schema, type="tuples") + assert "already registered" in str(exc_info.value) + + +def test_replace_after_unregister(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("version_1", 1)] + + def func_v2(): + return [("version_2", 2)] + + def func_v3(): + return [("version_3", 3)] + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function("test_func", func_v1, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_1" + + conn.unregister_table_function("test_func") + conn.create_table_function("test_func", func_v2, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_2" + + conn.unregister_table_function("test_func") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_2" + + conn.create_table_function("test_func", func_v3, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_3" + assert result[0][1] == 3 + + +def test_multiple_replacements(tmp_path): + """Replacing Table UDFs multiple times.""" + with duckdb.connect(tmp_path / "test.db") as conn: + schema = {"value": sqltypes.INTEGER} + + for i in range(1, 6): + + def make_func(val=i): + def func(): + return [(val,)] + + return func + + if i > 1: + conn.unregister_table_function("counter") + + conn.create_table_function("counter", make_func(), schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM counter()").fetchone() + assert result[0] == i + + +def test_replacement_with_different_schemas(tmp_path): + """Changing schema with replacements.""" + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("test", 1)] + + def func_v2(): + return [("modified", 2, 3.14)] + + schema_v1 = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + conn.create_table_function("evolving_func", func_v1, schema=schema_v1, type="tuples") + + result = conn.execute("SELECT * FROM evolving_func()").fetchall() + assert len(result[0]) == 2 + assert result[0][0] == "test" + + schema_v2 = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER, "value": sqltypes.DOUBLE} + conn.unregister_table_function("evolving_func") # Must unregister first + conn.create_table_function("evolving_func", func_v2, schema=schema_v2, type="tuples") + + result = conn.execute("SELECT * FROM 
evolving_func()").fetchall() + assert len(result[0]) == 3 + assert result[0][0] == "modified" + assert result[0][2] == 3.14 + + +def test_replacement_2(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("v1",)] + + def func_v2(): + return [("v2",)] + + schema = {"version": sqltypes.VARCHAR} + + conn.create_table_function("tracked_func", func_v1, schema=schema, type="tuples") + + conn.unregister_table_function("tracked_func") # Must unregister first + conn.create_table_function("tracked_func", func_v2, schema=schema, type="tuples") + + conn.unregister_table_function("tracked_func") + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.unregister_table_function("tracked_func") + assert "No table function by the name of 'tracked_func'" in str(exc_info.value) + + result = conn.execute("SELECT * FROM tracked_func()").fetchone() + assert result[0] == "v2" + + +def test_sql_drop_table_function(tmp_path): + """Documents current behavior - that dropping functions has no effect on Table UDFs.""" + with duckdb.connect(tmp_path / "test.db") as conn: + + def test_func(): + return [("test_value", 1)] + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + conn.create_table_function("test_func", test_func, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + with pytest.raises(duckdb.CatalogException): + conn.execute("DROP FUNCTION test_func") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + +def test_unregister_table_function(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def simple_function(): + return [("test_value", 1)] + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="test_func", + callable=simple_function, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert len(result) == 1 + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + conn.unregister_table_function("test_func") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert len(result) == 1 + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.unregister_table_function("test_func") + + assert "No table function by the name of 'test_func'" in str(exc_info.value) + + +def test_unregister_doesntexist(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.unregister_table_function("nonexistent_func") + + assert "No table function by the name of 'nonexistent_func'" in str(exc_info.value) + + +def test_reregister(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("version_1", 1)] + + def func_v2(): + return [("version_2", 2)] + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="versioned_func", + callable=func_v1, + schema=schema, + type="tuples", + ) + + result = conn.execute("SELECT * FROM versioned_func()").fetchall() + assert result[0][0] == "version_1" + + conn.unregister_table_function("versioned_func") + + conn.create_table_function( + name="versioned_func", + callable=func_v2, + schema=schema, + type="tuples", + ) + + result = 
conn.execute("SELECT * FROM versioned_func()").fetchall() + assert result[0][0] == "version_2" + + +def test_unregister_multi(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + cursor1 = conn.cursor() + cursor2 = conn.cursor() + + def test_func(): + return [("test_data", 1)] + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + cursor1.create_table_function( + name="shared_func", + callable=test_func, + schema=schema, + type="tuples", + ) + + result1 = cursor1.execute("SELECT * FROM shared_func()").fetchall() + assert result1[0][0] == "test_data" + + result2 = cursor2.execute("SELECT * FROM shared_func()").fetchall() + assert result2[0][0] == "test_data" + + cursor1.unregister_table_function("shared_func") + + result1 = cursor1.execute("SELECT * FROM shared_func()").fetchall() + assert result1[0][0] == "test_data" + + result2 = cursor2.execute("SELECT * FROM shared_func()").fetchall() + assert result2[0][0] == "test_data" diff --git a/tests/fast/table_udf/test_schema.py b/tests/fast/table_udf/test_schema.py new file mode 100644 index 00000000..7490162d --- /dev/null +++ b/tests/fast/table_udf/test_schema.py @@ -0,0 +1,429 @@ +"""Test schema validation for table-valued functions.""" + +from collections.abc import Iterator + +import pytest + +import duckdb +import duckdb.sqltypes as sqltypes + + +def test_valid_schema_basic_types(tmp_path): + def gen_function(): + return [("test", 42)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + assert result[0] == ("test", 42) + + +def test_valid_schema_numeric_types(tmp_path): + def gen_function(): + return [(1, 2, 3, 4, 5, 6.5, 7.25)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "tiny": sqltypes.TINYINT, + "small": sqltypes.SMALLINT, + "int": sqltypes.INTEGER, + "big": sqltypes.BIGINT, + "huge": sqltypes.HUGEINT, + "float": sqltypes.FLOAT, + "double": sqltypes.DOUBLE, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + + +def test_valid_schema_temporal_types(tmp_path): + from datetime import date, datetime, time + + def gen_function(): + return [(date(2024, 1, 1), time(12, 30, 45), datetime(2024, 1, 1, 12, 30, 45))] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "d": sqltypes.DATE, + "t": sqltypes.TIME, + "ts": sqltypes.TIMESTAMP, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + + +def test_valid_schema_boolean_and_blob(tmp_path): + def gen_function(): + return [(True, b"binary_data")] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "flag": sqltypes.BOOLEAN, + "data": sqltypes.BLOB, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + assert result[0][0] is True + assert result[0][1] == b"binary_data" + + +def test_valid_schema_single_column(tmp_path): + def gen_function(): + return 
[(42,), (43,), (44,)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"value": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 3 + + +def test_valid_schema_many_columns(tmp_path): + def gen_function(): + return [tuple(range(20))] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {f"col{i}": sqltypes.INTEGER for i in range(20)} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + assert len(result[0]) == 20 + + +def test_invalid_schema_none(tmp_path): + def gen_function(): + return [("test", 1)] + + with ( + duckdb.connect(tmp_path / "test.duckdb") as conn, + pytest.raises(duckdb.InvalidInputException, match="Table functions require a schema"), + ): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=None, + type="tuples", + ) + + +def test_invalid_schema_empty_dict(tmp_path): + def gen_function(): + return [("test", 1)] + + with ( + duckdb.connect(tmp_path / "test.duckdb") as conn, + pytest.raises(duckdb.InvalidInputException, match="schema cannot be empty"), + ): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema={}, + type="tuples", + ) + + +def test_invalid_schema_list_format(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + with pytest.raises(duckdb.InvalidInputException, match="schema must be a dict"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_tuple_format(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [("name", "VARCHAR"), ("id", "INT")] + + with pytest.raises(duckdb.InvalidInputException, match="schema must be a dict"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_string_value(tmp_path): + """Test that string type values are rejected.""" + + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # String types should be rejected + schema = {"name": "VARCHAR", "id": "INT"} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_integer_value(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": 123} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_none_value(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": None} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): + conn.create_table_function( + 
name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_mixed_types(tmp_path): + """Test that schema with mix of DuckDBPyType and strings is rejected.""" + + def gen_function(): + return [("test", 1, 2.5)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # Mix of DuckDBPyType and string - should reject strings + schema = {"name": sqltypes.VARCHAR, "id": "INT", "value": sqltypes.DOUBLE} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_python_type(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": str, "id": int} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_column_name_not_string(tmp_path): + def gen_function(): + return [(1, 2)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {1: sqltypes.INTEGER, 2: sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + + +def test_schema_column_name_special_characters(tmp_path): + def gen_function(): + return [("test", 42, 3.14)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "my-column": sqltypes.VARCHAR, + "another_column": sqltypes.INTEGER, + "column.with.dots": sqltypes.DOUBLE, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql('SELECT "my-column", another_column FROM gen_function()').fetchall() + assert len(result) == 1 + + +def test_schema_preserved_order(tmp_path): + def gen_function(): + return [(1, 2, 3, 4, 5)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "first": sqltypes.INTEGER, + "second": sqltypes.INTEGER, + "third": sqltypes.INTEGER, + "fourth": sqltypes.INTEGER, + "fifth": sqltypes.INTEGER, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("DESCRIBE select * from gen_function()").fetchall() + column_names = [row[0] for row in result] + assert column_names == ["first", "second", "third", "fourth", "fifth"] + + +def test_schema(tmp_path): + def gen_function(count: int = 10) -> Iterator[tuple[str, int]]: + for i in range(count): + yield (f"name_{i}", i) + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function(5)").fetchall() + assert len(result) == 5 + assert result[0][0] == "name_0" + assert result[-1][-1] == 4 + + +def test_schema_2(tmp_path): + """Test various types.""" + + def gen_function(count: int = 10) -> Iterator[tuple[str, int, float]]: + for i in range(count): + yield (f"name_{i}", i, i * 1.5) + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER, "value": sqltypes.DOUBLE} + + 
conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function(3)").fetchall() + assert len(result) == 3 + assert result[0] == ("name_0", 0, 0.0) + assert result[2] == ("name_2", 2, 3.0) + + +def test_schema_invalid_type(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": 123} # int is not valid + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_schema_not_dict(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # Schema must be dict, not list (this is the old list format) + schema = [["name", "VARCHAR"], ["id", "INT"]] + + with pytest.raises(duckdb.InvalidInputException, match="schema must be a dict"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) diff --git a/tests/fast/table_udf/test_tuples.py b/tests/fast/table_udf/test_tuples.py new file mode 100644 index 00000000..a91db688 --- /dev/null +++ b/tests/fast/table_udf/test_tuples.py @@ -0,0 +1,323 @@ +from collections.abc import Iterator + +import pytest + +import duckdb +import duckdb.sqltypes as sqltypes +from duckdb.functional import PythonTableUDFType + + +def simple_generator(count: int = 10) -> Iterator[tuple[str, int]]: + for i in range(count): + yield (f"name_{i}", i) + + +def simple_pylist(count: int = 10) -> list[tuple[str, int]]: + return [(f"name_{i}", i) for i in range(count)] + + +def simple_pylistlist(count: int = 10) -> list[list[str, int]]: + return [[f"name_{i}", i] for i in range(count)] + + +@pytest.mark.parametrize("gen_function", [simple_generator, simple_pylist, simple_pylistlist]) +def test_simple(tmp_path, gen_function): + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + parameters=None, + schema=schema, + type=PythonTableUDFType.TUPLES, + ) + + result = conn.sql("SELECT * FROM gen_function(5)").fetchall() + + assert len(result) == 5 + assert result[0][0] == "name_0" + assert result[-1][-1] == 4 + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + + assert len(result) == 10 + assert result[-1][0] == "name_9" + assert result[-1][1] == 9 + + +@pytest.mark.parametrize("gen_function", [simple_generator]) +def test_simple_large_fetchall_default_type(tmp_path, gen_function): + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + # don't pass type="tuples" to verify default is tuples + conn.create_table_function( + name="gen_function", + callable=gen_function, + parameters=None, + schema=schema, + ) + + result = conn.sql( + "SELECT * FROM gen_function(?)", + params=(count,), + ).fetchall() + + assert len(result) == count + assert result[0][0] == "name_0" + assert result[-1][-1] == count - 1 + + +@pytest.mark.parametrize("gen_function", [simple_generator]) +def test_simple_large_df(tmp_path, gen_function): + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + 
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            parameters=None,
+            schema=schema,
+            type="tuples",
+        )
+
+        result = conn.sql(
+            "SELECT * FROM gen_function(?)",
+            params=(count,),
+        ).df()
+
+        assert len(result) == count
+
+
+def test_no_schema(tmp_path):
+    def gen_function(n):
+        return n
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn, pytest.raises((duckdb.InvalidInputException, TypeError)):
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            type="tuples",
+        )
+
+
+def test_returns_scalar(tmp_path):
+    def gen_function(n):
+        return n
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            parameters=["n"],
+            schema={"value": sqltypes.INTEGER},
+            type="tuples",
+        )
+        # The error happens at execution time, not at registration
+        with pytest.raises(duckdb.InvalidInputException):
+            conn.sql("SELECT * FROM gen_function(5)").fetchall()
+
+
+def test_returns_list_scalar(tmp_path):
+    def gen_function_2(n):
+        return [n]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="gen_function_2",
+            callable=gen_function_2,
+            schema={"value": sqltypes.INTEGER},
+            type="tuples",
+        )
+        # The error happens at execution time, not at registration
+        with pytest.raises(duckdb.InvalidInputException):
+            conn.sql("SELECT * FROM gen_function_2(5)").fetchall()
+
+
+def test_returns_wrong_schema(tmp_path):
+    def gen_function(n):
+        # A flat list of ints does not match the two-column schema below
+        return list(range(n))
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER}
+
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            schema=schema,
+            type="tuples",
+        )
+        with pytest.raises(duckdb.InvalidInputException):
+            conn.sql("SELECT * FROM gen_function(5)").fetchall()
+
+
+def test_kwargs(tmp_path):
+    def simple_pylist(count, foo=10):
+        return [(f"name_{i}_{foo}", i) for i in range(count)]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="simple_pylist",
+            callable=simple_pylist,
+            parameters=["count"],
+            schema={"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER},
+            type="tuples",
+        )
+        result = conn.sql("SELECT * FROM simple_pylist(3)").fetchall()
+        assert result[-1][0] == "name_2_10"
+
+        result = conn.sql("SELECT * FROM simple_pylist(count:=3)").fetchall()
+        assert result[-1][0] == "name_2_10"
+
+        with pytest.raises(duckdb.BinderException):
+            result = conn.sql("SELECT * FROM simple_pylist(count:=3, foo:=2)").fetchall()
+
+
+def test_large_2(tmp_path):
+    """Aggregation and filtering over a larger table UDF."""
+    with duckdb.connect(tmp_path / "test.db") as conn:
+        count = 500000
+
+        def large_generator():
+            return [(f"item_{i}", i) for i in range(count)]
+
+        schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER}
+
+        conn.create_table_function(
+            name="large_table_udf",
+            callable=large_generator,
+            parameters=None,
+            schema=schema,
+            type="tuples",
+        )
+
+        result = conn.execute("SELECT COUNT(*) FROM large_table_udf()").fetchone()
+        assert result[0] == count
+
+        result = conn.sql("SELECT MAX(id) FROM large_table_udf()").fetchone()
+        assert result[0] == count - 1
+
+        result = conn.execute("SELECT COUNT(*) FROM large_table_udf() WHERE id < 100").fetchone()
+        assert result[0] == 100
+
+
+def test_parameters(tmp_path):
+    with duckdb.connect(tmp_path / "test.db") as conn:
+
+        def parametrized_function(count=10, prefix="item"):
+            return [(f"{prefix}_{i}", i) for i in range(count)]
+
{"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="param_table_udf", + callable=parametrized_function, + parameters=["count", "prefix"], + schema=schema, + type="tuples", + ) + + result1 = conn.execute("SELECT COUNT(*) FROM param_table_udf(5, 'test')").fetchone() + assert result1[0] == 5 + + result2 = conn.execute("SELECT COUNT(*) FROM param_table_udf(20, prefix:='data')").fetchone() + assert result2[0] == 20 + + # Test parameter order + result3 = conn.execute("SELECT name FROM param_table_udf(3, 'xyz') ORDER BY id LIMIT 1").fetchone() + assert result3[0] == "xyz_0" + + +def test_error(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def error_function(): + error_message = "Intentional Error" + raise ValueError(error_message) + + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="error_table_udf", + callable=error_function, + parameters=None, + schema=schema, + type="tuples", + ) + + with pytest.raises(duckdb.Error): + conn.execute("SELECT * FROM error_table_udf()").fetchall() + + +def test_callable_refcount(tmp_path): + import sys + + def gen_function(n): + return [(f"name_{i}", i) for i in range(n)] + + initial_refcount = sys.getrefcount(gen_function) + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + after_register_refcount = sys.getrefcount(gen_function) + assert after_register_refcount > initial_refcount, ( + f"Expected refcount to increase after registration, " + f"but got {after_register_refcount} (initial: {initial_refcount})" + ) + + for _ in range(3): + result = conn.sql("SELECT * FROM gen_function(5)").fetchall() + assert len(result) == 5 + + after_execution_refcount = sys.getrefcount(gen_function) + assert after_execution_refcount == after_register_refcount, ( + f"Expected refcount to remain stable after execution, " + f"but got {after_execution_refcount} (after register: {after_register_refcount})" + ) + + final_refcount = sys.getrefcount(gen_function) + assert final_refcount == initial_refcount, ( + f"Expected refcount to return to initial after unregistration, " + f"but got {final_refcount} (initial: {initial_refcount})" + ) + + +def test_callable_lifetime_in_view(tmp_path): + # registers a table UDF within a function scope + # and make sure it's still accessible from another scope (not GC'd) + with duckdb.connect(tmp_path / "test.duckdb") as conn: + + def create_and_register(): + def gen_data(count=5): + return [(f"item_{i}", i * 10) for i in range(count)] + + conn.create_table_function( + name="temp_function", + callable=gen_data, + schema={"name": sqltypes.VARCHAR, "value": sqltypes.INTEGER}, + type="tuples", + ) + + create_and_register() + + conn.execute("CREATE VIEW my_view AS SELECT * FROM temp_function(3)") + + # Unregister only allows the function to be reused - it'll still be accessible in the connection + conn.unregister_table_function("temp_function") + + result = conn.execute("SELECT * FROM my_view").fetchall() + assert len(result) == 3 + assert result[0] == ("item_0", 0) + assert result[1] == ("item_1", 10) + assert result[2] == ("item_2", 20) diff --git a/tests/fast/table_udf/test_tuples_datatypes.py b/tests/fast/table_udf/test_tuples_datatypes.py new file mode 100644 index 00000000..5e31711f --- /dev/null +++ b/tests/fast/table_udf/test_tuples_datatypes.py @@ -0,0 
+import duckdb
+import duckdb.sqltypes as sqltypes
+
+
+def test_bigint_params(tmp_path):
+    def bigint_func(big_value):
+        return [(big_value, big_value + 1, big_value * 2)]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="bigint_func",
+            callable=bigint_func,
+            schema={"orig": sqltypes.BIGINT, "plus_one": sqltypes.BIGINT, "doubled": sqltypes.BIGINT},
+            type="tuples",
+        )
+
+        large_val = 4611686018427387900  # Roughly half of the int64 max
+        result = conn.sql("SELECT * FROM bigint_func(?)", params=(large_val,)).fetchall()
+        assert result[0][0] == large_val
+        assert result[0][1] == large_val + 1
+        assert result[0][2] == large_val * 2
+
+
+def test_hugeint_params(tmp_path):
+    def hugeint_func(huge_value):
+        return [(huge_value, huge_value + 1)]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="hugeint_func",
+            callable=hugeint_func,
+            schema={"orig": sqltypes.HUGEINT, "plus_one": sqltypes.HUGEINT},
+            type="tuples",
+        )
+
+        huge_val = 9223372036854775808  # int64 max + 1, so it only fits in HUGEINT
+        result = conn.sql("SELECT * FROM hugeint_func(?)", params=(huge_val,)).fetchall()
+        assert result[0][0] == huge_val
+        assert result[0][1] == huge_val + 1
+
+
+def test_decimal_params(tmp_path):
+    from decimal import Decimal
+
+    def decimal_func(dec_value):
+        # Handle both float and Decimal inputs
+        result = dec_value * 2 if isinstance(dec_value, float) else Decimal(str(dec_value)) * 2
+        return [(dec_value, result)]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="decimal_func",
+            callable=decimal_func,
+            schema={"orig": duckdb.decimal_type(10, 2), "doubled": duckdb.decimal_type(10, 2)},
+            type="tuples",
+        )
+
+        result = conn.sql("SELECT * FROM decimal_func(?::decimal)", params=(123.45,)).fetchall()
+        assert float(result[0][0]) == 123.45
+        assert float(result[0][1]) == 246.90
+
+
+def test_uuid_params(tmp_path):
+    def uuid_func(uuid_value):
+        return [(str(uuid_value),)]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="uuid_func",
+            callable=uuid_func,
+            schema={"orig": sqltypes.UUID},
+            type="tuples",
+        )
+
+        test_uuid = "550e8400-e29b-41d4-a716-446655440000"
+        result = conn.sql("SELECT * FROM uuid_func(?::uuid)", params=(test_uuid,)).fetchall()
+        assert str(result[0][0]) == test_uuid