diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 0fe04a5a41..3ab08063a8 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1870,7 +1870,9 @@ impl PhysicalPlanner {
                 let builder = match datatype {
                     DataType::Decimal128(_, _) => {
-                        let func = AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?);
+                        let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                        let func =
+                            AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?);
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                     _ => {
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index c9037dcd69..a7736f561a 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -120,7 +120,7 @@ message Count {
 message Sum {
   Expr child = 1;
   DataType datatype = 2;
-  bool fail_on_error = 3;
+  EvalMode eval_mode = 3;
 }
 
 message Min {
diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs
index 3aa0233716..72628975b3 100644
--- a/native/spark-expr/benches/aggregate.rs
+++ b/native/spark-expr/benches/aggregate.rs
@@ -31,8 +31,8 @@ use datafusion::physical_expr::expressions::Column;
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion_comet_spark_expr::AvgDecimal;
 use datafusion_comet_spark_expr::SumDecimal;
+use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
 use futures::StreamExt;
 use std::hint::black_box;
 use std::sync::Arc;
@@ -97,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) {
 
     group.bench_function("sum_decimal_comet", |b| {
         let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
-            SumDecimal::try_new(DataType::Decimal128(38, 10)).unwrap(),
+            SumDecimal::try_new(DataType::Decimal128(38, 10), EvalMode::Legacy).unwrap(),
         ));
         b.to_async(&rt).iter(|| {
             black_box(agg_test(
diff --git a/native/spark-expr/src/agg_funcs/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs
index cc25855902..50645391fd 100644
--- a/native/spark-expr/src/agg_funcs/sum_decimal.rs
+++ b/native/spark-expr/src/agg_funcs/sum_decimal.rs
@@ -15,19 +15,19 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::utils::{build_bool_state, is_valid_decimal_precision};
+use crate::utils::is_valid_decimal_precision;
+use crate::{arithmetic_overflow_error, EvalMode};
 use arrow::array::{
     cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
 };
 use arrow::datatypes::{DataType, Field, FieldRef};
-use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer};
 use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue};
 use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion::logical_expr::Volatility::Immutable;
 use datafusion::logical_expr::{
     Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
 };
-use std::{any::Any, ops::BitAnd, sync::Arc};
+use std::{any::Any, sync::Arc};
 
 #[derive(Debug, PartialEq, Eq, Hash)]
 pub struct SumDecimal {
@@ -40,11 +40,11 @@ pub struct SumDecimal {
     precision: u8,
     /// Decimal scale
     scale: i8,
+    eval_mode: EvalMode,
 }
 
 impl SumDecimal {
-    pub fn try_new(data_type: DataType) -> DFResult<Self> {
-        // The `data_type` is the SUM result type passed from Spark side
+    pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
         let (precision, scale) = match data_type {
             DataType::Decimal128(p, s) => (p, s),
             _ => {
@@ -58,6 +58,7 @@ impl SumDecimal {
             result_type: data_type,
             precision,
             scale,
+            eval_mode,
         })
     }
 }
@@ -71,19 +72,18 @@ impl AggregateUDFImpl for SumDecimal {
         Ok(Box::new(SumDecimalAccumulator::new(
             self.precision,
             self.scale,
+            self.eval_mode,
         )))
     }
 
     fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
-        let fields = vec![
-            Arc::new(Field::new(
-                self.name(),
-                self.result_type.clone(),
-                self.is_nullable(),
-            )),
+        // For decimal sum, we always track is_empty regardless of eval_mode
+        // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true
+        let data_type = self.result_type.clone();
+        Ok(vec![
+            Arc::new(Field::new("sum", data_type, true)),
             Arc::new(Field::new("is_empty", DataType::Boolean, false)),
-        ];
-        Ok(fields)
+        ])
     }
 
     fn name(&self) -> &str {
@@ -109,6 +109,7 @@ impl AggregateUDFImpl for SumDecimal {
         Ok(Box::new(SumDecimalGroupsAccumulator::new(
             self.result_type.clone(),
             self.precision,
+            self.eval_mode,
         )))
     }
 
@@ -131,37 +132,48 @@ impl AggregateUDFImpl for SumDecimal {
 #[derive(Debug)]
 struct SumDecimalAccumulator {
-    sum: i128,
+    sum: Option<i128>,
     is_empty: bool,
-    is_not_null: bool,
-
     precision: u8,
     scale: i8,
+    eval_mode: EvalMode,
 }
 
 impl SumDecimalAccumulator {
-    fn new(precision: u8, scale: i8) -> Self {
+    fn new(precision: u8, scale: i8, eval_mode: EvalMode) -> Self {
+        // For decimal sum, always track is_empty regardless of eval_mode
+        // This matches Spark's behavior where DecimalType always uses shouldTrackIsEmpty = true
         Self {
-            sum: 0,
+            sum: Some(0),
             is_empty: true,
-            is_not_null: true,
             precision,
             scale,
+            eval_mode,
         }
     }
 
-    fn update_single(&mut self, values: &Decimal128Array, idx: usize) {
+    fn update_single(&mut self, values: &Decimal128Array, idx: usize) -> DFResult<()> {
+        // If already overflowed (sum is None but not empty), stay in overflow state
+        if !self.is_empty && self.sum.is_none() {
+            return Ok(());
+        }
+
         let v = unsafe { values.value_unchecked(idx) };
-        let (new_sum, is_overflow) = self.sum.overflowing_add(v);
+        let running_sum = self.sum.unwrap_or(0);
+        let (new_sum, is_overflow) = running_sum.overflowing_add(v);
 
         if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
-            // Overflow: set buffer accumulator to null
-            self.is_not_null = false;
-            return;
+            if self.eval_mode == EvalMode::Ansi {
+                return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+            }
+            self.sum = None;
+            self.is_empty = false;
+            return Ok(());
         }
 
-        self.sum = new_sum;
-        self.is_not_null = true;
+        self.sum = Some(new_sum);
+        self.is_empty = false;
+        Ok(())
     }
 }
 
@@ -174,49 +186,46 @@ impl Accumulator for SumDecimalAccumulator {
             values.len()
         );
 
-        if !self.is_empty && !self.is_not_null {
-            // This means there's a overflow in decimal, so we will just skip the rest
-            // of the computation
+        // For decimal sum, always check for overflow regardless of eval_mode (per Spark's expectation)
+        if !self.is_empty && self.sum.is_none() {
             return Ok(());
         }
 
         let values = &values[0];
         let data = values.as_primitive::<Decimal128Type>();
 
+        // Update is_empty: it remains true only if it was true AND all values are null
         self.is_empty = self.is_empty && values.len() == values.null_count();
 
-        if values.null_count() == 0 {
-            for i in 0..data.len() {
-                self.update_single(data, i);
-            }
-        } else {
-            for i in 0..data.len() {
-                if data.is_null(i) {
-                    continue;
-                }
-                self.update_single(data, i);
-            }
+        if self.is_empty {
+            return Ok(());
         }
 
+        for i in 0..data.len() {
+            if data.is_null(i) {
+                continue;
+            }
+            self.update_single(data, i)?;
+        }
         Ok(())
     }
 
     fn evaluate(&mut self) -> DFResult<ScalarValue> {
-        // For each group:
-        // 1. if `is_empty` is true, it means either there is no value or all values for the group
-        //    are null, in this case we'll return null
-        // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
-        //    non-ANSI mode Spark returns null.
-        if self.is_empty
-            || !self.is_not_null
-            || !is_valid_decimal_precision(self.sum, self.precision)
-        {
+        if self.is_empty {
             ScalarValue::new_primitive::<Decimal128Type>(
                 None,
                 &DataType::Decimal128(self.precision, self.scale),
             )
         } else {
-            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)
+            match self.sum {
+                Some(sum_value) if is_valid_decimal_precision(sum_value, self.precision) => {
+                    ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)
+                }
+                _ => ScalarValue::new_primitive::<Decimal128Type>(
+                    None,
+                    &DataType::Decimal128(self.precision, self.scale),
+                ),
+            }
         }
     }
 
@@ -225,38 +234,71 @@ impl Accumulator for SumDecimalAccumulator {
     }
 
     fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
-        let sum = if self.is_not_null {
-            ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)?
-        } else {
-            ScalarValue::new_primitive::<Decimal128Type>(
+        let sum = match self.sum {
+            Some(sum_value) => {
+                ScalarValue::try_new_decimal128(sum_value, self.precision, self.scale)?
+            }
+            None => ScalarValue::new_primitive::<Decimal128Type>(
                 None,
                 &DataType::Decimal128(self.precision, self.scale),
-            )?
+            )?,
         };
+
+        // For decimal sum, always return 2 state values regardless of eval_mode
         Ok(vec![sum, ScalarValue::from(self.is_empty)])
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+        // For decimal sum, always expect 2 state arrays regardless of eval_mode
         assert_eq!(
             states.len(),
             2,
-            "Expect two element in 'states' but found {}",
+            "Expect two elements in 'states' but found {}",
             states.len()
         );
         assert_eq!(states[0].len(), 1);
         assert_eq!(states[1].len(), 1);
 
-        let that_sum = states[0].as_primitive::<Decimal128Type>();
-        let that_is_empty = states[1].as_any().downcast_ref::<BooleanArray>().unwrap();
+        let that_sum_array = states[0].as_primitive::<Decimal128Type>();
+        let that_sum = if that_sum_array.is_null(0) {
+            None
+        } else {
+            Some(that_sum_array.value(0))
+        };
+
+        let that_is_empty = states[1].as_boolean().value(0);
+        let that_overflowed = !that_is_empty && that_sum.is_none();
+        let this_overflowed = !self.is_empty && self.sum.is_none();
 
-        let this_overflow = !self.is_empty && !self.is_not_null;
-        let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0);
+        if that_overflowed || this_overflowed {
+            self.sum = None;
+            self.is_empty = false;
+            return Ok(());
+        }
+
+        if that_is_empty {
+            return Ok(());
+        }
+
+        if self.is_empty {
+            self.sum = that_sum;
+            self.is_empty = false;
+            return Ok(());
+        }
 
-        self.is_not_null = !this_overflow && !that_overflow;
-        self.is_empty = self.is_empty && that_is_empty.value(0);
+        let left = self.sum.unwrap();
+        let right = that_sum.unwrap();
+        let (new_sum, is_overflow) = left.overflowing_add(right);
 
-        if self.is_not_null {
-            self.sum += that_sum.value(0);
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
+            if self.eval_mode == EvalMode::Ansi {
+                return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+            } else {
+                self.sum = None;
+                self.is_empty = false;
+            }
+        } else {
+            self.sum = Some(new_sum);
         }
 
         Ok(())
@@ -264,46 +306,50 @@
 }
 
 struct SumDecimalGroupsAccumulator {
-    // Whether aggregate buffer for a particular group is null. True indicates it is not null.
-    is_not_null: BooleanBufferBuilder,
-    is_empty: BooleanBufferBuilder,
-    sum: Vec<i128>,
+    sum: Vec<Option<i128>>,
+    is_empty: Vec<bool>,
     result_type: DataType,
     precision: u8,
+    eval_mode: EvalMode,
 }
 
 impl SumDecimalGroupsAccumulator {
-    fn new(result_type: DataType, precision: u8) -> Self {
+    fn new(result_type: DataType, precision: u8, eval_mode: EvalMode) -> Self {
         Self {
-            is_not_null: BooleanBufferBuilder::new(0),
-            is_empty: BooleanBufferBuilder::new(0),
             sum: Vec::new(),
+            is_empty: Vec::new(),
             result_type,
             precision,
+            eval_mode,
         }
     }
 
-    fn is_overflow(&self, index: usize) -> bool {
-        !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index)
+    fn resize_helper(&mut self, total_num_groups: usize) {
+        // For decimal sum, always initialize properly regardless of eval_mode
+        self.sum.resize(total_num_groups, Some(0));
+        self.is_empty.resize(total_num_groups, true);
     }
 
     #[inline]
-    fn update_single(&mut self, group_index: usize, value: i128) {
-        self.is_empty.set_bit(group_index, false);
-        let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value);
-        self.sum[group_index] = new_sum;
+    fn update_single(&mut self, group_index: usize, value: i128) -> DFResult<()> {
+        // For decimal sum, always check for overflow regardless of eval_mode
+        if !self.is_empty[group_index] && self.sum[group_index].is_none() {
+            return Ok(());
+        }
+
+        let running_sum = self.sum[group_index].unwrap_or(0);
+        let (new_sum, is_overflow) = running_sum.overflowing_add(value);
 
         if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
-            // Overflow: set buffer accumulator to null
-            self.is_not_null.set_bit(group_index, false);
+            if self.eval_mode == EvalMode::Ansi {
+                return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+            }
+            self.sum[group_index] = None;
+        } else {
+            self.sum[group_index] = Some(new_sum);
         }
-    }
-}
-
-fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
-    if builder.len() < capacity {
-        let additional = capacity - builder.len();
-        builder.append_n(additional, true);
+        self.is_empty[group_index] = false;
+        Ok(())
     }
 }
 
@@ -320,22 +366,19 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
         let values = values[0].as_primitive::<Decimal128Type>();
         let data = values.values();
 
-        // Update size for the accumulate states
-        self.sum.resize(total_num_groups, 0);
-        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
-        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
+        self.resize_helper(total_num_groups);
 
         let iter = group_indices.iter().zip(data.iter());
         if values.null_count() == 0 {
             for (&group_index, &value) in iter {
-                self.update_single(group_index, value);
+                self.update_single(group_index, value)?;
             }
         } else {
             for (idx, (&group_index, &value)) in iter.enumerate() {
                 if values.is_null(idx) {
                     continue;
                 }
-                self.update_single(group_index, value);
+                self.update_single(group_index, value)?;
             }
         }
 
@@ -343,42 +386,65 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
     }
 
     fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
-        // For each group:
-        // 1. if `is_empty` is true, it means either there is no value or all values for the group
-        //    are null, in this case we'll return null
-        // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
-        //    non-ANSI mode Spark returns null.
-        let result = emit_to.take_needed(&mut self.sum);
-        result.iter().enumerate().for_each(|(i, &v)| {
-            if !is_valid_decimal_precision(v, self.precision) {
-                self.is_not_null.set_bit(i, false);
+        match emit_to {
+            EmitTo::All => {
+                let result =
+                    Decimal128Array::from_iter(self.sum.iter().zip(self.is_empty.iter()).map(
+                        |(&sum, &empty)| {
+                            if empty {
+                                None
+                            } else {
+                                match sum {
+                                    Some(v) if is_valid_decimal_precision(v, self.precision) => {
+                                        Some(v)
+                                    }
+                                    _ => None,
+                                }
+                            }
+                        },
+                    ))
+                    .with_data_type(self.result_type.clone());
+
+                self.sum.clear();
+                self.is_empty.clear();
+                Ok(Arc::new(result))
             }
-        });
-
-        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
-        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
-        let x = (!&is_empty).bitand(&nulls);
-
-        let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
-            .with_data_type(self.result_type.clone());
-
-        Ok(Arc::new(result))
+            EmitTo::First(n) => {
+                let result = Decimal128Array::from_iter(
+                    self.sum
+                        .drain(..n)
+                        .zip(self.is_empty.drain(..n))
+                        .map(|(sum, empty)| {
+                            if empty {
+                                None
+                            } else {
+                                match sum {
+                                    Some(v) if is_valid_decimal_precision(v, self.precision) => {
+                                        Some(v)
+                                    }
+                                    _ => None,
+                                }
+                            }
+                        }),
+                )
+                .with_data_type(self.result_type.clone());
+
+                Ok(Arc::new(result))
+            }
+        }
     }
 
     fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
-        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
-        let nulls = Some(NullBuffer::new(nulls));
+        let sums = emit_to.take_needed(&mut self.sum);
 
-        let sum = emit_to.take_needed(&mut self.sum);
-        let sum = Decimal128Array::new(sum.into(), nulls.clone())
+        let sum_array = Decimal128Array::from_iter(sums.iter().copied())
             .with_data_type(self.result_type.clone());
 
-        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
-        let is_empty = BooleanArray::new(is_empty, None);
-
+        // For decimal sum, always return 2 state arrays regardless of eval_mode
+        let is_empty = emit_to.take_needed(&mut self.is_empty);
         Ok(vec![
-            Arc::new(sum) as ArrayRef,
-            Arc::new(is_empty) as ArrayRef,
+            Arc::new(sum_array),
+            Arc::new(BooleanArray::from(is_empty)),
         ])
     }
 
@@ -389,57 +455,70 @@
         opt_filter: Option<&BooleanArray>,
         total_num_groups: usize,
     ) -> DFResult<()> {
+        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+        self.resize_helper(total_num_groups);
+
+        // For decimal sum, always expect 2 arrays regardless of eval_mode
         assert_eq!(
             values.len(),
             2,
             "Expected two arrays: 'sum' and 'is_empty', but found {}",
             values.len()
         );
-        assert!(opt_filter.is_none(), "opt_filter is not supported yet");
 
-        // Make sure we have enough capacity for the additional groups
-        self.sum.resize(total_num_groups, 0);
-        ensure_bit_capacity(&mut self.is_empty, total_num_groups);
-        ensure_bit_capacity(&mut self.is_not_null, total_num_groups);
-
-        let that_sum = &values[0];
-        let that_sum = that_sum.as_primitive::<Decimal128Type>();
-        let that_is_empty = &values[1];
-        let that_is_empty = that_is_empty
-            .as_any()
-            .downcast_ref::<BooleanArray>()
-            .unwrap();
+        let that_sum = values[0].as_primitive::<Decimal128Type>();
+        let that_is_empty = values[1].as_boolean();
+
+        for (idx, &group_index) in group_indices.iter().enumerate() {
+            let that_sum_val = if that_sum.is_null(idx) {
+                None
+            } else {
+                Some(that_sum.value(idx))
+            };
 
-        group_indices
-            .iter()
-            .enumerate()
-            .for_each(|(idx, &group_index)| unsafe {
-                let this_overflow = self.is_overflow(group_index);
-                let that_is_empty = that_is_empty.value_unchecked(idx);
-                let that_overflow = !that_is_empty && that_sum.is_null(idx);
-                let is_overflow = this_overflow || that_overflow;
-
-                // This part follows the logic in Spark:
-                //   `org.apache.spark.sql.catalyst.expressions.aggregate.Sum`
-                self.is_not_null.set_bit(group_index, !is_overflow);
-                self.is_empty.set_bit(
-                    group_index,
-                    self.is_empty.get_bit(group_index) && that_is_empty,
-                );
-                if !is_overflow {
-                    // .. otherwise, the sum value for this particular index must not be null,
-                    // and thus we merge both values and update this sum.
-                    self.sum[group_index] += that_sum.value_unchecked(idx);
+            let that_is_empty_val = that_is_empty.value(idx);
+            let that_overflowed = !that_is_empty_val && that_sum_val.is_none();
+            let this_overflowed = !self.is_empty[group_index] && self.sum[group_index].is_none();
+
+            if that_overflowed || this_overflowed {
+                self.sum[group_index] = None;
+                self.is_empty[group_index] = false;
+                continue;
+            }
+
+            if that_is_empty_val {
+                continue;
+            }
+
+            if self.is_empty[group_index] {
+                self.sum[group_index] = that_sum_val;
+                self.is_empty[group_index] = false;
+                continue;
+            }
+
+            let left = self.sum[group_index].unwrap();
+            let right = that_sum_val.unwrap();
+            let (new_sum, is_overflow) = left.overflowing_add(right);
+
+            if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) {
+                if self.eval_mode == EvalMode::Ansi {
+                    return Err(DataFusionError::from(arithmetic_overflow_error("decimal")));
+                } else {
+                    self.sum[group_index] = None;
+                    self.is_empty[group_index] = false;
                 }
-            });
+            } else {
+                self.sum[group_index] = Some(new_sum);
+            }
+        }
 
         Ok(())
     }
 
     fn size(&self) -> usize {
-        self.sum.capacity() * std::mem::size_of::<i128>()
-            + self.is_empty.capacity() / 8
-            + self.is_not_null.capacity() / 8
+        self.sum.capacity() * std::mem::size_of::<Option<i128>>()
+            + self.is_empty.capacity() * std::mem::size_of::<bool>()
     }
 }
 
@@ -463,7 +542,7 @@ mod tests {
 
     #[test]
     fn invalid_data_type() {
-        assert!(SumDecimal::try_new(DataType::Int32).is_err());
+        assert!(SumDecimal::try_new(DataType::Int32, EvalMode::Legacy).is_err());
     }
 
     #[tokio::test]
@@ -486,6 +565,7 @@ mod tests {
 
         let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
            data_type.clone(),
+            EvalMode::Legacy,
        )?));
 
        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index d00bbf4dfa..8ab568dc83 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
 import org.apache.comet.CometConf
 import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil
 
 object CometMin extends CometAggregateExpressionSerde[Min] {
@@ -214,10 +215,10 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
 
   override def getSupportLevel(sum: Sum): SupportLevel = {
     sum.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
+      case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] =>
+        Incompatible(Some("ANSI mode for non decimal inputs is not supported"))
+      case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] =>
+        Incompatible(Some("TRY mode for non decimal inputs is not supported"))
       case _ =>
         Compatible()
     }
@@ -242,7 +243,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
     val builder = ExprOuterClass.Sum.newBuilder()
     builder.setChild(childExpr.get)
     builder.setDatatype(dataType.get)
-    builder.setFailOnError(sum.evalMode == EvalMode.ANSI)
+    builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode)))
 
     Some(
       ExprOuterClass.AggExpr
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 7e577c5fda..060579b2ba 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -24,10 +24,11 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
 import org.apache.spark.sql.comet.CometHashAggregateExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.{avg, count_distinct, sum}
+import org.apache.spark.sql.functions.{avg, col, count_distinct, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
 
@@ -1471,6 +1472,168 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("ANSI support for decimal sum - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "null_tbl") {
+          val res = sql("SELECT sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for try_sum decimal - null test") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "null_tbl") {
+          val res = sql("SELECT try_sum(_1) FROM null_tbl")
+          checkSparkAnswerAndOperator(res)
+          assert(res.collect() === Array(Row(null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for decimal sum - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "tbl") {
+          val res = sql("SELECT _2, sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  test("ANSI support for try_sum decimal - null test (group by)") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(
+          Seq(
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "a"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b"),
+            (null.asInstanceOf[java.math.BigDecimal], "b")),
+          "tbl") {
+          val res = sql("SELECT _2, try_sum(_1) FROM tbl group by 1")
+          checkSparkAnswerAndOperator(res)
+          assert(res.orderBy(col("_2")).collect() === Array(Row("a", null), Row("b", null)))
+        }
+      }
+    }
+  }
+
+  protected def generateOverflowDecimalInputs: Seq[(java.math.BigDecimal, Int)] = {
+    val maxDec38_0 = new java.math.BigDecimal("99999999999999999999")
+    (1 to 50).flatMap(_ => Seq((maxDec38_0, 1)))
+  }
+
+  test("ANSI support for decimal SUM function") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(generateOverflowDecimalInputs, "tbl") {
+          val res = sql("SELECT SUM(_1) FROM tbl")
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for decimal overflow in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+      }
+    }
+  }
+
+  test("ANSI support for decimal SUM - GROUP BY") {
+    Seq(true, false).foreach { ansiEnabled =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+        withParquetTable(generateOverflowDecimalInputs, "tbl") {
+          val res =
+            sql("SELECT _2, SUM(_1) FROM tbl GROUP BY _2").repartition(2)
+          if (ansiEnabled) {
+            checkSparkAnswerMaybeThrows(res) match {
+              case (Some(sparkExc), Some(cometExc)) =>
+                assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+                assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW"))
+              case _ =>
+                fail("Exception should be thrown for decimal overflow with GROUP BY in ANSI mode")
+            }
+          } else {
+            checkSparkAnswerAndOperator(res)
+          }
+        }
+      }
+    }
+  }
+
+  test("try_sum decimal overflow") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withParquetTable(generateOverflowDecimalInputs, "tbl") {
+        val res = sql("SELECT try_sum(_1) FROM tbl")
+        checkSparkAnswerAndOperator(res)
+      }
+    }
+  }
+
+  test("try_sum decimal overflow - with GROUP BY") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      withParquetTable(generateOverflowDecimalInputs, "tbl") {
+        val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2").repartition(2, col("_2"))
+        checkSparkAnswerAndOperator(res)
+      }
+    }
+  }
+
+  test("try_sum decimal partial overflow - with GROUP BY") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true") {
+      // Group 1 overflows, Group 2 succeeds
+      val data: Seq[(java.math.BigDecimal, Int)] = generateOverflowDecimalInputs ++ Seq(
+        (new java.math.BigDecimal(300), 2),
+        (new java.math.BigDecimal(200), 2))
+      withParquetTable(data, "tbl") {
+        val res = sql("SELECT _2, try_sum(_1) FROM tbl GROUP BY _2")
+        // Group 1 should be NULL, Group 2 should be 500
+        checkSparkAnswerAndOperator(res)
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)