diff --git a/datafusion/catalog/src/memory/schema.rs b/datafusion/catalog/src/memory/schema.rs index f1b3628f7aff..71ea304108db 100644 --- a/datafusion/catalog/src/memory/schema.rs +++ b/datafusion/catalog/src/memory/schema.rs @@ -17,7 +17,7 @@ //! [`MemorySchemaProvider`]: In-memory implementations of [`SchemaProvider`]. -use crate::{SchemaProvider, TableProvider}; +use crate::{SchemaProvider, TableFunction, TableProvider}; use async_trait::async_trait; use dashmap::DashMap; use datafusion_common::{exec_err, DataFusionError}; @@ -28,6 +28,7 @@ use std::sync::Arc; #[derive(Debug)] pub struct MemorySchemaProvider { tables: DashMap>, + table_functions: DashMap>, } impl MemorySchemaProvider { @@ -35,6 +36,7 @@ impl MemorySchemaProvider { pub fn new() -> Self { Self { tables: DashMap::new(), + table_functions: DashMap::new(), } } } @@ -86,4 +88,184 @@ impl SchemaProvider for MemorySchemaProvider { fn table_exist(&self, name: &str) -> bool { self.tables.contains_key(name) } + + fn udtf_names(&self) -> Vec { + self.table_functions + .iter() + .map(|f| f.key().clone()) + .collect() + } + + fn udtf( + &self, + name: &str, + ) -> datafusion_common::Result>, DataFusionError> { + Ok(self + .table_functions + .get(name) + .map(|f| Arc::clone(f.value()))) + } + + fn register_udtf( + &self, + name: String, + function: Arc, + ) -> datafusion_common::Result>> { + if self.udtf_exist(name.as_str()) { + return exec_err!("The table function {name} already exists"); + } + Ok(self.table_functions.insert(name, function)) + } + + fn deregister_udtf( + &self, + name: &str, + ) -> datafusion_common::Result>> { + Ok(self.table_functions.remove(name).map(|(_, f)| f)) + } + + fn udtf_exist(&self, name: &str) -> bool { + self.table_functions.contains_key(name) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::table::TableFunctionImpl; + use crate::Session; + use arrow::datatypes::Schema; + use datafusion_common::Result; + use datafusion_expr::{Expr, TableType}; + use datafusion_physical_plan::ExecutionPlan; + + #[derive(Debug)] + struct DummyTableFunc; + + #[derive(Debug)] + struct DummyTable { + schema: arrow::datatypes::SchemaRef, + } + + #[async_trait::async_trait] + impl TableProvider for DummyTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + datafusion_common::plan_err!("DummyTable does not support scanning") + } + } + + impl TableFunctionImpl for DummyTableFunc { + fn call(&self, _args: &[Expr]) -> Result> { + Ok(Arc::new(DummyTable { + schema: Arc::new(Schema::empty()), + })) + } + } + + #[test] + fn test_register_and_retrieve_udtf() { + let schema = MemorySchemaProvider::new(); + let func = Arc::new(TableFunction::new( + "my_func".to_string(), + Arc::new(DummyTableFunc), + )); + + let result = schema.register_udtf("my_func".to_string(), func.clone()); + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); + + assert!(schema.udtf_exist("my_func")); + assert_eq!(schema.udtf_names(), vec!["my_func"]); + + let retrieved = schema.udtf("my_func").unwrap(); + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().name(), "my_func"); + } + + #[test] + fn test_duplicate_udtf_registration_fails() { + let schema = MemorySchemaProvider::new(); + let func = Arc::new(TableFunction::new( + "my_func".to_string(), + Arc::new(DummyTableFunc), + )); + + schema + .register_udtf("my_func".to_string(), func.clone()) + .unwrap(); + + let result = schema.register_udtf("my_func".to_string(), func.clone()); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("already exists")); + } + + #[test] + fn test_deregister_udtf() { + let schema = MemorySchemaProvider::new(); + let func = Arc::new(TableFunction::new( + "my_func".to_string(), + Arc::new(DummyTableFunc), + )); + + schema.register_udtf("my_func".to_string(), func).unwrap(); + assert!(schema.udtf_exist("my_func")); + + let removed = schema.deregister_udtf("my_func").unwrap(); + assert!(removed.is_some()); + assert!(!schema.udtf_exist("my_func")); + assert_eq!(schema.udtf_names(), Vec::::new()); + + let removed = schema.deregister_udtf("my_func").unwrap(); + assert!(removed.is_none()); + } + + #[test] + fn test_udtf_not_found() { + let schema = MemorySchemaProvider::new(); + + assert!(!schema.udtf_exist("nonexistent")); + let result = schema.udtf("nonexistent").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_multiple_udtfs() { + let schema = MemorySchemaProvider::new(); + let func1 = Arc::new(TableFunction::new( + "func1".to_string(), + Arc::new(DummyTableFunc), + )); + let func2 = Arc::new(TableFunction::new( + "func2".to_string(), + Arc::new(DummyTableFunc), + )); + + schema.register_udtf("func1".to_string(), func1).unwrap(); + schema.register_udtf("func2".to_string(), func2).unwrap(); + + let mut names = schema.udtf_names(); + names.sort(); + assert_eq!(names, vec!["func1", "func2"]); + + assert!(schema.udtf_exist("func1")); + assert!(schema.udtf_exist("func2")); + } } diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs index 9ba55256f182..2e64006354cf 100644 --- a/datafusion/catalog/src/schema.rs +++ b/datafusion/catalog/src/schema.rs @@ -24,7 +24,7 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use crate::table::TableProvider; +use crate::table::{TableFunction, TableProvider}; use datafusion_common::Result; use datafusion_expr::TableType; @@ -88,4 +88,41 @@ pub trait SchemaProvider: Debug + Sync + Send { /// Returns true if table exist in the schema provider, false otherwise. fn table_exist(&self, name: &str) -> bool; + + /// Retrieves the list of available table function names in this schema. + fn udtf_names(&self) -> Vec { + vec![] + } + + /// Retrieves a specific table function from the schema by name, if it exists, + /// otherwise returns `None`. + fn udtf(&self, _name: &str) -> Result>> { + Ok(None) + } + + /// If supported by the implementation, adds a new table function named `name` to + /// this schema. + /// + /// If a table function of the same name was already registered, returns "Table + /// function already exists" error. + fn register_udtf( + &self, + _name: String, + _function: Arc, + ) -> Result>> { + exec_err!("schema provider does not support registering table functions") + } + + /// If supported by the implementation, removes the `name` table function from this + /// schema and returns the previously registered [`TableFunction`], if any. + /// + /// If no `name` table function exists, returns Ok(None). + fn deregister_udtf(&self, _name: &str) -> Result>> { + exec_err!("schema provider does not support deregistering table functions") + } + + /// Returns true if table function exists in the schema provider, false otherwise. + fn udtf_exist(&self, _name: &str) -> bool { + false + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index b04004dd495c..bdc91f8a3165 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1699,12 +1699,8 @@ impl ContextProvider for SessionContextProvider<'_> { name: &str, args: Vec, ) -> datafusion_common::Result> { - let tbl_func = self - .state - .table_functions - .get(name) - .cloned() - .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let table_ref = TableReference::parse_str(name); + let dummy_schema = DFSchema::empty(); let simplifier = ExprSimplifier::new(SessionSimplifyProvider::new(self.state, &dummy_schema)); @@ -1712,8 +1708,25 @@ impl ContextProvider for SessionContextProvider<'_> { .into_iter() .map(|arg| simplifier.simplify(arg)) .collect::>>()?; - let provider = tbl_func.create_table_provider(&args)?; + let tbl_func = if table_ref.schema().is_some() { + let func_name = table_ref.table().to_string(); + let schema = self.state.schema_for_ref(table_ref)?; + + schema.udtf(&func_name)?.ok_or_else(|| { + plan_datafusion_err!("Table function '{}' not found in schema", name) + })? + } else { + self.state + .table_functions + .get(name) + .cloned() + .ok_or_else(|| { + plan_datafusion_err!("table function '{name}' not found") + })? + }; + + let provider = tbl_func.create_table_provider(&args)?; Ok(provider_as_source(provider)) } diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 2c6611f382ce..c2252800fa5d 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -34,7 +34,7 @@ use datafusion::execution::TaskContext; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::SessionContext; use datafusion_catalog::Session; -use datafusion_catalog::TableFunctionImpl; +use datafusion_catalog::{SchemaProvider, TableFunctionImpl}; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; @@ -109,6 +109,76 @@ async fn test_deregister_udtf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_schema_qualified_udtf() -> Result<()> { + let ctx = SessionContext::new(); + + let catalog = ctx.catalog("datafusion").unwrap(); + let schema = catalog.schema("public").unwrap(); + let memory_schema = schema + .as_any() + .downcast_ref::() + .unwrap(); + + let func = Arc::new(datafusion_catalog::TableFunction::new( + "schema_func".to_string(), + Arc::new(SimpleCsvTableFunc {}), + )); + memory_schema + .register_udtf("schema_func".to_string(), func) + .unwrap(); + + let csv_file = "tests/tpch-csv/nation.csv"; + let rbs = ctx + .sql(format!("SELECT * FROM public.schema_func('{csv_file}', 3);").as_str()) + .await? + .collect() + .await?; + + assert_eq!(rbs[0].num_rows(), 3); + + Ok(()) +} + +/// Test that unqualified names still use global registry (backward compatibility) +#[tokio::test] +async fn test_unqualified_uses_global_registry() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("global_func", Arc::new(SimpleCsvTableFunc {})); + + let csv_file = "tests/tpch-csv/nation.csv"; + let rbs = ctx + .sql(format!("SELECT * FROM global_func('{csv_file}', 2);").as_str()) + .await? + .collect() + .await?; + + assert_eq!(rbs[0].num_rows(), 2); + + Ok(()) +} + +#[tokio::test] +async fn test_schema_qualified_not_in_global() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("global_only", Arc::new(SimpleCsvTableFunc {})); + + let csv_file = "tests/tpch-csv/nation.csv"; + let result = ctx + .sql(format!("SELECT * FROM public.global_only('{csv_file}');").as_str()) + .await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("not found in schema")); + + Ok(()) +} + #[derive(Debug)] struct SimpleCsvTable { schema: SchemaRef, diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 9dfa078701d3..d9ad4f48c5a2 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -43,8 +43,12 @@ impl SqlToRel<'_, S> { name, alias, args, .. } => { if let Some(func_args) = args { - let tbl_func_name = - name.0.first().unwrap().as_ident().unwrap().to_string(); + let tbl_func_name = name + .0 + .iter() + .map(|ident| ident.as_ident().unwrap().to_string()) + .collect::>() + .join("."); let args = func_args .args .into_iter() @@ -175,9 +179,10 @@ impl SqlToRel<'_, S> { _ => plan_err!("Unsupported function argument: {arg:?}"), }) .collect::>>()?; + let qualified_name = tbl_func_ref.to_string(); let provider = self .context_provider - .get_table_function_source(tbl_func_ref.table(), func_args)?; + .get_table_function_source(&qualified_name, func_args)?; let plan = LogicalPlanBuilder::scan(tbl_func_ref.table(), provider, None)? .build()?; diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index 0159abe8d06b..1e5c388a46ee 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -494,3 +494,32 @@ SELECT c, f.* FROM json_table, LATERAL generate_series(1,2) f; 1 2 2 1 2 2 + +# +# Test schema-qualified table function names +# Global table functions are not accessible via qualified names +# + +# Test: unqualified table function name resolves from global registry (backward compatibility) +query I +SELECT * FROM generate_series(1, 3) +---- +1 +2 +3 + +# Test: qualified name for global table function should fail +# (global table functions are not in schemas, they're in the global registry) +statement error DataFusion error: Error during planning: Table function 'public.generate_series' not found in schema +SELECT * FROM public.generate_series(1, 3) + +statement error DataFusion error: Error during planning: Table function 'datafusion.public.generate_series' not found in schema +SELECT * FROM datafusion.public.generate_series(1, 3) + +# Test: non-existent function with qualified name +statement error DataFusion error: Error during planning: Table function 'public.nonexistent_func' not found in schema +SELECT * FROM public.nonexistent_func(1, 2) + +# Test: non-existent function with unqualified name +statement error DataFusion error: Error during planning: table function 'nonexistent_func' not found +SELECT * FROM nonexistent_func(1, 2)