diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FlinkRelUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FlinkRelUtil.java index ee39cdb807bb6..9ce56b6a6b28d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FlinkRelUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/FlinkRelUtil.java @@ -24,19 +24,25 @@ import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLocalRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexProgram; import org.apache.calcite.rex.RexProgramBuilder; import org.apache.calcite.rex.RexSlot; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlKind; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; /** Utilities for {@link RelNode}. */ @@ -93,6 +99,9 @@ public static boolean isMergeable(Project topProject, Project bottomProject) { final int[] topInputRefCounter = initializeArray(topProject.getInput().getRowType().getFieldCount(), 0); + if (functionResultShouldBeReused(topProject, bottomProject)) { + return false; + } return mergeable(topInputRefCounter, topProject.getProjects(), bottomProject.getProjects()); } @@ -104,6 +113,10 @@ public static boolean isMergeable(Project topProject, Project bottomProject) { public static boolean isMergeable(Calc topCalc, Calc bottomCalc) { final RexProgram topProgram = topCalc.getProgram(); final RexProgram bottomProgram = bottomCalc.getProgram(); + if (functionResultShouldBeReused(topProgram, bottomProgram)) { + return false; + } + final int[] topInputRefCounter = initializeArray(topCalc.getInput().getRowType().getFieldCount(), 0); @@ -122,6 +135,84 @@ public static boolean isMergeable(Calc topCalc, Calc bottomCalc) { return mergeable(topInputRefCounter, topInputRefs, bottomProjects); } + private static boolean functionResultShouldBeReused(Project topProject, Project bottomProject) { + Set indexSet = new HashSet<>(); + List bottomProjectList = bottomProject.getProjects(); + for (int i = 0; i < bottomProjectList.size(); i++) { + RexNode project = bottomProjectList.get(i); + if (project instanceof RexCall + // && SqlKind.FUNCTION.contains(((RexCall) project).op.getKind()) + && ((RexCall) project).op.isDeterministic()) { + indexSet.add(i); + } + } + if (indexSet.isEmpty()) { + return false; + } + + Set rexNodes = new HashSet<>(); + List topProjectList = topProject.getProjects(); + for (RexNode rex : topProjectList) { + if (!(rex instanceof RexCall)) { + continue; + } + RexCall rCall = (RexCall) rex; + if (!(rCall.op instanceof SqlFunction)) { + continue; + } + List operands = rCall.operands; + for (RexNode op : operands) { + if (op instanceof RexSlot) { + if (indexSet.contains(((RexSlot) op).getIndex()) && !rexNodes.add(op)) { + return true; + } + } + } + } + + return false; + } + + private static boolean functionResultShouldBeReused( + RexProgram topProgram, RexProgram bottomProgram) { + Set indexSet = new HashSet<>(); + List bottomProjectList = bottomProgram.getProjectList(); + for (int i = 0; i < bottomProjectList.size(); i++) { + int index = bottomProjectList.get(i).getIndex(); + RexNode rexNode = bottomProgram.getExprList().get(index); + if (rexNode instanceof RexCall + && SqlKind.FUNCTION.contains(((RexCall) rexNode).op.getKind()) + && ((RexCall) rexNode).op.isDeterministic()) { + indexSet.add(i); + } + } + if (indexSet.isEmpty()) { + return false; + } + + Set rexNodes = new HashSet<>(); + List topExprList = topProgram.getExprList(); + for (RexNode rex : topExprList) { + if (!(rex instanceof RexCall)) { + continue; + } + RexCall rCall = (RexCall) rex; + if (!(rCall.op instanceof SqlFunction)) { + continue; + } + List operands = rCall.operands; + for (RexNode op : operands) { + if (op instanceof RexSlot) { + if (indexSet.contains(((RexSlot) op).getIndex()) && !rexNodes.add(op)) { + return true; + } + } + } + } + + return false; + } + /** * Merges the programs of two {@link Calc} instances and returns a new {@link Calc} instance * with the merged program. diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonCalc.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonCalc.scala index e78c5a57287cf..a1e501ddf3d09 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonCalc.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonCalc.scala @@ -27,8 +27,8 @@ import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.calcite.rel.core.Calc import org.apache.calcite.rel.hint.RelHint import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexProgram} -import org.apache.calcite.sql.SqlExplainLevel +import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexLocalRef, RexNode, RexProgram, RexShuttle} +import org.apache.calcite.sql.{SqlExplainLevel, SqlKind} import java.util.Collections @@ -49,12 +49,26 @@ abstract class CommonCalc( // conditions, etc. We only want to account for computations, not for simple projections. // CASTs in RexProgram are reduced as far as possible by ReduceExpressionsRule // in normalization stage. So we should ignore CASTs here in optimization stage. - val compCnt = calcProgram.getProjectList.map(calcProgram.expandLocalRef).toList.count { + val map = new java.util.HashMap[RexNode, Integer]() + val shuttle = new FunctionCounter(calcProgram.getExprList, map) + calcProgram.getProjectList.map(rf => rf.accept(shuttle)) + val compCnt1 = calcProgram.getProjectList.map(calcProgram.expandLocalRef).toList.count { case _: RexInputRef => false case _: RexLiteral => false case c: RexCall if c.getOperator.getName.equals("CAST") => false case _ => true } + val offset = map + .filterKeys { + case _: RexInputRef => false + case _: RexLiteral => false + case c: RexCall if c.getOperator.getName.equals("CAST") => false + case _ => true + } + .values + .foldLeft(0)(_ + _) + + val compCnt = Math.max(compCnt1, offset) val newRowCnt = mq.getRowCount(this) // TODO use inputRowCnt to compute cpu cost planner.getCostFactory.makeCost(newRowCnt, newRowCnt * compCnt, 0) @@ -102,4 +116,21 @@ abstract class CommonCalc( .mkString(", ") } + class FunctionCounter( + private val exprs: java.util.List[RexNode], + val map: java.util.Map[RexNode, Integer]) + extends RexShuttle { + override def visitLocalRef(localRef: RexLocalRef): RexNode = { + val tree: RexNode = this.exprs.get(localRef.getIndex) + if ( + SqlKind.FUNCTION.contains(tree.getKind) + && tree.isInstanceOf[RexCall] + && tree.asInstanceOf[RexCall].op.isDeterministic + ) { + map.merge(tree, 1, (x, y) => x + y) + } + tree.accept(this) + } + } + } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalCalc.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalCalc.scala index 520a9505d2568..b7fdc055c2b41 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalCalc.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalCalc.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.nodes.physical.stream import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.functions.sql.BuiltInSqlFunction import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, InputProperty} import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecCalc import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTableConfig @@ -26,7 +27,12 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.core.Calc -import org.apache.calcite.rex.RexProgram +import org.apache.calcite.rex.{RexCall, RexLocalRef, RexNode, RexProgram, RexShuttle} +import org.apache.calcite.sql.{SqlFunction, SqlKind} + +import java.util +import java.util.List +import java.util.function.BiFunction import scala.collection.JavaConversions._ @@ -44,7 +50,12 @@ class StreamPhysicalCalc( } override def translateToExecNode(): ExecNode[_] = { - val projection = calcProgram.getProjectList.map(calcProgram.expandLocalRef) + val funtionToCountMap = new util.HashMap[RexNode, Integer]() + val shuttle = new FunctionRefCounter(calcProgram.getExprList, funtionToCountMap) + + calcProgram.getProjectList.map(ref => ref.accept(shuttle)) + val projection = calcProgram.getProjectList.map( + ref => ref.accept(new ExpansionShuttle(calcProgram.getExprList, funtionToCountMap))) val condition = if (calcProgram.getCondition != null) { calcProgram.expandLocalRef(calcProgram.getCondition) } else { @@ -59,4 +70,49 @@ class StreamPhysicalCalc( FlinkTypeFactory.toLogicalRowType(getRowType), getRelDetailedDescription) } + + private def isDeterministicFunction(rexNode: RexNode): Boolean = { + SqlKind.FUNCTION.contains(rexNode.getKind) && rexNode.isInstanceOf[RexCall] && rexNode + .asInstanceOf[RexCall] + .op + .isInstanceOf[SqlFunction] && rexNode + .asInstanceOf[RexCall] + .op + .asInstanceOf[SqlFunction] + .isDeterministic + } + + private class ExpansionShuttle( + private val exprs: util.List[RexNode], + val map: util.Map[RexNode, Integer]) + extends RexShuttle { + override def visitLocalRef(localRef: RexLocalRef): RexNode = { + val tree: RexNode = this.exprs.get(localRef.getIndex) + if ( + isDeterministicFunction(tree) && map + .get(tree) > 1 + ) { + for (op <- tree.asInstanceOf[RexCall].operands) { + if (op.isInstanceOf[RexLocalRef]) { + return tree.accept(this) + } + } + return localRef + } + tree.accept(this) + } + } + + private class FunctionRefCounter( + private val exprs: util.List[RexNode], + val map: util.Map[RexNode, Integer]) + extends RexShuttle { + override def visitLocalRef(localRef: RexLocalRef): RexNode = { + val tree: RexNode = this.exprs.get(localRef.getIndex) + if (isDeterministicFunction(tree)) { + map.merge(tree, 1, (x, y) => x + y) + } + tree.accept(this) + } + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedScalarFunctions.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedScalarFunctions.java index 7013cd3208322..16715e3eb3121 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedScalarFunctions.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/utils/JavaUserDefinedScalarFunctions.java @@ -141,6 +141,26 @@ public boolean isDeterministic() { } } + /** Deterministic scalar function. */ + public static class DeterministicUdf extends ScalarFunction { + public int eval() { + return 0; + } + + public int eval(@DataTypeHint("INT") int v) { + return v; + } + + public String eval(String v) { + return v; + } + + @Override + public boolean isDeterministic() { + return true; + } + } + /** Test for Python Scalar Function. */ public static class PythonScalarFunction extends ScalarFunction implements PythonFunction { private final String name; diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml index 85b1675e09676..adb4ff9f15a2a 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml @@ -678,6 +678,44 @@ LogicalProject(EXPR$0=[$0]) + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala index d6c47d48459f3..d3feb4129fcf7 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.stream.sql import org.apache.flink.table.api._ import org.apache.flink.table.planner.plan.utils.MyPojo -import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.NonDeterministicUdf +import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.{DeterministicUdf, NonDeterministicUdf} import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.StringSplit import org.apache.flink.table.planner.utils.TableTestBase import org.apache.flink.table.types.AbstractDataType @@ -36,6 +36,7 @@ class CalcTest extends TableTestBase { def setup(): Unit = { util.addTableSource[(Long, Int, String)]("MyTable", 'a, 'b, 'c) util.addTemporarySystemFunction("random_udf", new NonDeterministicUdf) + util.addTemporarySystemFunction("deterministic_udf", new DeterministicUdf) } @Test @@ -43,6 +44,17 @@ class CalcTest extends TableTestBase { util.verifyExecPlan("SELECT a, c FROM MyTable") } + @Test + def testReusedFunctionInProject(): Unit = { + util.verifyExecPlan("SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) as q FROM MyTable) t") + } + + @Test + def testReusedFunctionInProjectWithDeterministicUdf(): Unit = { + util.verifyExecPlan( + "SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') FROM (SELECT deterministic_udf(c) as json_data FROM MyTable) t") + } + @Test def testProjectWithNaming(): Unit = { util.verifyExecPlan("SELECT `1-_./Ü`, b, c FROM (SELECT a as `1-_./Ü`, b, c FROM MyTable)")