Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions native/spark-expr/src/math_funcs/internal/make_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@ pub fn spark_make_decimal(
))),
sv => internal_err!("Expected Int64 but found {sv:?}"),
},
ColumnarValue::Array(a) => {
let arr = a.as_primitive::<Int64Type>();
let mut result = Decimal128Builder::new();
for v in arr.into_iter() {
result.append_option(long_to_decimal(&v, precision))
}
let result_type = DataType::Decimal128(precision, scale);
ColumnarValue::Array(a) => match a.data_type() {
DataType::Int64 => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also support Int32? 🤔

let arr = a.as_primitive::<Int64Type>();
let mut result = Decimal128Builder::new();
for v in arr.into_iter() {
result.append_option(long_to_decimal(&v, precision))
}
let result_type = DataType::Decimal128(precision, scale);

Ok(ColumnarValue::Array(Arc::new(
result.finish().with_data_type(result_type),
)))
}
Ok(ColumnarValue::Array(Arc::new(
result.finish().with_data_type(result_type),
)))
}
av => internal_err!("Expected Int64 but found {av:?}"),
},
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ object CometUnscaledValue extends CometExpressionSerde[UnscaledValue] {
}

object CometMakeDecimal extends CometExpressionSerde[MakeDecimal] {

override def getSupportLevel(expr: MakeDecimal): SupportLevel = {
expr.child.dataType match {
case _: LongType => Compatible()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
case _: LongType => Compatible()
case LongType => Compatible()

because LongType is an object/singleton

case other => Unsupported(Some(s"Unsupported input data type: $other"))
}
}

override def convert(
expr: MakeDecimal,
inputs: Seq[Attribute],
Expand Down
27 changes: 27 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.scalatest.Tag
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, TruncDate, TruncTimestamp}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan, WholeStageCodegenExec}
Expand Down Expand Up @@ -3187,4 +3188,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
CometConcat.unsupportedReason)
}
}

// https://github.com/apache/datafusion-comet/issues/2813
test("make decimal using DataFrame API") {
withTable("t1") {
sql("create table t1 using parquet as select 123456 as c1 from range(1)")

withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true",
SQLConf.ANSI_ENABLED.key -> "false",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true",
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are all of these settings needed for this test? E.g., the Sum and Iceberg-compat settings look unrelated.

SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {

val df = sql("select * from t1")
val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
val df1 = df.withColumn("result", makeDecimalColumn)

checkSparkAnswerAndFallbackReason(df1, "Unsupported input data type: IntegerType")
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it would be useful to add a positive test too, with LongType.


}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

trait ShimCometTestBase {
Expand All @@ -46,4 +46,8 @@ trait ShimCometTestBase {
/**
 * Returns the logical plan exposed by the given DataFrame.
 *
 * @param df the DataFrame whose plan is requested
 * @return the value of `df.logicalPlan`
 */
def extractLogicalPlan(df: DataFrame): LogicalPlan =
  df.logicalPlan

/**
 * Builds a Column that wraps a `MakeDecimal` expression over `child`.
 *
 * @param child     the expression passed to `MakeDecimal` as its input
 * @param precision the precision passed to `MakeDecimal`
 * @param scale     the scale passed to `MakeDecimal`
 * @return a Column constructed directly from the `MakeDecimal` expression
 */
def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
  val makeDecimalExpr = MakeDecimal(child, precision, scale)
  new Column(makeDecimalExpr)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

trait ShimCometTestBase {
Expand All @@ -47,4 +47,8 @@ trait ShimCometTestBase {
df.logicalPlan
}

/**
 * Builds a Column that wraps a `MakeDecimal` expression over `child`.
 *
 * @param child     the expression passed to `MakeDecimal` as its input
 * @param precision the precision passed to `MakeDecimal`
 * @param scale     the scale passed to `MakeDecimal`
 * @return a Column constructed directly from the `MakeDecimal` expression
 */
def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
  val makeDecimalExpr = MakeDecimal(child, precision, scale)
  new Column(makeDecimalExpr)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.apache.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode, SparkSession}

Expand All @@ -47,4 +47,8 @@ trait ShimCometTestBase {
/**
 * Returns the analyzed logical plan of the given DataFrame.
 *
 * @param df the DataFrame whose plan is requested
 * @return the value of `df.queryExecution.analyzed`
 */
def extractLogicalPlan(df: DataFrame): LogicalPlan = {
  val execution = df.queryExecution
  execution.analyzed
}

/**
 * Builds a Column that wraps a `MakeDecimal` expression over `child`.
 *
 * In this shim the expression is first wrapped in an `ExpressionColumnNode`
 * before being handed to the `Column` constructor.
 *
 * @param child     the expression passed to `MakeDecimal` as its input
 * @param precision the precision passed to `MakeDecimal`
 * @param scale     the scale passed to `MakeDecimal`
 * @return a Column backed by the `MakeDecimal` expression
 */
def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
  // NOTE(review): the trailing `true` is presumably MakeDecimal's null-on-overflow flag;
  // confirm against the Spark version this shim targets.
  val makeDecimalExpr = MakeDecimal(child, precision, scale, true)
  val columnNode = ExpressionColumnNode.apply(makeDecimalExpr)
  new Column(columnNode)
}
}
Loading