diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 0fe04a5a41..bdb9ff611c 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1893,6 +1893,8 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+
                 let builder = match datatype {
                     DataType::Decimal128(_, _) => {
                         let func =
@@ -1900,12 +1902,15 @@ impl PhysicalPlanner {
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                     _ => {
-                        // cast to the result data type of AVG if the result data type is different
-                        // from the input type, e.g. AVG(Int32). We should not expect a cast
-                        // failure since it should have already been checked at Spark side.
+                        // For all other numeric types (Int8/16/32/64, Float32/64):
+                        // Cast to Float64 for accumulation
                         let child: Arc<dyn PhysicalExpr> =
-                            Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
-                        let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
+                            Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
+                        let func = AggregateUDF::new_from_impl(Avg::new(
+                            "avg",
+                            DataType::Float64,
+                            eval_mode,
+                        ));
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                 };
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index c9037dcd69..7ec4a9aebe 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -137,7 +137,7 @@ message Avg {
   Expr child = 1;
   DataType datatype = 2;
   DataType sum_datatype = 3;
-  bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
+  EvalMode eval_mode = 4;
 }
 
 message First {
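The proto change above replaces the unused `bool fail_on_error` flag with the shared `EvalMode` enum, which is what allows TRY mode (and not just ANSI) to be communicated to the native AVG path. Below is a minimal, self-contained sketch of that idea; the variant names follow the `EvalMode` used elsewhere in Comet (Legacy/Ansi/Try), but the tag values and the helper are illustrative stand-ins for the generated protobuf code and `from_protobuf_eval_mode`, not the actual implementation:

```rust
// Illustrative stand-in: a three-way eval mode instead of a boolean flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EvalMode {
    Legacy, // Spark default: no error on arithmetic overflow
    Ansi,   // ANSI SQL: overflow raises an error
    Try,    // try_* functions: overflow produces NULL
}

// Hypothetical tag mapping for illustration only; the real mapping lives in
// the generated protobuf enum and `from_protobuf_eval_mode`.
fn eval_mode_from_proto(tag: i32) -> Result<EvalMode, String> {
    match tag {
        0 => Ok(EvalMode::Legacy),
        1 => Ok(EvalMode::Try),
        2 => Ok(EvalMode::Ansi),
        other => Err(format!("unsupported eval mode tag: {other}")),
    }
}

fn main() {
    // The old boolean could only express ANSI vs. non-ANSI; the enum also
    // carries TRY, which the planner change above threads into `Avg::new`.
    assert_eq!(eval_mode_from_proto(1), Ok(EvalMode::Try));
}
```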
diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs
index e8b90b4f46..9850c02605 100644
--- a/native/spark-expr/src/agg_funcs/avg.rs
+++ b/native/spark-expr/src/agg_funcs/avg.rs
@@ -15,11 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::EvalMode;
 use arrow::array::{
     builder::PrimitiveBuilder,
     cast::AsArray,
     types::{Float64Type, Int64Type},
-    Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray,
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray,
 };
 use arrow::compute::sum;
 use arrow::datatypes::{DataType, Field, FieldRef};
@@ -31,24 +32,22 @@ use datafusion::logical_expr::{
 use datafusion::physical_expr::expressions::format_state_name;
 use std::{any::Any, sync::Arc};
 
-use arrow::array::ArrowNativeTypeOp;
 use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion::logical_expr::Volatility::Immutable;
 use DataType::*;
 
-/// AVG aggregate expression
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct Avg {
     name: String,
     signature: Signature,
-    // expr: Arc<dyn PhysicalExpr>,
     input_data_type: DataType,
     result_data_type: DataType,
+    eval_mode: EvalMode,
 }
 
 impl Avg {
     /// Create a new AVG aggregate function
-    pub fn new(name: impl Into<String>, data_type: DataType) -> Self {
+    pub fn new(name: impl Into<String>, data_type: DataType, eval_mode: EvalMode) -> Self {
         let result_data_type = avg_return_type("avg", &data_type).unwrap();
 
         Self {
@@ -56,20 +55,20 @@ impl Avg {
             signature: Signature::user_defined(Immutable),
             input_data_type: data_type,
             result_data_type,
+            eval_mode,
         }
     }
 }
 
 impl AggregateUDFImpl for Avg {
-    /// Return a reference to Any that can be used for downcasting
     fn as_any(&self) -> &dyn Any {
         self
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        // instantiate specialized accumulator based for the type
+        // All numeric types use Float64 accumulation after casting
         match (&self.input_data_type, &self.result_data_type) {
-            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
+            (Float64, Float64) => Ok(Box::new(AvgAccumulator::new(self.eval_mode))),
             _ => not_impl_err!(
                 "AvgAccumulator for ({} --> {})",
                 self.input_data_type,
@@ -109,10 +108,10 @@ impl AggregateUDFImpl for Avg {
         &self,
         _args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
-        // instantiate specialized accumulator based for the type
         match (&self.input_data_type, &self.result_data_type) {
             (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
                 &self.input_data_type,
+                self.eval_mode,
                 |sum: f64, count: i64| Ok(sum / count as f64),
             ))),
 
@@ -137,11 +136,22 @@ impl AggregateUDFImpl for Avg {
     }
 }
 
-/// An accumulator to compute the average
-#[derive(Debug, Default)]
+#[derive(Debug)]
 pub struct AvgAccumulator {
     sum: Option<f64>,
     count: i64,
+    #[allow(dead_code)]
+    eval_mode: EvalMode,
+}
+
+impl AvgAccumulator {
+    pub fn new(eval_mode: EvalMode) -> Self {
+        Self {
+            sum: None,
+            count: 0,
+            eval_mode,
+        }
+    }
 }
 
 impl Accumulator for AvgAccumulator {
@@ -166,7 +176,7 @@ impl Accumulator for AvgAccumulator {
         // counts are summed
         self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
 
-        // sums are summed
+        // sums are summed - no overflow checking
         if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
             let v = self.sum.get_or_insert(0.);
             *v += x;
@@ -176,8 +186,6 @@ impl Accumulator for AvgAccumulator {
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
         if self.count == 0 {
-            // If all input are nulls, count will be 0 and we will get null after the division.
-            // This is consistent with Spark Average implementation.
             Ok(ScalarValue::Float64(None))
         } else {
             Ok(ScalarValue::Float64(
@@ -192,7 +200,7 @@ impl Accumulator for AvgAccumulator {
 }
 
 /// An accumulator to compute the average of `[PrimitiveArray<T>]`.
-/// Stores values as native types, and does overflow checking
+/// Stores values as native types.
 ///
 /// F: Function that calculates the average value from a sum of
 /// T::Native and a total count
@@ -211,6 +219,10 @@ where
     /// Sums per group, stored as the native type
     sums: Vec<T::Native>,
 
+    /// Evaluation mode (stored but not used for Float64)
+    #[allow(dead_code)]
+    eval_mode: EvalMode,
+
     /// Function that computes the final average (value / count)
     avg_fn: F,
 }
@@ -220,11 +232,12 @@ where
     T: ArrowNumericType + Send,
     F: Fn(T::Native, i64) -> Result<T::Native> + Send,
 {
-    pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
+    pub fn new(return_data_type: &DataType, eval_mode: EvalMode, avg_fn: F) -> Self {
         Self {
            return_data_type: return_data_type.clone(),
            counts: vec![],
            sums: vec![],
+           eval_mode,
            avg_fn,
         }
     }
@@ -254,6 +267,7 @@ where
         if values.null_count() == 0 {
             for (&group_index, &value) in iter {
                 let sum = &mut self.sums[group_index];
+                // No overflow checking - INFINITY is a valid result
                 *sum = (*sum).add_wrapping(value);
                 self.counts[group_index] += 1;
             }
@@ -264,7 +278,6 @@ where
                 }
                 let sum = &mut self.sums[group_index];
                 *sum = (*sum).add_wrapping(value);
-                self.counts[group_index] += 1;
             }
         }
 
@@ -280,9 +293,9 @@ where
         total_num_groups: usize,
     ) -> Result<()> {
         assert_eq!(values.len(), 2, "two arguments to merge_batch");
-        // first batch is partial sums, second is counts
         let partial_sums = values[0].as_primitive::<T>();
         let partial_counts = values[1].as_primitive::<Int64Type>();
 
+        // update counts with partial counts
         self.counts.resize(total_num_groups, 0);
         let iter1 = group_indices.iter().zip(partial_counts.values().iter());
@@ -290,7 +303,7 @@ where
             self.counts[group_index] += partial_count;
         }
 
-        // update sums
+        // update sums - no overflow checking
         self.sums.resize(total_num_groups, T::default_value());
         let iter2 = group_indices.iter().zip(partial_sums.values().iter());
         for (&group_index, &new_value) in iter2 {
@@ -319,7 +332,6 @@ where
         Ok(Arc::new(array))
     }
 
-    // return arrays for sums and counts
     fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
         let counts = emit_to.take_needed(&mut self.counts);
         let counts = Int64Array::new(counts.into(), None);
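The accumulator changes above amount to: accumulate every supported non-decimal input as Float64, track sum and count separately, return NULL when the count is zero, and skip overflow checks because f64 addition saturates to INFINITY rather than erroring (which is why the stored `eval_mode` is currently unused on the Float64 path). A minimal sketch of those semantics in plain Rust, independent of Arrow/DataFusion and not the actual Comet code:

```rust
// Sketch of the Float64 AVG semantics implemented by AvgAccumulator above:
// nulls are skipped, an all-null input yields None (NULL), and f64 overflow
// saturates to INFINITY instead of raising an error.
fn avg_f64(values: &[Option<f64>]) -> Option<f64> {
    let mut sum = 0.0_f64;
    let mut count = 0_i64;
    for v in values.iter().flatten() {
        sum += v; // no overflow check needed: f64 overflow yields INFINITY
        count += 1;
    }
    if count == 0 {
        None // all inputs null -> NULL, matching Spark's Average
    } else {
        Some(sum / count as f64)
    }
}

fn main() {
    assert_eq!(avg_f64(&[Some(10.0), Some(20.0), None]), Some(15.0));
    assert_eq!(avg_f64(&[None, None]), None);
    assert_eq!(avg_f64(&[Some(f64::MAX), Some(f64::MAX)]), Some(f64::INFINITY));
}
```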
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index d00bbf4dfa..c1760a8b67 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
 import org.apache.comet.CometConf
 import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
+import org.apache.comet.shims.CometEvalModeUtil
 
 object CometMin extends CometAggregateExpressionSerde[Min] {
 
@@ -150,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
 
 object CometAverage extends CometAggregateExpressionSerde[Average] {
 
-  override def getSupportLevel(avg: Average): SupportLevel = {
-    avg.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
-      case _ =>
-        Compatible()
-    }
-  }
-
   override def convert(
       aggExpr: AggregateExpression,
       avg: Average,
@@ -192,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
     val builder = ExprOuterClass.Avg.newBuilder()
     builder.setChild(childExpr.get)
     builder.setDatatype(dataType.get)
-    builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
+    builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
     builder.setSumDatatype(sumDataType.get)
 
     Some(
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 7e577c5fda..a9af3bc4f1 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1471,6 +1471,42 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("AVG and try_avg - basic functionality") {
+    withParquetTable(
+      Seq(
+        (10L, 1),
+        (20L, 1),
+        (null.asInstanceOf[Long], 1),
+        (100L, 2),
+        (200L, 2),
+        (null.asInstanceOf[Long], 3)),
+      "tbl") {
+
+      Seq(true, false).foreach({ k =>
+        // without GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
+          val res = sql("SELECT avg(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+
+        // with GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) {
+          val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+
+      })
+
+      // try_avg without GROUP BY
+      val resTry = sql("SELECT try_avg(_1) FROM tbl")
+      checkSparkAnswerAndOperator(resTry)
+
+      // try_avg with GROUP BY
+      val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+      checkSparkAnswerAndOperator(resTryGroup)
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index 8f260e2ca8..1c0c8f8966 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.sql.TPCDSBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -225,7 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
       CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
       CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
       // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
-      CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
       CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true",
       // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
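To connect the native changes to the GROUP BY tests above: `AvgGroupsAccumulator` keeps per-group sums and counts, `merge_batch` adds partial sums and partial counts component-wise, and the division happens only at evaluation time. The following is a self-contained sketch of that two-phase shape in plain Rust; the names and the `HashMap`-based state are illustrative only and do not mirror the actual Arrow-based implementation:

```rust
use std::collections::HashMap;

/// Per-group partial state: (sum, count), loosely mirroring the sums/counts
/// vectors kept by AvgGroupsAccumulator (illustrative, not the Comet code).
type PartialAvg = HashMap<i32, (f64, i64)>;

// Phase 1: each partition accumulates (sum, count) per group key.
fn update(state: &mut PartialAvg, key: i32, value: Option<f64>) {
    let entry = state.entry(key).or_insert((0.0, 0));
    if let Some(v) = value {
        entry.0 += v; // sums are summed - no overflow checking
        entry.1 += 1; // counts only advance for non-null values
    }
}

// Phase 2: partial states are merged component-wise, like merge_batch.
fn merge(into: &mut PartialAvg, other: &PartialAvg) {
    for (key, (sum, count)) in other {
        let entry = into.entry(*key).or_insert((0.0, 0));
        entry.0 += sum;
        entry.1 += count;
    }
}

// Finalization: divide once per group; a zero count means NULL (None).
fn evaluate(state: &PartialAvg, key: i32) -> Option<f64> {
    state.get(&key).and_then(|(sum, count)| {
        if *count == 0 {
            None
        } else {
            Some(sum / *count as f64)
        }
    })
}

fn main() {
    let (mut p1, mut p2) = (PartialAvg::new(), PartialAvg::new());
    update(&mut p1, 1, Some(10.0));
    update(&mut p1, 1, Some(20.0));
    update(&mut p2, 1, None); // a null row contributes nothing
    update(&mut p2, 2, Some(100.0));
    merge(&mut p1, &p2);
    assert_eq!(evaluate(&p1, 1), Some(15.0));
    assert_eq!(evaluate(&p1, 2), Some(100.0));
}
```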