diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index ee9ac0e7902d..11103472ae2a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -518,6 +518,7 @@ message AggregateUDFExprNode { LogicalExprNode filter = 3; repeated SortExprNode order_by = 4; optional bytes fun_definition = 6; + optional NullTreatment null_treatment = 7; } message ScalarUDFExprNode { @@ -538,6 +539,9 @@ message WindowExprNode { // repeated LogicalExprNode filter = 7; WindowFrame window_frame = 8; optional bytes fun_definition = 10; + optional NullTreatment null_treatment = 11; + bool distinct = 12; + LogicalExprNode filter = 13; } message BetweenNode { @@ -622,6 +626,11 @@ message WindowFrameBound { datafusion_common.ScalarValue bound_value = 2; } +enum NullTreatment { + RESPECT_NULLS = 0; + IGNORE_NULLS = 1; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // Arrow Data Types /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1365,4 +1374,4 @@ message SortMergeJoinExecNode { JoinFilter filter = 5; repeated SortExprNode sort_options = 6; datafusion_common.NullEquality null_equality = 7; -} \ No newline at end of file +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 29967d812000..b34da2c312de 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -596,6 +596,9 @@ impl serde::Serialize for AggregateUdfExprNode { if self.fun_definition.is_some() { len += 1; } + if self.null_treatment.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?; if !self.fun_name.is_empty() { struct_ser.serialize_field("funName", &self.fun_name)?; @@ -617,6 +620,11 @@ impl serde::Serialize for AggregateUdfExprNode { #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } + if let Some(v) = self.null_treatment.as_ref() { + let v = NullTreatment::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("nullTreatment", &v)?; + } struct_ser.end() } } @@ -636,6 +644,8 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "orderBy", "fun_definition", "funDefinition", + "null_treatment", + "nullTreatment", ]; #[allow(clippy::enum_variant_names)] @@ -646,6 +656,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { Filter, OrderBy, FunDefinition, + NullTreatment, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -673,6 +684,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "filter" => Ok(GeneratedField::Filter), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "nullTreatment" | "null_treatment" => Ok(GeneratedField::NullTreatment), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -698,6 +710,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { let mut filter__ = None; let mut order_by__ = None; let mut fun_definition__ = None; + let mut null_treatment__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::FunName => { @@ -738,6 +751,12 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::NullTreatment => { + if null_treatment__.is_some() { + return Err(serde::de::Error::duplicate_field("nullTreatment")); + } + null_treatment__ = map_.next_value::<::std::option::Option>()?.map(|x| x as i32); + } } } Ok(AggregateUdfExprNode { @@ -747,6 +766,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { filter: filter__, order_by: order_by__.unwrap_or_default(), fun_definition: fun_definition__, + null_treatment: null_treatment__, }) } } @@ -13284,6 +13304,77 @@ impl<'de> serde::Deserialize<'de> for Not { deserializer.deserialize_struct("datafusion.Not", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NullTreatment { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::RespectNulls => "RESPECT_NULLS", + Self::IgnoreNulls => "IGNORE_NULLS", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for NullTreatment { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "RESPECT_NULLS", + "IGNORE_NULLS", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NullTreatment; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "RESPECT_NULLS" => Ok(NullTreatment::RespectNulls), + "IGNORE_NULLS" => Ok(NullTreatment::IgnoreNulls), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for OptimizedLogicalPlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -23514,6 +23605,15 @@ impl serde::Serialize for WindowExprNode { if self.fun_definition.is_some() { len += 1; } + if self.null_treatment.is_some() { + len += 1; + } + if self.distinct { + len += 1; + } + if self.filter.is_some() { + len += 1; + } if self.window_function.is_some() { len += 1; } @@ -23535,6 +23635,17 @@ impl serde::Serialize for WindowExprNode { #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } + if let Some(v) = self.null_treatment.as_ref() { + let v = NullTreatment::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("nullTreatment", &v)?; + } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } if let Some(v) = self.window_function.as_ref() { match v { window_expr_node::WindowFunction::Udaf(v) => { @@ -23564,6 +23675,10 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "windowFrame", "fun_definition", "funDefinition", + "null_treatment", + "nullTreatment", + "distinct", + "filter", "udaf", "udwf", ]; @@ -23575,6 +23690,9 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { OrderBy, WindowFrame, FunDefinition, + NullTreatment, + Distinct, + Filter, Udaf, Udwf, } @@ -23603,6 +23721,9 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "nullTreatment" | "null_treatment" => Ok(GeneratedField::NullTreatment), + "distinct" => Ok(GeneratedField::Distinct), + "filter" => Ok(GeneratedField::Filter), "udaf" => Ok(GeneratedField::Udaf), "udwf" => Ok(GeneratedField::Udwf), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -23629,6 +23750,9 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { let mut order_by__ = None; let mut window_frame__ = None; let mut fun_definition__ = None; + let mut null_treatment__ = None; + let mut distinct__ = None; + let mut filter__ = None; let mut window_function__ = None; while let Some(k) = map_.next_key()? { match k { @@ -23664,6 +23788,24 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) ; } + GeneratedField::NullTreatment => { + if null_treatment__.is_some() { + return Err(serde::de::Error::duplicate_field("nullTreatment")); + } + null_treatment__ = map_.next_value::<::std::option::Option>()?.map(|x| x as i32); + } + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + distinct__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; + } GeneratedField::Udaf => { if window_function__.is_some() { return Err(serde::de::Error::duplicate_field("udaf")); @@ -23684,6 +23826,9 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { order_by: order_by__.unwrap_or_default(), window_frame: window_frame__, fun_definition: fun_definition__, + null_treatment: null_treatment__, + distinct: distinct__.unwrap_or_default(), + filter: filter__, window_function: window_function__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d3b5f566e98b..2e1c482db65c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -605,7 +605,7 @@ pub mod logical_expr_node { TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "18")] - WindowExpr(super::WindowExprNode), + WindowExpr(::prost::alloc::boxed::Box), /// AggregateUDF expressions #[prost(message, tag = "19")] AggregateUdfExpr(::prost::alloc::boxed::Box), @@ -795,6 +795,8 @@ pub struct AggregateUdfExprNode { pub order_by: ::prost::alloc::vec::Vec, #[prost(bytes = "vec", optional, tag = "6")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(enumeration = "NullTreatment", optional, tag = "7")] + pub null_treatment: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarUdfExprNode { @@ -818,6 +820,12 @@ pub struct WindowExprNode { pub window_frame: ::core::option::Option, #[prost(bytes = "vec", optional, tag = "10")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(enumeration = "NullTreatment", optional, tag = "11")] + pub null_treatment: ::core::option::Option, + #[prost(bool, tag = "12")] + pub distinct: bool, + #[prost(message, optional, boxed, tag = "13")] + pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(oneof = "window_expr_node::WindowFunction", tags = "3, 9")] pub window_function: ::core::option::Option, } @@ -2129,6 +2137,32 @@ impl WindowFrameBoundType { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum NullTreatment { + RespectNulls = 0, + IgnoreNulls = 1, +} +impl NullTreatment { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::RespectNulls => "RESPECT_NULLS", + Self::IgnoreNulls => "IGNORE_NULLS", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "RESPECT_NULLS" => Some(Self::RespectNulls), + "IGNORE_NULLS" => Some(Self::IgnoreNulls), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum DateUnit { Day = 0, DateMillisecond = 1, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index cbfa15183b5c..ec6415adc4c9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -23,7 +23,7 @@ use datafusion_common::{ RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, }; use datafusion_expr::dml::InsertOp; -use datafusion_expr::expr::{Alias, Placeholder, Sort}; +use datafusion_expr::expr::{Alias, NullTreatment, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; use datafusion_expr::{ expr::{self, InList, WindowFunction}, @@ -243,6 +243,15 @@ impl From for WriteOp { } } +impl From for NullTreatment { + fn from(t: protobuf::NullTreatment) -> Self { + match t { + protobuf::NullTreatment::RespectNulls => NullTreatment::RespectNulls, + protobuf::NullTreatment::IgnoreNulls => NullTreatment::IgnoreNulls, + } + } +} + pub fn parse_expr( proto: &protobuf::LogicalExprNode, registry: &dyn FunctionRegistry, @@ -301,9 +310,21 @@ pub fn parse_expr( exec_datafusion_err!("missing window frame during deserialization") })?; - // TODO: support null treatment, distinct, and filter in proto. - // See https://github.com/apache/datafusion/issues/17417 - match window_function { + let null_treatment = match expr.null_treatment { + Some(null_treatment) => { + let null_treatment = protobuf::NullTreatment::try_from(null_treatment) + .map_err(|_| { + proto_error(format!( + "Received a WindowExprNode message with unknown NullTreatment {}", + null_treatment + )) + })?; + Some(NullTreatment::from(null_treatment)) + } + None => None, + }; + + let agg_fn = match window_function { window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, @@ -311,17 +332,7 @@ pub fn parse_expr( .udaf(udaf_name) .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; - - let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::from(WindowFunction::new( - expr::WindowFunctionDefinition::AggregateUDF(udaf_function), - args, - )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .build() - .map_err(Error::DataFusionError) + expr::WindowFunctionDefinition::AggregateUDF(udaf_function) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -330,19 +341,28 @@ pub fn parse_expr( .udwf(udwf_name) .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; - - let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::from(WindowFunction::new( - expr::WindowFunctionDefinition::WindowUDF(udwf_function), - args, - )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .build() - .map_err(Error::DataFusionError) + expr::WindowFunctionDefinition::WindowUDF(udwf_function) } + }; + + let args = parse_exprs(&expr.exprs, registry, codec)?; + let mut builder = Expr::from(WindowFunction::new(agg_fn, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment); + + if expr.distinct { + builder = builder.distinct(); + }; + + if let Some(filter) = + parse_optional_expr(expr.filter.as_deref(), registry, codec)? + { + builder = builder.filter(filter); } + + builder.build().map_err(Error::DataFusionError) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, @@ -571,6 +591,19 @@ pub fn parse_expr( .udaf(&pb.fun_name) .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; + let null_treatment = match pb.null_treatment { + Some(null_treatment) => { + let null_treatment = protobuf::NullTreatment::try_from(null_treatment) + .map_err(|_| { + proto_error(format!( + "Received an AggregateUdfExprNode message with unknown NullTreatment {}", + null_treatment + )) + })?; + Some(NullTreatment::from(null_treatment)) + } + None => None, + }; Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, @@ -578,7 +611,7 @@ pub fn parse_expr( pb.distinct, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_sorts(&pb.order_by, registry, codec)?, - None, + null_treatment, ))) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1be3300008c7..6238c2f1cdde 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -25,7 +25,7 @@ use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, - Like, Placeholder, ScalarFunction, Unnest, + Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; use datafusion_expr::WriteOp; use datafusion_expr::{ @@ -314,11 +314,9 @@ pub fn serialize_expr( ref partition_by, ref order_by, ref window_frame, - // TODO: support null treatment, distinct, and filter in proto. - // See https://github.com/apache/datafusion/issues/17417 - null_treatment: _, - distinct: _, - filter: _, + ref null_treatment, + ref distinct, + ref filter, }, } = window_fun.as_ref(); let mut buf = Vec::new(); @@ -342,16 +340,24 @@ pub fn serialize_expr( let window_frame: Option = Some(window_frame.try_into()?); + let window_expr = protobuf::WindowExprNode { exprs: serialize_exprs(args, codec)?, window_function: Some(window_function), partition_by, order_by, window_frame, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), + None => None, + }, + null_treatment: null_treatment + .map(|nt| protobuf::NullTreatment::from(nt).into()), fun_definition, }; protobuf::LogicalExprNode { - expr_type: Some(ExprType::WindowExpr(window_expr)), + expr_type: Some(ExprType::WindowExpr(Box::new(window_expr))), } } Expr::AggregateFunction(expr::AggregateFunction { @@ -362,7 +368,7 @@ pub fn serialize_expr( ref distinct, ref filter, ref order_by, - null_treatment: _, + ref null_treatment, }, }) => { let mut buf = Vec::new(); @@ -379,6 +385,8 @@ pub fn serialize_expr( }, order_by: serialize_sorts(order_by, codec)?, fun_definition: (!buf.is_empty()).then_some(buf), + null_treatment: null_treatment + .map(|nt| protobuf::NullTreatment::from(nt).into()), }, ))), } @@ -722,3 +730,12 @@ impl From<&WriteOp> for protobuf::dml_node::Type { } } } + +impl From for protobuf::NullTreatment { + fn from(t: NullTreatment) -> Self { + match t { + NullTreatment::RespectNulls => protobuf::NullTreatment::RespectNulls, + NullTreatment::IgnoreNulls => protobuf::NullTreatment::IgnoreNulls, + } + } +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index c5d4b49092d9..3d51038eba72 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -73,8 +73,8 @@ use datafusion_common::{ }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - Unnest, WildcardOptions, + self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, NullTreatment, + ScalarFunction, Unnest, WildcardOptions, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ @@ -2190,7 +2190,11 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + let ctx = SessionContext::new(); + ctx.register_udaf(dummy_agg.clone()); + + // null_treatment absent + let test_expr1 = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], false, @@ -2199,10 +2203,29 @@ fn roundtrip_aggregate_udf() { None, )); - let ctx = SessionContext::new(); - ctx.register_udaf(dummy_agg); + // null_treatment respect nulls + let test_expr2 = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Arc::new(dummy_agg.clone()), + vec![lit(1.0_f64)], + true, + Some(Box::new(lit(true))), + vec![], + Some(NullTreatment::RespectNulls), + )); - roundtrip_expr_test(test_expr, ctx); + // null_treatment ignore nulls + let test_expr3 = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + Arc::new(dummy_agg), + vec![lit(1.0_f64)], + true, + Some(Box::new(lit(true))), + vec![], + Some(NullTreatment::IgnoreNulls), + )); + + roundtrip_expr_test(test_expr1, ctx.clone()); + roundtrip_expr_test(test_expr2, ctx.clone()); + roundtrip_expr_test(test_expr3, ctx); } fn dummy_udf() -> ScalarUDF { @@ -2566,8 +2589,10 @@ fn roundtrip_window() { .window_frame(row_number_frame.clone()) .build() .unwrap(); + ctx.register_udwf(dummy_window_udf); - let text_expr7 = Expr::from(expr::WindowFunction::new( + // 7. test with average udaf + let test_expr7 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) @@ -2575,7 +2600,53 @@ fn roundtrip_window() { .build() .unwrap(); - ctx.register_udwf(dummy_window_udf); + // 8. test with respect nulls + let test_expr8 = Expr::from(expr::WindowFunction::new( + WindowFunctionDefinition::WindowUDF(rank_udwf()), + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, false)]) + .window_frame(WindowFrame::new(Some(false))) + .null_treatment(NullTreatment::RespectNulls) + .build() + .unwrap(); + + // 9. test with ignore nulls + let test_expr9 = Expr::from(expr::WindowFunction::new( + WindowFunctionDefinition::WindowUDF(rank_udwf()), + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, false)]) + .window_frame(WindowFrame::new(Some(false))) + .null_treatment(NullTreatment::IgnoreNulls) + .build() + .unwrap(); + + // 10. test with distinct is `true` + let test_expr10 = Expr::from(expr::WindowFunction::new( + WindowFunctionDefinition::WindowUDF(rank_udwf()), + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, false)]) + .window_frame(WindowFrame::new(Some(false))) + .distinct() + .build() + .unwrap(); + + // 11. test with filter + let test_expr11 = Expr::from(expr::WindowFunction::new( + WindowFunctionDefinition::WindowUDF(rank_udwf()), + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, false)]) + .window_frame(WindowFrame::new(Some(false))) + .filter(col("col1").eq(lit(1))) + .build() + .unwrap(); roundtrip_expr_test(test_expr1, ctx.clone()); roundtrip_expr_test(test_expr2, ctx.clone()); @@ -2583,7 +2654,11 @@ fn roundtrip_window() { roundtrip_expr_test(test_expr4, ctx.clone()); roundtrip_expr_test(test_expr5, ctx.clone()); roundtrip_expr_test(test_expr6, ctx.clone()); - roundtrip_expr_test(text_expr7, ctx); + roundtrip_expr_test(test_expr7, ctx.clone()); + roundtrip_expr_test(test_expr8, ctx.clone()); + roundtrip_expr_test(test_expr9, ctx.clone()); + roundtrip_expr_test(test_expr10, ctx.clone()); + roundtrip_expr_test(test_expr11, ctx); } #[tokio::test]