From 8371e17057869f42817c77487493e03e17ec61a6 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Oct 2025 12:24:24 -0500 Subject: [PATCH 01/13] passing tests --- .../user_defined_scalar_functions.rs | 5 +- datafusion/expr/src/expr.rs | 52 ++++++- datafusion/expr/src/expr_schema.rs | 13 +- datafusion/expr/src/tree_node.rs | 4 +- datafusion/functions/src/core/arrow_cast.rs | 2 +- datafusion/physical-expr/src/planner.rs | 6 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 +- datafusion/sql/src/unparser/expr.rs | 138 +++++++----------- .../src/logical_plan/producer/expr/cast.rs | 12 +- .../src/logical_plan/producer/types.rs | 38 ++--- 10 files changed, 149 insertions(+), 125 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index f1af66de9b59..351b02d12ea7 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -712,10 +712,7 @@ impl ScalarUDFImpl for CastToI64UDF { arg } else { // need to use an actual cast to get the correct type - Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: DataType::Int64, - }) + Expr::Cast(datafusion_expr::Cast::new(Box::new(arg), DataType::Int64)) }; // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 282b3f6a0f55..f70510b1f81c 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1012,13 +1012,23 @@ pub struct Cast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub data_type: FieldRef, } impl Cast { /// Create a new Cast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + data_type: Field::new("", data_type, true).into(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { + expr, + data_type: field, + } } } @@ -1028,13 +1038,23 @@ pub struct TryCast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub data_type: FieldRef, } impl TryCast { /// Create a new TryCast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + data_type: Field::new("", data_type, true).into(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { + expr, + data_type: field, + } } } @@ -3488,10 +3508,28 @@ impl Display for Expr { write!(f, "END") } Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({expr} AS {data_type})") + if data_type.metadata().is_empty() { + write!(f, "CAST({expr} AS {})", data_type.data_type()) + } else { + write!( + f, + "CAST({expr} AS {}<{:?}>)", + data_type.data_type(), + data_type.metadata() + ) + } } Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr} AS {data_type})") + if data_type.metadata().is_empty() { + write!(f, "TRY_CAST({expr} AS {})", data_type.data_type()) + } else { + write!( + f, + "TRY_CAST({expr} AS {}<{:?}>)", + data_type.data_type(), + data_type.metadata() + ) + } } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), @@ -3844,7 +3882,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), - data_type: DataType::Utf8, + data_type: Field::new("", DataType::Utf8, true).into(), }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, format!("{expr}")); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index e803e3534130..31701812451a 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -127,7 +127,9 @@ impl ExprSchemable for Expr { .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), + | Expr::TryCast(TryCast { data_type, .. }) => { + Ok(data_type.data_type().clone()) + } Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list @@ -592,7 +594,14 @@ impl ExprSchemable for Expr { // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), Expr::Cast(Cast { expr, data_type }) => expr .to_field(schema) - .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) + .map(|(_, f)| { + f.as_ref() + .clone() + .with_data_type(data_type.data_type().clone()) + .with_metadata(f.metadata().clone()) + // TODO: should nullability be overridden here or derived from the + // input expression? + }) .map(Arc::new), Expr::Like(_) | Expr::SimilarTo(_) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f8060..e949bd71a71f 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -222,10 +222,10 @@ impl TreeNode for Expr { }), Expr::Cast(Cast { expr, data_type }) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new(be, data_type))), + .update_data(|be| Expr::Cast(Cast::new_from_field(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => expr .map_elements(f)? - .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, data_type))), Expr::ScalarFunction(ScalarFunction { func, args }) => { args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 94a41ba4bb25..964fdd9f996c 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -161,7 +161,7 @@ impl ScalarUDFImpl for ArrowCastFunc { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { expr: Box::new(arg), - data_type: target_type, + data_type: Field::new("", target_type, true).into(), }) }; // return the newly written argument to DataFusion diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 73df60c42e96..d9a5f1473a0f 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -291,12 +291,14 @@ pub fn create_physical_expr( Expr::Cast(Cast { expr, data_type }) => expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + // TODO: this drops extension metadata associated with the cast + data_type.data_type().clone(), ), Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + // TODO: this drops extension metadata associated with the cast + data_type.data_type().clone(), ), Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6238c2f1cdde..bb9d70fb2027 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -525,7 +525,7 @@ pub fn serialize_expr( Expr::Cast(Cast { expr, data_type }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(data_type.data_type().try_into()?), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), @@ -534,7 +534,7 @@ pub fn serialize_expr( Expr::TryCast(TryCast { expr, data_type }) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(data_type.data_type().try_into()?), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a7fe8efa153c..a2d2b978fac5 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -37,6 +37,7 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, + Field, FieldRef, }; use arrow::util::display::array_value_to_string; use datafusion_common::{ @@ -1133,24 +1134,27 @@ impl Unparser<'_> { // Explicit type cast on ast::Expr::Value is not needed by underlying engine for certain types // For example: CAST(Utf8("binary_value") AS Binary) and CAST(Utf8("dictionary_value") AS Dictionary) - fn cast_to_sql(&self, expr: &Expr, data_type: &DataType) -> Result { + fn cast_to_sql(&self, expr: &Expr, field: &FieldRef) -> Result { let inner_expr = self.expr_to_sql_inner(expr)?; + let data_type = field.data_type(); match inner_expr { ast::Expr::Value(_) => match data_type { - DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView => { + DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView + if field.metadata().is_empty() => + { Ok(inner_expr) } _ => Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }), }, _ => Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }), } @@ -1672,7 +1676,8 @@ impl Unparser<'_> { })) } - fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { + fn arrow_dtype_to_ast_dtype(&self, field: &FieldRef) -> Result { + let data_type = field.data_type(); match data_type { DataType::Null => { not_impl_err!("Unsupported DataType: conversion: {data_type}") @@ -1745,7 +1750,9 @@ impl Unparser<'_> { DataType::Union(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val), + DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype( + &Field::new("", val.as_ref().clone(), true).into(), + ), DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) | DataType::Decimal128(precision, scale) @@ -1885,34 +1892,25 @@ mod tests { r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Date64, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)), r#"CAST(a AS DATETIME)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+08:00".into()), - ), - }), + Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())), + )), r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Millisecond, None), - }), + Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Millisecond, None), + )), r#"CAST(a AS TIMESTAMP)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::UInt32, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::UInt32)), r#"CAST(a AS INTEGER UNSIGNED)"#, ), ( @@ -2227,10 +2225,7 @@ mod tests { r#"((a + b) > 100.123)"#, ), ( - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Decimal128(10, -2), - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Decimal128(10, -2))), r#"CAST(a AS DECIMAL(12,0))"#, ), ( @@ -2367,10 +2362,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Date64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2392,10 +2384,7 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2625,23 +2614,23 @@ mod tests { fn test_cast_value_to_binary_expr() { let tests = [ ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::Binary, - }), + DataType::Binary, + )), "'blah'", ), ( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::BinaryView, - }), + DataType::BinaryView, + )), "'blah'", ), ]; @@ -2672,10 +2661,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2758,10 +2744,7 @@ mod tests { [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Int64, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int64)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2786,10 +2769,7 @@ mod tests { [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] { let unparser = Unparser::new(&dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Int32, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2825,10 +2805,7 @@ mod tests { (&mysql_dialect, ×tamp_with_tz, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: data_type.clone(), - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type.clone())); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2881,10 +2858,7 @@ mod tests { ] { let unparser = Unparser::new(dialect); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type, - }); + let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type)); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -2924,13 +2898,13 @@ mod tests { #[test] fn test_cast_value_to_dict_expr() { let tests = [( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal( + Expr::Cast(Cast::new( + Box::new(Expr::Literal( ScalarValue::Utf8(Some("variation".to_string())), None, )), - data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), - }), + DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), + )), "'variation'", )]; for (value, expected) in tests { @@ -2962,10 +2936,7 @@ mod tests { datafusion_functions::math::round::RoundFunc::new(), )), args: vec![ - Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Float64, - }), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), Expr::Literal(ScalarValue::Int64(Some(2)), None), ], }); @@ -3127,10 +3098,12 @@ mod tests { let unparser = Unparser::new(&dialect); - let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), - ))?; + let arrow_field = Arc::new(Field::new( + "", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )); + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?; assert_eq!(ast_dtype, ast::DataType::Varchar(None)); @@ -3144,7 +3117,8 @@ mod tests { .build(); let unparser = Unparser::new(&dialect); - let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::Utf8View)?; + let arrow_field = Arc::new(Field::new("", DataType::Utf8View, true)); + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?; assert_eq!(ast_dtype, ast::DataType::Char(None)); @@ -3212,10 +3186,10 @@ mod tests { let dialect: Arc = Arc::new(SqliteDialect {}); let unparser = Unparser::new(dialect.as_ref()); - let expr = Expr::Cast(Cast { - expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Nanosecond, None), - }); + let expr = Expr::Cast(Cast::new( + Box::new(col("a")), + DataType::Timestamp(TimeUnit::Nanosecond, None), + )); let ast = unparser.expr_to_sql(&expr)?; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 9741dcdd1095..c42fea23daac 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::logical_plan::producer::{to_substrait_type_from_field, SubstraitProducer}; use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; use datafusion::common::{DFSchemaRef, ScalarValue}; use datafusion::logical_expr::{Cast, Expr, TryCast}; @@ -39,8 +39,8 @@ pub fn from_cast( let lit = Literal { nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::Null(to_substrait_type( - data_type, true, + literal_type: Some(LiteralType::Null(to_substrait_type_from_field( + data_type, )?)), }; return Ok(Expression { @@ -51,7 +51,7 @@ pub fn from_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), + r#type: Some(to_substrait_type_from_field(data_type)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, @@ -68,7 +68,7 @@ pub fn from_try_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), + r#type: Some(to_substrait_type_from_field(data_type)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), }, @@ -79,7 +79,7 @@ pub fn from_try_cast( #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::producer::to_substrait_extended_expr; + use crate::logical_plan::producer::{to_substrait_extended_expr, to_substrait_type}; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DFSchema; use datafusion::execution::SessionStateBuilder; diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 3da9269c5b9e..25a2e9db7875 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -27,7 +27,7 @@ use crate::variation_const::{ TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF, }; -use datafusion::arrow::datatypes::{DataType, IntervalUnit}; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit}; use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchemaRef}; use substrait::proto::{r#type, NamedStruct}; @@ -35,12 +35,18 @@ pub(crate) fn to_substrait_type( dt: &DataType, nullable: bool, ) -> datafusion::common::Result { - let nullability = if nullable { + to_substrait_type_from_field(&Field::new("", dt.clone(), nullable).into()) +} + +pub(crate) fn to_substrait_type_from_field( + dt: &FieldRef, +) -> datafusion::common::Result { + let nullability = if dt.is_nullable() { r#type::Nullability::Nullable as i32 } else { r#type::Nullability::Required as i32 }; - match dt { + match dt.data_type() { DataType::Null => internal_err!("Null cast is not valid"), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { @@ -244,7 +250,7 @@ pub(crate) fn to_substrait_type( })), }), DataType::List(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = to_substrait_type_from_field(inner)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -254,7 +260,7 @@ pub(crate) fn to_substrait_type( }) } DataType::LargeList(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + let inner_type = to_substrait_type_from_field(inner)?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), @@ -265,14 +271,8 @@ pub(crate) fn to_substrait_type( } DataType::Map(inner, _) => match inner.data_type() { DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - let key_type = to_substrait_type( - key_and_value[0].data_type(), - key_and_value[0].is_nullable(), - )?; - let value_type = to_substrait_type( - key_and_value[1].data_type(), - key_and_value[1].is_nullable(), - )?; + let key_type = to_substrait_type_from_field(&key_and_value[0])?; + let value_type = to_substrait_type_from_field(&key_and_value[1])?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), @@ -285,8 +285,12 @@ pub(crate) fn to_substrait_type( _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), }, DataType::Dictionary(key_type, value_type) => { - let key_type = to_substrait_type(key_type, nullable)?; - let value_type = to_substrait_type(value_type, nullable)?; + let key_type = to_substrait_type_from_field( + &Field::new("", key_type.as_ref().clone(), dt.is_nullable()).into(), + )?; + let value_type = to_substrait_type_from_field( + &Field::new("", value_type.as_ref().clone(), dt.is_nullable()).into(), + )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), @@ -299,7 +303,7 @@ pub(crate) fn to_substrait_type( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) + .map(|field| to_substrait_type_from_field(field)) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { @@ -341,7 +345,7 @@ pub(crate) fn to_substrait_named_struct( types: schema .fields() .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .map(|f| to_substrait_type_from_field(f)) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Required as i32, From 40c18c9145cd0556ee234839996ae05ec7364b2d Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Oct 2025 12:25:50 -0500 Subject: [PATCH 02/13] clippy --- datafusion/substrait/src/logical_plan/producer/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 25a2e9db7875..4ae1528372ee 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -345,7 +345,7 @@ pub(crate) fn to_substrait_named_struct( types: schema .fields() .iter() - .map(|f| to_substrait_type_from_field(f)) + .map(to_substrait_type_from_field) .collect::>()?, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, nullability: r#type::Nullability::Required as i32, From 7121449fbb3d2dd4ca70a28395434848c63b1f40 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Oct 2025 12:40:48 -0500 Subject: [PATCH 03/13] proto --- datafusion/proto/proto/datafusion.proto | 4 ++ datafusion/proto/src/generated/pbjson.rs | 72 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 14 ++++ .../proto/src/logical_plan/from_proto.rs | 10 ++- datafusion/proto/src/logical_plan/to_proto.rs | 4 ++ .../src/logical_plan/producer/types.rs | 2 +- 6 files changed, 103 insertions(+), 3 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 11103472ae2a..272190fee858 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -586,11 +586,15 @@ message WhenThen { message CastNode { LogicalExprNode expr = 1; datafusion_common.ArrowType arrow_type = 2; + map metadata = 3; + optional bool nullable = 4; } message TryCastNode { LogicalExprNode expr = 1; datafusion_common.ArrowType arrow_type = 2; + map metadata = 3; + optional bool nullable = 4; } message SortExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b34da2c312de..de0b640529b5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1834,6 +1834,12 @@ impl serde::Serialize for CastNode { if self.arrow_type.is_some() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } + if self.nullable.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -1841,6 +1847,12 @@ impl serde::Serialize for CastNode { if let Some(v) = self.arrow_type.as_ref() { struct_ser.serialize_field("arrowType", v)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if let Some(v) = self.nullable.as_ref() { + struct_ser.serialize_field("nullable", v)?; + } struct_ser.end() } } @@ -1854,12 +1866,16 @@ impl<'de> serde::Deserialize<'de> for CastNode { "expr", "arrow_type", "arrowType", + "metadata", + "nullable", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, ArrowType, + Metadata, + Nullable, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1883,6 +1899,8 @@ impl<'de> serde::Deserialize<'de> for CastNode { match value { "expr" => Ok(GeneratedField::Expr), "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "metadata" => Ok(GeneratedField::Metadata), + "nullable" => Ok(GeneratedField::Nullable), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1904,6 +1922,8 @@ impl<'de> serde::Deserialize<'de> for CastNode { { let mut expr__ = None; let mut arrow_type__ = None; + let mut metadata__ = None; + let mut nullable__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -1918,11 +1938,27 @@ impl<'de> serde::Deserialize<'de> for CastNode { } arrow_type__ = map_.next_value()?; } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = map_.next_value()?; + } } } Ok(CastNode { expr: expr__, arrow_type: arrow_type__, + metadata: metadata__.unwrap_or_default(), + nullable: nullable__, }) } } @@ -21996,6 +22032,12 @@ impl serde::Serialize for TryCastNode { if self.arrow_type.is_some() { len += 1; } + if !self.metadata.is_empty() { + len += 1; + } + if self.nullable.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.TryCastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -22003,6 +22045,12 @@ impl serde::Serialize for TryCastNode { if let Some(v) = self.arrow_type.as_ref() { struct_ser.serialize_field("arrowType", v)?; } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if let Some(v) = self.nullable.as_ref() { + struct_ser.serialize_field("nullable", v)?; + } struct_ser.end() } } @@ -22016,12 +22064,16 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { "expr", "arrow_type", "arrowType", + "metadata", + "nullable", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, ArrowType, + Metadata, + Nullable, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22045,6 +22097,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { match value { "expr" => Ok(GeneratedField::Expr), "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "metadata" => Ok(GeneratedField::Metadata), + "nullable" => Ok(GeneratedField::Nullable), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22066,6 +22120,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { { let mut expr__ = None; let mut arrow_type__ = None; + let mut metadata__ = None; + let mut nullable__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -22080,11 +22136,27 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { } arrow_type__ = map_.next_value()?; } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = map_.next_value()?; + } } } Ok(TryCastNode { expr: expr__, arrow_type: arrow_type__, + metadata: metadata__.unwrap_or_default(), + nullable: nullable__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2e1c482db65c..15e82ce3c812 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -906,6 +906,13 @@ pub struct CastNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub arrow_type: ::core::option::Option, + #[prost(map = "string, string", tag = "3")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(bool, optional, tag = "4")] + pub nullable: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TryCastNode { @@ -913,6 +920,13 @@ pub struct TryCastNode { pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] pub arrow_type: ::core::option::Option, + #[prost(map = "string, string", tag = "3")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(bool, optional, tag = "4")] + pub nullable: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExprNode { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ec6415adc4c9..d4b1d10c6faf 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use arrow::datatypes::Field; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, @@ -528,7 +529,8 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast(Cast::new(expr, data_type))) + let field = Field::new("", data_type, cast.nullable.unwrap_or(true)); + Ok(Expr::Cast(Cast::new_from_field(expr, Arc::new(field)))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( @@ -538,7 +540,11 @@ pub fn parse_expr( codec, )?); let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::TryCast(TryCast::new(expr, data_type))) + let field = Field::new("", data_type, cast.nullable.unwrap_or(true)); + Ok(Expr::TryCast(TryCast::new_from_field( + expr, + Arc::new(field), + ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index bb9d70fb2027..4f8ae19bdd62 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -526,6 +526,8 @@ pub fn serialize_expr( let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.data_type().try_into()?), + metadata: data_type.metadata().clone(), + nullable: Some(data_type.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), @@ -535,6 +537,8 @@ pub fn serialize_expr( let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), arrow_type: Some(data_type.data_type().try_into()?), + metadata: data_type.metadata().clone(), + nullable: Some(data_type.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 4ae1528372ee..0559a8bf016c 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -303,7 +303,7 @@ pub(crate) fn to_substrait_type_from_field( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| to_substrait_type_from_field(field)) + .map(to_substrait_type_from_field) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { From c9b77d9d4b0ccd0e2eed246da52c024d6fd8e3c4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 11:43:30 -0500 Subject: [PATCH 04/13] fix merge --- datafusion/proto/src/logical_plan/from_proto.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a6797df74b0e..2a458f9721ed 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use arrow::datatypes::Field; -use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, From b67cbfee426b2b04c2db3f786a242744bf7e7b86 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 11:45:53 -0500 Subject: [PATCH 05/13] fmt --- .../src/logical_plan/producer/expr/cast.rs | 2 +- .../substrait/src/logical_plan/producer/types.rs | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index bfdf33435713..2b9838406f20 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -80,7 +80,7 @@ pub fn from_try_cast( mod tests { use super::*; use crate::logical_plan::producer::{ - to_substrait_extended_expr, DefaultSubstraitProducer, to_substrait_type + to_substrait_extended_expr, to_substrait_type, DefaultSubstraitProducer, }; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DFSchema; diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 731513abbe71..3f0ebec98963 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -284,7 +284,8 @@ pub(crate) fn to_substrait_type_from_field( DataType::Map(inner, _) => match inner.data_type() { DataType::Struct(key_and_value) if key_and_value.len() == 2 => { let key_type = to_substrait_type_from_field(producer, &key_and_value[0])?; - let value_type = to_substrait_type_from_field(producer, &key_and_value[1])?; + let value_type = + to_substrait_type_from_field(producer, &key_and_value[1])?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { key: Some(Box::new(key_type)), @@ -298,10 +299,12 @@ pub(crate) fn to_substrait_type_from_field( }, DataType::Dictionary(key_type, value_type) => { let key_type = to_substrait_type_from_field( - producer, &Field::new("", key_type.as_ref().clone(), dt.is_nullable()).into(), + producer, + &Field::new("", key_type.as_ref().clone(), dt.is_nullable()).into(), )?; let value_type = to_substrait_type_from_field( - producer, &Field::new("", value_type.as_ref().clone(), dt.is_nullable()).into(), + producer, + &Field::new("", value_type.as_ref().clone(), dt.is_nullable()).into(), )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { @@ -315,9 +318,7 @@ pub(crate) fn to_substrait_type_from_field( DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| { - to_substrait_type_from_field(producer, field) - }) + .map(|field| to_substrait_type_from_field(producer, field)) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { From db86ce94df987dedee2fa5ba0ed7d4864d6b0d04 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 12:12:42 -0500 Subject: [PATCH 06/13] use the helper --- datafusion/expr/src/expr.rs | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 19a72a1ddd16..df188b2338c6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,6 +32,7 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; +use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; @@ -3295,28 +3296,18 @@ impl Display for Expr { write!(f, "END") } Expr::Cast(Cast { expr, data_type }) => { - if data_type.metadata().is_empty() { - write!(f, "CAST({expr} AS {})", data_type.data_type()) - } else { - write!( - f, - "CAST({expr} AS {}<{:?}>)", - data_type.data_type(), - data_type.metadata() - ) - } + let formatted = format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()), + ); + write!(f, "CAST({expr} AS {})", formatted) } Expr::TryCast(TryCast { expr, data_type }) => { - if data_type.metadata().is_empty() { - write!(f, "TRY_CAST({expr} AS {})", data_type.data_type()) - } else { - write!( - f, - "TRY_CAST({expr} AS {}<{:?}>)", - data_type.data_type(), - data_type.metadata() - ) - } + let formatted = format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()), + ); + write!(f, "TRY_CAST({expr} AS {})", formatted) } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), From dc8140e2948a36603c0e797971a1b2c18f26baec Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 20:37:40 -0500 Subject: [PATCH 07/13] clippy --- datafusion/expr/src/expr.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index df188b2338c6..f2de43399f8b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3300,14 +3300,14 @@ impl Display for Expr { data_type.data_type(), Some(data_type.metadata()), ); - write!(f, "CAST({expr} AS {})", formatted) + write!(f, "CAST({expr} AS {formatted})") } Expr::TryCast(TryCast { expr, data_type }) => { let formatted = format_type_and_metadata( data_type.data_type(), Some(data_type.metadata()), ); - write!(f, "TRY_CAST({expr} AS {})", formatted) + write!(f, "TRY_CAST({expr} AS {formatted})") } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), From a3987a27c6a5644120258a546fedcab42ec9e04f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 21:56:54 -0500 Subject: [PATCH 08/13] test not suported --- datafusion/physical-expr/src/planner.rs | 95 +++++++++++++++++++++---- 1 file changed, 81 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 61b85724ff45..ad22f90e65ac 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -25,7 +25,7 @@ use crate::{ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; -use datafusion_common::metadata::FieldMetadata; +use datafusion_common::metadata::{format_type_and_metadata, FieldMetadata}; use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; @@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, Like, Operator, TryCast, + binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -287,18 +287,50 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast(Cast { expr, data_type }) => expressions::cast( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - // TODO: this drops extension metadata associated with the cast - data_type.data_type().clone(), - ), - Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - // TODO: this drops extension metadata associated with the cast - data_type.data_type().clone(), - ), + Expr::Cast(Cast { expr, data_type }) => { + if !data_type.metadata().is_empty() { + let (_, src_field) = expr.to_field(input_dfschema)?; + return plan_err!( + "Cast from {} to {} is not supported", + format_type_and_metadata( + src_field.data_type(), + Some(src_field.metadata()), + ), + format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()) + ) + ); + } + + expressions::cast( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + data_type.data_type().clone(), + ) + } + Expr::TryCast(TryCast { expr, data_type }) => { + if !data_type.metadata().is_empty() { + let (_, src_field) = expr.to_field(input_dfschema)?; + return plan_err!( + "TryCast from {} to {} is not supported", + format_type_and_metadata( + src_field.data_type(), + Some(src_field.metadata()), + ), + format_type_and_metadata( + data_type.data_type(), + Some(data_type.metadata()) + ) + ); + } + + expressions::try_cast( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + data_type.data_type().clone(), + ) + } Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) } @@ -419,6 +451,7 @@ mod tests { use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field}; + use datafusion_common::datatype::DataTypeExt; use datafusion_expr::{col, lit}; use super::*; @@ -447,4 +480,38 @@ mod tests { Ok(()) } + + #[test] + fn test_cast_to_extension_type() -> Result<()> { + let extension_field_type = Arc::new( + DataType::FixedSizeBinary(16) + .into_nullable_field() + .with_metadata( + [("ARROW:extension:name".to_string(), "arrow.uuid".to_string())] + .into(), + ), + ); + let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58"); + let cast_expr = Expr::Cast(Cast::new_from_field( + Box::new(expr.clone()), + extension_field_type.clone(), + )); + let err = + create_physical_expr(&cast_expr, &DFSchema::empty(), &ExecutionProps::new()) + .unwrap_err(); + assert!(err.message().contains("arrow.uuid")); + + let try_cast_expr = Expr::TryCast(TryCast::new_from_field( + Box::new(expr.clone()), + extension_field_type.clone(), + )); + let err = create_physical_expr( + &try_cast_expr, + &DFSchema::empty(), + &ExecutionProps::new(), + ) + .unwrap_err(); + assert!(err.message().contains("arrow.uuid")); + Ok(()) + } } From d3a1e3fa8f040edbed234fcc61f4444221661e51 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 22:15:21 -0500 Subject: [PATCH 09/13] rename data type to field --- .../provider_filter_pushdown.rs | 2 +- datafusion/expr/src/expr.rs | 59 +++++++------------ datafusion/expr/src/expr_rewriter/order_by.rs | 8 +-- datafusion/expr/src/expr_schema.rs | 9 ++- datafusion/expr/src/tree_node.rs | 8 +-- datafusion/functions/src/core/arrow_cast.rs | 4 +- .../optimizer/src/eliminate_outer_join.rs | 4 +- datafusion/physical-expr/src/planner.rs | 22 +++---- datafusion/proto/src/logical_plan/to_proto.rs | 16 ++--- datafusion/sql/src/unparser/expr.rs | 8 +-- .../src/logical_plan/producer/expr/cast.rs | 10 ++-- 11 files changed, 64 insertions(+), 86 deletions(-) diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index c80c0b4bf54b..ca01a0657988 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -183,7 +183,7 @@ impl TableProvider for CustomProvider { Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, - Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { + Expr::Cast(Cast { expr, field: _ }) => match expr.deref() { Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f2de43399f8b..1e9418ec659f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,6 +32,7 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, @@ -787,7 +788,7 @@ pub struct Cast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: FieldRef, + pub field: FieldRef, } impl Cast { @@ -795,15 +796,12 @@ impl Cast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, - data_type: Field::new("", data_type, true).into(), + field: data_type.into_nullable_field_ref(), } } pub fn new_from_field(expr: Box, field: FieldRef) -> Self { - Self { - expr, - data_type: field, - } + Self { expr, field } } } @@ -813,7 +811,7 @@ pub struct TryCast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: FieldRef, + pub field: FieldRef, } impl TryCast { @@ -821,15 +819,12 @@ impl TryCast { pub fn new(expr: Box, data_type: DataType) -> Self { Self { expr, - data_type: Field::new("", data_type, true).into(), + field: data_type.into_nullable_field_ref(), } } pub fn new_from_field(expr: Box, field: FieldRef) -> Self { - Self { - expr, - data_type: field, - } + Self { expr, field } } } @@ -2264,23 +2259,23 @@ impl NormalizeEq for Expr { ( Expr::Cast(Cast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::Cast(Cast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), ) | ( Expr::TryCast(TryCast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::TryCast(TryCast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), - ) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ) => self_field == other_field && self_expr.normalize_eq(other_expr), ( Expr::ScalarFunction(ScalarFunction { func: self_func, @@ -2596,15 +2591,9 @@ impl HashNode for Expr { when_then_expr: _when_then_expr, else_expr: _else_expr, }) => {} - Expr::Cast(Cast { - expr: _expr, - data_type, - }) - | Expr::TryCast(TryCast { - expr: _expr, - data_type, - }) => { - data_type.hash(state); + Expr::Cast(Cast { expr: _expr, field }) + | Expr::TryCast(TryCast { expr: _expr, field }) => { + field.hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(state); @@ -3295,18 +3284,14 @@ impl Display for Expr { } write!(f, "END") } - Expr::Cast(Cast { expr, data_type }) => { - let formatted = format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()), - ); + Expr::Cast(Cast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); write!(f, "CAST({expr} AS {formatted})") } - Expr::TryCast(TryCast { expr, data_type }) => { - let formatted = format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()), - ); + Expr::TryCast(TryCast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); write!(f, "TRY_CAST({expr} AS {formatted})") } Expr::Not(expr) => write!(f, "NOT {expr}"), @@ -3693,7 +3678,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), - data_type: Field::new("", DataType::Utf8, true).into(), + field: DataType::Utf8.into_nullable_field_ref(), }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, format!("{expr}")); diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 6db95555502d..3b54f1b46c95 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -116,13 +116,13 @@ fn rewrite_in_terms_of_projection( if let Some(found) = found { return Ok(Transformed::yes(match normalized_expr { - Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { + Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast { expr: Box::new(found), - data_type, + field, }), - Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { + Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast { expr: Box::new(found), - data_type, + field, }), _ => found, })); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index fe22a0b76b9a..2a38b78e0aff 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -127,9 +127,8 @@ impl ExprSchemable for Expr { .as_ref() .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => { - Ok(data_type.data_type().clone()) + Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, .. }) => { + Ok(field.data_type().clone()) } Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; @@ -581,12 +580,12 @@ impl ExprSchemable for Expr { func.return_field_from_args(args) } // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .to_field(schema) .map(|(_, f)| { f.as_ref() .clone() - .with_data_type(data_type.data_type().clone()) + .with_data_type(field.data_type().clone()) .with_metadata(f.metadata().clone()) // TODO: should nullability be overridden here or derived from the // input expression? diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index e949bd71a71f..2fbd57a11ecb 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -220,12 +220,12 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new_from_field(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => expr + .update_data(|be| Expr::Cast(Cast::new_from_field(be, field))), + Expr::TryCast(TryCast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, field))), Expr::ScalarFunction(ScalarFunction { func, args }) => { args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 964fdd9f996c..29492e90fb76 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -19,6 +19,7 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; @@ -154,6 +155,7 @@ impl ScalarUDFImpl for ArrowCastFunc { let arg = args.pop().unwrap(); let source_type = info.get_data_type(&arg)?; + // TODO: check type equality for real let new_expr = if source_type == target_type { // the argument's data type is already the correct type arg @@ -161,7 +163,7 @@ impl ScalarUDFImpl for ArrowCastFunc { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { expr: Box::new(arg), - data_type: Field::new("", target_type, true).into(), + field: target_type.into_nullable_field_ref(), }) }; // return the newly written argument to DataFusion diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 45877642f276..160c09cde2f5 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -289,8 +289,8 @@ fn extract_non_nullable_columns( false, ) } - Expr::Cast(Cast { expr, data_type: _ }) - | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns( + Expr::Cast(Cast { expr, field: _ }) + | Expr::TryCast(TryCast { expr, field: _ }) => extract_non_nullable_columns( expr, non_nullable_cols, left_schema, diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index ad22f90e65ac..f3de5b43118b 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -287,8 +287,8 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast(Cast { expr, data_type }) => { - if !data_type.metadata().is_empty() { + Expr::Cast(Cast { expr, field }) => { + if !field.metadata().is_empty() { let (_, src_field) = expr.to_field(input_dfschema)?; return plan_err!( "Cast from {} to {} is not supported", @@ -296,21 +296,18 @@ pub fn create_physical_expr( src_field.data_type(), Some(src_field.metadata()), ), - format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()) - ) + format_type_and_metadata(field.data_type(), Some(field.metadata())) ); } expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.data_type().clone(), + field.data_type().clone(), ) } - Expr::TryCast(TryCast { expr, data_type }) => { - if !data_type.metadata().is_empty() { + Expr::TryCast(TryCast { expr, field }) => { + if !field.metadata().is_empty() { let (_, src_field) = expr.to_field(input_dfschema)?; return plan_err!( "TryCast from {} to {} is not supported", @@ -318,17 +315,14 @@ pub fn create_physical_expr( src_field.data_type(), Some(src_field.metadata()), ), - format_type_and_metadata( - data_type.data_type(), - Some(data_type.metadata()) - ) + format_type_and_metadata(field.data_type(), Some(field.metadata())) ); } expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.data_type().clone(), + field.data_type().clone(), ) } Expr::Not(expr) => { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 187aef25b8a8..c1b9f30f2de8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -522,23 +522,23 @@ pub fn serialize_expr( expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, field }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.data_type().try_into()?), - metadata: data_type.metadata().clone(), - nullable: Some(data_type.is_nullable()), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), } } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.data_type().try_into()?), - metadata: data_type.metadata().clone(), - nullable: Some(data_type.is_nullable()), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a2d2b978fac5..f51ff4c1ba9a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -190,9 +190,7 @@ impl Unparser<'_> { end_token: AttachedToken::empty(), }) } - Expr::Cast(Cast { expr, data_type }) => { - Ok(self.cast_to_sql(expr, data_type)?) - } + Expr::Cast(Cast { expr, field }) => Ok(self.cast_to_sql(expr, field)?), Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), Expr::WindowFunction(window_fun) => { @@ -465,12 +463,12 @@ impl Unparser<'_> { ) }) } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field)?, format: None, }) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 2b9838406f20..889a285eb6a0 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -29,7 +29,7 @@ pub fn from_cast( cast: &Cast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let Cast { expr, data_type } = cast; + let Cast { expr, field } = cast; // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null if let Expr::Literal(lit, _) = expr.as_ref() { // only the untyped(a null scalar value) null literal need this special handling @@ -40,7 +40,7 @@ pub fn from_cast( nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, literal_type: Some(LiteralType::Null(to_substrait_type_from_field( - producer, data_type, + producer, field, )?)), }; return Ok(Expression { @@ -51,7 +51,7 @@ pub fn from_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type_from_field(producer, data_type)?), + r#type: Some(to_substrait_type_from_field(producer, field)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, @@ -64,11 +64,11 @@ pub fn from_try_cast( cast: &TryCast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let TryCast { expr, data_type } = cast; + let TryCast { expr, field } = cast; Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type_from_field(producer, data_type)?), + r#type: Some(to_substrait_type_from_field(producer, field)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), }, From ee13fc5b723f8e004f04a2d03bbcc92a94f3bdfd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 22:48:00 -0500 Subject: [PATCH 10/13] clippy --- datafusion/physical-expr/src/planner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f3de5b43118b..bc80c6df1aa2 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -488,7 +488,7 @@ mod tests { let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58"); let cast_expr = Expr::Cast(Cast::new_from_field( Box::new(expr.clone()), - extension_field_type.clone(), + Arc::clone(&extension_field_type), )); let err = create_physical_expr(&cast_expr, &DFSchema::empty(), &ExecutionProps::new()) @@ -497,7 +497,7 @@ mod tests { let try_cast_expr = Expr::TryCast(TryCast::new_from_field( Box::new(expr.clone()), - extension_field_type.clone(), + Arc::clone(&extension_field_type), )); let err = create_physical_expr( &try_cast_expr, From dc3667efb675b35a0462f8900951271c819f5ad0 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 23:08:00 -0500 Subject: [PATCH 11/13] maybe better substrait consumer integration --- .../src/logical_plan/consumer/expr/cast.rs | 11 +++++----- .../src/logical_plan/consumer/types.rs | 20 ++++++++++++++++++- .../src/logical_plan/producer/types.rs | 12 +++++------ 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs index 5e8d3d93065f..ff88464b8bb0 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::consumer::types::from_substrait_type_without_names; -use crate::logical_plan::consumer::SubstraitConsumer; +use crate::logical_plan::consumer::{ + field_from_substrait_type_without_names, SubstraitConsumer, +}; use datafusion::common::{substrait_err, DFSchema}; use datafusion::logical_expr::{Cast, Expr, TryCast}; use substrait::proto::expression as substrait_expression; @@ -37,11 +38,11 @@ pub async fn from_cast( ) .await?, ); - let data_type = from_substrait_type_without_names(consumer, output_type)?; + let field = field_from_substrait_type_without_names(consumer, output_type)?; if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + Ok(Expr::TryCast(TryCast::new_from_field(input_expr, field))) } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) + Ok(Expr::Cast(Cast::new_from_field(input_expr, field))) } } None => substrait_err!("Cast expression without output type is not allowed"), diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs index 772ea7177ca2..738620c5f5aa 100644 --- a/datafusion/substrait/src/logical_plan/consumer/types.rs +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -34,14 +34,22 @@ use crate::variation_const::{ VIEW_CONTAINER_TYPE_VARIATION_REF, }; use datafusion::arrow::datatypes::{ - DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, + DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; +use datafusion::common::datatype::DataTypeExt; use datafusion::common::{ not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, }; use std::sync::Arc; use substrait::proto::{r#type, NamedStruct, Type}; +pub(crate) fn field_from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, + dt: &Type, +) -> datafusion::common::Result { + Ok(from_substrait_type_without_names(consumer, dt)?.into_nullable_field_ref()) +} + pub(crate) fn from_substrait_type_without_names( consumer: &impl SubstraitConsumer, dt: &Type, @@ -49,6 +57,16 @@ pub(crate) fn from_substrait_type_without_names( from_substrait_type(consumer, dt, &[], &mut 0) } +pub fn field_from_substrait_type( + consumer: &impl SubstraitConsumer, + dt: &Type, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + // We could add nullability here now that we are returning a Field + Ok(from_substrait_type(consumer, dt, dfs_names, name_idx)?.into_nullable_field_ref()) +} + pub fn from_substrait_type( consumer: &impl SubstraitConsumer, dt: &Type, diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs index 3f0ebec98963..99f5ecec6690 100644 --- a/datafusion/substrait/src/logical_plan/producer/types.rs +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -41,14 +41,14 @@ pub(crate) fn to_substrait_type( pub(crate) fn to_substrait_type_from_field( producer: &mut impl SubstraitProducer, - dt: &FieldRef, + field: &FieldRef, ) -> datafusion::common::Result { - let nullability = if dt.is_nullable() { + let nullability = if field.is_nullable() { r#type::Nullability::Nullable as i32 } else { r#type::Nullability::Required as i32 }; - match dt.data_type() { + match field.data_type() { DataType::Null => internal_err!("Null cast is not valid"), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { @@ -300,11 +300,11 @@ pub(crate) fn to_substrait_type_from_field( DataType::Dictionary(key_type, value_type) => { let key_type = to_substrait_type_from_field( producer, - &Field::new("", key_type.as_ref().clone(), dt.is_nullable()).into(), + &Field::new("", key_type.as_ref().clone(), field.is_nullable()).into(), )?; let value_type = to_substrait_type_from_field( producer, - &Field::new("", value_type.as_ref().clone(), dt.is_nullable()).into(), + &Field::new("", value_type.as_ref().clone(), field.is_nullable()).into(), )?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Map(Box::new(r#type::Map { @@ -344,7 +344,7 @@ pub(crate) fn to_substrait_type_from_field( precision: *p as i32, })), }), - _ => not_impl_err!("Unsupported cast type: {dt}"), + _ => not_impl_err!("Unsupported cast type: {field}"), } } From fdcb31d3ddc86dd1597a89a96241e228fae1d3ab Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 23:20:51 -0500 Subject: [PATCH 12/13] no need to update this --- datafusion/functions/src/core/arrow_cast.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 29492e90fb76..016b9c4bbccd 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -155,7 +155,6 @@ impl ScalarUDFImpl for ArrowCastFunc { let arg = args.pop().unwrap(); let source_type = info.get_data_type(&arg)?; - // TODO: check type equality for real let new_expr = if source_type == target_type { // the argument's data type is already the correct type arg From 099dc0524658e95b35776b2d7d6bdcf009f96e63 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 24 Oct 2025 23:26:23 -0500 Subject: [PATCH 13/13] comment about nullability --- datafusion/expr/src/expr_schema.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2a38b78e0aff..0f41c6967548 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -583,12 +583,13 @@ impl ExprSchemable for Expr { Expr::Cast(Cast { expr, field }) => expr .to_field(schema) .map(|(_, f)| { + // This currently propagates the nullability of the input + // expression as the resulting physical expression does + // not currently consider the nullability specified here f.as_ref() .clone() .with_data_type(field.data_type().clone()) .with_metadata(f.metadata().clone()) - // TODO: should nullability be overridden here or derived from the - // input expression? }) .map(Arc::new), Expr::Placeholder(Placeholder {