diff --git a/.github/actions/test_stateless_cluster_linux/action.yml b/.github/actions/test_stateless_cluster_linux/action.yml index 98bd2147ad77f..d30d4a69d4a4f 100644 --- a/.github/actions/test_stateless_cluster_linux/action.yml +++ b/.github/actions/test_stateless_cluster_linux/action.yml @@ -13,7 +13,7 @@ runs: - name: Start UDF Server shell: bash run: | - pip install databend-udf>=0.2.6 + pip install 'databend-udf>=0.2.7' python3 tests/udf/udf_server.py & sleep 2 diff --git a/.github/actions/test_stateless_cluster_macos/action.yml b/.github/actions/test_stateless_cluster_macos/action.yml index 024106d5299c1..2c1eb45a769b4 100644 --- a/.github/actions/test_stateless_cluster_macos/action.yml +++ b/.github/actions/test_stateless_cluster_macos/action.yml @@ -13,7 +13,7 @@ runs: - name: Start UDF Server shell: bash run: | - pip install databend-udf>=0.2.6 + pip install 'databend-udf>=0.2.7' python3 tests/udf/udf_server.py & sleep 2 diff --git a/.github/actions/test_stateless_standalone_linux/action.yml b/.github/actions/test_stateless_standalone_linux/action.yml index 09033cb945bf0..bd0206cbf5c96 100644 --- a/.github/actions/test_stateless_standalone_linux/action.yml +++ b/.github/actions/test_stateless_standalone_linux/action.yml @@ -13,7 +13,7 @@ runs: - name: Start UDF Server shell: bash run: | - pip install databend-udf>=0.2.6 + pip install 'databend-udf>=0.2.7' python3 tests/udf/udf_server.py & sleep 2 diff --git a/.github/actions/test_stateless_standalone_macos/action.yml b/.github/actions/test_stateless_standalone_macos/action.yml index 017ac3f6773d5..b164379cb700f 100644 --- a/.github/actions/test_stateless_standalone_macos/action.yml +++ b/.github/actions/test_stateless_standalone_macos/action.yml @@ -13,7 +13,7 @@ runs: - name: Start UDF Server shell: bash run: | - pip install databend-udf>=0.2.6 + pip install 'databend-udf>=0.2.7' python3 tests/udf/udf_server.py & sleep 2 diff --git a/.github/workflows/reuse.sqllogic.yml b/.github/workflows/reuse.sqllogic.yml 
index 2bb94f2eaf6f7..e1952c4788cac 100644 --- a/.github/workflows/reuse.sqllogic.yml +++ b/.github/workflows/reuse.sqllogic.yml @@ -92,7 +92,8 @@ jobs: - uses: actions/checkout@v4 - name: Start UDF Server run: | - pip install databend-udf>=0.2.6 + docker run -d --name minio -p 9000:9000 -p 9001:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin quay.io/minio/minio server /data --console-address ":9001" + pip install 'databend-udf>=0.2.7' python3 tests/udf/udf_server.py & sleep 2 - uses: ./.github/actions/test_sqllogic_standalone_linux diff --git a/src/meta/app/src/principal/user_defined_function.rs b/src/meta/app/src/principal/user_defined_function.rs index f6d1210dba115..37a9cf30ea35b 100644 --- a/src/meta/app/src/principal/user_defined_function.rs +++ b/src/meta/app/src/principal/user_defined_function.rs @@ -33,6 +33,7 @@ pub struct UDFServer { pub handler: String, pub headers: BTreeMap, pub language: String, + pub arg_names: Vec, pub arg_types: Vec, pub return_type: DataType, pub immutable: Option, @@ -168,6 +169,7 @@ impl UserDefinedFunction { handler: &str, headers: &BTreeMap, language: &str, + arg_names: Vec, arg_types: Vec, return_type: DataType, description: &str, @@ -181,6 +183,7 @@ impl UserDefinedFunction { handler: handler.to_string(), headers: headers.clone(), language: language.to_string(), + arg_names, arg_types, return_type, immutable, @@ -237,6 +240,7 @@ impl Display for UDFDefinition { } UDFDefinition::UDFServer(UDFServer { address, + arg_names, arg_types, return_type, handler, @@ -249,6 +253,9 @@ impl Display for UDFDefinition { write!(f, ", ")?; } write!(f, "{item}")?; + if let Some(arg_name) = arg_names.get(i) { + write!(f, " {arg_name}")?; + } } write!(f, ") RETURNS {return_type} LANGUAGE {language}")?; if let Some(immutable) = immutable { diff --git a/src/meta/proto-conv/src/schema_from_to_protobuf_impl.rs b/src/meta/proto-conv/src/schema_from_to_protobuf_impl.rs index 6d320b90b0179..871ca5af47620 100644 --- 
a/src/meta/proto-conv/src/schema_from_to_protobuf_impl.rs +++ b/src/meta/proto-conv/src/schema_from_to_protobuf_impl.rs @@ -312,6 +312,7 @@ impl FromToProto for ex::TableDataType { Dt24::VectorT(v) => { ex::TableDataType::Vector(ex::types::VectorDataType::from_pb(v)?) } + Dt24::StageLocationT(_) => ex::TableDataType::StageLocation, }; Ok(x) } @@ -380,6 +381,7 @@ impl FromToProto for ex::TableDataType { let x = v.to_pb()?; new_pb_dt24(Dt24::VectorT(x)) } + TableDataType::StageLocation => new_pb_dt24(Dt24::StageLocationT(pb::Empty {})), }; Ok(x) } diff --git a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs index fd850e5f4ff34..5ce96fc9254b0 100644 --- a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs +++ b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs @@ -77,6 +77,7 @@ impl FromToProto for mt::UDFServer { headers: p.headers, language: p.language, immutable: p.immutable, + arg_names: p.arg_names, }) } @@ -112,6 +113,7 @@ impl FromToProto for mt::UDFServer { arg_types, return_type: Some(return_type), immutable: self.immutable, + arg_names: self.arg_names.clone(), }) } } diff --git a/src/meta/proto-conv/src/util.rs b/src/meta/proto-conv/src/util.rs index efe1068fa9b4b..06a0feef99122 100644 --- a/src/meta/proto-conv/src/util.rs +++ b/src/meta/proto-conv/src/util.rs @@ -181,6 +181,7 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[ (149, "2025-09-24: Add: add AutoIncrement name and display on TableField"), (150, "2025-09-26: Add: RoleInfo::comment"), (151, "2025-09-28: Add: TableMeta::RowAccessPolicyColumnMap store policy name and column id"), + (152, "2025-10-14: Add: TableDataType::StageLocation and UDFServer add arg_names"), // Dear developer: // If you're gonna add a new metadata version, you'll have to add a test for it. 
// You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`) diff --git a/src/meta/proto-conv/tests/it/main.rs b/src/meta/proto-conv/tests/it/main.rs index e05b848a217ae..9189941994579 100644 --- a/src/meta/proto-conv/tests/it/main.rs +++ b/src/meta/proto-conv/tests/it/main.rs @@ -143,3 +143,4 @@ mod v148_virtual_schema; mod v149_field_auto_increment; mod v150_role_comment; mod v151_row_access_column_map; +mod v152_external_udf; diff --git a/src/meta/proto-conv/tests/it/proto_conv.rs b/src/meta/proto-conv/tests/it/proto_conv.rs index 3ce5e43913108..8cf3e2b067cc4 100644 --- a/src/meta/proto-conv/tests/it/proto_conv.rs +++ b/src/meta/proto-conv/tests/it/proto_conv.rs @@ -311,6 +311,7 @@ fn new_udf_server() -> databend_common_meta_app::principal::UDFServer { ("X-Api-Version".to_string(), "11".to_string()), ]), language: "python".to_string(), + arg_names: vec![], arg_types: vec![DataType::String], return_type: DataType::Boolean, immutable: None, diff --git a/src/meta/proto-conv/tests/it/v058_udf.rs b/src/meta/proto-conv/tests/it/v058_udf.rs index e32b984a6ca81..e52bfe3dbd31f 100644 --- a/src/meta/proto-conv/tests/it/v058_udf.rs +++ b/src/meta/proto-conv/tests/it/v058_udf.rs @@ -55,6 +55,7 @@ fn test_decode_v57_udf() -> anyhow::Result<()> { handler: "plus_int_py".to_string(), headers: BTreeMap::default(), language: "python".to_string(), + arg_names: vec![], arg_types: vec![ DataType::Number(NumberDataType::Int32), DataType::Number(NumberDataType::Int32), diff --git a/src/meta/proto-conv/tests/it/v079_udf_created_on.rs b/src/meta/proto-conv/tests/it/v079_udf_created_on.rs index 4e860860232c1..1740164f12b9b 100644 --- a/src/meta/proto-conv/tests/it/v079_udf_created_on.rs +++ b/src/meta/proto-conv/tests/it/v079_udf_created_on.rs @@ -58,6 +58,7 @@ fn test_decode_v79_udf_python() -> anyhow::Result<()> { handler: "plus_int_py".to_string(), headers: BTreeMap::default(), language: "python".to_string(), + arg_names: vec![], arg_types: vec![ 
DataType::Number(NumberDataType::Int32), DataType::Number(NumberDataType::Int32), diff --git a/src/meta/proto-conv/tests/it/v081_udf_script.rs b/src/meta/proto-conv/tests/it/v081_udf_script.rs index caa270a87f38d..143f906b9fa4c 100644 --- a/src/meta/proto-conv/tests/it/v081_udf_script.rs +++ b/src/meta/proto-conv/tests/it/v081_udf_script.rs @@ -59,6 +59,7 @@ fn test_decode_v81_udf_python() -> anyhow::Result<()> { handler: "plus_int_py".to_string(), headers: BTreeMap::default(), language: "python".to_string(), + arg_names: vec![], arg_types: vec![ DataType::Number(NumberDataType::Int32), DataType::Number(NumberDataType::Int32), diff --git a/src/meta/proto-conv/tests/it/v124_udf_server_headers.rs b/src/meta/proto-conv/tests/it/v124_udf_server_headers.rs index f994e45b5fb50..b4ee993f55965 100644 --- a/src/meta/proto-conv/tests/it/v124_udf_server_headers.rs +++ b/src/meta/proto-conv/tests/it/v124_udf_server_headers.rs @@ -47,6 +47,7 @@ fn test_decode_v124_udf_server_headers() -> anyhow::Result<()> { ("X-Api-Version".to_string(), "11".to_string()), ]), language: "python".to_string(), + arg_names: vec![], arg_types: vec![DataType::String], return_type: DataType::Boolean, immutable: None, diff --git a/src/meta/proto-conv/tests/it/v135_udf_immutable.rs b/src/meta/proto-conv/tests/it/v135_udf_immutable.rs index c326fa06ab5a4..3309357ff8dd9 100644 --- a/src/meta/proto-conv/tests/it/v135_udf_immutable.rs +++ b/src/meta/proto-conv/tests/it/v135_udf_immutable.rs @@ -60,6 +60,7 @@ fn test_decode_v135_udf_server() -> anyhow::Result<()> { ("X-Api-Version".to_string(), "11".to_string()), ]), language: "python".to_string(), + arg_names: vec![], arg_types: vec![DataType::String], return_type: DataType::Boolean, immutable: Some(true), diff --git a/src/meta/proto-conv/tests/it/v152_external_udf.rs b/src/meta/proto-conv/tests/it/v152_external_udf.rs new file mode 100644 index 0000000000000..304c239bb8624 --- /dev/null +++ b/src/meta/proto-conv/tests/it/v152_external_udf.rs @@ -0,0 +1,80 
@@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::BTreeMap; + +use chrono::DateTime; +use chrono::Utc; +use databend_common_expression::types::DataType; +use databend_common_expression::TableDataType; +use databend_common_meta_app::principal::UDFDefinition; +use databend_common_meta_app::principal::UDFServer; +use databend_common_meta_app::principal::UserDefinedFunction; +use fastrace::func_name; + +use crate::common; + +#[test] +fn test_decode_v152_external_udf() -> anyhow::Result<()> { + let user_defined_function_v152 = vec![ + 10, 8, 116, 101, 115, 116, 95, 117, 100, 102, 18, 17, 105, 115, 32, 97, 32, 101, 120, 116, + 101, 114, 110, 97, 108, 32, 117, 100, 102, 34, 123, 10, 21, 104, 116, 116, 112, 58, 47, 47, + 49, 50, 55, 46, 48, 46, 48, 46, 49, 58, 56, 56, 56, 56, 18, 7, 105, 115, 101, 109, 112, + 116, 121, 26, 6, 112, 121, 116, 104, 111, 110, 34, 10, 162, 3, 0, 160, 6, 152, 1, 168, 6, + 24, 34, 10, 146, 2, 0, 160, 6, 152, 1, 168, 6, 24, 42, 10, 138, 2, 0, 160, 6, 152, 1, 168, + 6, 24, 50, 19, 10, 13, 88, 45, 65, 112, 105, 45, 86, 101, 114, 115, 105, 111, 110, 18, 2, + 49, 49, 50, 17, 10, 7, 88, 45, 84, 111, 107, 101, 110, 18, 6, 97, 98, 99, 49, 50, 51, 160, + 6, 152, 1, 168, 6, 24, 42, 23, 49, 57, 55, 48, 45, 48, 49, 45, 48, 49, 32, 48, 48, 58, 48, + 48, 58, 48, 48, 32, 85, 84, 67, 160, 6, 152, 1, 168, 6, 24, + ]; + + let want = || UserDefinedFunction { + name: "test_udf".to_string(), + 
description: "is a external udf".to_string(), + definition: UDFDefinition::UDFServer(UDFServer { + address: "http://127.0.0.1:8888".to_string(), + handler: "isempty".to_string(), + headers: BTreeMap::from([ + ("X-Token".to_string(), "abc123".to_string()), + ("X-Api-Version".to_string(), "11".to_string()), + ]), + language: "python".to_string(), + arg_names: vec![], + arg_types: vec![DataType::StageLocation, DataType::String], + return_type: DataType::Boolean, + immutable: None, + }), + created_on: DateTime::::default(), + }; + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old( + func_name!(), + user_defined_function_v152.as_slice(), + 152, + want(), + )?; + + Ok(()) +} + +#[test] +fn test_decode_v152_data_type_stage_location() -> anyhow::Result<()> { + let table_data_type_v152 = vec![162, 3, 0, 160, 6, 152, 1, 168, 6, 24]; + + let want = || TableDataType::StageLocation; + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old(func_name!(), table_data_type_v152.as_slice(), 152, want())?; + + Ok(()) +} diff --git a/src/meta/protos/proto/datatype.proto b/src/meta/protos/proto/datatype.proto index 556b68c775347..4f4bafa45fd46 100644 --- a/src/meta/protos/proto/datatype.proto +++ b/src/meta/protos/proto/datatype.proto @@ -69,6 +69,7 @@ message DataType { Empty interval_t = 49; Vector vector_t = 50; uint64 opaque_t = 51; + Empty stage_location_t = 52; } } diff --git a/src/meta/protos/proto/udf.proto b/src/meta/protos/proto/udf.proto index 7ae97266617d7..a0efd94134051 100644 --- a/src/meta/protos/proto/udf.proto +++ b/src/meta/protos/proto/udf.proto @@ -38,6 +38,7 @@ message UDFServer { DataType return_type = 5; map headers = 6; optional bool immutable = 7; + repeated string arg_names = 8; } message UDFScript { diff --git a/src/query/ast/src/ast/expr.rs b/src/query/ast/src/ast/expr.rs index 8aaf8b5be0a20..199dcad97a5c0 100644 --- a/src/query/ast/src/ast/expr.rs +++ b/src/query/ast/src/ast/expr.rs @@ -296,6 +296,10 @@ pub enum Expr { 
Placeholder { span: Span, }, + StageLocation { + span: Span, + location: String, + }, } impl Expr { @@ -341,7 +345,8 @@ impl Expr { | Expr::PreviousDay { span, .. } | Expr::NextDay { span, .. } | Expr::Hole { span, .. } - | Expr::Placeholder { span } => *span, + | Expr::Placeholder { span } + | Expr::StageLocation { span, .. } => *span, } } @@ -510,6 +515,7 @@ impl Expr { Expr::NextDay { span, date, .. } => merge_span(*span, date.whole_span()), Expr::Hole { span, .. } => *span, Expr::Placeholder { span } => *span, + Expr::StageLocation { span, .. } => *span, } } @@ -906,6 +912,9 @@ impl Display for Expr { Expr::Placeholder { .. } => { write!(f, "?")?; } + Expr::StageLocation { location, .. } => { + write!(f, "@{location}")?; + } } if need_paren { @@ -1201,6 +1210,7 @@ pub enum TypeName { Vector(u64), Nullable(Box), NotNull(Box), + StageLocation, } impl TypeName { @@ -1333,6 +1343,9 @@ impl Display for TypeName { TypeName::Vector(dimension) => { write!(f, "VECTOR({dimension})")?; } + TypeName::StageLocation => { + write!(f, "STAGE_LOCATION")?; + } } Ok(()) } diff --git a/src/query/ast/src/ast/statements/udf.rs b/src/query/ast/src/ast/statements/udf.rs index 9215594afc227..c4d2fa4640a26 100644 --- a/src/query/ast/src/ast/statements/udf.rs +++ b/src/query/ast/src/ast/statements/udf.rs @@ -27,6 +27,12 @@ use crate::ast::Expr; use crate::ast::Identifier; use crate::ast::TypeName; +#[derive(Debug, Clone, PartialEq, Drive, DriveMut)] +pub enum UDFArgs { + Types(Vec), + NameWithTypes(Vec<(Identifier, TypeName)>), +} + #[derive(Debug, Clone, PartialEq, Drive, DriveMut)] pub enum UDFDefinition { LambdaUDF { @@ -34,7 +40,7 @@ pub enum UDFDefinition { definition: Box, }, UDFServer { - arg_types: Vec, + arg_types: UDFArgs, return_type: TypeName, address: String, handler: String, @@ -54,7 +60,7 @@ pub enum UDFDefinition { immutable: Option, }, UDAFServer { - arg_types: Vec, + arg_types: UDFArgs, state_fields: Vec, return_type: TypeName, address: String, @@ -83,6 +89,38 @@ pub 
enum UDFDefinition { }, } +impl UDFArgs { + pub fn len(&self) -> usize { + match self { + UDFArgs::Types(types) => types.len(), + UDFArgs::NameWithTypes(name_with_types) => name_with_types.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Display for UDFArgs { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + UDFArgs::Types(types) => { + write_comma_separated_list(f, types)?; + } + UDFArgs::NameWithTypes(name_with_types) => { + write_comma_separated_list( + f, + name_with_types + .iter() + .map(|(name, ty)| format!("{name} {ty}")), + )?; + } + } + Ok(()) + } +} + impl Display for UDFDefinition { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match self { @@ -103,8 +141,7 @@ impl Display for UDFDefinition { language, immutable, } => { - write!(f, "( ")?; - write_comma_separated_list(f, arg_types)?; + write!(f, "( {arg_types}")?; write!(f, " ) RETURNS {return_type} LANGUAGE {language}")?; if let Some(immutable) = immutable { if *immutable { @@ -169,8 +206,7 @@ impl Display for UDFDefinition { headers, language, } => { - write!(f, "( ")?; - write_comma_separated_list(f, arg_types)?; + write!(f, "( {arg_types}")?; write!(f, " ) STATE {{ ")?; write_comma_separated_list(f, state_types)?; write!(f, " }} RETURNS {return_type} LANGUAGE {language}")?; diff --git a/src/query/ast/src/parser/expr.rs b/src/query/ast/src/parser/expr.rs index 83c50b96c731e..2e581c22d4dfe 100644 --- a/src/query/ast/src/parser/expr.rs +++ b/src/query/ast/src/parser/expr.rs @@ -351,6 +351,9 @@ pub enum ExprElement { name: String, }, Placeholder, + StageLocation { + location: String, + }, } pub const BETWEEN_PREC: u32 = 20; @@ -471,6 +474,7 @@ impl ExprElement { ExprElement::Hole { .. } => Affix::Nilfix, ExprElement::Placeholder => Affix::Nilfix, ExprElement::VariableAccess { .. } => Affix::Nilfix, + ExprElement::StageLocation { .. } => Affix::Nilfix, } } } @@ -522,6 +526,7 @@ impl Expr { Expr::NextDay { .. 
} => Affix::Nilfix, Expr::Hole { .. } => Affix::Nilfix, Expr::Placeholder { .. } => Affix::Nilfix, + Expr::StageLocation { .. } => Affix::Nilfix, } } } @@ -770,6 +775,10 @@ impl<'a, I: Iterator>> PrattParser for ExprP let span = transform_span(elem.span.tokens); make_func_get_variable(span, name) } + ExprElement::StageLocation { location } => Expr::StageLocation { + span: transform_span(elem.span.tokens), + location, + }, _ => unreachable!(), }; Ok(expr) @@ -1490,6 +1499,10 @@ pub fn expr_element(i: Input) -> IResult> { } }); + let stage_location = map(rule! { #at_string }, |location| { + ExprElement::StageLocation { location } + }); + map( consumed(alt(( // Note: each `alt` call supports maximum of 21 parsers @@ -1540,6 +1553,7 @@ pub fn expr_element(i: Input) -> IResult> { #case : "`CASE ... END`" | #tuple : "`( [, ...])`" | #subquery : "`(SELECT ...)`" + | #stage_location: "@" | #column_ref : "" | #dot_access : "" | #map_access : "[] | . | :" @@ -1911,6 +1925,7 @@ pub fn type_name(i: Input) -> IResult { rule! { VECTOR ~ ^"(" ~ ^#literal_u64 ~ ^")" }, |(_, _, dimension, _)| TypeName::Vector(dimension), ); + let ty_stage_location = value(TypeName::StageLocation, rule! { STAGE_LOCATION }); map_res( alt(( rule! { @@ -1945,6 +1960,7 @@ pub fn type_name(i: Input) -> IResult { | #ty_geography | #ty_nullable | #ty_vector + | #ty_stage_location ) ~ #nullable? : "type name" }, )), |(ty, opt_nullable)| match opt_nullable { diff --git a/src/query/ast/src/parser/statement.rs b/src/query/ast/src/parser/statement.rs index b45514bf1a45b..b0f1f89dc90ca 100644 --- a/src/query/ast/src/parser/statement.rs +++ b/src/query/ast/src/parser/statement.rs @@ -5326,9 +5326,9 @@ pub fn udf_definition(i: Input) -> IResult { }, ); - let udf = map( + let udf = map_res( rule! { - "(" ~ #comma_separated_list0(type_name) ~ ")" + #udf_args ~ RETURNS ~ #type_name ~ LANGUAGE ~ #ident ~ (#udf_immutable)? 
@@ -5339,10 +5339,8 @@ pub fn udf_definition(i: Input) -> IResult { ~ #udf_script_or_address }, |( - _, arg_types, _, - _, return_type, _, language, @@ -5356,7 +5354,12 @@ pub fn udf_definition(i: Input) -> IResult { address_or_code, )| { if address_or_code.1 { - UDFDefinition::UDFScript { + let UDFArgs::Types(arg_types) = arg_types else { + return Err(nom::Err::Failure(ErrorKind::Other( + "UDFScript parameters can only be of type", + ))); + }; + Ok(UDFDefinition::UDFScript { arg_types, return_type, code: address_or_code.0, @@ -5372,9 +5375,9 @@ pub fn udf_definition(i: Input) -> IResult { // Now we use fixed runtime version runtime_version: "".to_string(), immutable, - } + }) } else { - UDFDefinition::UDFServer { + Ok(UDFDefinition::UDFServer { arg_types, return_type, address: address_or_code.0, @@ -5384,7 +5387,7 @@ pub fn udf_definition(i: Input) -> IResult { .map(|(_, _, _, headers, _)| BTreeMap::from_iter(headers)) .unwrap_or_default(), immutable, - } + }) } }, ); @@ -5409,9 +5412,9 @@ pub fn udf_definition(i: Input) -> IResult { }, ); - let udaf = map( + let udaf = map_res( rule! 
{ - "(" ~ #comma_separated_list0(type_name) ~ ")" + #udf_args ~ STATE ~ "{" ~ #comma_separated_list0(udaf_state_field) ~ "}" ~ RETURNS ~ #type_name ~ LANGUAGE ~ #ident @@ -5421,11 +5424,9 @@ pub fn udf_definition(i: Input) -> IResult { ~ #udf_script_or_address }, |( - _, arg_types, _, _, - _, state_types, _, _, @@ -5438,7 +5439,12 @@ pub fn udf_definition(i: Input) -> IResult { address_or_code, )| { if address_or_code.1 { - UDFDefinition::UDAFScript { + let UDFArgs::Types(arg_types) = arg_types else { + return Err(nom::Err::Failure(ErrorKind::Other( + "UDAFScript parameters can only be of type", + ))); + }; + Ok(UDFDefinition::UDAFScript { arg_types, state_fields: state_types, return_type, @@ -5453,9 +5459,9 @@ pub fn udf_definition(i: Input) -> IResult { // TODO inject runtime_version by user // Now we use fixed runtime version runtime_version: "".to_string(), - } + }) } else { - UDFDefinition::UDAFServer { + Ok(UDFDefinition::UDAFServer { arg_types, state_fields: state_types, return_type, @@ -5464,7 +5470,7 @@ pub fn udf_definition(i: Input) -> IResult { .map(|(_, _, _, headers, _)| BTreeMap::from_iter(headers)) .unwrap_or_default(), language: language.to_string(), - } + }) } }, ); @@ -5477,6 +5483,26 @@ pub fn udf_definition(i: Input) -> IResult { )(i) } +fn udf_args(i: Input) -> IResult { + let types = map( + rule! { + "(" ~ #comma_separated_list0(type_name) ~ ")" + }, + |(_, types, _)| UDFArgs::Types(types), + ); + let name_with_types = map( + rule! { + "(" ~ #comma_separated_list0(udtf_arg) ~ ")" + }, + |(_, name_with_types, _)| UDFArgs::NameWithTypes(name_with_types), + ); + + rule!( + #types: "(, ...)" + | #name_with_types: "(, ...)" + )(i) +} + fn udtf_arg(i: Input) -> IResult<(Identifier, TypeName)> { map(rule! 
{ #ident ~ ^#type_name }, |(name, ty)| (name, ty))(i) } diff --git a/src/query/ast/src/parser/token.rs b/src/query/ast/src/parser/token.rs index 8d0e22d8eb869..c02cfc468fa77 100644 --- a/src/query/ast/src/parser/token.rs +++ b/src/query/ast/src/parser/token.rs @@ -170,7 +170,7 @@ pub enum TokenKind { #[regex(r#"\$\$([^\$]|(\$[^\$]))*\$\$"#)] LiteralCodeString, - #[regex(r#"@([^\s`;'"()]|\\\s|\\'|\\"|\\\\)+"#)] + #[regex(r#"@([^\s,`;'"()]|\\\s|\\'|\\"|\\\\)+"#)] LiteralAtString, #[regex(r"[xX]'[a-fA-F0-9]*'")] @@ -1150,6 +1150,8 @@ pub enum TokenKind { SPLIT_SIZE, #[token("STAGE", ignore(ascii_case))] STAGE, + #[token("STAGE_LOCATION", ignore(ascii_case))] + STAGE_LOCATION, #[token("SYNTAX", ignore(ascii_case))] SYNTAX, #[token("USAGE", ignore(ascii_case))] diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index 65aedab8e6763..9567c4144026e 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -841,6 +841,7 @@ SELECT * from s;"#, r#"CREATE OR REPLACE FUNCTION isnotempty_test_replace AS(p) -> not(is_null(p)) DESC = 'This is a description';"#, r#"CREATE OR REPLACE FUNCTION isnotempty_test_replace (p STRING) RETURNS BOOL AS $$ not(is_null(p)) $$;"#, r#"CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#, + r#"CREATE FUNCTION binary_reverse (arg0 BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#, r#"CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' HEADERS = ('X-Authorization' = '123') ADDRESS = 'http://0.0.0.0:8815';"#, r#"CREATE FUNCTION binary_reverse_table () RETURNS TABLE (c1 int) AS $$ select * from binary_reverse $$;"#, r#"ALTER FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#, @@ -1402,6 +1403,7 @@ fn test_expr() { 
r#"MAP_TRANSFORM_VALUES({1:10,2:20,3:30}, (k, v) -> v + 1)"#, r#"INTERVAL '1 YEAR'"#, r#"(?, ?)"#, + r#"@test_stage/input/34"#, ]; for case in cases { diff --git a/src/query/ast/tests/it/testdata/expr-error.txt b/src/query/ast/tests/it/testdata/expr-error.txt index bb5cda9599713..da8722f56ed5b 100644 --- a/src/query/ast/tests/it/testdata/expr-error.txt +++ b/src/query/ast/tests/it/testdata/expr-error.txt @@ -29,7 +29,7 @@ error: --> SQL:1:14 | 1 | CAST(col1 AS foo) - | ---- ^^^ unexpected `foo`, expecting `BOOL`, `FLOAT`, `BOOLEAN`, `FLOAT32`, `FLOAT64`, `BLOB`, `JSON`, `DOUBLE`, `VECTOR`, `LONGBLOB`, `GEOMETRY`, `GEOGRAPHY`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `REAL`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `MEDIUMBLOB`, `TINYBLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, or `NULLABLE` + | ---- ^^^ unexpected `foo`, expecting `BOOL`, `FLOAT`, `BOOLEAN`, `FLOAT32`, `FLOAT64`, `BLOB`, `JSON`, `DOUBLE`, `VECTOR`, `LONGBLOB`, `GEOMETRY`, `GEOGRAPHY`, `STAGE_LOCATION`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `REAL`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `MEDIUMBLOB`, `TINYBLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, or `NULLABLE` | | | while parsing `CAST(... 
AS ...)` | while parsing expression @@ -52,7 +52,7 @@ error: --> SQL:1:10 | 1 | CAST(col1) - | ---- ^ unexpected `)`, expecting `AS`, `,`, `(`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `::`, `POSITION`, or 48 more ... + | ---- ^ unexpected `)`, expecting `AS`, `,`, `(`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `::`, `POSITION`, or 49 more ... | | | while parsing `CAST(... AS ...)` | while parsing expression @@ -81,7 +81,7 @@ error: 1 | $ abc + 3 | ^ | | - | unexpected `$`, expecting `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, `DATE_ADD`, or 46 more ... + | unexpected `$`, expecting `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, `DATE_ADD`, or 47 more ... 
| while parsing expression diff --git a/src/query/ast/tests/it/testdata/expr.txt b/src/query/ast/tests/it/testdata/expr.txt index c8198ae5bac5e..ccbe7f08ffa44 100644 --- a/src/query/ast/tests/it/testdata/expr.txt +++ b/src/query/ast/tests/it/testdata/expr.txt @@ -6651,3 +6651,16 @@ Tuple { } +---------- Input ---------- +@test_stage/input/34 +---------- Output --------- +@test_stage/input/34 +---------- AST ------------ +StageLocation { + span: Some( + 0..20, + ), + location: "test_stage/input/34", +} + + diff --git a/src/query/ast/tests/it/testdata/stmt-error.txt b/src/query/ast/tests/it/testdata/stmt-error.txt index 136da7162ffaf..cdce19d4656de 100644 --- a/src/query/ast/tests/it/testdata/stmt-error.txt +++ b/src/query/ast/tests/it/testdata/stmt-error.txt @@ -39,7 +39,7 @@ error: --> SQL:1:19 | 1 | create table a (c varch) - | ------ - ^^^^^ unexpected `varch`, expecting `VARCHAR`, `CHAR`, `VARIANT`, `CHARACTER`, `VARBINARY`, `ARRAY`, `BINARY`, `VECTOR`, `GEOGRAPHY`, `MAP`, `DATE`, `STRING`, `FLOAT32`, `FLOAT64`, `DECIMAL`, `NUMERIC`, `SMALLINT`, `DATETIME`, `INTERVAL`, `NULLABLE`, `REAL`, `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT`, `DOUBLE`, `BITMAP`, `TUPLE`, `TIMESTAMP`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `TEXT`, `JSON`, or `GEOMETRY` + | ------ - ^^^^^ unexpected `varch`, expecting `VARCHAR`, `CHAR`, `VARIANT`, `CHARACTER`, `VARBINARY`, `ARRAY`, `BINARY`, `VECTOR`, `GEOGRAPHY`, `STAGE_LOCATION`, `MAP`, `DATE`, `STRING`, `FLOAT32`, `FLOAT64`, `DECIMAL`, `NUMERIC`, `SMALLINT`, `DATETIME`, `INTERVAL`, `NULLABLE`, `REAL`, `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT`, `DOUBLE`, `BITMAP`, `TUPLE`, `TIMESTAMP`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `TEXT`, `JSON`, or `GEOMETRY` | | | | | while parsing ` [DEFAULT ] 
[AS () VIRTUAL] [AS () STORED] [CHECK ()] [COMMENT '']` | while parsing `CREATE [OR REPLACE] TABLE [IF NOT EXISTS] [.] [] []` @@ -52,7 +52,7 @@ error: --> SQL:1:25 | 1 | create table a (c tuple()) - | ------ - ----- ^ unexpected `)`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, `VECTOR`, , , or `IDENTIFIER` + | ------ - ----- ^ unexpected `)`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, `VECTOR`, `STAGE_LOCATION`, , , or `IDENTIFIER` | | | | | | | while parsing type name | | while parsing ` [DEFAULT ] [AS () VIRTUAL] [AS () STORED] [CHECK ()] [COMMENT '']` @@ -66,7 +66,7 @@ error: --> SQL:1:38 | 1 | create table a (b tuple(c int, uint64)); - | ------ - ----- ^ unexpected `)`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, 
`STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, or `VECTOR` + | ------ - ----- ^ unexpected `)`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, `VECTOR`, or `STAGE_LOCATION` | | | | | | | while parsing TUPLE( , ...) | | | while parsing type name @@ -608,7 +608,7 @@ error: --> SQL:1:41 | 1 | SELECT * FROM t GROUP BY GROUPING SETS () - | ------ ^ unexpected `)`, expecting `(`, `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, or 46 more ... + | ------ ^ unexpected `)`, expecting `(`, `IS`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `NOT`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, `TRY_CAST`, `::`, `POSITION`, `IdentVariable`, `DATEADD`, or 47 more ... 
| | | while parsing `SELECT ...` @@ -1039,7 +1039,7 @@ error: --> SQL:1:65 | 1 | CREATE FUNCTION IF NOT EXISTS isnotempty AS(p) -> not(is_null(p) - | ------ -- ---- ^ unexpected end of input, expecting `)`, `(`, `WITHIN`, `IGNORE`, `RESPECT`, `OVER`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, or 52 more ... + | ------ -- ---- ^ unexpected end of input, expecting `)`, `(`, `WITHIN`, `IGNORE`, `RESPECT`, `OVER`, `IS`, `NOT`, `IN`, `LIKE`, `EXISTS`, `BETWEEN`, `+`, `-`, `*`, `/`, `//`, `DIV`, `%`, `||`, `<=>`, `<+>`, `<->`, `>`, `<`, `>=`, `<=`, `=`, `<>`, `!=`, `^`, `AND`, `OR`, `XOR`, `REGEXP`, `RLIKE`, `SOUNDS`, , , , , , `->`, `->>`, `#>`, `#>>`, `?`, `?|`, `?&`, `@>`, `<@`, `@?`, `@@`, `#-`, , , , , , `CAST`, or 53 more ... | | | | | | | | | while parsing `( [, ...])` | | | while parsing expression @@ -1067,7 +1067,7 @@ error: --> SQL:1:40 | 1 | CREATE FUNCTION my_agg (INT) STATE { s STRIN } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815'; - | ------ - ^^^^^ unexpected `STRIN`, expecting `STRING`, `SIGNED`, `INTERVAL`, `TINYINT`, `VARIANT`, `SMALLINT`, `TINYBLOB`, `VARBINARY`, `INT8`, `JSON`, `INT16`, `INT32`, `INT64`, `UINT8`, `BIGINT`, `UINT16`, `UINT32`, `UINT64`, `BINARY`, `INTEGER`, `DATETIME`, `NUMERIC`, `TIMESTAMP`, `UNSIGNED`, `REAL`, `DATE`, `CHAR`, `TEXT`, `ARRAY`, `TUPLE`, `VECTOR`, `BOOLEAN`, `DECIMAL`, `VARCHAR`, `LONGBLOB`, `NULLABLE`, `CHARACTER`, `GEOGRAPHY`, `MEDIUMBLOB`, `BITMAP`, `}`, `BOOL`, `INT`, `FLOAT32`, `FLOAT`, `FLOAT64`, `DOUBLE`, `MAP`, `BLOB`, or `GEOMETRY` + | ------ - ^^^^^ unexpected `STRIN`, expecting `STRING`, `SIGNED`, `INTERVAL`, `TINYINT`, `VARIANT`, `SMALLINT`, `TINYBLOB`, `VARBINARY`, `INT8`, `JSON`, `INT16`, `INT32`, `INT64`, `UINT8`, `BIGINT`, 
`UINT16`, `UINT32`, `UINT64`, `BINARY`, `INTEGER`, `DATETIME`, `NUMERIC`, `TIMESTAMP`, `UNSIGNED`, `STAGE_LOCATION`, `REAL`, `DATE`, `CHAR`, `TEXT`, `ARRAY`, `TUPLE`, `VECTOR`, `BOOLEAN`, `DECIMAL`, `VARCHAR`, `LONGBLOB`, `NULLABLE`, `CHARACTER`, `GEOGRAPHY`, `MEDIUMBLOB`, `BITMAP`, `}`, `BOOL`, `INT`, `FLOAT32`, `FLOAT`, `FLOAT64`, `DOUBLE`, `MAP`, `BLOB`, or `GEOMETRY` | | | | | while parsing (, ...) STATE {, ...} RETURNS LANGUAGE { ADDRESS= | AS } | while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [DESC = ]` @@ -1155,7 +1155,7 @@ error: | ------ while parsing `CREATE [OR REPLACE] DICTIONARY [IF NOT EXISTS] [(, ...)] PRIMARY KEY [, ...] SOURCE ( ([])) [COMMENT ] ` 2 | ( 3 | user_name tuple(), - | --------- ----- ^ unexpected `)`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, `VECTOR`, , , or `IDENTIFIER` + | --------- ----- ^ unexpected `)`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, `VECTOR`, `STAGE_LOCATION`, , , or `IDENTIFIER` | | | | | while parsing type name | while parsing ` [DEFAULT ] [AS () VIRTUAL] [AS () STORED] [CHECK ()] [COMMENT '']` @@ 
-1206,7 +1206,7 @@ error: --> SQL:1:19 | 1 | drop procedure p1(a int) - | ---- ^ unexpected `a`, expecting `DATE`, `ARRAY`, `VARCHAR`, `VARIANT`, `SMALLINT`, `DATETIME`, `VARBINARY`, `CHARACTER`, `)`, `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `MAP`, `BITMAP`, `TUPLE`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `CHAR`, `TEXT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, or `VECTOR` + | ---- ^ unexpected `a`, expecting `DATE`, `ARRAY`, `VARCHAR`, `VARIANT`, `SMALLINT`, `DATETIME`, `VARBINARY`, `CHARACTER`, `STAGE_LOCATION`, `)`, `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `MAP`, `BITMAP`, `TUPLE`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `CHAR`, `TEXT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, or `VECTOR` | | | while parsing `DROP PROCEDURE ()` @@ -1238,7 +1238,7 @@ error: --> SQL:1:44 | 1 | create PROCEDURE p1() returns table(string not null, int null) language sql comment = 'test' as $$ - | ------ ----- ^^^ unexpected `not`, expecting `INT8`, `INT16`, `INT32`, `INT64`, `UINT16`, `UINT32`, `UINT64`, `INTEGER`, `FLOAT32`, `FLOAT64`, `INTERVAL`, `GEOMETRY`, `INT`, `BOOL`, `DATE`, `BLOB`, `TEXT`, `JSON`, `UINT8`, `FLOAT`, `TUPLE`, `DOUBLE`, `BITMAP`, `BINARY`, `STRING`, `VECTOR`, `BOOLEAN`, `NUMERIC`, `UNSIGNED`, `DATETIME`, `NULLABLE`, `TIMESTAMP`, `GEOGRAPHY`, `TINYINT`, `LONGBLOB`, `TINYBLOB`, `SMALLINT`, `BIGINT`, `SIGNED`, `REAL`, `DECIMAL`, `ARRAY`, `MAP`, `VARBINARY`, `MEDIUMBLOB`, `VARCHAR`, `CHAR`, `CHARACTER`, or `VARIANT` + | ------ ----- ^^^ unexpected `not`, expecting `INT8`, `INT16`, `INT32`, 
`INT64`, `UINT16`, `UINT32`, `UINT64`, `INTEGER`, `FLOAT32`, `FLOAT64`, `INTERVAL`, `GEOMETRY`, `INT`, `BOOL`, `DATE`, `BLOB`, `TEXT`, `JSON`, `UINT8`, `FLOAT`, `TUPLE`, `DOUBLE`, `BITMAP`, `BINARY`, `STRING`, `VECTOR`, `BOOLEAN`, `NUMERIC`, `UNSIGNED`, `DATETIME`, `NULLABLE`, `TIMESTAMP`, `GEOGRAPHY`, `TINYINT`, `LONGBLOB`, `TINYBLOB`, `STAGE_LOCATION`, `SMALLINT`, `BIGINT`, `SIGNED`, `REAL`, `DECIMAL`, `ARRAY`, `MAP`, `VARBINARY`, `MEDIUMBLOB`, `VARCHAR`, `CHAR`, `CHARACTER`, or `VARIANT` | | | | | while parsing TABLE( , ...) | while parsing `CREATE [ OR REPLACE ] PROCEDURE () RETURNS { [ NOT NULL ] | TABLE( , ...)} LANGUAGE SQL [ COMMENT = '' ] AS ` @@ -1259,7 +1259,7 @@ error: --> SQL:1:24 | 1 | create PROCEDURE p1(int, string) returns table(string not null, int null) language sql comment = 'test' as $$ - | ------ - ^ unexpected `,`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, or `VECTOR` + | ------ - ^ unexpected `,`, expecting `BOOLEAN`, `BOOL`, `UINT8`, `TINYINT`, `UINT16`, `SMALLINT`, `UINT32`, `INT`, `INTEGER`, `UINT64`, `UNSIGNED`, `BIGINT`, `INT8`, `INT16`, `INT32`, `INT64`, `SIGNED`, `FLOAT32`, `FLOAT`, `REAL`, `FLOAT64`, `DOUBLE`, `DECIMAL`, `ARRAY`, `MAP`, `BITMAP`, `TUPLE`, `DATE`, `DATETIME`, `TIMESTAMP`, `INTERVAL`, `NUMERIC`, `BINARY`, `VARBINARY`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB`, `BLOB`, `STRING`, `VARCHAR`, `CHAR`, `CHARACTER`, `TEXT`, `VARIANT`, `JSON`, `GEOMETRY`, `GEOGRAPHY`, `NULLABLE`, `VECTOR`, or `STAGE_LOCATION` | | | | | while parsing ( , ...) 
| while parsing `CREATE [ OR REPLACE ] PROCEDURE () RETURNS { [ NOT NULL ] | TABLE( , ...)} LANGUAGE SQL [ COMMENT = '' ] AS ` diff --git a/src/query/ast/tests/it/testdata/stmt.txt b/src/query/ast/tests/it/testdata/stmt.txt index d705e4d75364b..485295cc1449e 100644 --- a/src/query/ast/tests/it/testdata/stmt.txt +++ b/src/query/ast/tests/it/testdata/stmt.txt @@ -26453,9 +26453,55 @@ CreateUDF( }, description: None, definition: UDFServer { - arg_types: [ - Binary, - ], + arg_types: Types( + [ + Binary, + ], + ), + return_type: Binary, + address: "http://0.0.0.0:8815", + handler: "binary_reverse", + headers: {}, + language: "python", + immutable: None, + }, + }, +) + + +---------- Input ---------- +CREATE FUNCTION binary_reverse (arg0 BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815'; +---------- Output --------- +CREATE FUNCTION binary_reverse ( arg0 BINARY ) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815' +---------- AST ------------ +CreateUDF( + CreateUDFStmt { + create_option: Create, + udf_name: Identifier { + span: Some( + 16..30, + ), + name: "binary_reverse", + quote: None, + ident_type: None, + }, + description: None, + definition: UDFServer { + arg_types: NameWithTypes( + [ + ( + Identifier { + span: Some( + 32..36, + ), + name: "arg0", + quote: None, + ident_type: None, + }, + Binary, + ), + ], + ), return_type: Binary, address: "http://0.0.0.0:8815", handler: "binary_reverse", @@ -26485,9 +26531,11 @@ CreateUDF( }, description: None, definition: UDFServer { - arg_types: [ - Binary, - ], + arg_types: Types( + [ + Binary, + ], + ), return_type: Binary, address: "http://0.0.0.0:8815", handler: "binary_reverse", @@ -26558,9 +26606,11 @@ AlterUDF( }, description: None, definition: UDFServer { - arg_types: [ - Binary, - ], + arg_types: Types( + [ + Binary, + ], + ), return_type: Binary, address: "http://0.0.0.0:8815", handler: "binary_reverse", @@ -26590,9 +26640,11 @@ 
CreateUDF( }, description: None, definition: UDFServer { - arg_types: [ - Binary, - ], + arg_types: Types( + [ + Binary, + ], + ), return_type: Binary, address: "http://0.0.0.0:8815", handler: "binary_reverse", @@ -26835,9 +26887,11 @@ CreateUDF( }, description: None, definition: UDAFServer { - arg_types: [ - Int32, - ], + arg_types: Types( + [ + Int32, + ], + ), state_fields: [ UDAFStateField { name: Identifier { diff --git a/src/query/expression/src/aggregate/payload_row.rs b/src/query/expression/src/aggregate/payload_row.rs index db6ffb179de19..971710168b1fd 100644 --- a/src/query/expression/src/aggregate/payload_row.rs +++ b/src/query/expression/src/aggregate/payload_row.rs @@ -68,7 +68,7 @@ pub(super) fn rowformat_size(data_type: &DataType) -> usize { | DataType::Geography => 4 + 8, // u32 len + address DataType::Nullable(x) => rowformat_size(x), DataType::Array(_) | DataType::Map(_) | DataType::Tuple(_) | DataType::Vector(_) => 4 + 8, - DataType::Generic(_) => unreachable!(), + DataType::Generic(_) | DataType::StageLocation => unreachable!(), DataType::Opaque(size) => size * 8, } } diff --git a/src/query/expression/src/converts/arrow/from.rs b/src/query/expression/src/converts/arrow/from.rs index 77ddf8e4968af..a7ce0f3e0e0b5 100644 --- a/src/query/expression/src/converts/arrow/from.rs +++ b/src/query/expression/src/converts/arrow/from.rs @@ -437,6 +437,7 @@ impl Column { } } DataType::Generic(_) => unreachable!("Generic type is not supported"), + DataType::StageLocation => unreachable!("StageLocation type is not supported"), }; Ok(column) diff --git a/src/query/expression/src/converts/arrow/to.rs b/src/query/expression/src/converts/arrow/to.rs index 44295947e38f7..d754179395d13 100644 --- a/src/query/expression/src/converts/arrow/to.rs +++ b/src/query/expression/src/converts/arrow/to.rs @@ -221,6 +221,9 @@ impl From<&TableField> for Field { let inner_field = Arc::new(Field::new_list_field(inner_ty, false)); ArrowDataType::FixedSizeList(inner_field, dimension) 
} + TableDataType::StageLocation => { + unreachable!("TableDataType::StageLocation only for UDFServer/UDAFServer") + } }; Field::new(f.name(), ty, f.is_nullable()).with_metadata(metadata) diff --git a/src/query/expression/src/property.rs b/src/query/expression/src/property.rs index 3fc391dbb8601..348bd65b77332 100644 --- a/src/query/expression/src/property.rs +++ b/src/query/expression/src/property.rs @@ -229,7 +229,7 @@ impl Domain { | DataType::Geography | DataType::Vector(_) | DataType::Opaque(_) => Domain::Undefined, - DataType::Generic(_) => unreachable!(), + DataType::Generic(_) | DataType::StageLocation => unreachable!(), } } diff --git a/src/query/expression/src/schema.rs b/src/query/expression/src/schema.rs index 8731f16940fe8..d0fe2dadee38b 100644 --- a/src/query/expression/src/schema.rs +++ b/src/query/expression/src/schema.rs @@ -342,6 +342,8 @@ pub enum TableDataType { Interval, Vector(VectorDataType), Opaque(usize), + // Only used to persist DataType in meta + StageLocation, } impl DataSchema { @@ -1363,6 +1365,7 @@ impl From<&TableDataType> for DataType { TableDataType::Geometry => DataType::Geometry, TableDataType::Geography => DataType::Geography, TableDataType::Vector(ty) => DataType::Vector(*ty), + TableDataType::StageLocation => DataType::StageLocation, } } } @@ -1469,6 +1472,7 @@ impl TableDataType { TableDataType::String => "VARCHAR".to_string(), TableDataType::Vector(ty) => format!("VECTOR({})", ty.dimension()), TableDataType::Nullable(inner_ty) => format!("{} NULL", inner_ty.sql_name()), + TableDataType::StageLocation => "STAGE_LOCATION".to_string(), _ => self.to_string().to_uppercase(), } } @@ -1519,7 +1523,8 @@ impl TableDataType { | TableDataType::Geometry | TableDataType::Geography | TableDataType::Interval - | TableDataType::Vector(_) => ty.sql_name(), + | TableDataType::Vector(_) + | TableDataType::StageLocation => ty.sql_name(), }; if is_null { format!("{} NULL", s) @@ -1700,6 +1705,7 @@ pub fn infer_schema_type(data_type: &DataType) 
-> Result { }) } DataType::Vector(ty) => Ok(TableDataType::Vector(*ty)), + DataType::StageLocation => Ok(TableDataType::StageLocation), DataType::Generic(_) => Err(ErrorCode::SemanticError(format!( "Cannot create table with type: {data_type}", ))), diff --git a/src/query/expression/src/type_check.rs b/src/query/expression/src/type_check.rs index 01dfa83b9fcc5..e4815f31a1c2b 100755 --- a/src/query/expression/src/type_check.rs +++ b/src/query/expression/src/type_check.rs @@ -164,6 +164,11 @@ pub fn check_cast( dest_type.clone() }; + if expr.data_type() == &DataType::String + && dest_type.remove_nullable() == DataType::StageLocation + { + return Ok(expr); + } if expr.data_type() == &wrapped_dest_type { Ok(expr) } else if expr.data_type().wrap_nullable() == wrapped_dest_type { diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs index 74edc6e922240..ac0c32a8d2865 100755 --- a/src/query/expression/src/types.rs +++ b/src/query/expression/src/types.rs @@ -134,6 +134,8 @@ pub enum DataType { // Used internally for generic types Generic(usize), + + StageLocation, } impl DataType { @@ -214,7 +216,7 @@ impl DataType { DataType::Map(ty) => ty.has_generic(), DataType::Tuple(tys) => tys.iter().any(|ty| ty.has_generic()), DataType::Generic(_) => true, - DataType::Opaque(_) => false, + DataType::Opaque(_) | DataType::StageLocation => false, } } @@ -243,6 +245,7 @@ impl DataType { DataType::Map(ty) => ty.has_nested_nullable(), DataType::Tuple(tys) => tys.iter().any(|ty| ty.has_nested_nullable()), DataType::Opaque(_) => false, + DataType::StageLocation => false, } } @@ -470,7 +473,8 @@ impl DataType { | DataType::EmptyArray | DataType::EmptyMap | DataType::Opaque(_) - | DataType::Generic(_) => Err(ErrorCode::BadArguments(format!( + | DataType::Generic(_) + | DataType::StageLocation => Err(ErrorCode::BadArguments(format!( "Unsupported data type {} to sql type", self ))), diff --git a/src/query/expression/src/utils/display.rs 
b/src/query/expression/src/utils/display.rs index 9881a5d699316..9e501fcbdfc2b 100755 --- a/src/query/expression/src/utils/display.rs +++ b/src/query/expression/src/utils/display.rs @@ -640,6 +640,7 @@ impl Display for DataType { DataType::Vector(vector) => write!(f, "{vector}"), DataType::Generic(index) => write!(f, "T{index}"), DataType::Opaque(size) => write!(f, "Opaque({size})"), + DataType::StageLocation => write!(f, "StageLocation"), } } } @@ -691,6 +692,7 @@ impl Display for TableDataType { TableDataType::Geometry => write!(f, "Geometry"), TableDataType::Geography => write!(f, "Geography"), TableDataType::Vector(vector) => write!(f, "{vector}"), + TableDataType::StageLocation => write!(f, "StageLocation"), } } } diff --git a/src/query/expression/src/utils/udf_client.rs b/src/query/expression/src/utils/udf_client.rs index f88870774d25e..953314672b703 100644 --- a/src/query/expression/src/utils/udf_client.rs +++ b/src/query/expression/src/utils/udf_client.rs @@ -201,6 +201,20 @@ impl UDFFlightClient { arg_types: &[DataType], return_type: &DataType, ) -> Result<()> { + // DataType::StageLocation is only used to pass the stage location parameter to the external UDF. + // It will be passed to the Python server in the UDF headers and skipped when passing this parameter to the UDF server. + // That is why it is skipped. 
+ fn eq_skip_stage(remote_args: &[DataType], input_args: &[DataType]) -> bool { + remote_args + .iter() + .zip( + input_args + .iter() + .filter(|ty| ty.remove_nullable() != DataType::StageLocation), + ) + .all(|(x, y)| x == y) + } + let descriptor = FlightDescriptor::new_path(vec![func_name.to_string()]); let request = self.make_request(descriptor); let flight_info = self.inner.get_flight_info(request).await?.into_inner(); @@ -229,7 +243,7 @@ impl UDFFlightClient { .iter() .map(|f| f.data_type()) .collect::>(); - if remote_arg_types != arg_types { + if !eq_skip_stage(&remote_arg_types, arg_types) { return Err(ErrorCode::UDFSchemaMismatch(format!( "UDF arg types mismatch on UDF function {}, remote arg types: ({:?}), defined arg types: ({:?})", func_name, diff --git a/src/query/expression/src/utils/variant_transform.rs b/src/query/expression/src/utils/variant_transform.rs index aadda60813f2d..d08a7e35a9e55 100644 --- a/src/query/expression/src/utils/variant_transform.rs +++ b/src/query/expression/src/utils/variant_transform.rs @@ -49,6 +49,7 @@ pub fn contains_variant(data_type: &DataType) -> bool { DataType::Map(ty) => contains_variant(ty.as_ref()), DataType::Tuple(types) => types.iter().any(contains_variant), DataType::Opaque(_) => false, + DataType::StageLocation => false, } } diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs index 246f522082d2f..ed68ae01948af 100755 --- a/src/query/expression/src/values.rs +++ b/src/query/expression/src/values.rs @@ -1665,7 +1665,7 @@ impl Column { } Column::Vector(builder.build()) } - DataType::Generic(_) => unreachable!(), + DataType::Generic(_) | DataType::StageLocation => unreachable!(), DataType::Opaque(size) => { with_opaque_size!(|N| match *size { N => { @@ -2179,6 +2179,9 @@ impl ColumnBuilder { DataType::Generic(_) => { unreachable!("unable to initialize column builder for generic type") } + DataType::StageLocation => { + unreachable!("unable to initialize column builder for stage 
location type") + } } } @@ -2253,6 +2256,9 @@ impl ColumnBuilder { DataType::Generic(_) => { unreachable!("unable to initialize column builder for generic type") } + DataType::StageLocation => { + unreachable!("unable to initialize column builder for stage location type") + } } } diff --git a/src/query/functions/src/test_utils.rs b/src/query/functions/src/test_utils.rs index c801d7317da06..5e5ff1569d66d 100644 --- a/src/query/functions/src/test_utils.rs +++ b/src/query/functions/src/test_utils.rs @@ -690,6 +690,7 @@ fn transform_data_type(target_type: databend_common_ast::ast::TypeName) -> DataT DataType::Vector(VectorDataType::Float32(d)) } databend_common_ast::ast::TypeName::NotNull(inner_type) => transform_data_type(*inner_type), + databend_common_ast::ast::TypeName::StageLocation => DataType::StageLocation, } } diff --git a/src/query/management/tests/it/udf.rs b/src/query/management/tests/it/udf.rs index 2a8d16b398bcf..8dd72e91db31c 100644 --- a/src/query/management/tests/it/udf.rs +++ b/src/query/management/tests/it/udf.rs @@ -216,6 +216,7 @@ fn create_test_udf_server() -> UserDefinedFunction { "strlen_py", &BTreeMap::default(), "python", + vec![], vec![DataType::String], DataType::Number(NumberDataType::Int64), "This is a description", diff --git a/src/query/service/src/builtin/builtin_udfs.rs b/src/query/service/src/builtin/builtin_udfs.rs index 053d4b16f94a8..823ac7fbe5b12 100644 --- a/src/query/service/src/builtin/builtin_udfs.rs +++ b/src/query/service/src/builtin/builtin_udfs.rs @@ -15,6 +15,8 @@ use std::collections::HashMap; use databend_common_ast::ast::Statement; +use databend_common_ast::ast::TypeName; +use databend_common_ast::ast::UDFArgs; use databend_common_ast::ast::UDFDefinition; use databend_common_ast::parser::parse_sql; use databend_common_ast::parser::tokenize_sql; @@ -24,7 +26,9 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::DataType; use 
databend_common_meta_app::principal::UserDefinedFunction; +use databend_common_sql::normalize_identifier; use databend_common_sql::resolve_type_name_udf; +use databend_common_sql::NameResolutionContext; use log::error; pub struct BuiltinUDFs { @@ -52,9 +56,33 @@ impl BuiltinUDFs { language, immutable, } => { + let mut arg_names = Vec::with_capacity(arg_types.len()); let mut arg_datatypes = Vec::with_capacity(arg_types.len()); - for arg_type in arg_types { - arg_datatypes.push(DataType::from(&resolve_type_name_udf(&arg_type)?)); + match &arg_types { + UDFArgs::Types(types) => { + for arg_type in types { + if matches!(arg_type, TypeName::StageLocation) { + return Err(ErrorCode::InvalidArgument( + "StageLocation must have a corresponding variable name", + )); + } + arg_datatypes + .push(DataType::from(&resolve_type_name_udf(arg_type)?)); + } + } + UDFArgs::NameWithTypes(name_with_types) => { + for (arg_name, arg_type) in name_with_types { + arg_names.push( + normalize_identifier( + arg_name, + &NameResolutionContext::default(), + ) + .name, + ); + arg_datatypes + .push(DataType::from(&resolve_type_name_udf(arg_type)?)); + } + } } let return_type = DataType::from(&resolve_type_name_udf(&return_type)?); let udf = UserDefinedFunction::create_udf_server( @@ -63,6 +91,7 @@ impl BuiltinUDFs { &handler, &headers, &language, + arg_names, arg_datatypes, return_type, "Built-in UDF", diff --git a/src/query/sql/src/planner/binder/udf.rs b/src/query/sql/src/planner/binder/udf.rs index 8107a02842f02..72e243a4d2cbd 100644 --- a/src/query/sql/src/planner/binder/udf.rs +++ b/src/query/sql/src/planner/binder/udf.rs @@ -20,6 +20,7 @@ use databend_common_ast::ast::CreateUDFStmt; use databend_common_ast::ast::Identifier; use databend_common_ast::ast::TypeName; use databend_common_ast::ast::UDAFStateField; +use databend_common_ast::ast::UDFArgs; use databend_common_ast::ast::UDFDefinition; use databend_common_exception::ErrorCode; use databend_common_exception::Result; @@ -90,8 +91,26 @@ 
impl Binder { UDFValidator::is_udf_server_allowed(address.as_str())?; let mut arg_datatypes = Vec::with_capacity(arg_types.len()); - for arg_type in arg_types { - arg_datatypes.push(DataType::from(&resolve_type_name_udf(arg_type)?)); + let mut arg_names = Vec::with_capacity(arg_types.len()); + match arg_types { + UDFArgs::Types(types) => { + for arg_type in types { + if matches!(arg_type, TypeName::StageLocation) { + return Err(ErrorCode::InvalidArgument( + "StageLocation must have a corresponding variable name", + )); + } + arg_datatypes.push(DataType::from(&resolve_type_name_udf(arg_type)?)); + } + } + UDFArgs::NameWithTypes(name_with_types) => { + for (arg_name, arg_type) in name_with_types { + arg_names.push( + normalize_identifier(arg_name, &self.name_resolution_ctx).name, + ); + arg_datatypes.push(DataType::from(&resolve_type_name_udf(arg_type)?)); + } + } } let return_type = DataType::from(&resolve_type_name_udf(return_type)?); @@ -132,6 +151,7 @@ impl Binder { description, definition: PlanUDFDefinition::UDFServer(UDFServer { address: address.clone(), + arg_names, arg_types: arg_datatypes, return_type, handler: handler.clone(), diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index 76d75b026de49..26df172cb357b 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -108,6 +108,8 @@ use databend_common_license::license::Feature; use databend_common_license::license_manager::LicenseManagerSwitch; use databend_common_meta_app::principal::LambdaUDF; use databend_common_meta_app::principal::ScalarUDF; +use databend_common_meta_app::principal::StageInfo; +use databend_common_meta_app::principal::StageType; use databend_common_meta_app::principal::UDAFScript; use databend_common_meta_app::principal::UDFDefinition; use databend_common_meta_app::principal::UDFScript; @@ -117,6 +119,7 @@ use 
databend_common_meta_app::schema::DictionaryIdentity; use databend_common_meta_app::schema::GetSequenceReq; use databend_common_meta_app::schema::SequenceIdent; use databend_common_meta_app::schema::TableIndexType; +use databend_common_meta_app::storage::StorageParams; use databend_common_storage::init_stage_operator; use databend_common_users::Object; use databend_common_users::UserApiProvider; @@ -137,6 +140,7 @@ use super::name_resolution::NameResolutionContext; use super::normalize_identifier; use crate::binder::bind_values; use crate::binder::resolve_file_location; +use crate::binder::resolve_stage_location; use crate::binder::resolve_stage_locations; use crate::binder::wrap_cast; use crate::binder::Binder; @@ -1208,6 +1212,9 @@ impl<'a> TypeChecker<'a> { ) .set_span(*span)) } + Expr::StageLocation { span, location } => { + self.resolve_stage_location(*span, location)? + } }; Ok(Box::new((scalar, data_type))) } @@ -4872,8 +4879,15 @@ impl<'a> TypeChecker<'a> { span: Span, name: String, arguments: &[Expr], - udf_definition: UDFServer, + mut udf_definition: UDFServer, ) -> Result> { + #[derive(serde::Serialize, serde::Deserialize)] + struct StageLocationParam { + param_name: String, + relative_path: String, + stage_info: StageInfo, + } + UDFValidator::is_udf_server_allowed(&udf_definition.address)?; if arguments.len() != udf_definition.arg_types.len() { return Err(ErrorCode::InvalidArgument(format!( @@ -4886,18 +4900,75 @@ impl<'a> TypeChecker<'a> { let mut all_args_const = true; let mut args = Vec::with_capacity(arguments.len()); - for (argument, dest_type) in arguments.iter().zip(udf_definition.arg_types.iter()) { + let mut stage_locations = Vec::new(); + for (i, (argument, dest_type)) in arguments + .iter() + .zip(udf_definition.arg_types.iter()) + .enumerate() + { let box (arg, ty) = self.resolve(argument)?; // TODO: support cast constant - if !matches!(arg, ScalarExpr::ConstantExpr(_)) || ty != dest_type.remove_nullable() { + if !matches!(arg, 
ScalarExpr::ConstantExpr(_)) + || (ty != dest_type.remove_nullable() + && dest_type.remove_nullable() != DataType::StageLocation) + { all_args_const = false; } + if dest_type.remove_nullable() == DataType::StageLocation { + if udf_definition.arg_names.is_empty() { + return Err(ErrorCode::InvalidArgument( + "StageLocation must have a corresponding variable name", + )); + } + let expr = arg.as_expr()?; + let (expr, _) = ConstantFolder::fold(&expr, &self.func_ctx, &BUILTIN_FUNCTIONS); + let Ok(Some(location)) = + expr.into_constant().map(|c| c.scalar.as_string().cloned()) + else { + return Err(ErrorCode::SemanticError(format!( + "invalid parameter {argument} for udf function, expected constant string", + )) + .set_span(span)); + }; + let (stage_info, relative_path) = databend_common_base::runtime::block_on( + resolve_stage_location(self.ctx.as_ref(), &location), + )?; + + if !matches!(stage_info.stage_type, StageType::External) { + return Err(ErrorCode::SemanticError(format!( + "stage {} type is {}, UDF only support External Stage", + stage_info.stage_name, stage_info.stage_type, + )) + .set_span(span)); + } + if let StorageParams::S3(config) = &stage_info.stage_params.storage { + if !config.security_token.is_empty() || !config.role_arn.is_empty() { + return Err(ErrorCode::SemanticError(format!( + "StageLocation: @{} must use a separate credential", + location + ))); + } + } + + stage_locations.push(StageLocationParam { + param_name: udf_definition.arg_names[i].clone(), + relative_path, + stage_info, + }); + continue; + } if ty != *dest_type { args.push(wrap_cast(&arg, dest_type)); } else { args.push(arg); } } + if !stage_locations.is_empty() { + let stage_location_value = serde_json::to_string(&stage_locations)?; + udf_definition + .headers + .insert("databend-stage-mapping".to_string(), stage_location_value); + } let immutable = udf_definition.immutable.unwrap_or_default(); if immutable && all_args_const { let mut arg_scalars = Vec::with_capacity(args.len()); @@ 
-4953,7 +5024,15 @@ impl<'a> TypeChecker<'a> { udf_definition: UDFServer, ) -> Result { let mut block_entries = Vec::with_capacity(args.len()); - for (arg, dest_type) in args.into_iter().zip(udf_definition.arg_types.iter()) { + for (arg, dest_type) in args.into_iter().zip( + udf_definition + .arg_types + .iter() + .filter(|ty| ty.remove_nullable() != DataType::StageLocation), + ) { + if matches!(dest_type, DataType::StageLocation) { + continue; + } let entry = BlockEntry::new_const_column(dest_type.clone(), arg, 1); block_entries.push(entry); } @@ -6055,6 +6134,20 @@ impl<'a> TypeChecker<'a> { ))) } + fn resolve_stage_location( + &mut self, + span: Span, + location: &str, + ) -> Result> { + Ok(Box::new(( + ScalarExpr::ConstantExpr(ConstantExpr { + span, + value: Scalar::String(location.to_string()), + }), + DataType::String, + ))) + } + #[allow(clippy::only_used_in_recursion)] pub fn clone_expr_with_replacement(original_expr: &Expr, replacement_fn: F) -> Result where F: Fn(&Expr) -> Result> { @@ -6212,6 +6305,7 @@ pub fn resolve_type_name(type_name: &TypeName, not_null: bool) -> Result TableDataType::StageLocation, }; if !matches!(type_name, TypeName::Nullable(_) | TypeName::NotNull(_)) && !not_null { return Ok(data_type.wrap_nullable()); diff --git a/src/query/storages/common/stage/src/read/cast.rs b/src/query/storages/common/stage/src/read/cast.rs index c111860c0be96..2ed37c85e87b6 100644 --- a/src/query/storages/common/stage/src/read/cast.rs +++ b/src/query/storages/common/stage/src/read/cast.rs @@ -55,7 +55,7 @@ pub fn load_can_auto_cast_to(from_type: &DataType, to_type: &DataType) -> bool { // we mainly care about which types can/cannot cast to to_type. // the match branches is grouped in a way to make it easier to read this info. 
match (from_type, to_type) { - (_, Null | EmptyArray | EmptyMap | Generic(_)) => unreachable!(), + (_, Null | EmptyArray | EmptyMap | Generic(_) | StageLocation) => unreachable!(), // ==== remove null first, all trivial (Null, Nullable(_)) => true, diff --git a/src/query/users/tests/it/user_udf.rs b/src/query/users/tests/it/user_udf.rs index 488682dc91d1e..7d3d794b3cdce 100644 --- a/src/query/users/tests/it/user_udf.rs +++ b/src/query/users/tests/it/user_udf.rs @@ -132,6 +132,7 @@ async fn test_user_udf_server() -> Result<()> { isempty, &BTreeMap::default(), "python", + vec![], arg_types.clone(), return_type.clone(), description, @@ -148,6 +149,7 @@ async fn test_user_udf_server() -> Result<()> { isnotempty, &BTreeMap::default(), "python", + vec![], arg_types.clone(), return_type.clone(), description, diff --git a/src/tests/sqlsmith/src/sql_gen/ddl.rs b/src/tests/sqlsmith/src/sql_gen/ddl.rs index 7cf282a37a568..040f2aa9c3b46 100644 --- a/src/tests/sqlsmith/src/sql_gen/ddl.rs +++ b/src/tests/sqlsmith/src/sql_gen/ddl.rs @@ -379,5 +379,6 @@ fn gen_default_expr(type_name: &TypeName) -> Expr { value: Literal::Null, }, TypeName::NotNull(box ty) => gen_default_expr(ty), + TypeName::StageLocation => unreachable!(), } } diff --git a/tests/sqllogictests/suites/udf_server/udf_server_test.test b/tests/sqllogictests/suites/udf_server/udf_server_test.test index 46b584051cce0..02b03ffe00922 100644 --- a/tests/sqllogictests/suites/udf_server/udf_server_test.test +++ b/tests/sqllogictests/suites/udf_server/udf_server_test.test @@ -680,4 +680,29 @@ select sum((value::Int)::Int) from system.metrics where metric = 'external_runni ---- 0 +statement ok +CREATE OR REPLACE FUNCTION stage_summary(data_stage STAGE_LOCATION, value INT) RETURNS STRING LANGUAGE python HANDLER = 'stage_summary' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE OR REPLACE FUNCTION multi_stage_process(input_stage STAGE_LOCATION, output_stage STAGE_LOCATION, value INT) RETURNS INT LANGUAGE python HANDLER = 
'multi_stage_process' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE OR REPLACE FUNCTION immutable_multi_stage_process(input_stage STAGE_LOCATION, output_stage STAGE_LOCATION, value VARCHAR) RETURNS INT LANGUAGE python IMMUTABLE HANDLER = 'immutable_multi_stage_process' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE OR REPLACE STAGE s3_stage URL = 's3://test/' CONNECTION = ( AWS_KEY_ID = 'minioadmin' AWS_SECRET_KEY = 'minioadmin' ENDPOINT_URL = 'http://127.0.0.1:9000') FILE_FORMAT = (TYPE = CSV); +query T +SELECT stage_summary(@s3_stage/input/2024/, 21) +---- +s3_stage:test:input/2024/:21 + +query I +SELECT multi_stage_process(@s3_stage/input/2024/, @s3_stage/output/2024, 21) +---- +29 + +query I +SELECT immutable_multi_stage_process(@s3_stage/input/2024/, @s3_stage/output/2024, 'hello') +---- +13 diff --git a/tests/udf/udf_server.py b/tests/udf/udf_server.py index 6a758476b6a84..00a25322d21cf 100644 --- a/tests/udf/udf_server.py +++ b/tests/udf/udf_server.py @@ -20,7 +20,7 @@ from pyarrow import flight # https://github.com/datafuselabs/databend-udf -from databend_udf import udf, UDFServer +from databend_udf import StageLocation, UDFServer, udf logging.basicConfig(level=logging.INFO) @@ -425,6 +425,50 @@ def embedding_4(s: str): return [1.1, 1.2, 1.3, 1.4] +@udf(stage_refs=["data_stage"], input_types=["INT"], result_type="VARCHAR") +def stage_summary(data_stage: StageLocation, value: int) -> str: + assert data_stage.stage_type.lower() == "external" + assert data_stage.storage + bucket = data_stage.storage.get("bucket", data_stage.storage.get("container", "")) + return f"{data_stage.stage_name}:{bucket}:{data_stage.relative_path}:{value}" + + +@udf( + stage_refs=["input_stage", "output_stage"], + input_types=["INT"], + result_type="INT", +) +def multi_stage_process( + input_stage: StageLocation, output_stage: StageLocation, value: int +) -> int: + assert input_stage.storage and output_stage.storage + assert input_stage.stage_type.lower() == 
"external" + assert output_stage.stage_type.lower() == "external" + # Simple deterministic behaviour for testing + return ( + value + + len(input_stage.storage.get("bucket", "")) + + len(output_stage.storage.get("bucket", "")) + ) + +@udf( + stage_refs=["input_stage", "output_stage"], + input_types=["VARCHAR"], + result_type="INT", +) +def immutable_multi_stage_process( + input_stage: StageLocation, output_stage: StageLocation, value: str +) -> int: + assert input_stage.storage and output_stage.storage + assert input_stage.stage_type.lower() == "external" + assert output_stage.stage_type.lower() == "external" + # Simple deterministic behaviour for testing + return ( + len(value) + + len(input_stage.storage.get("bucket", "")) + + len(output_stage.storage.get("bucket", "")) + ) + if __name__ == "__main__": udf_server = CheckHeadersServer( location="0.0.0.0:8815", middleware={"headers": HeadersMiddlewareFactory()} @@ -456,6 +500,9 @@ def embedding_4(s: str): udf_server.add_function(url_len_mul_100) udf_server.add_function(check_headers) udf_server.add_function(embedding_4) + udf_server.add_function(stage_summary) + udf_server.add_function(multi_stage_process) + udf_server.add_function(immutable_multi_stage_process) # Built-in function udf_server.add_function(ping)