diff --git a/native/spark-expr/src/math_funcs/internal/make_decimal.rs b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
index 8feba54f5a..3383175202 100644
--- a/native/spark-expr/src/math_funcs/internal/make_decimal.rs
+++ b/native/spark-expr/src/math_funcs/internal/make_decimal.rs
@@ -40,18 +40,21 @@ pub fn spark_make_decimal(
             ))),
             sv => internal_err!("Expected Int64 but found {sv:?}"),
         },
-        ColumnarValue::Array(a) => {
-            let arr = a.as_primitive::<Int64Type>();
-            let mut result = Decimal128Builder::new();
-            for v in arr.into_iter() {
-                result.append_option(long_to_decimal(&v, precision))
-            }
-            let result_type = DataType::Decimal128(precision, scale);
+        ColumnarValue::Array(a) => match a.data_type() {
+            DataType::Int64 => {
+                let arr = a.as_primitive::<Int64Type>();
+                let mut result = Decimal128Builder::new();
+                for v in arr.into_iter() {
+                    result.append_option(long_to_decimal(&v, precision))
+                }
+                let result_type = DataType::Decimal128(precision, scale);
 
-            Ok(ColumnarValue::Array(Arc::new(
-                result.finish().with_data_type(result_type),
-            )))
-        }
+                Ok(ColumnarValue::Array(Arc::new(
+                    result.finish().with_data_type(result_type),
+                )))
+            }
+            av => internal_err!("Expected Int64 but found {av:?}"),
+        },
     }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala b/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
index c606d1ac5b..9a019b9035 100644
--- a/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala
@@ -38,6 +38,14 @@ object CometUnscaledValue extends CometExpressionSerde[UnscaledValue] {
 }
 
 object CometMakeDecimal extends CometExpressionSerde[MakeDecimal] {
+
+  override def getSupportLevel(expr: MakeDecimal): SupportLevel = {
+    expr.child.dataType match {
+      case _: LongType => Compatible()
+      case other => Unsupported(Some(s"Unsupported input data type: $other"))
+    }
+  }
+
   override def convert(
       expr: MakeDecimal,
       inputs: Seq[Attribute],
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index b0c718a2b6..b8d0864749 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.Tag
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, TruncDate, TruncTimestamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
 import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
 import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan, WholeStageCodegenExec}
@@ -3187,4 +3188,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         CometConcat.unsupportedReason)
     }
   }
+
+  // https://github.com/apache/datafusion-comet/issues/2813
+  test("make decimal using DataFrame API") {
+    withTable("t1") {
+      sql("create table t1 using parquet as select 123456 as c1 from range(1)")
+
+      withSQLConf(
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true",
+        SQLConf.ANSI_ENABLED.key -> "false",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true",
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT,
+        SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+
+        val df = sql("select * from t1")
+        val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
+        val df1 = df.withColumn("result", makeDecimalColumn)
+
+        checkSparkAnswerAndFallbackReason(df1, "Unsupported input data type: IntegerType")
+      }
+    }
+  }
+
 }
diff --git a/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
index b8ecfacb31..a7dfb42645 100644
--- a/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 trait ShimCometTestBase {
 
@@ -46,4 +46,8 @@
   def extractLogicalPlan(df: DataFrame): LogicalPlan = {
     df.logicalPlan
   }
+
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(MakeDecimal(child, precision, scale))
+  }
 }
diff --git a/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
index f2b4195565..7f22494ad2 100644
--- a/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 trait ShimCometTestBase {
 
@@ -47,4 +47,8 @@
     df.logicalPlan
   }
 
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(MakeDecimal(child, precision, scale))
+  }
+
 }
diff --git a/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala b/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
index 8fb2e69705..5ad4543220 100644
--- a/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
+++ b/spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode, SparkSession}
 
@@ -47,4 +47,8 @@
   def extractLogicalPlan(df: DataFrame): LogicalPlan = {
     df.queryExecution.analyzed
   }
+
+  def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
+    new Column(ExpressionColumnNode.apply(MakeDecimal(child, precision, scale, true)))
+  }
 }
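
Reviewer note (not part of the patch): a minimal sketch of how the new fallback is expected to surface from the DataFrame API, assuming a Spark 3.x `spark-shell` session (so `spark` is in scope and the public `new Column(expr)` constructor is available; on Spark 4.0 the expression has to go through `ExpressionColumnNode`, which is exactly why the `ShimCometTestBase` helpers above exist). The column name `c1` is illustrative.

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.MakeDecimal

// An integer literal produces an IntegerType column, so MakeDecimal's child
// is not LongType and CometMakeDecimal.getSupportLevel reports Unsupported.
val df = spark.sql("select 123456 as c1")

// Build MakeDecimal directly, bypassing the analyzer paths that would
// normally guarantee a LongType child (this is what the new test does via
// createMakeDecimalColumn).
val out = df.withColumn("result", new Column(MakeDecimal(df.col("c1").expr, 3, 0)))

// With COMET_EXPLAIN_FALLBACK_ENABLED set (as in the test above), the
// projection should fall back to Spark with the new reason string
// "Unsupported input data type: IntegerType" instead of reaching the native
// kernel, which previously failed on a non-Int64 input array.
out.collect()
```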