From a9bb62d4fff4ba54ffd41a74e11a2a68b36807bd Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Fri, 3 Oct 2025 13:35:41 +0000 Subject: [PATCH 1/4] feat: implement table valued functions / user defined table functions --- duckdb/functional/__init__.py | 10 +- scripts/connection_methods.json | 45 +++ src/duckdb_py/CMakeLists.txt | 1 + src/duckdb_py/functional/functional.cpp | 5 + .../conversions/python_tvf_type_enum.hpp | 72 ++++ .../pyconnection/pyconnection.hpp | 15 + src/duckdb_py/pyconnection.cpp | 68 +++- src/duckdb_py/python_tvf.cpp | 355 ++++++++++++++++++ 8 files changed, 567 insertions(+), 4 deletions(-) create mode 100644 src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp create mode 100644 src/duckdb_py/python_tvf.cpp diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index ac4a6495..7a209dd9 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,17 +1,23 @@ from _duckdb.functional import ( FunctionNullHandling, PythonUDFType, + PythonTVFType, SPECIAL, DEFAULT, NATIVE, - ARROW + ARROW, + TUPLES, + ARROW_TABLE ) __all__ = [ "FunctionNullHandling", "PythonUDFType", + "PythonTVFType", "SPECIAL", "DEFAULT", "NATIVE", - "ARROW" + "ARROW", + "TUPLES", + "ARROW_TABLE" ] diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index a87b992f..fcba443b 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -107,6 +107,51 @@ ], "return": "DuckDBPyConnection" }, + { + "name": "create_table_function", + "function": "RegisterTableFunction", + "docs": "Register a table valued function via Callable", + "args": [ + { + "name": "name", + "type": "str" + }, + { + "name": "callable", + "type": "Callable" + } + ], + "kwargs": [ + { + "name": "parameters", + "type": "Optional[Any]", + "default": "None" + }, + { + "name": "schema", + "type": "Optional[Any]", + "default": "None" + }, + { + "name": "type", + "type": "Optional[PythonTVFType]", 
+ "default": "PythonTVFType.TUPLES" + } + ], + "return": "DuckDBPyConnection" + }, + { + "name": "unregister_table_function", + "function": "UnregisterTableFunction", + "docs": "Unregister a table valued function", + "args": [ + { + "name": "name", + "type": "str" + } + ], + "return": "DuckDBPyConnection" + }, { "name": [ "sqltype", diff --git a/src/duckdb_py/CMakeLists.txt b/src/duckdb_py/CMakeLists.txt index 2252ba29..78fdf5b9 100644 --- a/src/duckdb_py/CMakeLists.txt +++ b/src/duckdb_py/CMakeLists.txt @@ -28,6 +28,7 @@ add_library(python_src OBJECT python_import_cache.cpp python_replacement_scan.cpp python_udf.cpp + python_tvf.cpp ) target_link_libraries(python_src PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/functional/functional.cpp b/src/duckdb_py/functional/functional.cpp index 6761a264..d63c2672 100644 --- a/src/duckdb_py/functional/functional.cpp +++ b/src/duckdb_py/functional/functional.cpp @@ -11,6 +11,11 @@ void DuckDBPyFunctional::Initialize(py::module_ &parent) { .value("ARROW", duckdb::PythonUDFType::ARROW) .export_values(); + py::enum_(m, "PythonTVFType") + .value("TUPLES", duckdb::PythonTVFType::TUPLES) + .value("ARROW_TABLE", duckdb::PythonTVFType::ARROW_TABLE) + .export_values(); + py::enum_(m, "FunctionNullHandling") .value("DEFAULT", duckdb::FunctionNullHandling::DEFAULT_NULL_HANDLING) .value("SPECIAL", duckdb::FunctionNullHandling::SPECIAL_HANDLING) diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp new file mode 100644 index 00000000..729669c5 --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/string_util.hpp" + +using duckdb::InvalidInputException; +using duckdb::string; +using duckdb::StringUtil; + +namespace duckdb { + 
+enum class PythonTVFType : uint8_t { TUPLES, ARROW_TABLE }; + +} // namespace duckdb + +using duckdb::PythonTVFType; + +namespace py = pybind11; + +static PythonTVFType PythonTVFTypeFromString(const string &type) { + auto ltype = StringUtil::Lower(type); + if (ltype.empty() || ltype == "tuples") { + return PythonTVFType::TUPLES; + } else if (ltype == "arrow_table") { + return PythonTVFType::ARROW_TABLE; + } else { + throw InvalidInputException("'%s' is not a recognized type for 'tvf_type'", type); + } +} + +static PythonTVFType PythonTVFTypeFromInteger(int64_t value) { + if (value == 0) { + return PythonTVFType::TUPLES; + } else if (value == 1) { + return PythonTVFType::ARROW_TABLE; + } else { + throw InvalidInputException("'%d' is not a recognized type for 'tvf_type'", value); + } +} + +namespace PYBIND11_NAMESPACE { +namespace detail { + +template <> +struct type_caster : public type_caster_base { + using base = type_caster_base; + PythonTVFType tmp; + +public: + bool load(handle src, bool convert) { + if (base::load(src, convert)) { + return true; + } else if (py::isinstance(src)) { + tmp = PythonTVFTypeFromString(py::str(src)); + value = &tmp; + return true; + } else if (py::isinstance(src)) { + tmp = PythonTVFTypeFromInteger(src.cast()); + value = &tmp; + return true; + } + return false; + } + + static handle cast(PythonTVFType src, return_value_policy policy, handle parent) { + return base::cast(src, policy, parent); + } +}; + +} // namespace detail +} // namespace PYBIND11_NAMESPACE \ No newline at end of file diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index 48ee055e..c59a9a25 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -23,6 +23,7 @@ #include "duckdb/function/scalar_function.hpp" #include 
"duckdb_python/pybind11/conversions/exception_handling_enum.hpp" #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" +#include "duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp" #include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/shared_ptr.hpp" @@ -169,6 +170,8 @@ struct DuckDBPyConnection : public enable_shared_from_this { //! MemoryFileSystem used to temporarily store file-like objects for reading shared_ptr internal_object_filesystem; case_insensitive_map_t> registered_functions; + case_insensitive_map_t> registered_table_functions; + case_insensitive_set_t registered_objects; public: @@ -232,6 +235,13 @@ struct DuckDBPyConnection : public enable_shared_from_this { PythonExceptionHandling exception_handling = PythonExceptionHandling::FORWARD_ERROR, bool side_effects = false); + shared_ptr RegisterTableFunction(const string &name, const py::function &function, + const py::object ¶meters = py::none(), + const py::object &schema = py::none(), + PythonTVFType type = PythonTVFType::TUPLES); + + shared_ptr UnregisterTableFunction(const string &name); + shared_ptr UnregisterUDF(const string &name); shared_ptr ExecuteMany(const py::object &query, py::object params = py::list()); @@ -355,6 +365,11 @@ struct DuckDBPyConnection : public enable_shared_from_this { const shared_ptr &return_type, bool vectorized, FunctionNullHandling null_handling, PythonExceptionHandling exception_handling, bool side_effects); + + duckdb::TableFunction CreateTableFunctionFromCallable(const std::string &name, const py::function &callable, + const py::object ¶meters, const py::object &schema, + PythonTVFType type); + void RegisterArrowObject(const py::object &arrow_object, const string &name); vector> GetStatements(const py::object &query); diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index b88b88ed..4db77b46 100644 --- a/src/duckdb_py/pyconnection.cpp +++ 
b/src/duckdb_py/pyconnection.cpp @@ -393,7 +393,6 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud auto scalar_function = CreateScalarUDF(name, udf, parameters_p, return_type_p, type == PythonUDFType::ARROW, null_handling, exception_handling, side_effects); CreateScalarFunctionInfo info(scalar_function); - context.RegisterFunction(info); auto dependency = make_uniq(); @@ -403,6 +402,57 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud return shared_from_this(); } +shared_ptr DuckDBPyConnection::RegisterTableFunction(const string &name, + const py::function &function, + const py::object ¶meters, + const py::object &schema, + PythonTVFType type) { + + auto &connection = con.GetConnection(); + auto &context = *connection.context; + + if (context.transaction.HasActiveTransaction()) { + context.CancelTransaction(); + } + + if (registered_table_functions.find(name) != registered_table_functions.end()) { + throw NotImplementedException("A table function by the name of '%s' is already registered, " + "please unregister it first", + name); + } + + auto table_function = CreateTableFunctionFromCallable(name, function, parameters, schema, type); + CreateTableFunctionInfo info(table_function); + + // re-registration: changing the callable to another + info.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT; + + context.RegisterFunction(info); + + auto dependency = make_uniq(); + dependency->AddDependency("function", PythonDependencyItem::Create(function)); + registered_table_functions[name] = std::move(dependency); + + return shared_from_this(); +} + +shared_ptr DuckDBPyConnection::UnregisterTableFunction(const string &name) { + auto entry = registered_table_functions.find(name); + if (entry == registered_table_functions.end()) { + throw InvalidInputException( + "No table function by the name of '%s' was found in the list of registered table functions", name); + } + + auto &connection = con.GetConnection(); + 
auto &context = *connection.context; + + // Remove from our registry. + // TODO: Callable still exists in the function catalog, since duckdb doesn't (yet?) support removal + registered_table_functions.erase(entry); + + return shared_from_this(); +} + void DuckDBPyConnection::Initialize(py::handle &m) { auto connection_module = py::class_>(m, "DuckDBPyConnection", py::module_local()); @@ -411,6 +461,14 @@ void DuckDBPyConnection::Initialize(py::handle &m) { .def("__exit__", &DuckDBPyConnection::Exit, py::arg("exc_type"), py::arg("exc"), py::arg("traceback")); connection_module.def("__del__", &DuckDBPyConnection::Close); + connection_module.def("create_table_function", &DuckDBPyConnection::RegisterTableFunction, + "Register a table valued function via Callable", py::arg("name"), py::arg("callable"), + py::arg("parameters") = py::none(), py::arg("schema") = py::none(), + py::arg("type") = PythonTVFType::TUPLES); + + connection_module.def("unregister_table_function", &DuckDBPyConnection::UnregisterTableFunction, + "Unregister a table valued function", py::arg("name")); + InitializeConnectionMethods(connection_module); connection_module.def_property_readonly("description", &DuckDBPyConnection::GetDescription, "Get result set attributes, mainly column names"); @@ -1575,7 +1633,12 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer } if (res->type == QueryResultType::STREAM_RESULT) { auto &stream_result = res->Cast(); - res = stream_result.Materialize(); + { + // Release the GIL, as Materialize *may* need the GIL (TVFs, for instance) + D_ASSERT(py::gil_check()); + py::gil_scoped_release release; + res = stream_result.Materialize(); + } } auto &materialized_result = res->Cast(); relation = make_shared_ptr(connection.context, materialized_result.TakeCollection(), @@ -1826,6 +1889,7 @@ void DuckDBPyConnection::Close() { // https://peps.python.org/pep-0249/#Connection.close cursors.ClearCursors(); registered_functions.clear(); + 
registered_table_functions.clear(); } void DuckDBPyConnection::Interrupt() { diff --git a/src/duckdb_py/python_tvf.cpp b/src/duckdb_py/python_tvf.cpp new file mode 100644 index 00000000..b1538251 --- /dev/null +++ b/src/duckdb_py/python_tvf.cpp @@ -0,0 +1,355 @@ +#include "duckdb_python/pybind11/pybind_wrapper.hpp" +#include "duckdb_python/pytype.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb/common/arrow/arrow.hpp" +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb_python/arrow/arrow_array_stream.hpp" +#include "duckdb/function/table/arrow.hpp" +#include "duckdb/function/function.hpp" +#include "duckdb/parser/tableref/table_function_ref.hpp" +#include "duckdb_python/python_conversion.hpp" +#include "duckdb_python/python_objects.hpp" +#include "duckdb_python/pybind11/python_object_container.hpp" + +namespace duckdb { + +struct PyTVFInfo : public TableFunctionInfo { + py::function callable; + vector return_types; + vector return_names; + PythonTVFType return_type; + + PyTVFInfo(py::function callable_p, vector types_p, vector names_p, PythonTVFType return_type_p) + : callable(std::move(callable_p)), return_types(std::move(types_p)), return_names(std::move(names_p)), + return_type(return_type_p) { + } + + ~PyTVFInfo() override { + py::gil_scoped_acquire acquire; + callable = py::function(); + } +}; + +struct PyTVFBindData : public TableFunctionData { + string func_name; + vector args; + named_parameter_map_t kwargs; + vector return_types; + vector return_names; + PythonObjectContainer python_objects; // Holds the callable + + PyTVFBindData(string func_name, vector args, named_parameter_map_t kwargs, vector return_types, + vector return_names, py::function callable) + : func_name(std::move(func_name)), args(std::move(args)), kwargs(std::move(kwargs)), + return_types(std::move(return_types)), return_names(std::move(return_names)) { + // gil acquired inside push + python_objects.Push(std::move(callable)); + } +}; + +struct 
PyTVFTuplesGlobalState : public GlobalTableFunctionState { + PythonObjectContainer python_objects; + bool iterator_exhausted = false; + + PyTVFTuplesGlobalState() : iterator_exhausted(false) { + } +}; + +struct PyTVFArrowGlobalState : public GlobalTableFunctionState { + unique_ptr arrow_factory; + unique_ptr arrow_bind_data; + unique_ptr arrow_global_state; + PythonObjectContainer python_objects; + idx_t num_columns; + + PyTVFArrowGlobalState() { + } +}; + +static void PyTVFTuplesScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { + auto &gs = input.global_state->Cast(); + auto &bd = input.bind_data->Cast(); + + if (gs.iterator_exhausted) { + output.SetCardinality(0); + return; + } + + py::gil_scoped_acquire gil; + auto &it = gs.python_objects.LastAddedObject(); + + idx_t row_idx = 0; + for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { + py::object next_item; + try { + next_item = it.attr("__next__")(); + } catch (py::error_already_set &e) { + if (e.matches(PyExc_StopIteration)) { + gs.iterator_exhausted = true; + PyErr_Clear(); + break; + } + throw; + } + + try { + // Extract each column from the tuple/list + for (idx_t col_idx = 0; col_idx < bd.return_types.size(); col_idx++) { + auto py_val = next_item[py::int_(col_idx)]; + Value duck_val = TransformPythonValue(py_val, bd.return_types[col_idx]); + output.SetValue(col_idx, row_idx, duck_val); + } + } catch (py::error_already_set &e) { + throw InvalidInputException("Table function '%s' returned invalid data: %s", bd.func_name, e.what()); + } + row_idx++; + } + output.SetCardinality(row_idx); +} + +struct PyTVFArrowLocalState : public LocalTableFunctionState { + unique_ptr arrow_local_state; + + explicit PyTVFArrowLocalState(unique_ptr arrow_local) + : arrow_local_state(std::move(arrow_local)) { + } +}; + +static void PyTVFArrowScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { + // Delegates to ArrowScanFunction + auto &gs = 
input.global_state->Cast(); + auto &ls = input.local_state->Cast(); + + TableFunctionInput arrow_input(gs.arrow_bind_data.get(), ls.arrow_local_state.get(), gs.arrow_global_state.get()); + ArrowTableFunction::ArrowScanFunction(context, arrow_input, output); +} + +static unique_ptr PyTVFBindInternal(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, vector &return_names) { + // Disable progress bar to prevent GIL deadlock with Jupyter + // TODO: Decide if this is still needed - was a problem when fully materializing, but switched to streaming + ClientConfig::GetConfig(context).enable_progress_bar = false; + ClientConfig::GetConfig(context).system_progress_bar_disable_reason = + "Table Valued Functions do not support the progress bar"; + + if (!in.info) { + throw InvalidInputException("Table function '%s' missing function info", in.table_function.name); + } + + auto &tvf_info = in.info->Cast(); + return_types = tvf_info.return_types; + return_names = tvf_info.return_names; + + // Acquire gil before copying py::function + py::gil_scoped_acquire gil; + return make_uniq(in.table_function.name, in.inputs, in.named_parameters, return_types, return_names, + tvf_info.callable); +} + +static unique_ptr PyTVFTuplesBindFunction(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector &return_names) { + auto bd = PyTVFBindInternal(context, in, return_types, return_names); + return std::move(bd); +} + +static unique_ptr PyTVFArrowBindFunction(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector &return_names) { + auto bd = PyTVFBindInternal(context, in, return_types, return_names); + return std::move(bd); +} + +static py::object CallPythonTVF(ClientContext &context, PyTVFBindData &bd) { + py::gil_scoped_acquire gil; + + // positional arguments + py::tuple args(bd.args.size()); + for (idx_t i = 0; i < bd.args.size(); i++) { + args[i] = PythonObject::FromValue(bd.args[i], bd.args[i].type(), 
context.GetClientProperties()); + } + + // keyword arguments + py::dict kwargs; + for (auto &kv : bd.kwargs) { + kwargs[py::str(kv.first)] = PythonObject::FromValue(kv.second, kv.second.type(), context.GetClientProperties()); + } + + // Call Python function + auto &callable = bd.python_objects.LastAddedObject(); + py::object result = callable(*args, **kwargs); + + if (result.is_none()) { + throw InvalidInputException("Table function '%s' returned None, expected iterable or Arrow table", + bd.func_name); + } + + return result; +} + +static unique_ptr PyTVFTuplesInitGlobal(ClientContext &context, TableFunctionInitInput &in) { + auto &bd = in.bind_data->Cast(); + auto gs = make_uniq(); + + { + py::gil_scoped_acquire gil; + // const_cast is safe here - we only read from python_objects, not modify bind_data structure + py::object result = CallPythonTVF(context, const_cast(bd)); + try { + py::iterator it = py::iter(result); + gs->python_objects.Push(std::move(it)); + gs->iterator_exhausted = false; + } catch (const py::error_already_set &e) { + throw InvalidInputException("Table function '%s' returned non-iterable result: %s", bd.func_name, e.what()); + } + } + + return std::move(gs); +} + +static unique_ptr PyTVFArrowInitGlobal(ClientContext &context, TableFunctionInitInput &in) { + auto &bd = in.bind_data->Cast(); + auto gs = make_uniq(); + + { + py::gil_scoped_acquire gil; + + py::object result = CallPythonTVF(context, const_cast(bd)); + PyObject *ptr = result.ptr(); + + // TODO: Should we verify this is an arrow table, or just fail later + gs->python_objects.Push(std::move(result)); + + gs->arrow_factory = make_uniq(ptr, context.GetClientProperties(), + DBConfig::GetConfig(context)); + } + + // Build bind input for Arrow scan + vector children; + children.push_back(Value::POINTER(CastPointerToValue(gs->arrow_factory.get()))); + children.push_back(Value::POINTER(CastPointerToValue(PythonTableArrowArrayStreamFactory::Produce))); + 
children.push_back(Value::POINTER(CastPointerToValue(PythonTableArrowArrayStreamFactory::GetSchema))); + + TableFunctionRef empty_ref; + duckdb::TableFunction dummy_tf; + dummy_tf.name = "PyTVFArrowWrapper"; + + named_parameter_map_t named_params; + vector input_types; + vector input_names; + + TableFunctionBindInput bind_input(children, named_params, input_types, input_names, nullptr, nullptr, dummy_tf, + empty_ref); + + vector return_types; + vector return_names; + gs->arrow_bind_data = ArrowTableFunction::ArrowScanBind(context, bind_input, return_types, return_names); + + // Validate Arrow schema matches declared + if (return_types.size() != bd.return_types.size()) { + throw InvalidInputException("Schema mismatch in table function '%s': " + "Arrow table has %lu columns but %lu were declared", + bd.func_name, return_types.size(), bd.return_types.size()); + } + + // Check column types match + for (idx_t i = 0; i < return_types.size(); i++) { + if (return_types[i] != bd.return_types[i]) { + throw InvalidInputException("Schema mismatch in table function '%s' at column %lu: " + "Arrow table has type %s but %s was declared", + bd.func_name, i, return_types[i].ToString().c_str(), + bd.return_types[i].ToString().c_str()); + } + } + + gs->num_columns = return_types.size(); + vector all_columns; + for (idx_t i = 0; i < gs->num_columns; i++) { + all_columns.push_back(i); + } + + TableFunctionInitInput init_input(gs->arrow_bind_data.get(), all_columns, all_columns, in.filters.get()); + gs->arrow_global_state = ArrowTableFunction::ArrowScanInitGlobal(context, init_input); + + return std::move(gs); +} + +static unique_ptr PyTVFArrowInitLocal(ExecutionContext &context, TableFunctionInitInput &in, + GlobalTableFunctionState *gstate) { + auto &gs = gstate->Cast(); + + vector all_columns; + for (idx_t i = 0; i < gs.num_columns; i++) { + all_columns.push_back(i); + } + + TableFunctionInitInput arrow_init(gs.arrow_bind_data.get(), all_columns, all_columns, in.filters.get()); + auto 
arrow_local_state = + ArrowTableFunction::ArrowScanInitLocalInternal(context.client, arrow_init, gs.arrow_global_state.get()); + + return make_uniq(std::move(arrow_local_state)); +} + +duckdb::TableFunction DuckDBPyConnection::CreateTableFunctionFromCallable(const std::string &name, + const py::function &callable, + const py::object ¶meters, + const py::object &schema, + PythonTVFType type) { + + // Schema + if (schema.is_none()) { + throw InvalidInputException("Table functions require a schema."); + } + + vector types; + vector names; + for (auto c : py::iter(schema)) { + auto item = py::cast(c); + if (py::isinstance(item)) { + throw InvalidInputException("Invalid schema format: expected [name, type] pairs, got string '%s'", + py::str(item).cast()); + } + if (!py::hasattr(item, "__getitem__") || py::len(item) < 2) { + throw InvalidInputException("Invalid schema format: each schema item must be a [name, type] pair"); + } + names.emplace_back(py::str(item[py::int_(0)])); + types.emplace_back(TransformStringToLogicalType(py::str(item[py::int_(1)]))); + } + + if (types.empty()) { + throw InvalidInputException("Table function '%s' schema cannot be empty", name); + } + + duckdb::TableFunction tf; + switch (type) { + case PythonTVFType::TUPLES: + tf = + duckdb::TableFunction(name, {}, +PyTVFTuplesScanFunction, +PyTVFTuplesBindFunction, +PyTVFTuplesInitGlobal); + break; + case PythonTVFType::ARROW_TABLE: + tf = duckdb::TableFunction(name, {}, +PyTVFArrowScanFunction, +PyTVFArrowBindFunction, +PyTVFArrowInitGlobal, + +PyTVFArrowInitLocal); + break; + default: + throw InvalidInputException("Unknown return type for table function '%s'", name); + } + + // Store the Python callable and schema + tf.function_info = make_shared_ptr(callable, types, names, type); + + // args + tf.varargs = LogicalType::ANY; + tf.named_parameters["args"] = LogicalType::ANY; + + // kwargs + if (!parameters.is_none()) { + for (auto ¶m : py::cast(parameters)) { + string param_name = py::str(param); + 
tf.named_parameters[param_name] = LogicalType::ANY; + } + } + + return tf; +} + +} // namespace duckdb From 013081f5c28022e44600dc6045e1e6eb283ca59b Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Fri, 3 Oct 2025 13:40:41 +0000 Subject: [PATCH 2/4] tests: add TVF test cases --- tests/fast/tvf/test_arrow.py | 204 +++++++++++++++ tests/fast/tvf/test_arrow_schema.py | 125 +++++++++ tests/fast/tvf/test_register.py | 328 ++++++++++++++++++++++++ tests/fast/tvf/test_tuples.py | 298 +++++++++++++++++++++ tests/fast/tvf/test_tuples_datatypes.py | 94 +++++++ 5 files changed, 1049 insertions(+) create mode 100644 tests/fast/tvf/test_arrow.py create mode 100644 tests/fast/tvf/test_arrow_schema.py create mode 100644 tests/fast/tvf/test_register.py create mode 100644 tests/fast/tvf/test_tuples.py create mode 100644 tests/fast/tvf/test_tuples_datatypes.py diff --git a/tests/fast/tvf/test_arrow.py b/tests/fast/tvf/test_arrow.py new file mode 100644 index 00000000..f55f5ee0 --- /dev/null +++ b/tests/fast/tvf/test_arrow.py @@ -0,0 +1,204 @@ +from typing import Iterator + +import pytest + +import duckdb +from duckdb.functional import PythonTVFType + + +def simple_generator(count: int = 10) -> Iterator[tuple[str, int]]: + for i in range(count): + yield (f"name_{i}", i) + + +def simple_arrow_table(count: int): + import pyarrow as pa + + data = { + "id": list(range(count)), + "value": [i * 2 for i in range(count)], + "name": [f"row_{i}" for i in range(count)], + } + return pa.table(data) + + +def test_arrow_small(tmp_path): + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "simple_arrow", + simple_arrow_table, + schema=[("x", "BIGINT"), ("y", "VARCHAR")], # Wrong schema! 
+ type=PythonTVFType.ARROW_TABLE, + ) + + with pytest.raises(Exception) as exc_info: + result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() + + assert ( + "Vector::Reference" in str(exc_info.value) + or "schema" in str(exc_info.value).lower() + ) + + +def test_arrow_large_1(tmp_path): + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + n = 2048 * 1000 + + conn.create_table_function( + "large_arrow", + simple_arrow_table, + schema=[("id", "BIGINT"), ("value", "BIGINT"), ("name", "VARCHAR")], + type="arrow_table", + ) + + result = conn.execute( + "SELECT COUNT(*) FROM large_arrow(?)", parameters=(n,) + ).fetchone() + assert result[0] == n + + df = conn.sql(f"SELECT * FROM large_arrow({n}) LIMIT 10").df() + assert len(df) == 10 + assert df["id"].tolist() == list(range(10)) + + arrow_result = conn.execute( + "SELECT * FROM large_arrow(?)", parameters=(n,) + ).fetch_arrow_table() + assert len(arrow_result) == n + + result = conn.sql( + "SELECT SUM(value) FROM large_arrow(?)", params=(n,) + ).fetchone() + expected_sum = sum(i * 2 for i in range(n)) + assert result[0] == expected_sum + + +def test_large_arrow_execute(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + conn.create_table_function( + name="gen_function", + callable=simple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.execute( + "SELECT * FROM gen_function(?)", + parameters=(count,), + ).fetch_arrow_table() + + assert len(result) == count + + +def test_large_arrow_sql(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + conn.create_table_function( + name="gen_function", + callable=simple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = 
conn.sql( + "SELECT * FROM gen_function(?)", + params=(count,), + ).fetch_arrow_table() + + assert len(result) == count + + +def test_arrowbatched_execute(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + conn.create_table_function( + name="gen_function", + callable=simple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.execute( + "SELECT * FROM gen_function(?)", + parameters=(count,), + ).fetch_record_batch() + + result = conn.execute( + f"SELECT * FROM gen_function({count})", + ).fetch_record_batch() + + c = 0 + for batch in result: + c += batch.num_rows + assert c == count + + +def test_arrowbatched_sql_relation(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + conn.create_table_function( + name="gen_function", + callable=simple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.sql( + f"SELECT * FROM gen_function({count})", + ).fetch_arrow_reader() + + c = 0 + for batch in result: + c += batch.num_rows + assert c == count + + +def test_arrowbatched_sql_materialized(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + conn.create_table_function( + name="gen_function", + callable=simple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + # passing parameters makes it non-lazy /materialized + result = conn.sql( + "SELECT * FROM gen_function(?)", + params=(count,), + ).fetch_arrow_reader() + + c = 0 + for batch in result: + c += batch.num_rows + assert c == count diff --git a/tests/fast/tvf/test_arrow_schema.py b/tests/fast/tvf/test_arrow_schema.py new file mode 100644 index 00000000..2b089fc0 --- 
/dev/null +++ b/tests/fast/tvf/test_arrow_schema.py @@ -0,0 +1,125 @@ +"""Test Arrow TVF schema validation""" + +import pytest + +import duckdb +from duckdb.functional import PythonTVFType + + +def simple_arrow_table(count: int = 10): + import pyarrow as pa + + data = { + "id": list(range(count)), + "value": [i * 2 for i in range(count)], + "name": [f"row_{i}" for i in range(count)], + } + return pa.table(data) + + +def test_arrow_correct_schema(tmp_path): + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema=[("id", "BIGINT"), ("value", "BIGINT"), ("name", "VARCHAR")], + type=PythonTVFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM arrow_func(5)").fetchall() + assert len(result) == 5 + assert result[0] == (0, 0, "row_0") + + +def test_arrow_more_columns(tmp_path): + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # table has 3 cols, but declare only 2 + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema=[("x", "BIGINT"), ("y", "BIGINT")], # Missing third column + type=PythonTVFType.ARROW_TABLE, + ) + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.execute("SELECT * FROM arrow_func(5)").fetchall() + + error_msg = str(exc_info.value).lower() + assert ( + "schema mismatch" in error_msg + or "3 columns" in error_msg + or "2 were declared" in error_msg + ) + + +def test_arrow_fewer_columns(tmp_path): + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # table has 3 columns, but declare 4 + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema=[ + ("id", "BIGINT"), + ("value", "BIGINT"), + ("name", "VARCHAR"), + ("extra", "INT"), # Extra column that doesn't exist + ], + type=PythonTVFType.ARROW_TABLE, + ) + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + 
conn.execute("SELECT * FROM arrow_func(5)").fetchall() + + error_msg = str(exc_info.value).lower() + assert ( + "schema mismatch" in error_msg + or "3 columns" in error_msg + or "4 were declared" in error_msg + ) + + +def test_arrow_type_mismatch(tmp_path): + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema=[ + ("id", "VARCHAR"), # Wrong type - should be BIGINT + ("value", "BIGINT"), + ("name", "VARCHAR"), + ], + type=PythonTVFType.ARROW_TABLE, + ) + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.execute("SELECT * FROM arrow_func(5)").fetchall() + + error_msg = str(exc_info.value).lower() + assert "type" in error_msg or "mismatch" in error_msg + + +def test_arrow_name_mismatch_allowed(tmp_path): + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "arrow_func", + simple_arrow_table, + schema=[ + ("a", "BIGINT"), # Arrow has 'id' + ("b", "BIGINT"), # Arrow has 'value' + ("c", "VARCHAR"), # Arrow has 'name' + ], + type=PythonTVFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM arrow_func(3)").fetchall() + assert len(result) == 3 diff --git a/tests/fast/tvf/test_register.py b/tests/fast/tvf/test_register.py new file mode 100644 index 00000000..5160e929 --- /dev/null +++ b/tests/fast/tvf/test_register.py @@ -0,0 +1,328 @@ +import pytest + +import duckdb + + +def test_registry_collision(tmp_path): + """Two tvfs on different connections with same name""" "" + conn1 = duckdb.connect(tmp_path / "db1.db") + conn2 = duckdb.connect(tmp_path / "db2.db") + + def func_for_conn1(): + return [("conn1_data", 1)] + + def func_for_conn2(): + return [("conn2_data", 2)] + + schema = [("name", "VARCHAR"), ("id", "INT")] + + conn1.create_table_function( + name="same_name", + callable=func_for_conn1, + parameters=None, + schema=schema, + type="tuples", + ) 
+ + conn2.create_table_function( + name="same_name", + callable=func_for_conn2, + parameters=None, + schema=schema, + type="tuples", + ) + + result1 = conn1.execute("SELECT * FROM same_name()").fetchall() + assert result1[0][0] == "conn1_data" + assert result1[0][1] == 1 + + result2 = conn2.execute("SELECT * FROM same_name()").fetchall() + assert result2[0][0] == "conn2_data" + assert result2[0][1] == 2 + + result1 = conn1.sql("SELECT * FROM same_name()").fetchall() + assert result1[0][0] == "conn1_data" + assert result1[0][1] == 1 + + conn1.close() + conn2.close() + + +def test_replace_without_unregister(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("version_1", 1)] + + def func_v2(): + return [("version_2", 2)] + + schema = [("name", "VARCHAR"), ("id", "INT")] + + conn.create_table_function("test_func", func_v1, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_1" + assert result[0][1] == 1 + + with pytest.raises(duckdb.NotImplementedException) as exc_info: + conn.create_table_function( + "test_func", func_v2, schema=schema, type="tuples" + ) + assert "already registered" in str(exc_info.value) + + +def test_replace_after_unregister(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("version_1", 1)] + + def func_v2(): + return [("version_2", 2)] + + def func_v3(): + return [("version_3", 3)] + + schema = [("name", "VARCHAR"), ("id", "INT")] + + conn.create_table_function("test_func", func_v1, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_1" + + conn.unregister_table_function("test_func") + conn.create_table_function("test_func", func_v2, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_2" + + 
conn.unregister_table_function("test_func") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_2" + + conn.create_table_function("test_func", func_v3, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "version_3" + assert result[0][1] == 3 + + +def test_multiple_replacements(tmp_path): + """Replacing TVFs multiple times""" + with duckdb.connect(tmp_path / "test.db") as conn: + schema = [("value", "INT")] + + for i in range(1, 6): + + def make_func(val=i): + def func(): + return [(val,)] + + return func + + if i > 1: + conn.unregister_table_function("counter") + + conn.create_table_function( + "counter", make_func(), schema=schema, type="tuples" + ) + + result = conn.execute("SELECT * FROM counter()").fetchone() + assert result[0] == i + + +def test_replacement_with_different_schemas(tmp_path): + """Changing schema with replacements""" + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("test", 1)] + + def func_v2(): + return [("modified", 2, 3.14)] + + schema_v1 = [("name", "VARCHAR"), ("id", "INT")] + conn.create_table_function( + "evolving_func", func_v1, schema=schema_v1, type="tuples" + ) + + result = conn.execute("SELECT * FROM evolving_func()").fetchall() + assert len(result[0]) == 2 + assert result[0][0] == "test" + + schema_v2 = [("name", "VARCHAR"), ("id", "INT"), ("value", "DOUBLE")] + conn.unregister_table_function("evolving_func") # Must unregister first + conn.create_table_function( + "evolving_func", func_v2, schema=schema_v2, type="tuples" + ) + + result = conn.execute("SELECT * FROM evolving_func()").fetchall() + assert len(result[0]) == 3 + assert result[0][0] == "modified" + assert result[0][2] == 3.14 + + +def test_replacement_2(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("v1",)] + + def func_v2(): + return [("v2",)] + + schema = [("version", 
"VARCHAR")] + + conn.create_table_function( + "tracked_func", func_v1, schema=schema, type="tuples" + ) + + conn.unregister_table_function("tracked_func") # Must unregister first + conn.create_table_function( + "tracked_func", func_v2, schema=schema, type="tuples" + ) + + conn.unregister_table_function("tracked_func") + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.unregister_table_function("tracked_func") + assert "No table function by the name of 'tracked_func'" in str(exc_info.value) + + result = conn.execute("SELECT * FROM tracked_func()").fetchone() + assert result[0] == "v2" + + +def test_sql_drop_table_function(tmp_path): + """Documents current behavior - that dropping functions has no effect on TVFs""" + with duckdb.connect(tmp_path / "test.db") as conn: + + def test_func(): + return [("test_value", 1)] + + schema = [("name", "VARCHAR"), ("id", "INT")] + conn.create_table_function("test_func", test_func, schema=schema, type="tuples") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + with pytest.raises(Exception): + conn.execute("DROP FUNCTION test_func") + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + +def test_unregister_table_function(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def simple_function(): + return [("test_value", 1)] + + schema = [("name", "VARCHAR"), ("id", "INT")] + + conn.create_table_function( + name="test_func", + callable=simple_function, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.execute("SELECT * FROM test_func()").fetchall() + assert len(result) == 1 + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + conn.unregister_table_function("test_func") + + # TODO: Decide whether we want to fail or keep this behavior + result = conn.execute("SELECT * FROM 
test_func()").fetchall() + assert len(result) == 1 + assert result[0][0] == "test_value" + assert result[0][1] == 1 + + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.unregister_table_function("test_func") + + assert "No table function by the name of 'test_func'" in str(exc_info.value) + + +def test_unregister_doesntexist(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + with pytest.raises(duckdb.InvalidInputException) as exc_info: + conn.unregister_table_function("nonexistent_func") + + assert "No table function by the name of 'nonexistent_func'" in str( + exc_info.value + ) + + +def test_reregister(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def func_v1(): + return [("version_1", 1)] + + def func_v2(): + return [("version_2", 2)] + + schema = [("name", "VARCHAR"), ("id", "INT")] + + conn.create_table_function( + name="versioned_func", + callable=func_v1, + schema=schema, + type="tuples", + ) + + result = conn.execute("SELECT * FROM versioned_func()").fetchall() + assert result[0][0] == "version_1" + + conn.unregister_table_function("versioned_func") + + conn.create_table_function( + name="versioned_func", + callable=func_v2, + schema=schema, + type="tuples", + ) + + result = conn.execute("SELECT * FROM versioned_func()").fetchall() + assert result[0][0] == "version_2" + + +def test_unregister_multi(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + cursor1 = conn.cursor() + cursor2 = conn.cursor() + + def test_func(): + return [("test_data", 1)] + + schema = [("name", "VARCHAR"), ("id", "INT")] + + cursor1.create_table_function( + name="shared_func", + callable=test_func, + schema=schema, + type="tuples", + ) + + result1 = cursor1.execute("SELECT * FROM shared_func()").fetchall() + assert result1[0][0] == "test_data" + + result2 = cursor2.execute("SELECT * FROM shared_func()").fetchall() + assert result2[0][0] == "test_data" + + cursor1.unregister_table_function("shared_func") + + # 
TODO: Decide whether to keep this unregister behavior
+        result1 = cursor1.execute("SELECT * FROM shared_func()").fetchall()
+        assert result1[0][0] == "test_data"
+
+        result2 = cursor2.execute("SELECT * FROM shared_func()").fetchall()
+        assert result2[0][0] == "test_data"
diff --git a/tests/fast/tvf/test_tuples.py b/tests/fast/tvf/test_tuples.py
new file mode 100644
index 00000000..05506d7f
--- /dev/null
+++ b/tests/fast/tvf/test_tuples.py
@@ -0,0 +1,298 @@
+from typing import Iterator
+
+import pytest
+
+import duckdb
+from duckdb.functional import PythonTVFType
+
+
+def simple_generator(count: int = 10) -> Iterator[tuple[str, int]]:
+    for i in range(count):
+        yield (f"name_{i}", i)
+
+
+def simple_pylist(count: int = 10) -> list[tuple[str, int]]:
+    return [(f"name_{i}", i) for i in range(count)]
+
+
+def simple_pylistlist(count: int = 10) -> list[list]:
+    return [[f"name_{i}", i] for i in range(count)]
+
+
+@pytest.mark.parametrize(
+    "gen_function", [simple_generator, simple_pylist, simple_pylistlist]
+)
+def test_simple(tmp_path, gen_function):
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        schema = [["name", "VARCHAR"], ["id", "INT"]]
+
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            parameters=None,
+            schema=schema,
+            type=PythonTVFType.TUPLES,
+        )
+
+        result = conn.sql("SELECT * FROM gen_function(5)").fetchall()
+
+        assert len(result) == 5
+        assert result[0][0] == "name_0"
+        assert result[-1][-1] == 4
+
+        result = conn.sql("SELECT * FROM gen_function()").fetchall()
+
+        assert len(result) == 10
+        assert result[-1][0] == "name_9"
+        assert result[-1][1] == 9
+
+
+@pytest.mark.parametrize("gen_function", [simple_generator])
+def test_simple_large_fetchall(tmp_path, gen_function):
+    count = 2048 * 1000
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        schema = [["name", "VARCHAR"], ["id", "INT"]]
+
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            parameters=None,
+
schema=schema,
+            type="tuples",
+        )
+
+        result = conn.sql(
+            "SELECT * FROM gen_function(?)",
+            params=(count,),
+        ).fetchall()
+
+        assert len(result) == count
+        assert result[0][0] == "name_0"
+        assert result[-1][-1] == count - 1
+
+
+@pytest.mark.parametrize("gen_function", [simple_generator])
+def test_simple_large_df(tmp_path, gen_function):
+    count = 2048 * 1000
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        schema = [["name", "VARCHAR"], ["id", "INT"]]
+
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            parameters=None,
+            schema=schema,
+            type="tuples",
+        )
+
+        result = conn.sql(
+            "SELECT * FROM gen_function(?)",
+            params=(count,),
+        ).df()
+
+        assert len(result) == count
+
+
+def test_no_schema(tmp_path):
+    def gen_function(n):
+        return n
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        with pytest.raises(duckdb.InvalidInputException):
+            conn.create_table_function(
+                name="gen_function",
+                callable=gen_function,
+                type="tuples",
+            )
+
+
+def test_returns_scalar(tmp_path):
+    def gen_function(n):
+        return n
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        with pytest.raises(duckdb.InvalidInputException):
+            conn.create_table_function(
+                name="gen_function",
+                callable=gen_function,
+                parameters=["n"],
+                schema=["value"],
+                type="tuples",
+            )
+
+
+def test_returns_list_scalar(tmp_path):
+    def gen_function_2(n):
+        return [n]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        with pytest.raises(duckdb.InvalidInputException):
+            conn.create_table_function(
+                name="gen_function_2",
+                callable=gen_function_2,
+                schema=["value"],
+                type="tuples",
+            )
+
+
+def test_returns_wrong_schema(tmp_path):
+    def gen_function(n):
+        return list(range(n))
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        schema = [["name", "VARCHAR"], ["id", "INT"]]
+
+        conn.create_table_function(
+            name="gen_function",
+            callable=gen_function,
+            schema=schema,
+            type="tuples",
+        )
+        with
pytest.raises(duckdb.InvalidInputException):
+            conn.sql("SELECT * FROM gen_function(5)").fetchall()
+
+
+def test_kwargs(tmp_path):
+    def simple_pylist(count, foo=10):
+        return [(f"name_{i}_{foo}", i) for i in range(count)]
+
+    with duckdb.connect(tmp_path / "test.duckdb") as conn:
+        conn.create_table_function(
+            name="simple_pylist",
+            callable=simple_pylist,
+            parameters=["count"],
+            schema=[["name", "VARCHAR"], ["id", "INT"]],
+            type="tuples",
+        )
+        result = conn.sql("SELECT * FROM simple_pylist(3)").fetchall()
+        assert result[-1][0] == "name_2_10"
+
+        result = conn.sql("SELECT * FROM simple_pylist(count:=3)").fetchall()
+        assert result[-1][0] == "name_2_10"
+
+        with pytest.raises(duckdb.BinderException):
+            result = conn.sql(
+                "SELECT * FROM simple_pylist(count:=3, foo:=2)"
+            ).fetchall()
+
+
+def test_large_2(tmp_path):
+    """aggregates and filtering"""
+    with duckdb.connect(tmp_path / "test.db") as conn:
+        count = 500000
+
+        def large_generator():
+            return [(f"item_{i}", i) for i in range(count)]
+
+        schema = [("name", "VARCHAR"), ("id", "INT")]
+
+        conn.create_table_function(
+            name="large_tvf",
+            callable=large_generator,
+            parameters=None,
+            schema=schema,
+            type="tuples",
+        )
+
+        result = conn.execute("SELECT COUNT(*) FROM large_tvf()").fetchone()
+        assert result[0] == count
+
+        result = conn.sql("SELECT MAX(id) FROM large_tvf()").fetchone()
+        assert result[0] == count - 1
+
+        result = conn.execute(
+            "SELECT COUNT(*) FROM large_tvf() WHERE id < 100"
+        ).fetchone()
+        assert result[0] == 100
+
+
+def test__parameters(tmp_path):
+    with duckdb.connect(tmp_path / "test.db") as conn:
+
+        def parametrized_function(count=10, prefix="item"):
+            return [(f"{prefix}_{i}", i) for i in range(count)]
+
+        schema = [("name", "VARCHAR"), ("id", "INT")]
+
+        conn.create_table_function(
+            name="param_tvf",
+            callable=parametrized_function,
+            parameters=["count", "prefix"],
+            schema=schema,
+            type="tuples",
+        )
+
+        result1 = conn.execute("SELECT COUNT(*) FROM param_tvf(5,
'test')").fetchone() + assert result1[0] == 5 + + result2 = conn.execute( + "SELECT COUNT(*) FROM param_tvf(20, prefix:='data')" + ).fetchone() + assert result2[0] == 20 + + # Test parameter order + result3 = conn.execute( + "SELECT name FROM param_tvf(3, 'xyz') ORDER BY id LIMIT 1" + ).fetchone() + assert result3[0] == "xyz_0" + + +def test_error(tmp_path): + with duckdb.connect(tmp_path / "test.db") as conn: + + def error_function(): + raise ValueError("Intentional") + + schema = [("name", "VARCHAR"), ("id", "INT")] + + conn.create_table_function( + name="error_tvf", + callable=error_function, + parameters=None, + schema=schema, + type="tuples", + ) + + with pytest.raises(Exception): + conn.execute("SELECT * FROM error_tvf()").fetchall() + + +def test_callable_refcount(tmp_path): + import sys + + def gen_function(n): + return [(f"name_{i}", i) for i in range(n)] + + initial_refcount = sys.getrefcount(gen_function) + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + after_register_refcount = sys.getrefcount(gen_function) + assert after_register_refcount > initial_refcount, ( + f"Expected refcount to increase after registration, " + f"but got {after_register_refcount} (initial: {initial_refcount})" + ) + + for _ in range(3): + result = conn.sql("SELECT * FROM gen_function(5)").fetchall() + assert len(result) == 5 + + after_execution_refcount = sys.getrefcount(gen_function) + assert after_execution_refcount == after_register_refcount, ( + f"Expected refcount to remain stable after execution, " + f"but got {after_execution_refcount} (after register: {after_register_refcount})" + ) + + final_refcount = sys.getrefcount(gen_function) + assert final_refcount == initial_refcount, ( + f"Expected refcount to return to initial after unregistration, " + f"but got {final_refcount} (initial: 
{initial_refcount})" + ) diff --git a/tests/fast/tvf/test_tuples_datatypes.py b/tests/fast/tvf/test_tuples_datatypes.py new file mode 100644 index 00000000..91bbb0bd --- /dev/null +++ b/tests/fast/tvf/test_tuples_datatypes.py @@ -0,0 +1,94 @@ +import pytest + +import duckdb + + +def test_bigint_params(tmp_path): + def bigint_func(big_value): + return [(big_value, big_value + 1, big_value * 2)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + name="bigint_func", + callable=bigint_func, + schema=[["orig", "BIGINT"], ["plus_one", "BIGINT"], ["doubled", "BIGINT"]], + type="tuples", + ) + + large_val = 4611686018427387900 # Half of max int64 + result = conn.sql( + f"SELECT * FROM bigint_func(?)", params=(large_val,) + ).fetchall() + assert result[0][0] == large_val + assert result[0][1] == large_val + 1 + assert result[0][2] == large_val * 2 + + +def test_hugeint_params(tmp_path): + def hugeint_func(huge_value): + return [(huge_value, huge_value + 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + name="hugeint_func", + callable=hugeint_func, + schema=[["orig", "HUGEINT"], ["plus_one", "HUGEINT"]], + type="tuples", + ) + + huge_val = 9223372036854775808 + result = conn.sql( + f"SELECT * FROM hugeint_func(?)", params=(huge_val,) + ).fetchall() + assert result[0][0] == huge_val + assert result[0][1] == huge_val + 1 + + +def test_decimal_params(tmp_path): + from decimal import Decimal + + def decimal_func(dec_value): + if isinstance(dec_value, float): + result = dec_value * 2 + else: + result = Decimal(str(dec_value)) * 2 + return [(dec_value, result)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + name="decimal_func", + callable=decimal_func, + schema=[["orig", "DECIMAL(10,2)"], ["doubled", "DECIMAL(10,2)"]], + type="tuples", + ) + + result = conn.sql( + "SELECT * FROM decimal_func(?::decimal)", params=(123.45,) + ).fetchall() + assert 
float(result[0][0]) == 123.45 + assert float(result[0][1]) == 246.90 + + +def test_uuid_params(tmp_path): + import uuid + + def uuid_func(uuid_value): + if isinstance(uuid_value, str): + parsed = uuid.UUID(uuid_value) + else: + parsed = uuid_value + return [(str(uuid_value),)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + name="uuid_func", + callable=uuid_func, + schema=[("orig", "UUID")], + type="tuples", + ) + + test_uuid = "550e8400-e29b-41d4-a716-446655440000" + result = conn.sql( + f"SELECT * FROM uuid_func(?::uuid)", params=(test_uuid,) + ).fetchall() + assert str(result[0][0]) == test_uuid From 8384dc3e41dc31e59491f992adeff87dc346ea89 Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Sat, 11 Oct 2025 13:45:12 +0000 Subject: [PATCH 3/4] Cleanup from first round of PR review: Rename from TVF to Table UDF, use a dict[str,duckdb.sqltype] for schema, kwargs for create_table_function, clean up tests --- duckdb/func/__init__.py | 4 +- duckdb/functional/__init__.py | 6 +- scripts/connection_methods.json | 7 +- src/duckdb_py/CMakeLists.txt | 2 +- src/duckdb_py/functional/functional.cpp | 6 +- ...num.hpp => python_table_udf_type_enum.hpp} | 28 +- .../pyconnection/pyconnection.hpp | 10 +- src/duckdb_py/pyconnection.cpp | 16 +- .../{python_tvf.cpp => python_table_udf.cpp} | 163 ++++--- tests/fast/table_udf/test_arrow.py | 247 ++++++++++ .../{tvf => table_udf}/test_arrow_schema.py | 63 ++- .../fast/{tvf => table_udf}/test_register.py | 58 +-- tests/fast/table_udf/test_schema.py | 425 ++++++++++++++++++ tests/fast/{tvf => table_udf}/test_tuples.py | 140 +++--- .../test_tuples_datatypes.py | 36 +- tests/fast/tvf/test_arrow.py | 204 --------- 16 files changed, 943 insertions(+), 472 deletions(-) rename src/duckdb_py/include/duckdb_python/pybind11/conversions/{python_tvf_type_enum.hpp => python_table_udf_type_enum.hpp} (58%) rename src/duckdb_py/{python_tvf.cpp => python_table_udf.cpp} (58%) create mode 100644 
tests/fast/table_udf/test_arrow.py rename tests/fast/{tvf => table_udf}/test_arrow_schema.py (63%) rename tests/fast/{tvf => table_udf}/test_register.py (84%) create mode 100644 tests/fast/table_udf/test_schema.py rename tests/fast/{tvf => table_udf}/test_tuples.py (63%) rename tests/fast/{tvf => table_udf}/test_tuples_datatypes.py (64%) delete mode 100644 tests/fast/tvf/test_arrow.py diff --git a/duckdb/func/__init__.py b/duckdb/func/__init__.py index 5d73f490..0518f667 100644 --- a/duckdb/func/__init__.py +++ b/duckdb/func/__init__.py @@ -1,3 +1,3 @@ -from _duckdb._func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType # noqa: D104 +from _duckdb._func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonTableUDFType, PythonUDFType # noqa: D104 -__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTableUDFType", "PythonUDFType"] diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index e9f2f4db..9830c94f 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -2,12 +2,12 @@ import warnings -from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonTVFType, PythonUDFType +from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonTableUDFType, PythonUDFType -__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTVFType", "PythonUDFType"] +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTableUDFType", "PythonUDFType"] warnings.warn( "`duckdb.functional` is deprecated and will be removed in a future version. 
Please use `duckdb.func` instead.", DeprecationWarning, stacklevel=2, -) \ No newline at end of file +) diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index fcba443b..b11cce49 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -134,8 +134,8 @@ }, { "name": "type", - "type": "Optional[PythonTVFType]", - "default": "PythonTVFType.TUPLES" + "type": "Optional[PythonTableUDFType]", + "default": "PythonTableUDFType.TUPLES" } ], "return": "DuckDBPyConnection" @@ -457,7 +457,6 @@ "fetch_record_batch", "arrow" ], - "function": "FetchRecordBatchReader", "docs": "Fetch an Arrow RecordBatchReader following execute()", "args": [ @@ -1139,4 +1138,4 @@ ], "return": "None" } -] +] \ No newline at end of file diff --git a/src/duckdb_py/CMakeLists.txt b/src/duckdb_py/CMakeLists.txt index 9e2083cc..1fcede3e 100644 --- a/src/duckdb_py/CMakeLists.txt +++ b/src/duckdb_py/CMakeLists.txt @@ -28,7 +28,7 @@ add_library( python_dependency.cpp python_import_cache.cpp python_replacement_scan.cpp - python_tvf.cpp + python_table_udf.cpp python_udf.cpp) target_link_libraries(python_src PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/functional/functional.cpp b/src/duckdb_py/functional/functional.cpp index c1e58eb4..32e64145 100644 --- a/src/duckdb_py/functional/functional.cpp +++ b/src/duckdb_py/functional/functional.cpp @@ -10,9 +10,9 @@ void DuckDBPyFunctional::Initialize(py::module_ &parent) { .value("ARROW", duckdb::PythonUDFType::ARROW) .export_values(); - py::enum_(m, "PythonTVFType") - .value("TUPLES", duckdb::PythonTVFType::TUPLES) - .value("ARROW_TABLE", duckdb::PythonTVFType::ARROW_TABLE) + py::enum_(m, "PythonTableUDFType") + .value("TUPLES", duckdb::PythonTableUDFType::TUPLES) + .value("ARROW_TABLE", duckdb::PythonTableUDFType::ARROW_TABLE) .export_values(); py::enum_(m, "FunctionNullHandling") diff --git a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp 
b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp similarity index 58% rename from src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp rename to src/duckdb_py/include/duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp index 729669c5..d61deeca 100644 --- a/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp +++ b/src/duckdb_py/include/duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp @@ -10,30 +10,30 @@ using duckdb::StringUtil; namespace duckdb { -enum class PythonTVFType : uint8_t { TUPLES, ARROW_TABLE }; +enum class PythonTableUDFType : uint8_t { TUPLES, ARROW_TABLE }; } // namespace duckdb -using duckdb::PythonTVFType; +using duckdb::PythonTableUDFType; namespace py = pybind11; -static PythonTVFType PythonTVFTypeFromString(const string &type) { +static PythonTableUDFType PythonTableUDFTypeFromString(const string &type) { auto ltype = StringUtil::Lower(type); if (ltype.empty() || ltype == "tuples") { - return PythonTVFType::TUPLES; + return PythonTableUDFType::TUPLES; } else if (ltype == "arrow_table") { - return PythonTVFType::ARROW_TABLE; + return PythonTableUDFType::ARROW_TABLE; } else { throw InvalidInputException("'%s' is not a recognized type for 'tvf_type'", type); } } -static PythonTVFType PythonTVFTypeFromInteger(int64_t value) { +static PythonTableUDFType PythonTableUDFTypeFromInteger(int64_t value) { if (value == 0) { - return PythonTVFType::TUPLES; + return PythonTableUDFType::TUPLES; } else if (value == 1) { - return PythonTVFType::ARROW_TABLE; + return PythonTableUDFType::ARROW_TABLE; } else { throw InvalidInputException("'%d' is not a recognized type for 'tvf_type'", value); } @@ -43,27 +43,27 @@ namespace PYBIND11_NAMESPACE { namespace detail { template <> -struct type_caster : public type_caster_base { - using base = type_caster_base; - PythonTVFType tmp; +struct type_caster : public type_caster_base { + using base = 
type_caster_base; + PythonTableUDFType tmp; public: bool load(handle src, bool convert) { if (base::load(src, convert)) { return true; } else if (py::isinstance(src)) { - tmp = PythonTVFTypeFromString(py::str(src)); + tmp = PythonTableUDFTypeFromString(py::str(src)); value = &tmp; return true; } else if (py::isinstance(src)) { - tmp = PythonTVFTypeFromInteger(src.cast()); + tmp = PythonTableUDFTypeFromInteger(src.cast()); value = &tmp; return true; } return false; } - static handle cast(PythonTVFType src, return_value_policy policy, handle parent) { + static handle cast(PythonTableUDFType src, return_value_policy policy, handle parent) { return base::cast(src, policy, parent); } }; diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index c59a9a25..9109d5c5 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -23,7 +23,7 @@ #include "duckdb/function/scalar_function.hpp" #include "duckdb_python/pybind11/conversions/exception_handling_enum.hpp" #include "duckdb_python/pybind11/conversions/python_udf_type_enum.hpp" -#include "duckdb_python/pybind11/conversions/python_tvf_type_enum.hpp" +#include "duckdb_python/pybind11/conversions/python_table_udf_type_enum.hpp" #include "duckdb_python/pybind11/conversions/python_csv_line_terminator_enum.hpp" #include "duckdb/common/shared_ptr.hpp" @@ -236,9 +236,9 @@ struct DuckDBPyConnection : public enable_shared_from_this { bool side_effects = false); shared_ptr RegisterTableFunction(const string &name, const py::function &function, - const py::object ¶meters = py::none(), - const py::object &schema = py::none(), - PythonTVFType type = PythonTVFType::TUPLES); + const py::object &schema, + PythonTableUDFType type = PythonTableUDFType::TUPLES, + const py::object ¶meters = py::none()); shared_ptr UnregisterTableFunction(const string &name); 
@@ -368,7 +368,7 @@ struct DuckDBPyConnection : public enable_shared_from_this { duckdb::TableFunction CreateTableFunctionFromCallable(const std::string &name, const py::function &callable, const py::object ¶meters, const py::object &schema, - PythonTVFType type); + PythonTableUDFType type); void RegisterArrowObject(const py::object &arrow_object, const string &name); vector> GetStatements(const py::object &query); diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 4db77b46..32645690 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -402,11 +402,9 @@ DuckDBPyConnection::RegisterScalarUDF(const string &name, const py::function &ud return shared_from_this(); } -shared_ptr DuckDBPyConnection::RegisterTableFunction(const string &name, - const py::function &function, - const py::object ¶meters, - const py::object &schema, - PythonTVFType type) { +shared_ptr +DuckDBPyConnection::RegisterTableFunction(const string &name, const py::function &function, const py::object &schema, + PythonTableUDFType type, const py::object ¶meters) { auto &connection = con.GetConnection(); auto &context = *connection.context; @@ -462,12 +460,12 @@ void DuckDBPyConnection::Initialize(py::handle &m) { connection_module.def("__del__", &DuckDBPyConnection::Close); connection_module.def("create_table_function", &DuckDBPyConnection::RegisterTableFunction, - "Register a table valued function via Callable", py::arg("name"), py::arg("callable"), - py::arg("parameters") = py::none(), py::arg("schema") = py::none(), - py::arg("type") = PythonTVFType::TUPLES); + "Register a table user defined function via Callable", py::arg("name"), py::arg("callable"), + py::arg("schema"), py::kw_only(), py::arg("type") = PythonTableUDFType::TUPLES, + py::arg("parameters") = py::none()); connection_module.def("unregister_table_function", &DuckDBPyConnection::UnregisterTableFunction, - "Unregister a table valued function", py::arg("name")); + "Unregister a 
table user defined function", py::arg("name")); InitializeConnectionMethods(connection_module); connection_module.def_property_readonly("description", &DuckDBPyConnection::GetDescription, diff --git a/src/duckdb_py/python_tvf.cpp b/src/duckdb_py/python_table_udf.cpp similarity index 58% rename from src/duckdb_py/python_tvf.cpp rename to src/duckdb_py/python_table_udf.cpp index b1538251..ed45f105 100644 --- a/src/duckdb_py/python_tvf.cpp +++ b/src/duckdb_py/python_table_udf.cpp @@ -13,24 +13,25 @@ namespace duckdb { -struct PyTVFInfo : public TableFunctionInfo { +struct PyTableUDFInfo : public TableFunctionInfo { py::function callable; vector return_types; vector return_names; - PythonTVFType return_type; + PythonTableUDFType return_type; - PyTVFInfo(py::function callable_p, vector types_p, vector names_p, PythonTVFType return_type_p) + PyTableUDFInfo(py::function callable_p, vector types_p, vector names_p, + PythonTableUDFType return_type_p) : callable(std::move(callable_p)), return_types(std::move(types_p)), return_names(std::move(names_p)), return_type(return_type_p) { } - ~PyTVFInfo() override { + ~PyTableUDFInfo() override { py::gil_scoped_acquire acquire; callable = py::function(); } }; -struct PyTVFBindData : public TableFunctionData { +struct PyTableUDFBindData : public TableFunctionData { string func_name; vector args; named_parameter_map_t kwargs; @@ -38,8 +39,8 @@ struct PyTVFBindData : public TableFunctionData { vector return_names; PythonObjectContainer python_objects; // Holds the callable - PyTVFBindData(string func_name, vector args, named_parameter_map_t kwargs, vector return_types, - vector return_names, py::function callable) + PyTableUDFBindData(string func_name, vector args, named_parameter_map_t kwargs, + vector return_types, vector return_names, py::function callable) : func_name(std::move(func_name)), args(std::move(args)), kwargs(std::move(kwargs)), return_types(std::move(return_types)), return_names(std::move(return_names)) { // gil 
acquired inside push @@ -47,28 +48,28 @@ struct PyTVFBindData : public TableFunctionData { } }; -struct PyTVFTuplesGlobalState : public GlobalTableFunctionState { +struct PyTableUDFTuplesGlobalState : public GlobalTableFunctionState { PythonObjectContainer python_objects; bool iterator_exhausted = false; - PyTVFTuplesGlobalState() : iterator_exhausted(false) { + PyTableUDFTuplesGlobalState() : iterator_exhausted(false) { } }; -struct PyTVFArrowGlobalState : public GlobalTableFunctionState { +struct PyTableUDFArrowGlobalState : public GlobalTableFunctionState { unique_ptr arrow_factory; unique_ptr arrow_bind_data; unique_ptr arrow_global_state; PythonObjectContainer python_objects; idx_t num_columns; - PyTVFArrowGlobalState() { + PyTableUDFArrowGlobalState() { } }; -static void PyTVFTuplesScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { - auto &gs = input.global_state->Cast(); - auto &bd = input.bind_data->Cast(); +static void PyTableUDFTuplesScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { + auto &gs = input.global_state->Cast(); + auto &bd = input.bind_data->Cast(); if (gs.iterator_exhausted) { output.SetCardinality(0); @@ -107,60 +108,55 @@ static void PyTVFTuplesScanFunction(ClientContext &context, TableFunctionInput & output.SetCardinality(row_idx); } -struct PyTVFArrowLocalState : public LocalTableFunctionState { +struct PyTableUDFArrowLocalState : public LocalTableFunctionState { unique_ptr arrow_local_state; - explicit PyTVFArrowLocalState(unique_ptr arrow_local) + explicit PyTableUDFArrowLocalState(unique_ptr arrow_local) : arrow_local_state(std::move(arrow_local)) { } }; -static void PyTVFArrowScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { +static void PyTableUDFArrowScanFunction(ClientContext &context, TableFunctionInput &input, DataChunk &output) { // Delegates to ArrowScanFunction - auto &gs = input.global_state->Cast(); - auto &ls = 
input.local_state->Cast(); + auto &gs = input.global_state->Cast(); + auto &ls = input.local_state->Cast(); TableFunctionInput arrow_input(gs.arrow_bind_data.get(), ls.arrow_local_state.get(), gs.arrow_global_state.get()); ArrowTableFunction::ArrowScanFunction(context, arrow_input, output); } -static unique_ptr PyTVFBindInternal(ClientContext &context, TableFunctionBindInput &in, - vector &return_types, vector &return_names) { - // Disable progress bar to prevent GIL deadlock with Jupyter - // TODO: Decide if this is still needed - was a problem when fully materializing, but switched to streaming - ClientConfig::GetConfig(context).enable_progress_bar = false; - ClientConfig::GetConfig(context).system_progress_bar_disable_reason = - "Table Valued Functions do not support the progress bar"; - +static unique_ptr PyTableUDFBindInternal(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector &return_names) { if (!in.info) { throw InvalidInputException("Table function '%s' missing function info", in.table_function.name); } - auto &tvf_info = in.info->Cast(); - return_types = tvf_info.return_types; - return_names = tvf_info.return_names; + auto &tableudf_info = in.info->Cast(); + return_types = tableudf_info.return_types; + return_names = tableudf_info.return_names; // Acquire gil before copying py::function py::gil_scoped_acquire gil; - return make_uniq(in.table_function.name, in.inputs, in.named_parameters, return_types, return_names, - tvf_info.callable); + return make_uniq(in.table_function.name, in.inputs, in.named_parameters, return_types, + return_names, tableudf_info.callable); } -static unique_ptr PyTVFTuplesBindFunction(ClientContext &context, TableFunctionBindInput &in, - vector &return_types, - vector &return_names) { - auto bd = PyTVFBindInternal(context, in, return_types, return_names); +static unique_ptr PyTableUDFTuplesBindFunction(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector 
&return_names) { + auto bd = PyTableUDFBindInternal(context, in, return_types, return_names); return std::move(bd); } -static unique_ptr PyTVFArrowBindFunction(ClientContext &context, TableFunctionBindInput &in, - vector &return_types, - vector &return_names) { - auto bd = PyTVFBindInternal(context, in, return_types, return_names); +static unique_ptr PyTableUDFArrowBindFunction(ClientContext &context, TableFunctionBindInput &in, + vector &return_types, + vector &return_names) { + auto bd = PyTableUDFBindInternal(context, in, return_types, return_names); return std::move(bd); } -static py::object CallPythonTVF(ClientContext &context, PyTVFBindData &bd) { +static py::object CallPythonTableUDF(ClientContext &context, PyTableUDFBindData &bd) { py::gil_scoped_acquire gil; // positional arguments @@ -187,14 +183,15 @@ static py::object CallPythonTVF(ClientContext &context, PyTVFBindData &bd) { return result; } -static unique_ptr PyTVFTuplesInitGlobal(ClientContext &context, TableFunctionInitInput &in) { - auto &bd = in.bind_data->Cast(); - auto gs = make_uniq(); +static unique_ptr PyTableUDFTuplesInitGlobal(ClientContext &context, + TableFunctionInitInput &in) { + auto &bd = in.bind_data->Cast(); + auto gs = make_uniq(); { py::gil_scoped_acquire gil; // const_cast is safe here - we only read from python_objects, not modify bind_data structure - py::object result = CallPythonTVF(context, const_cast(bd)); + py::object result = CallPythonTableUDF(context, const_cast(bd)); try { py::iterator it = py::iter(result); gs->python_objects.Push(std::move(it)); @@ -207,17 +204,17 @@ static unique_ptr PyTVFTuplesInitGlobal(ClientContext return std::move(gs); } -static unique_ptr PyTVFArrowInitGlobal(ClientContext &context, TableFunctionInitInput &in) { - auto &bd = in.bind_data->Cast(); - auto gs = make_uniq(); +static unique_ptr PyTableUDFArrowInitGlobal(ClientContext &context, + TableFunctionInitInput &in) { + auto &bd = in.bind_data->Cast(); + auto gs = make_uniq(); { 
py::gil_scoped_acquire gil; - py::object result = CallPythonTVF(context, const_cast(bd)); + py::object result = CallPythonTableUDF(context, const_cast(bd)); PyObject *ptr = result.ptr(); - // TODO: Should we verify this is an arrow table, or just fail later gs->python_objects.Push(std::move(result)); gs->arrow_factory = make_uniq(ptr, context.GetClientProperties(), @@ -232,7 +229,7 @@ static unique_ptr PyTVFArrowInitGlobal(ClientContext & TableFunctionRef empty_ref; duckdb::TableFunction dummy_tf; - dummy_tf.name = "PyTVFArrowWrapper"; + dummy_tf.name = "PyTableUDFArrowWrapper"; named_parameter_map_t named_params; vector input_types; @@ -274,9 +271,9 @@ static unique_ptr PyTVFArrowInitGlobal(ClientContext & return std::move(gs); } -static unique_ptr PyTVFArrowInitLocal(ExecutionContext &context, TableFunctionInitInput &in, - GlobalTableFunctionState *gstate) { - auto &gs = gstate->Cast(); +static unique_ptr +PyTableUDFArrowInitLocal(ExecutionContext &context, TableFunctionInitInput &in, GlobalTableFunctionState *gstate) { + auto &gs = gstate->Cast(); vector all_columns; for (idx_t i = 0; i < gs.num_columns; i++) { @@ -287,14 +284,14 @@ static unique_ptr PyTVFArrowInitLocal(ExecutionContext auto arrow_local_state = ArrowTableFunction::ArrowScanInitLocalInternal(context.client, arrow_init, gs.arrow_global_state.get()); - return make_uniq(std::move(arrow_local_state)); + return make_uniq(std::move(arrow_local_state)); } duckdb::TableFunction DuckDBPyConnection::CreateTableFunctionFromCallable(const std::string &name, const py::function &callable, const py::object ¶meters, const py::object &schema, - PythonTVFType type) { + PythonTableUDFType type) { // Schema if (schema.is_none()) { @@ -303,17 +300,39 @@ duckdb::TableFunction DuckDBPyConnection::CreateTableFunctionFromCallable(const vector types; vector names; - for (auto c : py::iter(schema)) { - auto item = py::cast(c); - if (py::isinstance(item)) { - throw InvalidInputException("Invalid schema format: expected 
[name, type] pairs, got string '%s'", - py::str(item).cast()); + + // Schema must be dict format: {"col1": DuckDBPyType, "col2": DuckDBPyType} + if (!py::isinstance(schema)) { + throw InvalidInputException("Table function '%s' schema must be a dict mapping column names to duckdb.sqltypes " + "(e.g., {\"col1\": INTEGER, \"col2\": VARCHAR})", + name); + } + + auto schema_dict = py::cast(schema); + for (auto item : schema_dict) { + // schema is a dict of str => DuckDBPyType + + string col_name = py::str(item.first); + names.emplace_back(col_name); + + auto type_obj = py::cast(item.second); + + // Check for string BEFORE DuckDBPyType because pybind11 has implicit conversion from str to DuckDBPyType + if (py::isinstance(type_obj)) { + throw InvalidInputException("Invalid schema format: type for column '%s' must be a duckdb.sqltype (e.g., " + "INTEGER, VARCHAR), not a string. " + "Use sqltypes.%s instead of \"%s\"", + col_name, py::str(type_obj).cast().c_str(), + py::str(type_obj).cast().c_str()); } - if (!py::hasattr(item, "__getitem__") || py::len(item) < 2) { - throw InvalidInputException("Invalid schema format: each schema item must be a [name, type] pair"); + + if (!py::isinstance(type_obj)) { + throw InvalidInputException( + "Invalid schema format: type for column '%s' must be a duckdb.sqltype (e.g., INTEGER, VARCHAR), got %s", + col_name, py::str(type_obj.get_type()).cast()); } - names.emplace_back(py::str(item[py::int_(0)])); - types.emplace_back(TransformStringToLogicalType(py::str(item[py::int_(1)]))); + auto pytype = py::cast>(type_obj); + types.emplace_back(pytype->Type()); } if (types.empty()) { @@ -322,20 +341,20 @@ duckdb::TableFunction DuckDBPyConnection::CreateTableFunctionFromCallable(const duckdb::TableFunction tf; switch (type) { - case PythonTVFType::TUPLES: - tf = - duckdb::TableFunction(name, {}, +PyTVFTuplesScanFunction, +PyTVFTuplesBindFunction, +PyTVFTuplesInitGlobal); + case PythonTableUDFType::TUPLES: + tf = duckdb::TableFunction(name, {}, 
PyTableUDFTuplesScanFunction, PyTableUDFTuplesBindFunction, + PyTableUDFTuplesInitGlobal); break; - case PythonTVFType::ARROW_TABLE: - tf = duckdb::TableFunction(name, {}, +PyTVFArrowScanFunction, +PyTVFArrowBindFunction, +PyTVFArrowInitGlobal, - +PyTVFArrowInitLocal); + case PythonTableUDFType::ARROW_TABLE: + tf = duckdb::TableFunction(name, {}, PyTableUDFArrowScanFunction, PyTableUDFArrowBindFunction, + PyTableUDFArrowInitGlobal, PyTableUDFArrowInitLocal); break; default: throw InvalidInputException("Unknown return type for table function '%s'", name); } // Store the Python callable and schema - tf.function_info = make_shared_ptr(callable, types, names, type); + tf.function_info = make_shared_ptr(callable, types, names, type); // args tf.varargs = LogicalType::ANY; diff --git a/tests/fast/table_udf/test_arrow.py b/tests/fast/table_udf/test_arrow.py new file mode 100644 index 00000000..86bc956d --- /dev/null +++ b/tests/fast/table_udf/test_arrow.py @@ -0,0 +1,247 @@ +from typing import Iterator + +import pytest + +import duckdb +import duckdb.sqltypes as sqltypes +from duckdb.functional import PythonTableUDFType + + +def tuple_generator(count: int = 10) -> Iterator[tuple[str, int]]: + for i in range(count): + yield (f"name_{i}", i) + + +def simple_arrow_table(count: int): + pa = pytest.importorskip("pyarrow") + + data = { + "id": list(range(count)), + "value": [i * 2 for i in range(count)], + "name": [f"row_{i}" for i in range(count)], + } + return pa.table(data) + + +def arrow_all_types(count: int): + pa = pytest.importorskip("pyarrow") + from decimal import Decimal + from datetime import datetime, timedelta, timezone + + now = datetime.now(timezone.utc) + data = { + "col_tinyint": pa.array(range(count), type=pa.int8()), + "col_smallint": pa.array(range(count), type=pa.int16()), + "col_int": pa.array(range(count), type=pa.int32()), + "col_bigint": pa.array(range(count), type=pa.int64()), + "col_utinyint": pa.array(range(count), type=pa.uint8()), + 
"col_usmallint": pa.array(range(count), type=pa.uint16()), + "col_uint": pa.array(range(count), type=pa.uint32()), + "col_ubigint": pa.array(range(count), type=pa.uint64()), + "col_float": pa.array((i * 1.5 for i in range(count)), type=pa.float32()), + "col_double": pa.array((i * 2.5 for i in range(count)), type=pa.float64()), + "col_varchar": pa.array((f"row_{i}" for i in range(count)), type=pa.string()), + "col_bool": pa.array((i % 2 == 0 for i in range(count)), type=pa.bool_()), + "col_timestamp": pa.array( + (now + timedelta(seconds=i) for i in range(count)), type=pa.timestamp("us", tz="UTC") + ), + "col_date": pa.array((now.date() + timedelta(days=i) for i in range(count)), type=pa.date32()), + "col_time": pa.array(((now + timedelta(microseconds=i)).time() for i in range(count)), type=pa.time64("ns")), + "col_decimal": pa.array((Decimal(i) / 10 for i in range(count)), type=pa.decimal128(10, 2)), + "col_blob": pa.array((f"bin_{i}".encode() for i in range(count)), type=pa.binary()), + "col_list": pa.array(([i, i + 1] for i in range(count)), type=pa.list_(pa.int32())), + "col_struct": pa.array( + ({"x": i, "y": float(i)} for i in range(count)), type=pa.struct([("x", pa.int32()), ("y", pa.float32())]) + ), + } + return pa.table(data) + + +ALL_TYPES_SCHEMA = { + "col_tinyint": sqltypes.TINYINT, + "col_smallint": sqltypes.SMALLINT, + "col_int": sqltypes.INTEGER, + "col_bigint": sqltypes.BIGINT, + "col_utinyint": sqltypes.UTINYINT, + "col_usmallint": sqltypes.USMALLINT, + "col_uint": sqltypes.UINTEGER, + "col_ubigint": sqltypes.UBIGINT, + "col_float": sqltypes.FLOAT, + "col_double": sqltypes.DOUBLE, + "col_varchar": sqltypes.VARCHAR, + "col_bool": sqltypes.BOOLEAN, + "col_timestamp": sqltypes.TIMESTAMP_TZ, + "col_date": sqltypes.DATE, + "col_time": sqltypes.TIME, + "col_decimal": duckdb.decimal_type(10, 2), + "col_blob": sqltypes.BLOB, + "col_list": duckdb.list_type(sqltypes.INTEGER), + "col_struct": duckdb.struct_type({"x": sqltypes.INTEGER, "y": sqltypes.FLOAT}), 
+} + + +def test_arrow_small(tmp_path): + """Defines and creates a Table UDF with only positional parameters, verifies that it works + and verifies it fails from another connection scope. + """ + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "simple_arrow", + simple_arrow_table, + schema={"x": sqltypes.BIGINT, "y": sqltypes.BIGINT, "name": sqltypes.VARCHAR}, + type=PythonTableUDFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() + + assert len(result) == 5 + + # Should fail because it's not defined in this conn + with duckdb.connect(tmp_path / "test2.duckdb") as conn, pytest.raises(duckdb.CatalogException): + result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # Should fail because schema is missing a col + conn.create_table_function( + "simple_arrow", + simple_arrow_table, + schema={"x": sqltypes.BIGINT, "y": sqltypes.BIGINT}, + type=PythonTableUDFType.ARROW_TABLE, + ) + with pytest.raises(duckdb.InvalidInputException) as exc_info: + result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() + assert "Vector::Reference" in str(exc_info.value) or "schema" in str(exc_info.value).lower() + + +def test_arrow_large_1(tmp_path): + """tests: more rows, aggregation, limits, named parameters, parameters.""" + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + n = 2048 * 1000 + + conn.create_table_function( + "large_arrow", + simple_arrow_table, + schema={"id": sqltypes.BIGINT, "value": sqltypes.BIGINT, "name": sqltypes.VARCHAR}, + type="arrow_table", + parameters=["count"], + ) + + result = conn.execute("SELECT COUNT(*) FROM large_arrow(count:=?)", parameters=(n,)).fetchone() + assert result[0] == n + + df = conn.sql(f"SELECT * FROM large_arrow({n}) LIMIT 10").df() + assert len(df) == 10 + assert df["id"].tolist() == 
list(range(10)) + + arrow_result = conn.execute("SELECT * FROM large_arrow(?)", parameters=(n,)).fetch_arrow_table() + assert len(arrow_result) == n + + result = conn.sql("SELECT SUM(value) FROM large_arrow(count:=$count)", params={"count": n}).fetchone() + expected_sum = sum(i * 2 for i in range(n)) + assert result[0] == expected_sum + + +def test_arrowbatched_execute(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=tuple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.execute( + "SELECT * FROM gen_function(?)", + parameters=(count,), + ).fetch_record_batch() + + result = conn.execute( + f"SELECT * FROM gen_function({count})", + ).fetch_record_batch() + + c = 0 + for batch in result: + c += batch.num_rows + assert c == count + + +def test_arrowbatched_sql_relation(tmp_path): + pytest.importorskip("pyarrow") + + count = 2048 * 1000 + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=tuple_generator, + parameters=None, + schema=schema, + type="tuples", + ) + + result = conn.sql( + f"SELECT * FROM gen_function({count})", + ).fetch_arrow_reader() + + c = 0 + for batch in result: + c += batch.num_rows + assert c == count + + +def test_arrow_types(tmp_path): + """Return many types from an arrow table UDF, and verify the results are correct""" + pa = pytest.importorskip("pyarrow") + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + "all_types_arrow", + arrow_all_types, + schema=ALL_TYPES_SCHEMA, + type=PythonTableUDFType.ARROW_TABLE, + ) + + result = conn.execute("SELECT * FROM all_types_arrow(3)").fetchall() + assert len(result) == 3 + + first_row = result[0] + 
assert first_row[0] == 0 # col_tinyint + assert first_row[1] == 0 # col_smallint + assert first_row[2] == 0 # col_int + assert first_row[3] == 0 # col_bigint + assert first_row[10] == "row_0" # col_varchar + assert first_row[11] is True # col_bool + + result = conn.execute("SELECT SUM(col_int), AVG(col_float) FROM all_types_arrow(100)").fetchone() + expected_sum = sum(range(100)) + assert result[0] == expected_sum + + result = conn.execute("SELECT COUNT(*) FROM all_types_arrow(50) WHERE col_bool = true").fetchone() + assert result[0] == 25 # Half should be true (even numbers) + + result = conn.execute("SELECT col_varchar, col_int FROM all_types_arrow(5)").fetchall() + assert len(result) == 5 + assert result[2] == ("row_2", 2) + + result = conn.execute("SELECT col_list FROM all_types_arrow(2)").fetchall() + assert result[0][0] == [0, 1] + assert result[1][0] == [1, 2] + + result = conn.sql("SELECT col_struct FROM all_types_arrow(2)").fetchall() + assert result[0][0] == {"x": 0, "y": 0.0} + assert result[1][0] == {"x": 1, "y": 1.0} + + schema_result = conn.sql("DESCRIBE SELECT * FROM all_types_arrow(1)").fetchall() + column_names = [row[0] for row in schema_result] + assert list(ALL_TYPES_SCHEMA.keys()) == column_names diff --git a/tests/fast/tvf/test_arrow_schema.py b/tests/fast/table_udf/test_arrow_schema.py similarity index 63% rename from tests/fast/tvf/test_arrow_schema.py rename to tests/fast/table_udf/test_arrow_schema.py index 2b089fc0..61db1105 100644 --- a/tests/fast/tvf/test_arrow_schema.py +++ b/tests/fast/table_udf/test_arrow_schema.py @@ -1,9 +1,10 @@ -"""Test Arrow TVF schema validation""" +"""Test Arrow Table UDF schema validation""" import pytest import duckdb -from duckdb.functional import PythonTVFType +import duckdb.sqltypes as sqltypes +from duckdb.functional import PythonTableUDFType def simple_arrow_table(count: int = 10): @@ -24,8 +25,8 @@ def test_arrow_correct_schema(tmp_path): conn.create_table_function( "arrow_func", simple_arrow_table, - 
schema=[("id", "BIGINT"), ("value", "BIGINT"), ("name", "VARCHAR")], - type=PythonTVFType.ARROW_TABLE, + schema={"id": sqltypes.BIGINT, "value": sqltypes.BIGINT, "name": sqltypes.VARCHAR}, + type=PythonTableUDFType.ARROW_TABLE, ) result = conn.execute("SELECT * FROM arrow_func(5)").fetchall() @@ -41,19 +42,15 @@ def test_arrow_more_columns(tmp_path): conn.create_table_function( "arrow_func", simple_arrow_table, - schema=[("x", "BIGINT"), ("y", "BIGINT")], # Missing third column - type=PythonTVFType.ARROW_TABLE, + schema={"x": sqltypes.BIGINT, "y": sqltypes.BIGINT}, # Missing third column + type=PythonTableUDFType.ARROW_TABLE, ) with pytest.raises(duckdb.InvalidInputException) as exc_info: conn.execute("SELECT * FROM arrow_func(5)").fetchall() error_msg = str(exc_info.value).lower() - assert ( - "schema mismatch" in error_msg - or "3 columns" in error_msg - or "2 were declared" in error_msg - ) + assert "schema mismatch" in error_msg or "3 columns" in error_msg or "2 were declared" in error_msg def test_arrow_fewer_columns(tmp_path): @@ -64,24 +61,20 @@ def test_arrow_fewer_columns(tmp_path): conn.create_table_function( "arrow_func", simple_arrow_table, - schema=[ - ("id", "BIGINT"), - ("value", "BIGINT"), - ("name", "VARCHAR"), - ("extra", "INT"), # Extra column that doesn't exist - ], - type=PythonTVFType.ARROW_TABLE, + schema={ + "id": sqltypes.BIGINT, + "value": sqltypes.BIGINT, + "name": sqltypes.VARCHAR, + "extra": sqltypes.INTEGER, # Extra column that doesn't exist + }, + type=PythonTableUDFType.ARROW_TABLE, ) with pytest.raises(duckdb.InvalidInputException) as exc_info: conn.execute("SELECT * FROM arrow_func(5)").fetchall() error_msg = str(exc_info.value).lower() - assert ( - "schema mismatch" in error_msg - or "3 columns" in error_msg - or "4 were declared" in error_msg - ) + assert "schema mismatch" in error_msg or "3 columns" in error_msg or "4 were declared" in error_msg def test_arrow_type_mismatch(tmp_path): @@ -91,12 +84,12 @@ def 
test_arrow_type_mismatch(tmp_path): conn.create_table_function( "arrow_func", simple_arrow_table, - schema=[ - ("id", "VARCHAR"), # Wrong type - should be BIGINT - ("value", "BIGINT"), - ("name", "VARCHAR"), - ], - type=PythonTVFType.ARROW_TABLE, + schema={ + "id": sqltypes.VARCHAR, # Wrong type - should be BIGINT + "value": sqltypes.BIGINT, + "name": sqltypes.VARCHAR, + }, + type=PythonTableUDFType.ARROW_TABLE, ) with pytest.raises(duckdb.InvalidInputException) as exc_info: @@ -113,12 +106,12 @@ def test_arrow_name_mismatch_allowed(tmp_path): conn.create_table_function( "arrow_func", simple_arrow_table, - schema=[ - ("a", "BIGINT"), # Arrow has 'id' - ("b", "BIGINT"), # Arrow has 'value' - ("c", "VARCHAR"), # Arrow has 'name' - ], - type=PythonTVFType.ARROW_TABLE, + schema={ + "a": sqltypes.BIGINT, # Arrow has 'id' + "b": sqltypes.BIGINT, # Arrow has 'value' + "c": sqltypes.VARCHAR, # Arrow has 'name' + }, + type=PythonTableUDFType.ARROW_TABLE, ) result = conn.execute("SELECT * FROM arrow_func(3)").fetchall() diff --git a/tests/fast/tvf/test_register.py b/tests/fast/table_udf/test_register.py similarity index 84% rename from tests/fast/tvf/test_register.py rename to tests/fast/table_udf/test_register.py index 5160e929..3d7d6d2d 100644 --- a/tests/fast/tvf/test_register.py +++ b/tests/fast/table_udf/test_register.py @@ -1,10 +1,11 @@ import pytest import duckdb +import duckdb.sqltypes as sqltypes def test_registry_collision(tmp_path): - """Two tvfs on different connections with same name""" "" + """Two table_udfs on different connections with same name""" "" conn1 = duckdb.connect(tmp_path / "db1.db") conn2 = duckdb.connect(tmp_path / "db2.db") @@ -14,7 +15,7 @@ def func_for_conn1(): def func_for_conn2(): return [("conn2_data", 2)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn1.create_table_function( name="same_name", @@ -57,7 +58,7 @@ def func_v1(): def func_v2(): return [("version_2", 2)] - 
schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function("test_func", func_v1, schema=schema, type="tuples") @@ -66,9 +67,7 @@ def func_v2(): assert result[0][1] == 1 with pytest.raises(duckdb.NotImplementedException) as exc_info: - conn.create_table_function( - "test_func", func_v2, schema=schema, type="tuples" - ) + conn.create_table_function("test_func", func_v2, schema=schema, type="tuples") assert "already registered" in str(exc_info.value) @@ -84,7 +83,7 @@ def func_v2(): def func_v3(): return [("version_3", 3)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function("test_func", func_v1, schema=schema, type="tuples") @@ -110,9 +109,9 @@ def func_v3(): def test_multiple_replacements(tmp_path): - """Replacing TVFs multiple times""" + """Replacing Table UDFs multiple times""" with duckdb.connect(tmp_path / "test.db") as conn: - schema = [("value", "INT")] + schema = {"value": sqltypes.INTEGER} for i in range(1, 6): @@ -125,9 +124,7 @@ def func(): if i > 1: conn.unregister_table_function("counter") - conn.create_table_function( - "counter", make_func(), schema=schema, type="tuples" - ) + conn.create_table_function("counter", make_func(), schema=schema, type="tuples") result = conn.execute("SELECT * FROM counter()").fetchone() assert result[0] == i @@ -143,20 +140,16 @@ def func_v1(): def func_v2(): return [("modified", 2, 3.14)] - schema_v1 = [("name", "VARCHAR"), ("id", "INT")] - conn.create_table_function( - "evolving_func", func_v1, schema=schema_v1, type="tuples" - ) + schema_v1 = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + conn.create_table_function("evolving_func", func_v1, schema=schema_v1, type="tuples") result = conn.execute("SELECT * FROM evolving_func()").fetchall() assert len(result[0]) == 2 assert result[0][0] == "test" - schema_v2 = [("name", "VARCHAR"), ("id", "INT"), ("value", 
"DOUBLE")] + schema_v2 = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER, "value": sqltypes.DOUBLE} conn.unregister_table_function("evolving_func") # Must unregister first - conn.create_table_function( - "evolving_func", func_v2, schema=schema_v2, type="tuples" - ) + conn.create_table_function("evolving_func", func_v2, schema=schema_v2, type="tuples") result = conn.execute("SELECT * FROM evolving_func()").fetchall() assert len(result[0]) == 3 @@ -173,16 +166,12 @@ def func_v1(): def func_v2(): return [("v2",)] - schema = [("version", "VARCHAR")] + schema = {"version": sqltypes.VARCHAR} - conn.create_table_function( - "tracked_func", func_v1, schema=schema, type="tuples" - ) + conn.create_table_function("tracked_func", func_v1, schema=schema, type="tuples") conn.unregister_table_function("tracked_func") # Must unregister first - conn.create_table_function( - "tracked_func", func_v2, schema=schema, type="tuples" - ) + conn.create_table_function("tracked_func", func_v2, schema=schema, type="tuples") conn.unregister_table_function("tracked_func") @@ -195,13 +184,13 @@ def func_v2(): def test_sql_drop_table_function(tmp_path): - """Documents current behavior - that dropping functions has no effect on TVFs""" + """Documents current behavior - that dropping functions has no effect on Table UDFs""" with duckdb.connect(tmp_path / "test.db") as conn: def test_func(): return [("test_value", 1)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function("test_func", test_func, schema=schema, type="tuples") result = conn.execute("SELECT * FROM test_func()").fetchall() @@ -222,7 +211,7 @@ def test_unregister_table_function(tmp_path): def simple_function(): return [("test_value", 1)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( name="test_func", @@ -239,7 +228,6 @@ def simple_function(): 
conn.unregister_table_function("test_func") - # TODO: Decide whether we want to fail or keep this behavior result = conn.execute("SELECT * FROM test_func()").fetchall() assert len(result) == 1 assert result[0][0] == "test_value" @@ -256,9 +244,7 @@ def test_unregister_doesntexist(tmp_path): with pytest.raises(duckdb.InvalidInputException) as exc_info: conn.unregister_table_function("nonexistent_func") - assert "No table function by the name of 'nonexistent_func'" in str( - exc_info.value - ) + assert "No table function by the name of 'nonexistent_func'" in str(exc_info.value) def test_reregister(tmp_path): @@ -270,7 +256,7 @@ def func_v1(): def func_v2(): return [("version_2", 2)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( name="versioned_func", @@ -303,7 +289,7 @@ def test_unregister_multi(tmp_path): def test_func(): return [("test_data", 1)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} cursor1.create_table_function( name="shared_func", diff --git a/tests/fast/table_udf/test_schema.py b/tests/fast/table_udf/test_schema.py new file mode 100644 index 00000000..f416ad94 --- /dev/null +++ b/tests/fast/table_udf/test_schema.py @@ -0,0 +1,425 @@ +"""Test schema validation for table-valued functions.""" + +from typing import Iterator + +import pytest + +import duckdb +import duckdb.sqltypes as sqltypes + + +def test_valid_schema_basic_types(tmp_path): + def gen_function(): + return [("test", 42)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + assert result[0] == ("test", 42) + + +def test_valid_schema_numeric_types(tmp_path): + def 
gen_function(): + return [(1, 2, 3, 4, 5, 6.5, 7.25)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "tiny": sqltypes.TINYINT, + "small": sqltypes.SMALLINT, + "int": sqltypes.INTEGER, + "big": sqltypes.BIGINT, + "huge": sqltypes.HUGEINT, + "float": sqltypes.FLOAT, + "double": sqltypes.DOUBLE, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + + +def test_valid_schema_temporal_types(tmp_path): + from datetime import date, datetime, time + + def gen_function(): + return [(date(2024, 1, 1), time(12, 30, 45), datetime(2024, 1, 1, 12, 30, 45))] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "d": sqltypes.DATE, + "t": sqltypes.TIME, + "ts": sqltypes.TIMESTAMP, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + + +def test_valid_schema_boolean_and_blob(tmp_path): + def gen_function(): + return [(True, b"binary_data")] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "flag": sqltypes.BOOLEAN, + "data": sqltypes.BLOB, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + assert result[0][0] is True + assert result[0][1] == b"binary_data" + + +def test_valid_schema_single_column(tmp_path): + def gen_function(): + return [(42,), (43,), (44,)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"value": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() 
+ assert len(result) == 3 + + +def test_valid_schema_many_columns(tmp_path): + def gen_function(): + return [tuple(range(20))] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {f"col{i}": sqltypes.INTEGER for i in range(20)} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + assert len(result[0]) == 20 + + +def test_invalid_schema_none(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + with pytest.raises(duckdb.InvalidInputException, match="Table functions require a schema"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=None, + type="tuples", + ) + + +def test_invalid_schema_empty_dict(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + with pytest.raises(duckdb.InvalidInputException, match="schema cannot be empty"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema={}, + type="tuples", + ) + + +def test_invalid_schema_list_format(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [["name", "VARCHAR"], ["id", "INT"]] + + with pytest.raises(duckdb.InvalidInputException, match="schema must be a dict"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_tuple_format(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = [("name", "VARCHAR"), ("id", "INT")] + + with pytest.raises(duckdb.InvalidInputException, match="schema must be a dict"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", 
+ ) + + +def test_invalid_schema_string_value(tmp_path): + """Test that string type values are rejected.""" + + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # String types should be rejected + schema = {"name": "VARCHAR", "id": "INT"} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_integer_value(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": 123} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_none_value(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": None} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_mixed_types(tmp_path): + """Test that schema with mix of DuckDBPyType and strings is rejected.""" + + def gen_function(): + return [("test", 1, 2.5)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # Mix of DuckDBPyType and string - should reject strings + schema = {"name": sqltypes.VARCHAR, "id": "INT", "value": sqltypes.DOUBLE} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_python_type(tmp_path): + def gen_function(): + return [("test", 1)] + + 
with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": str, "id": int} + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_invalid_schema_column_name_not_string(tmp_path): + def gen_function(): + return [(1, 2)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {1: sqltypes.INTEGER, 2: sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function()").fetchall() + assert len(result) == 1 + + +def test_schema_column_name_special_characters(tmp_path): + def gen_function(): + return [("test", 42, 3.14)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "my-column": sqltypes.VARCHAR, + "another_column": sqltypes.INTEGER, + "column.with.dots": sqltypes.DOUBLE, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql('SELECT "my-column", another_column FROM gen_function()').fetchall() + assert len(result) == 1 + + +def test_schema_preserved_order(tmp_path): + def gen_function(): + return [(1, 2, 3, 4, 5)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = { + "first": sqltypes.INTEGER, + "second": sqltypes.INTEGER, + "third": sqltypes.INTEGER, + "fourth": sqltypes.INTEGER, + "fifth": sqltypes.INTEGER, + } + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("DESCRIBE select * from gen_function()").fetchall() + column_names = [row[0] for row in result] + assert column_names == ["first", "second", "third", "fourth", "fifth"] + + +def test_schema(tmp_path): + def gen_function(count: int = 10) -> Iterator[tuple[str, int]]: + for i in 
range(count): + yield (f"name_{i}", i) + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function(5)").fetchall() + assert len(result) == 5 + assert result[0][0] == "name_0" + assert result[-1][-1] == 4 + + +def test_schema_2(tmp_path): + """Test various types.""" + + def gen_function(count: int = 10) -> Iterator[tuple[str, int, float]]: + for i in range(count): + yield (f"name_{i}", i, i * 1.5) + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER, "value": sqltypes.DOUBLE} + + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + result = conn.sql("SELECT * FROM gen_function(3)").fetchall() + assert len(result) == 3 + assert result[0] == ("name_0", 0, 0.0) + assert result[2] == ("name_2", 2, 3.0) + + +def test_schema_invalid_type(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + schema = {"name": sqltypes.VARCHAR, "id": 123} # int is not valid + + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) + + +def test_schema_not_dict(tmp_path): + def gen_function(): + return [("test", 1)] + + with duckdb.connect(tmp_path / "test.duckdb") as conn: + # Schema must be dict, not list (this is the old list format) + schema = [["name", "VARCHAR"], ["id", "INT"]] + + with pytest.raises(duckdb.InvalidInputException, match="schema must be a dict"): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=schema, + type="tuples", + ) diff --git a/tests/fast/tvf/test_tuples.py 
b/tests/fast/table_udf/test_tuples.py similarity index 63% rename from tests/fast/tvf/test_tuples.py rename to tests/fast/table_udf/test_tuples.py index 05506d7f..aee26cd7 100644 --- a/tests/fast/tvf/test_tuples.py +++ b/tests/fast/table_udf/test_tuples.py @@ -1,9 +1,10 @@ -from typing import Iterator +from collections.abc import Iterator import pytest import duckdb -from duckdb.functional import PythonTVFType +import duckdb.sqltypes as sqltypes +from duckdb.functional import PythonTableUDFType def simple_generator(count: int = 10) -> Iterator[tuple[str, int]]: @@ -19,19 +20,17 @@ def simple_pylistlist(count: int = 10) -> list[list[str, int]]: return [[f"name_{i}", i] for i in range(count)] -@pytest.mark.parametrize( - "gen_function", [simple_generator, simple_pylist, simple_pylistlist] -) +@pytest.mark.parametrize("gen_function", [simple_generator, simple_pylist, simple_pylistlist]) def test_simple(tmp_path, gen_function): with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( name="gen_function", callable=gen_function, parameters=None, schema=schema, - type=PythonTVFType.TUPLES, + type=PythonTableUDFType.TUPLES, ) result = conn.sql("SELECT * FROM gen_function(5)").fetchall() @@ -48,17 +47,17 @@ def test_simple(tmp_path, gen_function): @pytest.mark.parametrize("gen_function", [simple_generator]) -def test_simple_large_fetchall(tmp_path, gen_function): +def test_simple_large_fetchall_default_type(tmp_path, gen_function): count = 2048 * 1000 with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} + # don't pass type="tuples" to verify default is tuples conn.create_table_function( name="gen_function", callable=gen_function, parameters=None, schema=schema, - type="tuples", ) result = conn.sql( @@ -75,7 +74,7 @@ def 
test_simple_large_fetchall(tmp_path, gen_function): def test_simple_large_df(tmp_path, gen_function): count = 2048 * 1000 with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( name="gen_function", @@ -97,13 +96,12 @@ def test_no_schema(tmp_path): def gen_function(n): return n - with duckdb.connect(tmp_path / "test.duckdb") as conn: - with pytest.raises(duckdb.InvalidInputException): - conn.create_table_function( - name="gen_function", - callable=gen_function, - type="tuples", - ) + with duckdb.connect(tmp_path / "test.duckdb") as conn, pytest.raises((duckdb.InvalidInputException, TypeError)): + conn.create_table_function( + name="gen_function", + callable=gen_function, + type="tuples", + ) def test_returns_scalar(tmp_path): @@ -111,14 +109,16 @@ def gen_function(n): return n with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + name="gen_function", + callable=gen_function, + parameters=["n"], + schema={"value": sqltypes.INTEGER}, + type="tuples", + ) + # Error happens at execution time, not registration with pytest.raises(duckdb.InvalidInputException): - conn.create_table_function( - name="gen_function", - callable=gen_function, - parameters=["n"], - schema=["value"], - type="tuples", - ) + conn.sql("SELECT * FROM gen_function(5)").fetchall() def test_returns_list_scalar(tmp_path): @@ -126,13 +126,15 @@ def gen_function_2(n): return [n] with duckdb.connect(tmp_path / "test.duckdb") as conn: + conn.create_table_function( + name="gen_function_2", + callable=gen_function_2, + schema={"value": sqltypes.INTEGER}, + type="tuples", + ) + # Error happens at execution time, not registration with pytest.raises(duckdb.InvalidInputException): - conn.create_table_function( - name="gen_function_2", - callable=gen_function_2, - schema=["value"], - type="tuples", - ) + conn.sql("SELECT * FROM 
gen_function_2(5)").fetchall() def test_returns_wrong_schema(tmp_path): @@ -140,7 +142,7 @@ def gen_function(n): return list[range(n)] with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( name="gen_function", @@ -161,7 +163,7 @@ def simple_pylist(count, foo=10): name="simple_pylist", callable=simple_pylist, parameters=["count"], - schema=[["name", "VARCHAR"], ["id", "INT"]], + schema={"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER}, type="tuples", ) result = conn.sql("SELECT * FROM simple_pylist(3)").fetchall() @@ -171,38 +173,34 @@ def simple_pylist(count, foo=10): assert result[-1][0] == "name_2_10" with pytest.raises(duckdb.BinderException): - result = conn.sql( - "SELECT * FROM simple_pylist(count:=3, foo:=2)" - ).fetchall() + result = conn.sql("SELECT * FROM simple_pylist(count:=3, foo:=2)").fetchall() def test_large_2(tmp_path): - """aggregtes and filtering""" + """Aggregation and filtering""" with duckdb.connect(tmp_path / "test.db") as conn: count = 500000 def large_generator(): return [(f"item_{i}", i) for i in range(count)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( - name="large_tvf", + name="large_table_udf", callable=large_generator, parameters=None, schema=schema, type="tuples", ) - result = conn.execute("SELECT COUNT(*) FROM large_tvf()").fetchone() + result = conn.execute("SELECT COUNT(*) FROM large_table_udf()").fetchone() assert result[0] == count - result = conn.sql("SELECT MAX(id) FROM large_tvf()").fetchone() + result = conn.sql("SELECT MAX(id) FROM large_table_udf()").fetchone() assert result[0] == count - 1 - result = conn.execute( - "SELECT COUNT(*) FROM large_tvf() WHERE id < 100" - ).fetchone() + result = conn.execute("SELECT COUNT(*) FROM large_table_udf() WHERE id < 100").fetchone() assert result[0] == 
100 @@ -212,28 +210,24 @@ def test__parameters(tmp_path): def parametrized_function(count=10, prefix="item"): return [(f"{prefix}_{i}", i) for i in range(count)] - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( - name="param_tvf", + name="param_table_udf", callable=parametrized_function, parameters=["count", "prefix"], schema=schema, type="tuples", ) - result1 = conn.execute("SELECT COUNT(*) FROM param_tvf(5, 'test')").fetchone() + result1 = conn.execute("SELECT COUNT(*) FROM param_table_udf(5, 'test')").fetchone() assert result1[0] == 5 - result2 = conn.execute( - "SELECT COUNT(*) FROM param_tvf(20, prefix:='data')" - ).fetchone() + result2 = conn.execute("SELECT COUNT(*) FROM param_table_udf(20, prefix:='data')").fetchone() assert result2[0] == 20 # Test parameter order - result3 = conn.execute( - "SELECT name FROM param_tvf(3, 'xyz') ORDER BY id LIMIT 1" - ).fetchone() + result3 = conn.execute("SELECT name FROM param_table_udf(3, 'xyz') ORDER BY id LIMIT 1").fetchone() assert result3[0] == "xyz_0" @@ -243,18 +237,18 @@ def test_error(tmp_path): def error_function(): raise ValueError("Intentional") - schema = [("name", "VARCHAR"), ("id", "INT")] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( - name="error_tvf", + name="error_table_udf", callable=error_function, parameters=None, schema=schema, type="tuples", ) - with pytest.raises(Exception): - conn.execute("SELECT * FROM error_tvf()").fetchall() + with pytest.raises(duckdb.Error): + conn.execute("SELECT * FROM error_table_udf()").fetchall() def test_callable_refcount(tmp_path): @@ -266,7 +260,7 @@ def gen_function(n): initial_refcount = sys.getrefcount(gen_function) with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] + schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} conn.create_table_function( name="gen_function", @@ 
-296,3 +290,33 @@ def gen_function(n): f"Expected refcount to return to initial after unregistration, " f"but got {final_refcount} (initial: {initial_refcount})" ) + + +def test_callable_lifetime_in_view(tmp_path): + # registers a table UDF within a function scope + # and make sure it's still accessible from another scope (not GC'd) + with duckdb.connect(tmp_path / "test.duckdb") as conn: + + def create_and_register(): + def gen_data(count=5): + return [(f"item_{i}", i * 10) for i in range(count)] + + conn.create_table_function( + name="temp_function", + callable=gen_data, + schema={"name": sqltypes.VARCHAR, "value": sqltypes.INTEGER}, + type="tuples", + ) + + create_and_register() + + conn.execute("CREATE VIEW my_view AS SELECT * FROM temp_function(3)") + + # Unregister only allows the function to be reused - it'll still be accessible in the connection + conn.unregister_table_function("temp_function") + + result = conn.execute("SELECT * FROM my_view").fetchall() + assert len(result) == 3 + assert result[0] == ("item_0", 0) + assert result[1] == ("item_1", 10) + assert result[2] == ("item_2", 20) diff --git a/tests/fast/tvf/test_tuples_datatypes.py b/tests/fast/table_udf/test_tuples_datatypes.py similarity index 64% rename from tests/fast/tvf/test_tuples_datatypes.py rename to tests/fast/table_udf/test_tuples_datatypes.py index 91bbb0bd..e7856e1b 100644 --- a/tests/fast/tvf/test_tuples_datatypes.py +++ b/tests/fast/table_udf/test_tuples_datatypes.py @@ -1,6 +1,7 @@ import pytest import duckdb +import duckdb.sqltypes as sqltypes def test_bigint_params(tmp_path): @@ -11,14 +12,12 @@ def bigint_func(big_value): conn.create_table_function( name="bigint_func", callable=bigint_func, - schema=[["orig", "BIGINT"], ["plus_one", "BIGINT"], ["doubled", "BIGINT"]], + schema={"orig": sqltypes.BIGINT, "plus_one": sqltypes.BIGINT, "doubled": sqltypes.BIGINT}, type="tuples", ) large_val = 4611686018427387900 # Half of max int64 - result = conn.sql( - f"SELECT * FROM 
bigint_func(?)", params=(large_val,) - ).fetchall() + result = conn.sql("SELECT * FROM bigint_func(?)", params=(large_val,)).fetchall() assert result[0][0] == large_val assert result[0][1] == large_val + 1 assert result[0][2] == large_val * 2 @@ -32,14 +31,12 @@ def hugeint_func(huge_value): conn.create_table_function( name="hugeint_func", callable=hugeint_func, - schema=[["orig", "HUGEINT"], ["plus_one", "HUGEINT"]], + schema={"orig": sqltypes.HUGEINT, "plus_one": sqltypes.HUGEINT}, type="tuples", ) huge_val = 9223372036854775808 - result = conn.sql( - f"SELECT * FROM hugeint_func(?)", params=(huge_val,) - ).fetchall() + result = conn.sql(f"SELECT * FROM hugeint_func(?)", params=(huge_val,)).fetchall() assert result[0][0] == huge_val assert result[0][1] == huge_val + 1 @@ -48,47 +45,34 @@ def test_decimal_params(tmp_path): from decimal import Decimal def decimal_func(dec_value): - if isinstance(dec_value, float): - result = dec_value * 2 - else: - result = Decimal(str(dec_value)) * 2 + result = dec_value * 2 if isinstance(dec_value, float) else Decimal(str(dec_value)) * 2 return [(dec_value, result)] with duckdb.connect(tmp_path / "test.duckdb") as conn: conn.create_table_function( name="decimal_func", callable=decimal_func, - schema=[["orig", "DECIMAL(10,2)"], ["doubled", "DECIMAL(10,2)"]], + schema={"orig": duckdb.decimal_type(10, 2), "doubled": duckdb.decimal_type(10, 2)}, type="tuples", ) - result = conn.sql( - "SELECT * FROM decimal_func(?::decimal)", params=(123.45,) - ).fetchall() + result = conn.sql("SELECT * FROM decimal_func(?::decimal)", params=(123.45,)).fetchall() assert float(result[0][0]) == 123.45 assert float(result[0][1]) == 246.90 def test_uuid_params(tmp_path): - import uuid - def uuid_func(uuid_value): - if isinstance(uuid_value, str): - parsed = uuid.UUID(uuid_value) - else: - parsed = uuid_value return [(str(uuid_value),)] with duckdb.connect(tmp_path / "test.duckdb") as conn: conn.create_table_function( name="uuid_func", callable=uuid_func, 
- schema=[("orig", "UUID")], + schema={"orig": sqltypes.UUID}, type="tuples", ) test_uuid = "550e8400-e29b-41d4-a716-446655440000" - result = conn.sql( - f"SELECT * FROM uuid_func(?::uuid)", params=(test_uuid,) - ).fetchall() + result = conn.sql(f"SELECT * FROM uuid_func(?::uuid)", params=(test_uuid,)).fetchall() assert str(result[0][0]) == test_uuid diff --git a/tests/fast/tvf/test_arrow.py b/tests/fast/tvf/test_arrow.py deleted file mode 100644 index f55f5ee0..00000000 --- a/tests/fast/tvf/test_arrow.py +++ /dev/null @@ -1,204 +0,0 @@ -from typing import Iterator - -import pytest - -import duckdb -from duckdb.functional import PythonTVFType - - -def simple_generator(count: int = 10) -> Iterator[tuple[str, int]]: - for i in range(count): - yield (f"name_{i}", i) - - -def simple_arrow_table(count: int): - import pyarrow as pa - - data = { - "id": list(range(count)), - "value": [i * 2 for i in range(count)], - "name": [f"row_{i}" for i in range(count)], - } - return pa.table(data) - - -def test_arrow_small(tmp_path): - pa = pytest.importorskip("pyarrow") - - with duckdb.connect(tmp_path / "test.duckdb") as conn: - conn.create_table_function( - "simple_arrow", - simple_arrow_table, - schema=[("x", "BIGINT"), ("y", "VARCHAR")], # Wrong schema! 
- type=PythonTVFType.ARROW_TABLE, - ) - - with pytest.raises(Exception) as exc_info: - result = conn.execute("SELECT * FROM simple_arrow(5)").fetchall() - - assert ( - "Vector::Reference" in str(exc_info.value) - or "schema" in str(exc_info.value).lower() - ) - - -def test_arrow_large_1(tmp_path): - pa = pytest.importorskip("pyarrow") - - with duckdb.connect(tmp_path / "test.duckdb") as conn: - n = 2048 * 1000 - - conn.create_table_function( - "large_arrow", - simple_arrow_table, - schema=[("id", "BIGINT"), ("value", "BIGINT"), ("name", "VARCHAR")], - type="arrow_table", - ) - - result = conn.execute( - "SELECT COUNT(*) FROM large_arrow(?)", parameters=(n,) - ).fetchone() - assert result[0] == n - - df = conn.sql(f"SELECT * FROM large_arrow({n}) LIMIT 10").df() - assert len(df) == 10 - assert df["id"].tolist() == list(range(10)) - - arrow_result = conn.execute( - "SELECT * FROM large_arrow(?)", parameters=(n,) - ).fetch_arrow_table() - assert len(arrow_result) == n - - result = conn.sql( - "SELECT SUM(value) FROM large_arrow(?)", params=(n,) - ).fetchone() - expected_sum = sum(i * 2 for i in range(n)) - assert result[0] == expected_sum - - -def test_large_arrow_execute(tmp_path): - pytest.importorskip("pyarrow") - - count = 2048 * 1000 - with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] - - conn.create_table_function( - name="gen_function", - callable=simple_generator, - parameters=None, - schema=schema, - type="tuples", - ) - - result = conn.execute( - "SELECT * FROM gen_function(?)", - parameters=(count,), - ).fetch_arrow_table() - - assert len(result) == count - - -def test_large_arrow_sql(tmp_path): - pytest.importorskip("pyarrow") - - count = 2048 * 1000 - with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] - - conn.create_table_function( - name="gen_function", - callable=simple_generator, - parameters=None, - schema=schema, - type="tuples", - ) - - result = 
conn.sql( - "SELECT * FROM gen_function(?)", - params=(count,), - ).fetch_arrow_table() - - assert len(result) == count - - -def test_arrowbatched_execute(tmp_path): - pytest.importorskip("pyarrow") - - count = 2048 * 1000 - with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] - - conn.create_table_function( - name="gen_function", - callable=simple_generator, - parameters=None, - schema=schema, - type="tuples", - ) - - result = conn.execute( - "SELECT * FROM gen_function(?)", - parameters=(count,), - ).fetch_record_batch() - - result = conn.execute( - f"SELECT * FROM gen_function({count})", - ).fetch_record_batch() - - c = 0 - for batch in result: - c += batch.num_rows - assert c == count - - -def test_arrowbatched_sql_relation(tmp_path): - pytest.importorskip("pyarrow") - - count = 2048 * 1000 - with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] - - conn.create_table_function( - name="gen_function", - callable=simple_generator, - parameters=None, - schema=schema, - type="tuples", - ) - - result = conn.sql( - f"SELECT * FROM gen_function({count})", - ).fetch_arrow_reader() - - c = 0 - for batch in result: - c += batch.num_rows - assert c == count - - -def test_arrowbatched_sql_materialized(tmp_path): - pytest.importorskip("pyarrow") - - count = 2048 * 1000 - with duckdb.connect(tmp_path / "test.duckdb") as conn: - schema = [["name", "VARCHAR"], ["id", "INT"]] - - conn.create_table_function( - name="gen_function", - callable=simple_generator, - parameters=None, - schema=schema, - type="tuples", - ) - - # passing parameters makes it non-lazy /materialized - result = conn.sql( - "SELECT * FROM gen_function(?)", - params=(count,), - ).fetch_arrow_reader() - - c = 0 - for batch in result: - c += batch.num_rows - assert c == count From e2f465e13bb29a707123f776ffbafe790648fd78 Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Sat, 11 Oct 2025 14:17:14 +0000 Subject: 
[PATCH 4/4] chore: linting and formatting (triggered by new pre-commits). --- _duckdb-stubs/_func.pyi | 29 ++++++++++- duckdb/func/__init__.py | 10 +++- tests/fast/table_udf/test_arrow.py | 15 +++--- tests/fast/table_udf/test_arrow_schema.py | 12 ++--- tests/fast/table_udf/test_register.py | 9 ++-- tests/fast/table_udf/test_schema.py | 50 ++++++++++--------- tests/fast/table_udf/test_tuples.py | 5 +- tests/fast/table_udf/test_tuples_datatypes.py | 6 +-- 8 files changed, 84 insertions(+), 52 deletions(-) diff --git a/_duckdb-stubs/_func.pyi b/_duckdb-stubs/_func.pyi index 68484499..8bf0480b 100644 --- a/_duckdb-stubs/_func.pyi +++ b/_duckdb-stubs/_func.pyi @@ -1,6 +1,14 @@ import typing as pytyping -__all__: list[str] = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] +__all__: list[str] = [ + "ARROW", + "DEFAULT", + "NATIVE", + "SPECIAL", + "FunctionNullHandling", + "PythonTableUDFType", + "PythonUDFType", +] class FunctionNullHandling: DEFAULT: pytyping.ClassVar[FunctionNullHandling] # value = @@ -21,6 +29,25 @@ class FunctionNullHandling: @property def value(self) -> int: ... +class PythonTableUDFType: + ARROW_TABLE: pytyping.ClassVar[PythonTableUDFType] # value = + TUPLES: pytyping.ClassVar[PythonTableUDFType] # value = + __members__: pytyping.ClassVar[ + dict[str, PythonTableUDFType] + ] # value = {'TUPLES': , 'ARROW_TABLE': } + def __eq__(self, other: object) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: object) -> bool: ... + def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... 
+ class PythonUDFType: ARROW: pytyping.ClassVar[PythonUDFType] # value = NATIVE: pytyping.ClassVar[PythonUDFType] # value = diff --git a/duckdb/func/__init__.py b/duckdb/func/__init__.py index 0518f667..fdb1adfd 100644 --- a/duckdb/func/__init__.py +++ b/duckdb/func/__init__.py @@ -1,3 +1,11 @@ -from _duckdb._func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonTableUDFType, PythonUDFType # noqa: D104 +from _duckdb._func import ( # noqa: D104 + ARROW, + DEFAULT, + NATIVE, + SPECIAL, + FunctionNullHandling, + PythonTableUDFType, + PythonUDFType, +) __all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonTableUDFType", "PythonUDFType"] diff --git a/tests/fast/table_udf/test_arrow.py b/tests/fast/table_udf/test_arrow.py index 86bc956d..c709f4bf 100644 --- a/tests/fast/table_udf/test_arrow.py +++ b/tests/fast/table_udf/test_arrow.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator import pytest @@ -25,8 +25,8 @@ def simple_arrow_table(count: int): def arrow_all_types(count: int): pa = pytest.importorskip("pyarrow") - from decimal import Decimal from datetime import datetime, timedelta, timezone + from decimal import Decimal now = datetime.now(timezone.utc) data = { @@ -81,10 +81,7 @@ def arrow_all_types(count: int): def test_arrow_small(tmp_path): - """Defines and creates a Table UDF with only positional parameters, verifies that it works - and verifies it fails from another connection scope. 
- """ - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: conn.create_table_function( @@ -117,7 +114,7 @@ def test_arrow_small(tmp_path): def test_arrow_large_1(tmp_path): """tests: more rows, aggregation, limits, named parameters, parameters.""" - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: n = 2048 * 1000 @@ -201,8 +198,8 @@ def test_arrowbatched_sql_relation(tmp_path): def test_arrow_types(tmp_path): - """Return many types from an arrow table UDF, and verify the results are correct""" - pa = pytest.importorskip("pyarrow") + """Return many types from an arrow table UDF, and verify the results are correct.""" + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: conn.create_table_function( diff --git a/tests/fast/table_udf/test_arrow_schema.py b/tests/fast/table_udf/test_arrow_schema.py index 61db1105..a11b77c9 100644 --- a/tests/fast/table_udf/test_arrow_schema.py +++ b/tests/fast/table_udf/test_arrow_schema.py @@ -1,5 +1,3 @@ -"""Test Arrow Table UDF schema validation""" - import pytest import duckdb @@ -19,7 +17,7 @@ def simple_arrow_table(count: int = 10): def test_arrow_correct_schema(tmp_path): - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: conn.create_table_function( @@ -35,7 +33,7 @@ def test_arrow_correct_schema(tmp_path): def test_arrow_more_columns(tmp_path): - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: # table has 3 cols, but declare only 2 @@ -54,7 +52,7 @@ def test_arrow_more_columns(tmp_path): def test_arrow_fewer_columns(tmp_path): - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: # table has 3 columns, but declare 4 @@ -78,7 +76,7 
@@ def test_arrow_fewer_columns(tmp_path): def test_arrow_type_mismatch(tmp_path): - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: conn.create_table_function( @@ -100,7 +98,7 @@ def test_arrow_type_mismatch(tmp_path): def test_arrow_name_mismatch_allowed(tmp_path): - pa = pytest.importorskip("pyarrow") + pytest.importorskip("pyarrow") with duckdb.connect(tmp_path / "test.duckdb") as conn: conn.create_table_function( diff --git a/tests/fast/table_udf/test_register.py b/tests/fast/table_udf/test_register.py index 3d7d6d2d..44ef91ae 100644 --- a/tests/fast/table_udf/test_register.py +++ b/tests/fast/table_udf/test_register.py @@ -109,7 +109,7 @@ def func_v3(): def test_multiple_replacements(tmp_path): - """Replacing Table UDFs multiple times""" + """Replacing Table UDFs multiple times.""" with duckdb.connect(tmp_path / "test.db") as conn: schema = {"value": sqltypes.INTEGER} @@ -131,7 +131,7 @@ def func(): def test_replacement_with_different_schemas(tmp_path): - """Changing schema with replacements""" + """Changing schema with replacements.""" with duckdb.connect(tmp_path / "test.db") as conn: def func_v1(): @@ -184,7 +184,7 @@ def func_v2(): def test_sql_drop_table_function(tmp_path): - """Documents current behavior - that dropping functions has no effect on Table UDFs""" + """Documents current behavior - that dropping functions has no effect on Table UDFs.""" with duckdb.connect(tmp_path / "test.db") as conn: def test_func(): @@ -197,7 +197,7 @@ def test_func(): assert result[0][0] == "test_value" assert result[0][1] == 1 - with pytest.raises(Exception): + with pytest.raises(duckdb.CatalogException): conn.execute("DROP FUNCTION test_func") result = conn.execute("SELECT * FROM test_func()").fetchall() @@ -306,7 +306,6 @@ def test_func(): cursor1.unregister_table_function("shared_func") - # TODO: Decide whether to keep this unregister behavior result1 = cursor1.execute("SELECT * FROM 
shared_func()").fetchall() assert result1[0][0] == "test_data" diff --git a/tests/fast/table_udf/test_schema.py b/tests/fast/table_udf/test_schema.py index f416ad94..7490162d 100644 --- a/tests/fast/table_udf/test_schema.py +++ b/tests/fast/table_udf/test_schema.py @@ -1,6 +1,6 @@ """Test schema validation for table-valued functions.""" -from typing import Iterator +from collections.abc import Iterator import pytest @@ -141,28 +141,32 @@ def test_invalid_schema_none(tmp_path): def gen_function(): return [("test", 1)] - with duckdb.connect(tmp_path / "test.duckdb") as conn: - with pytest.raises(duckdb.InvalidInputException, match="Table functions require a schema"): - conn.create_table_function( - name="gen_function", - callable=gen_function, - schema=None, - type="tuples", - ) + with ( + duckdb.connect(tmp_path / "test.duckdb") as conn, + pytest.raises(duckdb.InvalidInputException, match="Table functions require a schema"), + ): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema=None, + type="tuples", + ) def test_invalid_schema_empty_dict(tmp_path): def gen_function(): return [("test", 1)] - with duckdb.connect(tmp_path / "test.duckdb") as conn: - with pytest.raises(duckdb.InvalidInputException, match="schema cannot be empty"): - conn.create_table_function( - name="gen_function", - callable=gen_function, - schema={}, - type="tuples", - ) + with ( + duckdb.connect(tmp_path / "test.duckdb") as conn, + pytest.raises(duckdb.InvalidInputException, match="schema cannot be empty"), + ): + conn.create_table_function( + name="gen_function", + callable=gen_function, + schema={}, + type="tuples", + ) def test_invalid_schema_list_format(tmp_path): @@ -207,7 +211,7 @@ def gen_function(): # String types should be rejected schema = {"name": "VARCHAR", "id": "INT"} - with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): 
conn.create_table_function( name="gen_function", callable=gen_function, @@ -223,7 +227,7 @@ def gen_function(): with duckdb.connect(tmp_path / "test.duckdb") as conn: schema = {"name": sqltypes.VARCHAR, "id": 123} - with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): conn.create_table_function( name="gen_function", callable=gen_function, @@ -239,7 +243,7 @@ def gen_function(): with duckdb.connect(tmp_path / "test.duckdb") as conn: schema = {"name": sqltypes.VARCHAR, "id": None} - with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): conn.create_table_function( name="gen_function", callable=gen_function, @@ -258,7 +262,7 @@ def gen_function(): # Mix of DuckDBPyType and string - should reject strings schema = {"name": sqltypes.VARCHAR, "id": "INT", "value": sqltypes.DOUBLE} - with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): conn.create_table_function( name="gen_function", callable=gen_function, @@ -274,7 +278,7 @@ def gen_function(): with duckdb.connect(tmp_path / "test.duckdb") as conn: schema = {"name": str, "id": int} - with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): conn.create_table_function( name="gen_function", callable=gen_function, @@ -399,7 +403,7 @@ def gen_function(): with duckdb.connect(tmp_path / "test.duckdb") as conn: schema = {"name": sqltypes.VARCHAR, "id": 123} # int is not valid - with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb.sqltype"): + with pytest.raises(duckdb.InvalidInputException, match="must be a duckdb\\.sqltype"): 
conn.create_table_function( name="gen_function", callable=gen_function, diff --git a/tests/fast/table_udf/test_tuples.py b/tests/fast/table_udf/test_tuples.py index aee26cd7..a91db688 100644 --- a/tests/fast/table_udf/test_tuples.py +++ b/tests/fast/table_udf/test_tuples.py @@ -177,7 +177,7 @@ def simple_pylist(count, foo=10): def test_large_2(tmp_path): - """Aggregation and filtering""" + """Aggregation and filtering.""" with duckdb.connect(tmp_path / "test.db") as conn: count = 500000 @@ -235,7 +235,8 @@ def test_error(tmp_path): with duckdb.connect(tmp_path / "test.db") as conn: def error_function(): - raise ValueError("Intentional") + error_message = "Intentional Error" + raise ValueError(error_message) schema = {"name": sqltypes.VARCHAR, "id": sqltypes.INTEGER} diff --git a/tests/fast/table_udf/test_tuples_datatypes.py b/tests/fast/table_udf/test_tuples_datatypes.py index e7856e1b..5e31711f 100644 --- a/tests/fast/table_udf/test_tuples_datatypes.py +++ b/tests/fast/table_udf/test_tuples_datatypes.py @@ -1,5 +1,3 @@ -import pytest - import duckdb import duckdb.sqltypes as sqltypes @@ -36,7 +34,7 @@ def hugeint_func(huge_value): ) huge_val = 9223372036854775808 - result = conn.sql(f"SELECT * FROM hugeint_func(?)", params=(huge_val,)).fetchall() + result = conn.sql("SELECT * FROM hugeint_func(?)", params=(huge_val,)).fetchall() assert result[0][0] == huge_val assert result[0][1] == huge_val + 1 @@ -74,5 +72,5 @@ def uuid_func(uuid_value): ) test_uuid = "550e8400-e29b-41d4-a716-446655440000" - result = conn.sql(f"SELECT * FROM uuid_func(?::uuid)", params=(test_uuid,)).fetchall() + result = conn.sql("SELECT * FROM uuid_func(?::uuid)", params=(test_uuid,)).fetchall() assert str(result[0][0]) == test_uuid