Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,25 @@
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlKind;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/** Utilities for {@link RelNode}. */
Expand Down Expand Up @@ -93,6 +99,9 @@ public static boolean isMergeable(Project topProject, Project bottomProject) {
final int[] topInputRefCounter =
initializeArray(topProject.getInput().getRowType().getFieldCount(), 0);

if (functionResultShouldBeReused(topProject, bottomProject)) {
return false;
}
return mergeable(topInputRefCounter, topProject.getProjects(), bottomProject.getProjects());
}

Expand All @@ -104,6 +113,10 @@ public static boolean isMergeable(Project topProject, Project bottomProject) {
public static boolean isMergeable(Calc topCalc, Calc bottomCalc) {
final RexProgram topProgram = topCalc.getProgram();
final RexProgram bottomProgram = bottomCalc.getProgram();
if (functionResultShouldBeReused(topProgram, bottomProgram)) {
return false;
}

final int[] topInputRefCounter =
initializeArray(topCalc.getInput().getRowType().getFieldCount(), 0);

Expand All @@ -122,6 +135,84 @@ public static boolean isMergeable(Calc topCalc, Calc bottomCalc) {
return mergeable(topInputRefCounter, topInputRefs, bottomProjects);
}

private static boolean functionResultShouldBeReused(Project topProject, Project bottomProject) {
Set<Integer> indexSet = new HashSet<>();
List<RexNode> bottomProjectList = bottomProject.getProjects();
for (int i = 0; i < bottomProjectList.size(); i++) {
RexNode project = bottomProjectList.get(i);
if (project instanceof RexCall
// && SqlKind.FUNCTION.contains(((RexCall) project).op.getKind())
&& ((RexCall) project).op.isDeterministic()) {
Copy link
Contributor

@davidradl davidradl Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not know this area very well, just to check my understanding

  • why are we calling the rexNodes projects and not rexNodes?
  • I assume that isDeterministic implies that for a given input you get the same output. If there are non-deterministic functions in the list of projects, this could imply we cannot use the output for those. Is it safe to reuse the results if there are non-deterministic projects? What would it mean to reuse the results when there is a mix of deterministic and non-deterministic projects? Maybe have a test with the mix to ensure it works as you intend.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is still a kind of PoC

Regarding determinism: further down the stack there is a dedicated check for deterministic/non-deterministic functions, so I do not care much about it here.

indexSet.add(i);
}
}
if (indexSet.isEmpty()) {
return false;
}

Set<RexNode> rexNodes = new HashSet<>();
List<RexNode> topProjectList = topProject.getProjects();
for (RexNode rex : topProjectList) {
if (!(rex instanceof RexCall)) {
continue;
}
RexCall rCall = (RexCall) rex;
if (!(rCall.op instanceof SqlFunction)) {
continue;
}
List<RexNode> operands = rCall.operands;
for (RexNode op : operands) {
if (op instanceof RexSlot) {
if (indexSet.contains(((RexSlot) op).getIndex()) && !rexNodes.add(op)) {
return true;
}
}
}
}

return false;
}

/**
 * Decides whether the result of a deterministic function produced by {@code bottomProgram} is
 * consumed more than once by function calls in {@code topProgram}.
 *
 * <p>First, the output positions of the bottom program whose expression is a deterministic
 * {@link RexCall} with a kind contained in {@link SqlKind#FUNCTION} are collected. Then the
 * direct operands of every {@link SqlFunction} call in the top program's expression list are
 * scanned; as soon as an equal {@link RexSlot} operand whose index is one of the collected
 * positions shows up a second time, the method reports {@code true}.
 *
 * @param topProgram program whose function-call operands are inspected
 * @param bottomProgram program whose projected expressions are inspected
 * @return {@code true} if such an operand is encountered at least twice, {@code false}
 *     otherwise (including when the bottom program projects no deterministic function)
 */
private static boolean functionResultShouldBeReused(
        RexProgram topProgram, RexProgram bottomProgram) {
    final List<RexLocalRef> bottomOutputs = bottomProgram.getProjectList();
    final Set<Integer> functionOutputPositions = new HashSet<>();
    int position = 0;
    for (RexLocalRef outputRef : bottomOutputs) {
        final RexNode expr = bottomProgram.getExprList().get(outputRef.getIndex());
        if (expr instanceof RexCall) {
            final RexCall call = (RexCall) expr;
            if (SqlKind.FUNCTION.contains(call.op.getKind()) && call.op.isDeterministic()) {
                // Remember the output position, not the index into the expression list.
                functionOutputPositions.add(position);
            }
        }
        position++;
    }
    if (functionOutputPositions.isEmpty()) {
        return false;
    }

    final Set<RexNode> seenOperands = new HashSet<>();
    for (RexNode expr : topProgram.getExprList()) {
        if (expr instanceof RexCall && ((RexCall) expr).op instanceof SqlFunction) {
            for (RexNode operand : ((RexCall) expr).operands) {
                if (!(operand instanceof RexSlot)) {
                    continue;
                }
                final boolean refersToFunctionOutput =
                        functionOutputPositions.contains(((RexSlot) operand).getIndex());
                // Set#add returns false on the second sighting of an equal operand.
                if (refersToFunctionOutput && !seenOperands.add(operand)) {
                    return true;
                }
            }
        }
    }
    return false;
}

/**
* Merges the programs of two {@link Calc} instances and returns a new {@link Calc} instance
* with the merged program.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rel.hint.RelHint
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexProgram}
import org.apache.calcite.sql.SqlExplainLevel
import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexLocalRef, RexNode, RexProgram, RexShuttle}
import org.apache.calcite.sql.{SqlExplainLevel, SqlKind}

import java.util.Collections

Expand All @@ -49,12 +49,26 @@ abstract class CommonCalc(
// conditions, etc. We only want to account for computations, not for simple projections.
// CASTs in RexProgram are reduced as far as possible by ReduceExpressionsRule
// in normalization stage. So we should ignore CASTs here in optimization stage.
val compCnt = calcProgram.getProjectList.map(calcProgram.expandLocalRef).toList.count {
val map = new java.util.HashMap[RexNode, Integer]()
val shuttle = new FunctionCounter(calcProgram.getExprList, map)
calcProgram.getProjectList.map(rf => rf.accept(shuttle))
val compCnt1 = calcProgram.getProjectList.map(calcProgram.expandLocalRef).toList.count {
case _: RexInputRef => false
case _: RexLiteral => false
case c: RexCall if c.getOperator.getName.equals("CAST") => false
case _ => true
}
val offset = map
.filterKeys {
case _: RexInputRef => false
case _: RexLiteral => false
case c: RexCall if c.getOperator.getName.equals("CAST") => false
case _ => true
}
.values
.foldLeft(0)(_ + _)

val compCnt = Math.max(compCnt1, offset)
val newRowCnt = mq.getRowCount(this)
// TODO use inputRowCnt to compute cpu cost
planner.getCostFactory.makeCost(newRowCnt, newRowCnt * compCnt, 0)
Expand Down Expand Up @@ -102,4 +116,21 @@ abstract class CommonCalc(
.mkString(", ")
}

/**
 * Shuttle that counts how often deterministic function calls are reached through local
 * references.
 *
 * For every visited [[RexLocalRef]] the referenced expression is looked up in `exprs`. If it is
 * a [[RexCall]] whose kind is contained in `SqlKind.FUNCTION` and whose operator is
 * deterministic, its counter in `map` is incremented. The traversal then continues into the
 * referenced expression, whose rewritten form is returned.
 */
class FunctionCounter(
    private val exprs: java.util.List[RexNode],
    val map: java.util.Map[RexNode, Integer])
  extends RexShuttle {
  override def visitLocalRef(localRef: RexLocalRef): RexNode = {
    val expanded: RexNode = exprs.get(localRef.getIndex)
    val isCountable = expanded match {
      case call: RexCall =>
        SqlKind.FUNCTION.contains(call.getKind) && call.op.isDeterministic
      case _ => false
    }
    if (isCountable) {
      map.merge(expanded, 1, (left, right) => left + right)
    }
    // Keep descending so that nested local references are visited as well.
    expanded.accept(this)
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.flink.table.planner.plan.nodes.physical.stream

import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.functions.sql.BuiltInSqlFunction
import org.apache.flink.table.planner.plan.nodes.exec.{ExecNode, InputProperty}
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecCalc
import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapTableConfig
Expand All @@ -26,7 +27,12 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.Calc
import org.apache.calcite.rex.RexProgram
import org.apache.calcite.rex.{RexCall, RexLocalRef, RexNode, RexProgram, RexShuttle}
import org.apache.calcite.sql.{SqlFunction, SqlKind}

import java.util
import java.util.List
import java.util.function.BiFunction

import scala.collection.JavaConversions._

Expand All @@ -44,7 +50,12 @@ class StreamPhysicalCalc(
}

override def translateToExecNode(): ExecNode[_] = {
val projection = calcProgram.getProjectList.map(calcProgram.expandLocalRef)
val funtionToCountMap = new util.HashMap[RexNode, Integer]()
val shuttle = new FunctionRefCounter(calcProgram.getExprList, funtionToCountMap)

calcProgram.getProjectList.map(ref => ref.accept(shuttle))
val projection = calcProgram.getProjectList.map(
ref => ref.accept(new ExpansionShuttle(calcProgram.getExprList, funtionToCountMap)))
val condition = if (calcProgram.getCondition != null) {
calcProgram.expandLocalRef(calcProgram.getCondition)
} else {
Expand All @@ -59,4 +70,49 @@ class StreamPhysicalCalc(
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription)
}

/**
 * Returns true if `rexNode` has a kind contained in `SqlKind.FUNCTION`, is a [[RexCall]], and
 * its operator is a [[SqlFunction]] that reports itself as deterministic.
 */
private def isDeterministicFunction(rexNode: RexNode): Boolean = {
  val hasFunctionKind = SqlKind.FUNCTION.contains(rexNode.getKind)
  hasFunctionKind && (rexNode match {
    case call: RexCall =>
      call.op match {
        case function: SqlFunction => function.isDeterministic
        case _ => false
      }
    case _ => false
  })
}

/**
 * Shuttle that replaces local references by the expressions they point to in `exprs`, except
 * for some deterministic functions that are referenced more than once.
 *
 * When the referenced expression is a deterministic function whose count in `map` is greater
 * than one, the reference is kept as is unless at least one operand of the call is itself a
 * [[RexLocalRef]]; in that case, and in all other cases, the expression is expanded
 * recursively.
 */
private class ExpansionShuttle(
    private val exprs: util.List[RexNode],
    val map: util.Map[RexNode, Integer])
  extends RexShuttle {
  override def visitLocalRef(localRef: RexLocalRef): RexNode = {
    val referenced: RexNode = exprs.get(localRef.getIndex)
    val usedMoreThanOnce = isDeterministicFunction(referenced) && map.get(referenced) > 1
    if (!usedMoreThanOnce) {
      referenced.accept(this)
    } else if (referenced.asInstanceOf[RexCall].operands.exists(_.isInstanceOf[RexLocalRef])) {
      // At least one operand is a local reference: expand the call.
      referenced.accept(this)
    } else {
      localRef
    }
  }
}

/**
 * Shuttle that records in `map` how many times each deterministic function expression is
 * reached through a [[RexLocalRef]].
 *
 * Every visited reference is resolved against `exprs`; the counter of the resolved expression
 * is incremented when `isDeterministicFunction` accepts it, and the traversal continues into
 * the resolved expression.
 */
private class FunctionRefCounter(
    private val exprs: util.List[RexNode],
    val map: util.Map[RexNode, Integer])
  extends RexShuttle {
  override def visitLocalRef(localRef: RexLocalRef): RexNode = {
    val referenced: RexNode = exprs.get(localRef.getIndex)
    val countable = isDeterministicFunction(referenced)
    if (countable) {
      map.merge(referenced, 1, (left, right) => left + right)
    }
    referenced.accept(this)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,26 @@ public boolean isDeterministic() {
}
}

/**
 * Deterministic scalar function.
 *
 * <p>Provides three {@code eval} overloads and explicitly reports itself as deterministic via
 * {@link #isDeterministic()}.
 */
public static class DeterministicUdf extends ScalarFunction {
/** No-argument overload; always returns {@code 0}. */
public int eval() {
return 0;
}

/** Returns the given {@code int} argument unchanged. */
public int eval(@DataTypeHint("INT") int v) {
return v;
}

/** Returns the given string argument unchanged. */
public String eval(String v) {
return v;
}

/** Always returns {@code true}. */
@Override
public boolean isDeterministic() {
return true;
}
}

/** Test for Python Scalar Function. */
public static class PythonScalarFunction extends ScalarFunction implements PythonFunction {
private final String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,44 @@ LogicalProject(EXPR$0=[$0])
<![CDATA[
Calc(select=[a AS EXPR$0])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
<TestCase name="testReusedFunctionInProject">
<Resource name="sql">
<![CDATA[SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) as q FROM MyTable) t]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(EXPR$0=[LTRIM($0)], EXPR$1=[RTRIM($0)])
+- LogicalProject(q=[TRIM(FLAG(BOTH), _UTF-16LE' ', $2)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[LTRIM(q) AS EXPR$0, RTRIM(q) AS EXPR$1])
+- Calc(select=[TRIM(BOTH, ' ', c) AS q])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
<TestCase name="testReusedFunctionInProjectWithDeterministicUdf">
<Resource name="sql">
<![CDATA[SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') FROM (SELECT deterministic_udf(c) as json_data FROM MyTable) t]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(EXPR$0=[JSON_VALUE($0, _UTF-16LE'$.id')], EXPR$1=[JSON_VALUE($0, _UTF-16LE'$.name')])
+- LogicalProject(json_data=[deterministic_udf($2)])
+- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
Calc(select=[JSON_VALUE(json_data, '$.id') AS EXPR$0, JSON_VALUE(json_data, '$.name') AS EXPR$1])
+- Calc(select=[deterministic_udf(c) AS json_data])
+- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.stream.sql

import org.apache.flink.table.api._
import org.apache.flink.table.planner.plan.utils.MyPojo
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.NonDeterministicUdf
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.{DeterministicUdf, NonDeterministicUdf}
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.StringSplit
import org.apache.flink.table.planner.utils.TableTestBase
import org.apache.flink.table.types.AbstractDataType
Expand All @@ -36,13 +36,25 @@ class CalcTest extends TableTestBase {
def setup(): Unit = {
  // Source table used by the tests: three columns a, b, c.
  util.addTableSource[(Long, Int, String)]("MyTable", 'a, 'b, 'c)

  // Register one non-deterministic and one deterministic UDF.
  val randomUdf = new NonDeterministicUdf
  util.addTemporarySystemFunction("random_udf", randomUdf)
  val deterministicUdf = new DeterministicUdf
  util.addTemporarySystemFunction("deterministic_udf", deterministicUdf)
}

@Test
def testOnlyProject(): Unit = {
  // Plain projection of two columns.
  val sql = "SELECT a, c FROM MyTable"
  util.verifyExecPlan(sql)
}

@Test
def testReusedFunctionInProject(): Unit = {
  // The inner TRIM result is consumed by two function calls in the outer query.
  val sql = "SELECT LTRIM(q), RTRIM(q) FROM (SELECT TRIM(c) as q FROM MyTable) t"
  util.verifyExecPlan(sql)
}

@Test
def testReusedFunctionInProjectWithDeterministicUdf(): Unit = {
  // The result of the deterministic UDF is consumed by two JSON_VALUE calls in the outer query.
  val sql =
    "SELECT JSON_VALUE(json_data, '$.id'), JSON_VALUE(json_data, '$.name') FROM (SELECT deterministic_udf(c) as json_data FROM MyTable) t"
  util.verifyExecPlan(sql)
}

@Test
def testProjectWithNaming(): Unit = {
util.verifyExecPlan("SELECT `1-_./Ü`, b, c FROM (SELECT a as `1-_./Ü`, b, c FROM MyTable)")
Expand Down