diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
index dc6673d4cfc..dc8d1a35a62 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PersistedView, ViewType}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, HiveTableRelation}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, ScalarSubquery, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -572,6 +572,24 @@ trait LineageParser {
             p.cacheBuilder.tableName.toSeq)
         }

+      case p: Filter =>
+        if (SparkContextHelper.getConf(
+            LineageConf.COLLECT_FILTER_CONDITION_TABLES_ENABLED)) {
+          p.condition.foreach {
+            case expression: SubqueryExpression =>
+              extractColumnsLineage(
+                expression.plan,
+                ListMap[Attribute, AttributeSet](),
+                inputTablesByPlan)
+            case _ =>
+          }
+        }
+
+        p.children.map(extractColumnsLineage(
+          _,
+          parentColumnsLineage,
+          inputTablesByPlan)).reduce(mergeColumnsLineage)
+
       case p if p.children.isEmpty => ListMap[Attribute, AttributeSet]()

       case p =>
diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala
index afffb5f578b..7bcef559fbe 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/spark/kyuubi/lineage/LineageConf.scala
@@ -55,6 +55,14 @@ object LineageConf {
       .booleanConf
       .createWithDefault(false)

+  val COLLECT_FILTER_CONDITION_TABLES_ENABLED =
+    ConfigBuilder("spark.kyuubi.plugin.lineage.collectFilterTables")
+      .internal
+      .doc("Whether to collect the tables referenced in filter conditions as lineage input tables.")
+      .version("1.11.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val DEFAULT_CATALOG: String = SQLConf.get.getConf(SQLConf.DEFAULT_CATALOG)

 }
diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
index 90da4650b1f..12335ea656a 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
@@ -986,6 +986,7 @@ abstract class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
         |""".stripMargin
     ddls.split("\n").filter(_.nonEmpty).foreach(spark.sql(_).collect())
     withTable("table0", "table1") { _ =>
+      SparkContextHelper.setConf(LineageConf.COLLECT_FILTER_CONDITION_TABLES_ENABLED, false)
       val sql0 =
         """
           |select a as aa, bb, cc from (select b as bb, c as cc from table1) t0, table0
@@ -1465,6 +1466,102 @@ abstract class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
     }
   }

+  test("columns lineage extract - extract condition tables") {
+    val ddls =
+      """
+        |create table v2_catalog.db.tb1(col1 string, col2 string, col3 string)
+        |create table v2_catalog.db.tb2(col1 string, col2 string, col3 string)
+        |create table v2_catalog.db.tb3(col1 string, col2 string, col3 string)
+        |""".stripMargin
+    ddls.split("\n").filter(_.nonEmpty).foreach(spark.sql)
+    withTable("v2_catalog.db.tb1", "v2_catalog.db.tb2", "v2_catalog.db.tb3") { _ =>
+      SparkContextHelper.setConf(LineageConf.COLLECT_FILTER_CONDITION_TABLES_ENABLED, true)
+      val sql0 =
+        """
+          |insert overwrite v2_catalog.db.tb3
+          |select *
+          |from v2_catalog.db.tb1 t1
+          |where exists (select 1 from v2_catalog.db.tb2 t2 where t2.col1 = t1.col1);
+          |""".stripMargin
+      val ret0 = extractLineage(sql0)
+      assert(ret0 == Lineage(
+        List("v2_catalog.db.tb1", "v2_catalog.db.tb2"),
+        List("v2_catalog.db.tb3"),
+        List(
+          ("v2_catalog.db.tb3.col1", Set("v2_catalog.db.tb1.col1")),
+          ("v2_catalog.db.tb3.col2", Set("v2_catalog.db.tb1.col2")),
+          ("v2_catalog.db.tb3.col3", Set("v2_catalog.db.tb1.col3")))))
+
+      val sql1 =
+        """
+          |insert overwrite v2_catalog.db.tb3
+          |select col1, 'agg_flag' as col2, 'static' as col3
+          |from v2_catalog.db.tb1
+          |group by col1
+          |having count(*) > (select count(*) from v2_catalog.db.tb2)
+          |""".stripMargin
+      val ret1 = extractLineage(sql1)
+      assert(ret1 == Lineage(
+        List("v2_catalog.db.tb1", "v2_catalog.db.tb2"),
+        List("v2_catalog.db.tb3"),
+        List(
+          ("v2_catalog.db.tb3.col1", Set("v2_catalog.db.tb1.col1")),
+          ("v2_catalog.db.tb3.col2", Set()),
+          ("v2_catalog.db.tb3.col3", Set()))))
+
+      val sql2 =
+        """
+          |insert overwrite v2_catalog.db.tb3
+          |select *
+          |from v2_catalog.db.tb1
+          |where col1 > (select max(col1) from v2_catalog.db.tb2 where col2 = 'X')
+          |""".stripMargin
+      val ret2 = extractLineage(sql2)
+      assert(ret2 == Lineage(
+        List("v2_catalog.db.tb1", "v2_catalog.db.tb2"),
+        List("v2_catalog.db.tb3"),
+        List(
+          ("v2_catalog.db.tb3.col1", Set("v2_catalog.db.tb1.col1")),
+          ("v2_catalog.db.tb3.col2", Set("v2_catalog.db.tb1.col2")),
+          ("v2_catalog.db.tb3.col3", Set("v2_catalog.db.tb1.col3")))))
+
+      val sql3 =
+        """
+          |insert into v2_catalog.db.tb3
+          |select *
+          |from v2_catalog.db.tb1
+          |where col1 not in (select col1 from v2_catalog.db.tb2 where col2 is not null)
+          |""".stripMargin
+      val ret3 = extractLineage(sql3)
+      assert(ret3 == Lineage(
+        List("v2_catalog.db.tb1", "v2_catalog.db.tb2"),
+        List("v2_catalog.db.tb3"),
+        List(
+          ("v2_catalog.db.tb3.col1", Set("v2_catalog.db.tb1.col1")),
+          ("v2_catalog.db.tb3.col2", Set("v2_catalog.db.tb1.col2")),
+          ("v2_catalog.db.tb3.col3", Set("v2_catalog.db.tb1.col3")))))
+
+      val sql4 =
+        """
+          |insert into v2_catalog.db.tb3
+          |select *
+          |from v2_catalog.db.tb1
+          |where col3 in (
+          |  select col3 from v2_catalog.db.tb2
+          |  where col1 in (select col1 from v2_catalog.db.tb3 where col2 = 'V')
+          |)
+          |""".stripMargin
+      val ret4 = extractLineage(sql4)
+      assert(ret4 == Lineage(
+        List("v2_catalog.db.tb1", "v2_catalog.db.tb2", "v2_catalog.db.tb3"),
+        List("v2_catalog.db.tb3"),
+        List(
+          ("v2_catalog.db.tb3.col1", Set("v2_catalog.db.tb1.col1")),
+          ("v2_catalog.db.tb3.col2", Set("v2_catalog.db.tb1.col2")),
+          ("v2_catalog.db.tb3.col3", Set("v2_catalog.db.tb1.col3")))))
+    }
+  }
+
   protected def extractLineageWithoutExecuting(sql: String): Lineage = {
     val parsed = spark.sessionState.sqlParser.parsePlan(sql)
     val analyzed = spark.sessionState.analyzer.execute(parsed)
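For anyone trying the change out, here is a minimal sketch of toggling the new switch from an application. Only the config key `spark.kyuubi.plugin.lineage.collectFilterTables` and its default of `true` come from the diff above; the session setup and plugin wiring are assumptions for illustration:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch, not part of the PR. Assumes the kyuubi-spark-lineage jar is
// on the classpath and a lineage event dispatcher is already configured.
// The key defaults to true; setting it to false restores the previous
// behavior, where tables referenced only inside filter-condition subqueries
// (EXISTS / IN / scalar subqueries in WHERE or HAVING, as exercised by the
// tests above) are not reported as lineage input tables.
val spark = SparkSession.builder()
  .appName("lineage-filter-tables-demo")
  .config("spark.kyuubi.plugin.lineage.collectFilterTables", "false")
  .getOrCreate()
```

Since the flag is read through `SparkContextHelper.getConf`, it appears to be resolved from the SparkContext's `SparkConf`, so setting it at session build time as sketched, or via `spark-submit --conf`, should take effect; a runtime `spark.conf.set` likely would not.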