diff --git a/datafusion/catalog/src/default_table_source.rs b/datafusion/catalog/src/default_table_source.rs
index 11963c06c88f..6bd7366db483 100644
--- a/datafusion/catalog/src/default_table_source.rs
+++ b/datafusion/catalog/src/default_table_source.rs
@@ -83,6 +83,10 @@ impl TableSource for DefaultTableSource {
     fn get_column_default(&self, column: &str) -> Option<&Expr> {
         self.table_provider.get_column_default(column)
     }
+
+    fn statistics(&self) -> Option<Statistics> {
+        self.table_provider.statistics()
+    }
 }
 
 /// Wrap TableProvider in TableSource
diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs
index d3b253c0e102..dbd034ffa736 100644
--- a/datafusion/expr/src/table_source.rs
+++ b/datafusion/expr/src/table_source.rs
@@ -20,7 +20,7 @@
 use crate::{Expr, LogicalPlan};
 
 use arrow::datatypes::SchemaRef;
-use datafusion_common::{Constraints, Result};
+use datafusion_common::{Constraints, Result, Statistics};
 
 use std::{any::Any, borrow::Cow};
 
@@ -129,4 +129,12 @@ pub trait TableSource: Sync + Send {
     fn get_column_default(&self, _column: &str) -> Option<&Expr> {
         None
     }
+
+    /// Get statistics for this table source, if available.
+    /// Although not presently used in mainline DataFusion, this allows
+    /// implementation-specific behavior in downstream repositories, in conjunction
+    /// with specialized optimizer rules, to perform operations such as join reordering.
+    fn statistics(&self) -> Option<Statistics> {
+        None
+    }
 }
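For illustration, a downstream `TableSource` that surfaces a known row count might look like the sketch below. The `RowCountedSource` type is invented for this note (it mirrors the `JoinSource` test helper added later in this diff); only the `statistics()` override is the point:

```rust
use std::{any::Any, sync::Arc};

use arrow::datatypes::SchemaRef;
use datafusion_common::{stats::Precision, Statistics};
use datafusion_expr::TableSource;

/// Hypothetical source that knows its row count up front.
#[derive(Debug)]
struct RowCountedSource {
    schema: SchemaRef,
    num_rows: usize,
}

impl TableSource for RowCountedSource {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    // Surface a row-count estimate so cost-based rules can use it.
    fn statistics(&self) -> Option<Statistics> {
        Some(
            Statistics::new_unknown(&self.schema)
                .with_num_rows(Precision::Inexact(self.num_rows)),
        )
    }
}
```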
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 85fa9493f449..7d2af1743949 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -58,6 +58,7 @@ pub mod optimizer;
 pub mod propagate_empty_relation;
 pub mod push_down_filter;
 pub mod push_down_limit;
+pub mod reorder_join;
 pub mod replace_distinct_aggregate;
 pub mod scalar_subquery_to_join;
 pub mod simplify_expressions;
diff --git a/datafusion/optimizer/src/reorder_join/cost.rs b/datafusion/optimizer/src/reorder_join/cost.rs
new file mode 100644
index 000000000000..e346c8bc6441
--- /dev/null
+++ b/datafusion/optimizer/src/reorder_join/cost.rs
@@ -0,0 +1,59 @@
+use datafusion_common::{plan_datafusion_err, plan_err, stats::Precision, Result};
+use datafusion_expr::{Join, JoinType, LogicalPlan};
+
+pub trait JoinCostEstimator: std::fmt::Debug {
+    fn cardinality(&self, plan: &LogicalPlan) -> Option<f64> {
+        estimate_cardinality(plan).ok()
+    }
+
+    fn selectivity(&self, join: &Join) -> f64 {
+        match join.join_type {
+            JoinType::Inner => 0.1,
+            _ => 1.0,
+        }
+    }
+
+    fn cost(&self, selectivity: f64, cardinality: f64) -> f64 {
+        selectivity * cardinality
+    }
+}
+
+/// Default implementation of JoinCostEstimator
+#[derive(Debug, Clone, Copy)]
+pub struct DefaultCostEstimator;
+
+impl JoinCostEstimator for DefaultCostEstimator {}
+
+fn estimate_cardinality(plan: &LogicalPlan) -> Result<f64> {
+    match plan {
+        LogicalPlan::Filter(filter) => {
+            let input_cardinality = estimate_cardinality(&filter.input)?;
+            Ok(0.1 * input_cardinality)
+        }
+        LogicalPlan::Aggregate(agg) => {
+            let input_cardinality = estimate_cardinality(&agg.input)?;
+            Ok(0.1 * input_cardinality)
+        }
+        LogicalPlan::TableScan(scan) => {
+            let statistics = scan
+                .source
+                .statistics()
+                .ok_or_else(|| plan_datafusion_err!("Table statistics not available"))?;
+            if let Precision::Exact(num_rows) | Precision::Inexact(num_rows) =
+                statistics.num_rows
+            {
+                Ok(num_rows as f64)
+            } else {
+                plan_err!("Number of rows not available")
+            }
+        }
+        x => {
+            let inputs = x.inputs();
+            if inputs.len() == 1 {
+                estimate_cardinality(inputs[0])
+            } else {
+                plan_err!("Cannot estimate cardinality for plan with multiple inputs")
+            }
+        }
+    }
+}
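The trait's default methods are intended to be overridden. As a sketch (struct name and numbers are illustrative, not part of this diff), a downstream estimator could hand-tune selectivities while inheriting the default `cardinality` and `cost`:

```rust
use datafusion_expr::{Join, JoinType};

/// Hypothetical estimator with hand-tuned selectivities.
#[derive(Debug)]
struct HandTunedEstimator;

impl JoinCostEstimator for HandTunedEstimator {
    // Override only `selectivity`; `cardinality` and `cost` keep the
    // trait's default implementations.
    fn selectivity(&self, join: &Join) -> f64 {
        match join.join_type {
            JoinType::Inner => 0.01,
            JoinType::LeftSemi | JoinType::RightSemi => 0.05,
            _ => 1.0,
        }
    }
}
```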
diff --git a/datafusion/optimizer/src/reorder_join/left_deep_join_plan.rs b/datafusion/optimizer/src/reorder_join/left_deep_join_plan.rs
new file mode 100644
index 000000000000..f3b65de1e2e4
--- /dev/null
+++ b/datafusion/optimizer/src/reorder_join/left_deep_join_plan.rs
@@ -0,0 +1,878 @@
+use std::{collections::HashSet, fmt::Debug, rc::Rc, sync::Arc};
+
+use datafusion_common::{plan_datafusion_err, plan_err, Result};
+use datafusion_expr::LogicalPlan;
+
+use crate::reorder_join::{
+    cost::JoinCostEstimator,
+    query_graph::{NodeId, QueryGraph},
+};
+
+/// Generates an optimized left-deep join plan from a logical plan using the Ibaraki-Kameda algorithm.
+///
+/// This function is the main entry point for join reordering optimization. It takes a logical plan
+/// that may contain joins along with wrapper operators (filters, sorts, aggregations, etc.) and
+/// produces an optimized plan with reordered joins while preserving the wrapper operators.
+///
+/// # Algorithm Overview
+///
+/// The optimization process consists of several steps:
+///
+/// 1. **Extraction**: Separates the join subtree from wrapper operators (filters, sorts, limits, etc.)
+/// 2. **Graph Conversion**: Converts the join subtree into a query graph representation where:
+///    - Nodes represent base relations (table scans, subqueries, etc.)
+///    - Edges represent join conditions between relations
+/// 3. **Optimization**: Uses the Ibaraki-Kameda algorithm to find the optimal left-deep join ordering
+///    by trying each node as a potential root and selecting the plan with the lowest estimated cost
+/// 4. **Reconstruction**: Rebuilds the complete logical plan by applying the wrapper operators
+///    to the optimized join plan
+///
+/// # Left-Deep Join Plans
+///
+/// A left-deep join plan is a join tree where:
+/// - Each join has a relation or previous join result on the left side
+/// - Each join has a single relation on the right side
+/// - This creates a linear "chain" of joins processed left-to-right
+///
+/// Example: `((A ⋈ B) ⋈ C) ⋈ D` is left-deep, while `(A ⋈ B) ⋈ (C ⋈ D)` is not.
+///
+/// Left-deep plans are preferred because they:
+/// - Allow pipelining of intermediate results
+/// - Work well with hash join implementations
+/// - Have predictable memory usage patterns
+///
+/// # Arguments
+///
+/// * `plan` - The logical plan to optimize. Must contain at least one join node.
+/// * `cost_estimator` - Cost estimator for calculating join costs, cardinality, and selectivity.
+///   Used to compare different join orderings and select the optimal one.
+///
+/// # Returns
+///
+/// Returns a `LogicalPlan` with optimized join ordering. The plan structure is:
+/// - Wrapper operators (filters, sorts, etc.) in their original positions
+/// - Joins reordered to minimize estimated execution cost
+/// - Join semantics preserved (same result set as input plan)
+///
+/// # Errors
+///
+/// Returns an error if:
+/// - The plan does not contain any join nodes
+/// - Join extraction fails (e.g., joins are not consecutive in the plan tree)
+/// - The query graph cannot be constructed from the join subtree
+/// - Join reordering optimization fails (no valid join ordering found)
+/// - Plan reconstruction fails
+///
+/// # Example
+///
+/// ```ignore
+/// use datafusion_optimizer::reorder_join::{optimal_left_deep_join_plan, cost::JoinCostEstimator};
+/// use std::rc::Rc;
+///
+/// // Assume we have a plan with joins: customer ⋈ orders ⋈ lineitem
+/// let plan = ...; // Your logical plan
+/// let cost_estimator: Rc<dyn JoinCostEstimator> = Rc::new(MyCostEstimator::new());
+///
+/// // Optimize join ordering
+/// let optimized = optimal_left_deep_join_plan(plan, cost_estimator)?;
+/// // Result might reorder to: lineitem ⋈ orders ⋈ customer (if this is cheaper)
+/// ```
+pub fn optimal_left_deep_join_plan(
+    plan: LogicalPlan,
+    cost_estimator: Rc<dyn JoinCostEstimator>,
+) -> Result<LogicalPlan> {
+    // Extract the join subtree and wrappers
+    let (join_subtree, wrappers) =
+        crate::reorder_join::query_graph::extract_join_subtree(plan)?;
+
+    // Convert join subtree to query graph
+    let query_graph = QueryGraph::try_from(join_subtree)?;
+
+    // Optimize the joins
+    let optimized_joins =
+        query_graph_to_optimal_left_deep_join_plan(query_graph, cost_estimator)?;
+
+    // Reconstruct the full plan with wrappers
+    crate::reorder_join::query_graph::reconstruct_plan(optimized_joins, wrappers)
+}
+
+/// Generates an optimized linear join plan from a query graph using the Ibaraki-Kameda algorithm.
+///
+/// This function finds the optimal join ordering for a query by:
+/// 1. Trying each node in the query graph as a potential root
+/// 2. For each root, building a precedence tree and optimizing it through normalization/denormalization
+/// 3. Selecting the plan with the lowest estimated cost
+///
+/// The optimization process uses the Ibaraki-Kameda algorithm, which arranges joins to minimize
+/// intermediate result sizes by considering both cardinality and cost estimates.
+///
+/// # Algorithm Steps
+///
+/// For each candidate root node:
+/// 1. **Construction**: Build a precedence tree from the query graph starting at that node
+/// 2. **Normalization**: Transform the tree into a chain structure ordered by rank
+/// 3. **Denormalization**: Split merged operations back into individual nodes while maintaining chain structure
+/// 4. **Cost Comparison**: Compare the resulting plan's cost against the current best
+///
+/// # Arguments
+///
+/// * `query_graph` - The query graph containing logical plan nodes and join specifications
+/// * `cost_estimator` - The cost estimator to use for calculating cardinality, selectivity, and cost
+///
+/// # Returns
+///
+/// Returns a `LogicalPlan` representing the optimal join ordering with the lowest estimated cost.
+///
+/// # Errors
+///
+/// Returns an error if:
+/// - The query graph is empty or invalid
+/// - Tree construction, normalization, or denormalization fails
+/// - No valid precedence graph can be generated
+pub fn query_graph_to_optimal_left_deep_join_plan(
+    query_graph: QueryGraph,
+    cost_estimator: Rc<dyn JoinCostEstimator>,
+) -> Result<LogicalPlan> {
+    let mut best_graph: Option<PrecedenceTreeNode> = None;
+
+    for (node_id, _) in query_graph.nodes() {
+        let mut precedence_graph = PrecedenceTreeNode::from_query_graph(
+            &query_graph,
+            node_id,
+            Rc::clone(&cost_estimator),
+        )?;
+        precedence_graph.normalize();
+        precedence_graph.denormalize()?;
+
+        best_graph = match best_graph.take() {
+            Some(current) => {
+                let new_cost = precedence_graph.cost()?;
+                if new_cost < current.cost()? {
+                    Some(precedence_graph)
+                } else {
+                    Some(current)
+                }
+            }
+            None => Some(precedence_graph),
+        };
+    }
+
+    best_graph
+        .ok_or_else(|| plan_datafusion_err!("No valid precedence graph found"))?
+        .into_logical_plan(&query_graph)
+}
+
+#[derive(Debug)]
+struct QueryNode {
+    node_id: NodeId,
+    // T in [IbarakiKameda84]
+    selectivity: f64,
+    // C in [IbarakiKameda84]
+    cost: f64,
+}
+
+impl QueryNode {
+    fn rank(&self) -> f64 {
+        (self.selectivity - 1.0) / self.cost
+    }
+}
+
+/// A node in the precedence tree for query optimization.
+///
+/// The precedence tree is a data structure used by the Ibaraki-Kameda algorithm for
+/// optimizing join ordering in database queries. It can represent both arbitrary tree
+/// structures and linear chain structures (where each node has at most one child).
+///
+/// # Lifecycle
+///
+/// A typical precedence tree goes through three phases:
+///
+/// 1. **Construction** ([`from_query_graph`](Self::from_query_graph)): Build an initial tree
+///    from a query graph, creating nodes with cost/cardinality estimates
+/// 2. **Normalization** ([`normalize`](Self::normalize)): Transform the tree into a chain
+///    where nodes are ordered by rank, potentially merging multiple query operations into
+///    single nodes
+/// 3. **Denormalization** ([`denormalize`](Self::denormalize)): Split merged operations back
+///    into individual nodes while maintaining the optimized chain structure
+///
+/// The result is a linear execution order that minimizes intermediate result sizes.
+///
+/// # Fields
+///
+/// * `query_nodes` - Vector of query operations with cost estimates. In an initial tree,
+///   contains one operation. After normalization, may contain multiple merged operations.
+///   After denormalization, contains exactly one operation.
+/// * `children` - Child nodes in the tree. In a normalized/denormalized chain, contains
+///   at most one child. In an arbitrary tree, may contain multiple children.
+/// * `query_graph` - Reference to the original query graph, used for accessing node
+///   relationships and metadata during tree transformations.
+struct PrecedenceTreeNode<'graph> {
+    query_nodes: Vec<QueryNode>,
+    children: Vec<PrecedenceTreeNode<'graph>>,
+    query_graph: &'graph QueryGraph,
+}
+
+impl Debug for PrecedenceTreeNode<'_> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PrecedenceTreeNode")
+            .field("query_nodes", &self.query_nodes)
+            .field("children", &self.children)
+            .finish()
+    }
+}
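To make the rank heuristic concrete, here is an illustrative calculation with invented numbers (this is a note, not code from the diff; `QueryNode` is private, so it would only compile inside this module). With rank = (T − 1) / C, an operation that shrinks its input sorts before one that grows it:

```rust
// Illustrative only: rank = (T - 1) / C from [IbarakiKameda84].
// Negative rank (T < 1, shrinks the input) schedules earlier than
// positive rank (T > 1, grows the input).
let shrinking = QueryNode { node_id: 0, selectivity: 0.5, cost: 10.0 };
let growing = QueryNode { node_id: 1, selectivity: 2.0, cost: 10.0 };
assert_eq!(shrinking.rank(), -0.05);
assert_eq!(growing.rank(), 0.1);
assert!(shrinking.rank() < growing.rank());
```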
+ /// + /// The function performs a depth-first traversal starting from the root node, + /// building a tree where: + /// - Each node contains cost/cardinality estimates for a query operation + /// - Children represent connected query nodes (joins, filters, etc.) + /// - The root node starts with selectivity of 1.0 (no filtering) + /// + /// # Arguments + /// + /// * `graph` - The query graph to transform into a precedence tree + /// * `root_id` - The ID of the node to use as the root of the tree + /// * `cost_estimator` - The cost estimator to use for calculating cardinality, selectivity, and cost + /// + /// # Returns + /// + /// Returns a `PrecedenceTreeNode` representing the entire query graph as a tree structure, + /// with the specified root node at the top. + /// + /// # Errors + /// + /// Returns an error if: + /// - The `root_id` is not found in the query graph + /// - Any connected node cannot be found during traversal + pub(crate) fn from_query_graph( + graph: &'graph QueryGraph, + root_id: NodeId, + cost_estimator: Rc, + ) -> Result { + let mut remaining: HashSet = graph.nodes().map(|(x, _)| x).collect(); + remaining.remove(&root_id); + PrecedenceTreeNode::from_query_node( + root_id, + 1.0, + graph, + &mut remaining, + cost_estimator, + true, + ) + } + + /// Recursively constructs a precedence tree node from a query graph node. + /// + /// This function builds a tree structure by: + /// 1. Creating a node with cost and cardinality estimates for the current query node + /// 2. Recursively processing all connected unvisited nodes as children + /// 3. Removing visited nodes from the `remaining` set to avoid cycles + /// + /// # Arguments + /// + /// * `node_id` - The ID of the query graph node to process + /// * `selectivity` - The selectivity factor from the parent edge (1.0 for root) + /// * `query_graph` - Reference to the query graph being transformed + /// * `remaining` - Mutable set of node IDs not yet visited (updated during traversal) + /// * `cost_estimator` - The cost estimator to use for calculating cardinality, selectivity, and cost + /// + /// # Returns + /// + /// Returns a `PrecedenceTreeNode` containing: + /// - A single `NodeEstimates` with cardinality and cost based on input cardinality and selectivity + /// - Child nodes for each connected unvisited neighbor in the query graph + /// + /// # Errors + /// + /// Returns an error if the specified `node_id` is not found in the query graph. 
+    fn from_query_node(
+        node_id: NodeId,
+        selectivity: f64,
+        query_graph: &'graph QueryGraph,
+        remaining: &mut HashSet<NodeId>,
+        cost_estimator: Rc<dyn JoinCostEstimator>,
+        is_root: bool,
+    ) -> Result<Self> {
+        let node = query_graph
+            .get_node(node_id)
+            .ok_or_else(|| plan_datafusion_err!("Node {node_id} not found"))?;
+        let input_cardinality = cost_estimator.cardinality(&node.plan).unwrap_or(1.0);
+
+        let children = node
+            .connections()
+            .iter()
+            .filter_map(|edge_id| {
+                let edge = query_graph.get_edge(*edge_id)?;
+                let other = edge
+                    .nodes
+                    .into_iter()
+                    .find(|x| *x != node_id && remaining.contains(x))?;
+
+                remaining.remove(&other);
+                let child_selectivity = cost_estimator.selectivity(&edge.join);
+                Some(PrecedenceTreeNode::from_query_node(
+                    other,
+                    child_selectivity,
+                    query_graph,
+                    remaining,
+                    Rc::clone(&cost_estimator),
+                    false,
+                ))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(PrecedenceTreeNode {
+            query_nodes: vec![QueryNode {
+                node_id,
+                selectivity: (selectivity * input_cardinality),
+                cost: if is_root {
+                    0.0
+                } else {
+                    cost_estimator.cost(selectivity, input_cardinality)
+                },
+            }],
+            children,
+            query_graph,
+        })
+    }
+
+    /// Rank function according to [IbarakiKameda84]
+    fn rank(&self) -> f64 {
+        let (cardinality, cost) =
+            self.query_nodes
+                .iter()
+                .fold((1.0, 0.0), |(cardinality, cost), node| {
+                    let cost = cost + cardinality * node.cost;
+                    let cardinality = cardinality * node.selectivity;
+                    (cardinality, cost)
+                });
+        if cost == 0.0 {
+            0.0
+        } else {
+            (cardinality - 1.0) / cost
+        }
+    }
+
+    /// Normalizes the precedence tree into a linear chain structure.
+    ///
+    /// This transformation converts the tree into a normalized form where each node
+    /// has at most one child, creating a linear sequence of query nodes. The normalization
+    /// process uses the rank function to determine optimal ordering according to the
+    /// Ibaraki-Kameda algorithm.
+    ///
+    /// The normalization handles three cases:
+    /// - **Leaf nodes (0 children)**: Already normalized, no action needed
+    /// - **Single child (1 child)**: If the child has lower rank than current node, merge
+    ///   the child's query nodes into the current node, creating a sequence. Otherwise,
+    ///   recursively normalize the child.
+    /// - **Multiple children (2+ children)**: Recursively normalize all children into chains,
+    ///   then merge all child chains into a single chain using the merge operation.
+    ///
+    /// After normalization, the tree becomes a chain where nodes are ordered by their
+    /// rank values, with each node containing one or more query operations in sequence.
+    ///
+    /// # Algorithm
+    ///
+    /// Based on the Ibaraki-Kameda join ordering algorithm, which optimizes query
+    /// execution by arranging operations to minimize intermediate result sizes.
+    fn normalize(&mut self) {
+        match self.children.len() {
+            0 => (),
+            1 => {
+                // If child has lower rank, merge it into current node
+                if self.children[0].rank() < self.rank() {
+                    let mut child = self.children.pop().unwrap();
+                    self.query_nodes.append(&mut child.query_nodes);
+                    self.children = child.children;
+                    self.normalize();
+                } else {
+                    self.children[0].normalize();
+                }
+            }
+            _ => {
+                // Normalize all child trees into chains, then merge them
+                for child in &mut self.children {
+                    child.normalize();
+                }
+                let child = std::mem::take(&mut self.children)
+                    .into_iter()
+                    .reduce(Self::merge)
+                    .unwrap();
+                self.children = vec![child];
+            }
+        }
+    }
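The multi-child case relies on `merge` (below), which behaves like the merge step of merge sort keyed by rank. A standalone sketch of that idea, independent of the types in this file (rank values invented):

```rust
// Standalone illustration of the rank-ordered chain merge used by
// `normalize`: repeatedly take the lower-ranked head, as in merge sort.
fn merge_by_rank(mut a: Vec<f64>, mut b: Vec<f64>) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len() + b.len());
    a.reverse(); // treat the Vec tail as the chain head
    b.reverse();
    while let (Some(&x), Some(&y)) = (a.last(), b.last()) {
        out.push(if x <= y { a.pop().unwrap() } else { b.pop().unwrap() });
    }
    out.extend(a.into_iter().rev());
    out.extend(b.into_iter().rev());
    out
}

// merge_by_rank(vec![-0.5, 0.2], vec![-0.1, 0.4])
//     == vec![-0.5, -0.1, 0.2, 0.4]
```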
+ /// + /// This operation combines two normalized tree chains (each with at most one child) + /// into a single chain, preserving rank ordering. The chain with the lower rank becomes + /// the parent, and the higher-ranked chain is attached as a descendant. + /// + /// The merge strategy depends on whether the lower-ranked chain has children: + /// - **No children**: The higher-ranked chain becomes the direct child + /// - **Has child**: Recursively merge the higher-ranked chain with the child, + /// maintaining the chain structure + /// + /// This ensures the resulting chain maintains proper rank ordering from root to leaf, + /// which is essential for the Ibaraki-Kameda optimization algorithm. + /// + /// # Arguments + /// + /// * `self` - The first tree chain to merge + /// * `other` - The second tree chain to merge + /// + /// # Returns + /// + /// Returns a merged `PrecedenceTreeNode` chain with both input chains combined, + /// ordered by rank values. + /// + /// # Panics + /// + /// May panic if called on non-normalized trees (trees with multiple children). + fn merge(self, other: PrecedenceTreeNode<'graph>) -> Self { + let (mut first, second) = if self.rank() < other.rank() { + (self, other) + } else { + (other, self) + }; + if first.children.is_empty() { + first.children = vec![second]; + } else { + first.children = vec![first.children.pop().unwrap().merge(second)]; + } + first + } + + /// Denormalizes a normalized precedence tree by splitting merged query nodes. + /// + /// This is the inverse operation of normalization, but with a critical property: + /// **the result is still a chain structure** (each node has at most one child). + /// It converts a normalized chain where nodes contain multiple query operations + /// into a longer chain where each node contains exactly one query operation. + /// + /// The denormalization process: + /// 1. **Validates input**: Ensures the tree is normalized (0 or 1 children per node) + /// 2. **Recursively processes children**: Denormalizes the child chain first + /// 3. **Splits merged nodes**: For nodes with multiple query operations, iteratively + /// extracts operations one at a time based on neighbor relationships with the child + /// 4. **Maintains ordering**: Uses rank-based selection to determine which query node + /// to extract next, choosing the highest-ranked neighbor of the child node + /// + /// **Key property**: After denormalization, the result remains a chain (not a tree with + /// branches). Each node contains exactly one query operation, but the chain structure + /// is preserved. This is the essence of the normalize-denormalize algorithm: transforming + /// an arbitrary tree into an optimized chain while respecting query dependencies. + /// + /// # Errors + /// + /// Returns an error if: + /// - The tree is not normalized (has more than one child at any level) + /// + /// # Algorithm + /// + /// The splitting process uses the query graph's neighbor relationships to determine + /// which nodes should be adjacent in the chain, maintaining logical dependencies + /// between query operations while producing a linear execution order. 
+    fn denormalize(&mut self) -> Result<()> {
+        // Normalized trees must have 0 or 1 children
+        match self.children.len() {
+            0 => (),
+            1 => self.children[0].denormalize()?,
+            _ => return plan_err!("Tree is not normalized"),
+        }
+
+        // Split query nodes into a chain based on neighbor relationships,
+        // inserting each extracted node between the current node and its child
+        while self.query_nodes.len() > 1 {
+            if self.children.is_empty() {
+                let highest_rank_idx = self
+                    .query_nodes
+                    .iter()
+                    .enumerate()
+                    .max_by(|(_, a), (_, b)| a.rank().partial_cmp(&b.rank()).unwrap())
+                    .map(|(idx, _)| idx)
+                    .unwrap();
+
+                let node = self.query_nodes.remove(highest_rank_idx);
+
+                self.children.push(PrecedenceTreeNode {
+                    query_nodes: vec![node],
+                    children: Vec::new(),
+                    query_graph: self.query_graph,
+                });
+            } else {
+                let child_id = self.children[0].query_nodes[0].node_id;
+                let child_node = self.query_graph.get_node(child_id).unwrap();
+                let neighbours = child_node.neighbours(child_id, self.query_graph);
+
+                // Find the highest-ranked neighbor node
+                let highest_rank_idx = self
+                    .query_nodes
+                    .iter()
+                    .enumerate()
+                    .filter(|(_, node)| neighbours.contains(&node.node_id))
+                    .max_by(|(_, a), (_, b)| a.rank().partial_cmp(&b.rank()).unwrap())
+                    .map(|(idx, _)| idx)
+                    .unwrap();
+
+                let node = self.query_nodes.remove(highest_rank_idx);
+
+                let child = std::mem::replace(
+                    &mut self.children[0],
+                    PrecedenceTreeNode {
+                        query_nodes: vec![node],
+                        children: Vec::new(),
+                        query_graph: self.query_graph,
+                    },
+                );
+                self.children[0].children = vec![child];
+            };
+        }
+        Ok(())
+    }
+
+    /// Converts the precedence tree chain into a DataFusion `LogicalPlan`.
+    ///
+    /// This method walks down the optimized chain structure, building a left-deep join tree
+    /// by repeatedly joining the accumulated result with the next node in the chain.
+    ///
+    /// # Algorithm
+    ///
+    /// 1. Start with the first node's `LogicalPlan` from the query graph
+    /// 2. For each subsequent node in the chain:
+    ///    - Get the node's `LogicalPlan` from the query graph
+    ///    - Find the edge connecting the current and next nodes
+    ///    - Create a join using the edge's join specification
+    ///    - The accumulated plan becomes the left side of the join
+    /// 3. Return the final joined `LogicalPlan`
+    ///
+    /// # Arguments
+    ///
+    /// * `query_graph` - The query graph containing the logical plans and join specifications
+    ///
+    /// # Returns
+    ///
+    /// Returns a `LogicalPlan` representing the optimized join execution order.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - A node or edge is missing from the query graph
+    /// - The precedence tree is not in the expected chain format
+    pub(crate) fn into_logical_plan(
+        self,
+        query_graph: &QueryGraph,
+    ) -> Result<LogicalPlan> {
+        // Get the first node's logical plan
+        let current_node_id = self.query_nodes[0].node_id;
+        let mut current_plan = query_graph
+            .get_node(current_node_id)
+            .ok_or_else(|| plan_datafusion_err!("Node {:?} not found", current_node_id))?
+            .plan
+            .as_ref()
+            .clone();
+
+        // Track all processed nodes in order
+        let mut processed_nodes = vec![current_node_id];
+
+        // Walk down the chain, joining each subsequent node
+        let mut current_chain = &self;
+
+        while !current_chain.children.is_empty() {
+            let child = &current_chain.children[0];
+            let next_node_id = child.query_nodes[0].node_id;
+
+            // Get the next node's logical plan
+            let next_plan = query_graph
+                .get_node(next_node_id)
+                .ok_or_else(|| plan_datafusion_err!("Node {:?} not found", next_node_id))?
+                .plan
+                .as_ref()
+                .clone();
+
+            // Find the edge connecting next_node to any processed node
+            let next_node = query_graph.get_node(next_node_id).ok_or_else(|| {
+                plan_datafusion_err!("Node {:?} not found", next_node_id)
+            })?;
+
+            let edge = processed_nodes
+                .iter()
+                .rev()
+                .find_map(|&processed_id| {
+                    next_node.connection_with(processed_id, query_graph)
+                })
+                .ok_or_else(|| {
+                    plan_datafusion_err!(
+                        "No edge found between {:?} and any processed nodes {:?}",
+                        next_node_id,
+                        processed_nodes
+                    )
+                })?;
+
+            // Determine if the join order was swapped compared to the original edge.
+            // We check if the qualified columns (relation + name) from the join expressions
+            // match the schemas. This handles all cases including when multiple tables
+            // have columns with the same name.
+            let current_schema = current_plan.schema();
+            let next_schema = next_plan.schema();
+
+            let join_order_swapped = if !edge.join.on.is_empty() {
+                // Extract columns from the first join condition
+                let (left_expr, right_expr) = &edge.join.on[0];
+                let left_columns = left_expr.column_refs();
+                let right_columns = right_expr.column_refs();
+
+                // Helper to check if a qualified column exists in a schema
+                let column_in_schema = |col: &datafusion_common::Column,
+                                        schema: &datafusion_common::DFSchema|
+                 -> bool {
+                    if let Some(relation) = &col.relation {
+                        // Column has a table qualifier - must match exactly (relation + name)
+                        schema.iter().any(|(qualifier, field)| {
+                            qualifier == Some(relation) && field.name() == col.name()
+                        })
+                    } else {
+                        // Unqualified column - check if the name exists anywhere in schema
+                        schema.field_with_unqualified_name(&col.name).is_ok()
+                    }
+                };
+
+                // Check which schema each expression's columns belong to
+                let left_in_current = left_columns
+                    .iter()
+                    .all(|c| column_in_schema(c, current_schema.as_ref()));
+                let right_in_next = right_columns
+                    .iter()
+                    .all(|c| column_in_schema(c, next_schema.as_ref()));
+                let left_in_next = left_columns
+                    .iter()
+                    .all(|c| column_in_schema(c, next_schema.as_ref()));
+                let right_in_current = right_columns
+                    .iter()
+                    .all(|c| column_in_schema(c, current_schema.as_ref()));
+
+                // Determine swap based on where the qualified columns are found
+                if left_in_current && right_in_next {
+                    // Left expression belongs to current, right to next → no swap
+                    false
+                } else if left_in_next && right_in_current {
+                    // Left expression belongs to next, right to current → swap
+                    true
+                } else {
+                    // Ambiguous or error case - default to no swap to preserve original order
+                    // This shouldn't happen with properly qualified columns
+                    false
+                }
+            } else {
+                // If there are no join conditions, we can't determine swap status
+                // This shouldn't happen in practice for equi-joins
+                false
+            };
+
+            // When the join order is swapped, we need to adjust the on conditions and join type
+            // to maintain correct semantics. For example:
+            // - Original: A LeftSemi B ON A.x = B.y
+            // - After swap: B RightSemi A ON B.y = A.x
+            let (on, join_type) = if join_order_swapped {
+                let swapped_on = edge
+                    .join
+                    .on
+                    .iter()
+                    .map(|(left, right)| (right.clone(), left.clone()))
+                    .collect();
+                (swapped_on, edge.join.join_type.swap())
+            } else {
+                (edge.join.on.clone(), edge.join.join_type)
+            };
+
+            // Create the join plan
+            current_plan = LogicalPlan::Join(datafusion_expr::Join {
+                left: Arc::new(current_plan),
+                right: Arc::new(next_plan),
+                on,
+                filter: edge.join.filter.clone(),
+                join_type,
+                join_constraint: edge.join.join_constraint,
+                schema: Arc::clone(&edge.join.schema),
+                null_equality: edge.join.null_equality,
+            });
+
+            // Move to the next node in the chain
+            processed_nodes.push(next_node_id);
+            current_chain = child;
+        }
+
+        Ok(current_plan)
+    }
+
+    fn cost(&self) -> Result<f64> {
+        self.cost_recursive(self.query_nodes[0].selectivity, 0.0)
+    }
+
+    fn cost_recursive(&self, cardinality: f64, cost: f64) -> Result<f64> {
+        let cost = match self.children.len() {
+            0 => cost + cardinality * self.query_nodes[0].cost,
+            1 => self.children[0].cost_recursive(
+                cardinality * self.query_nodes[0].selectivity,
+                cost + cardinality * self.query_nodes[0].cost,
+            )?,
+            _ => {
+                return plan_err!(
+                    "Cost calculation requires normalized tree with 0 or 1 children"
+                )
+            }
+        };
+        Ok(cost)
+    }
+}
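For reference, the swap relies on `JoinType::swap` from `datafusion_expr`, which this diff already calls; a quick sanity check of the pairing described in the comment above (assuming the current behavior of that method):

```rust
use datafusion_expr::JoinType;

// Sides exchange while semantics are preserved.
assert_eq!(JoinType::LeftSemi.swap(), JoinType::RightSemi);
assert_eq!(JoinType::Left.swap(), JoinType::Right);
assert_eq!(JoinType::Inner.swap(), JoinType::Inner);
```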
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery;
+    use crate::eliminate_filter::EliminateFilter;
+    use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
+    use crate::filter_null_join_keys::FilterNullJoinKeys;
+    use crate::optimizer::{Optimizer, OptimizerContext};
+    use crate::push_down_filter::PushDownFilter;
+    use crate::reorder_join::cost::JoinCostEstimator;
+    use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
+    use crate::simplify_expressions::SimplifyExpressions;
+    use crate::test::*;
+    use datafusion_expr::logical_plan::JoinType;
+    use datafusion_expr::LogicalPlanBuilder;
+
+    /// A simple cost estimator for testing
+    #[derive(Debug)]
+    struct TestCostEstimator;
+
+    impl JoinCostEstimator for TestCostEstimator {}
+
+    /// A simple TableSource implementation for testing join ordering with statistics
+    #[derive(Debug)]
+    struct JoinSource {
+        schema: arrow::datatypes::SchemaRef,
+        num_rows: usize,
+    }
+
+    impl datafusion_expr::TableSource for JoinSource {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn schema(&self) -> arrow::datatypes::SchemaRef {
+            Arc::clone(&self.schema)
+        }
+
+        fn statistics(&self) -> Option<datafusion_common::Statistics> {
+            use datafusion_common::stats::Precision;
+            Some(
+                datafusion_common::Statistics::new_unknown(&self.schema)
+                    .with_num_rows(Precision::Exact(self.num_rows)),
+            )
+        }
+    }
+
+    /// Create a table scan with statistics for testing join ordering
+    fn scan_tpch_table_with_stats(table: &str, num_rows: usize) -> LogicalPlan {
+        let schema = Arc::new(get_tpch_table_schema(table));
+        let table_source: Arc<dyn datafusion_expr::TableSource> = Arc::new(JoinSource {
+            schema: Arc::clone(&schema),
+            num_rows,
+        });
+        LogicalPlanBuilder::scan(table, table_source, None)
+            .unwrap()
+            .build()
+            .unwrap()
+    }
+
+    /// Test three-way join: customer -> orders -> lineitem
+    #[test]
+    fn test_three_way_join_customer_orders_lineitem() -> Result<()> {
+        use datafusion_expr::test::function_stub::sum;
+        use datafusion_expr::{col, in_subquery, lit};
+
+        // Create the base table scans with statistics
+        let customer = scan_tpch_table_with_stats("customer", 150);
+        let orders = scan_tpch_table_with_stats("orders", 1_500);
+        let lineitem = scan_tpch_table_with_stats("lineitem", 6_000);
+
+        // Step 1: Build the subquery
+        // SELECT l_orderkey FROM lineitem
+        // GROUP BY l_orderkey
+        // HAVING sum(l_quantity) > 300
+        let subquery = LogicalPlanBuilder::from(lineitem.clone())
+            .aggregate(vec![col("l_orderkey")], vec![sum(col("l_quantity"))])?
+            .filter(sum(col("l_quantity")).gt(lit(300)))?
+            .project(vec![col("l_orderkey")])?
+            .build()?;
+
+        // Step 2: Build the main query with joins
+        let plan = LogicalPlanBuilder::from(customer.clone())
+            .join(
+                orders.clone(),
+                JoinType::Inner,
+                (vec!["c_custkey"], vec!["o_custkey"]),
+                None,
+            )?
+            .join(
+                lineitem.clone(),
+                JoinType::Inner,
+                (vec!["o_orderkey"], vec!["l_orderkey"]),
+                None,
+            )?
+            // Step 3: Apply the IN subquery filter
+            .filter(in_subquery(col("o_orderkey"), Arc::new(subquery)))?
+            // Step 4: Aggregate
+            .aggregate(
+                vec![
+                    col("c_name"),
+                    col("c_custkey"),
+                    col("o_orderkey"),
+                    col("o_totalprice"),
+                ],
+                vec![sum(col("l_quantity"))],
+            )?
+            // Step 5: Sort
+            .sort(vec![col("o_totalprice").sort(false, true)])?
+            // Step 6: Limit
+            .limit(0, Some(100))?
+            .build()?;
+
+        println!("{}", plan.display_indent());
+
+        // Optimize the plan with custom optimizer before join reordering
+        // We exclude OptimizeProjections to keep joins consecutive
+        let config = OptimizerContext::new().with_skip_failing_rules(false);
+        let optimizer = Optimizer::with_rules(vec![
+            Arc::new(SimplifyExpressions::new()),
+            Arc::new(DecorrelatePredicateSubquery::new()),
+            Arc::new(ScalarSubqueryToJoin::new()),
+            Arc::new(ExtractEquijoinPredicate::new()),
+            Arc::new(EliminateFilter::new()),
+            Arc::new(FilterNullJoinKeys::default()),
+            Arc::new(PushDownFilter::new()),
+            // Note: OptimizeProjections is intentionally excluded to keep joins consecutive
+        ]);
+        let plan = optimizer.optimize(plan, &config, |_, _| {}).unwrap();
+
+        println!("After standard optimization:");
+        println!("{}", plan.display_indent());
+
+        let optimized_plan =
+            optimal_left_deep_join_plan(plan, Rc::new(TestCostEstimator)).unwrap();
+
+        println!("Optimized Plan:");
+        println!("{}", optimized_plan.display_indent());
+
+        // Verify the plan structure
+        assert!(matches!(optimized_plan, LogicalPlan::Limit(_)));
+
+        Ok(())
+    }
+}
diff --git a/datafusion/optimizer/src/reorder_join/mod.rs b/datafusion/optimizer/src/reorder_join/mod.rs
new file mode 100644
index 000000000000..758f1daeacb1
--- /dev/null
+++ b/datafusion/optimizer/src/reorder_join/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimizer rule for reordering joins to minimize query execution cost
+
+pub mod cost;
+pub mod left_deep_join_plan;
+pub mod query_graph;
diff --git a/datafusion/optimizer/src/reorder_join/query_graph.rs b/datafusion/optimizer/src/reorder_join/query_graph.rs
new file mode 100644
index 000000000000..ce5e67f29fe4
--- /dev/null
+++ b/datafusion/optimizer/src/reorder_join/query_graph.rs
@@ -0,0 +1,492 @@
+use std::sync::Arc;
+
+use datafusion_common::{
+    plan_err,
+    tree_node::{TreeNode, TreeNodeRecursion},
+    DataFusionError, Result,
+};
+use datafusion_expr::{utils::check_all_columns_from_schema, Join, LogicalPlan};
+
+pub type NodeId = usize;
+
+pub struct Node {
+    pub plan: Arc<LogicalPlan>,
+    pub(crate) connections: Vec<EdgeId>,
+}
+
+impl Node {
+    pub(crate) fn connections(&self) -> &[EdgeId] {
+        &self.connections
+    }
+
+    pub(crate) fn connection_with<'graph>(
+        &self,
+        node_id: NodeId,
+        query_graph: &'graph QueryGraph,
+    ) -> Option<&'graph Edge> {
+        self.connections
+            .iter()
+            .filter_map(|edge_id| query_graph.get_edge(*edge_id))
+            .find(move |x| x.nodes.contains(&node_id))
+    }
+
+    pub(crate) fn neighbours(
+        &self,
+        node_id: NodeId,
+        query_graph: &QueryGraph,
+    ) -> Vec<NodeId> {
+        self.connections
+            .iter()
+            .filter_map(|edge_id| query_graph.get_edge(*edge_id))
+            .flat_map(|edge| edge.nodes)
+            .filter(|&id| id != node_id)
+            .collect()
+    }
+}
+
+pub type EdgeId = usize;
+
+pub struct Edge {
+    pub nodes: [NodeId; 2],
+    pub join: Join,
+}
+
+pub struct QueryGraph {
+    pub(crate) nodes: VecMap<Node>,
+    edges: VecMap<Edge>,
+}
+
+impl QueryGraph {
+    pub(crate) fn new() -> Self {
+        Self {
+            nodes: VecMap::new(),
+            edges: VecMap::new(),
+        }
+    }
+
+    pub(crate) fn add_node(&mut self, node_data: Arc<LogicalPlan>) -> NodeId {
+        self.nodes.insert(Node {
+            plan: node_data,
+            connections: Vec::new(),
+        })
+    }
+
+    pub(crate) fn add_node_with_edge(
+        &mut self,
+        other: NodeId,
+        node_data: Arc<LogicalPlan>,
+        edge_data: Join,
+    ) -> Option<NodeId> {
+        if self.nodes.contains_key(other) {
+            let new_id = self.nodes.insert(Node {
+                plan: node_data,
+                connections: Vec::new(),
+            });
+            self.add_edge(new_id, other, edge_data);
+            Some(new_id)
+        } else {
+            None
+        }
+    }
+
+    fn add_edge(&mut self, from: NodeId, to: NodeId, data: Join) -> Option<EdgeId> {
+        if self.nodes.contains_key(from) && self.nodes.contains_key(to) {
+            let edge_id = self.edges.insert(Edge {
+                nodes: [from, to],
+                join: data,
+            });
+            if let Some(from) = self.nodes.get_mut(from) {
+                from.connections.push(edge_id);
+            }
+            if let Some(to) = self.nodes.get_mut(to) {
+                to.connections.push(edge_id);
+            }
+            Some(edge_id)
+        } else {
+            None
+        }
+    }
+
+    pub(crate) fn remove_node(&mut self, node_id: NodeId) -> Option<Arc<LogicalPlan>> {
+        if let Some(node) = self.nodes.remove(node_id) {
+            // Remove all edges connected to this node
+            for edge_id in &node.connections {
+                if let Some(edge) = self.edges.remove(*edge_id) {
+                    // Remove the edge from the other node's connections
+                    for other_node_id in edge.nodes {
+                        if other_node_id != node_id {
+                            if let Some(other_node) = self.nodes.get_mut(other_node_id) {
+                                other_node.connections.retain(|id| id != edge_id);
+                            }
+                        }
+                    }
+                }
+            }
+            Some(node.plan)
+        } else {
+            None
+        }
+    }
+
+    fn remove_edge(&mut self, edge_id: EdgeId) -> Option<Join> {
+        if let Some(edge) = self.edges.remove(edge_id) {
+            // Remove the edge from both nodes' connections
+            for node_id in edge.nodes {
+                if let Some(node) = self.nodes.get_mut(node_id) {
+                    node.connections.retain(|id| *id != edge_id);
+                }
+            }
+            Some(edge.join)
+        } else {
+            None
+        }
+    }
+
+    pub(crate) fn nodes(&self) -> impl Iterator<Item = (NodeId, &Node)> {
+        self.nodes.iter()
+    }
+
+    pub(crate) fn get_node(&self, key: NodeId) -> Option<&Node> {
+        self.nodes.get(key)
+    }
+
+    pub(crate) fn get_edge(&self, key: EdgeId) -> Option<&Edge> {
+        self.edges.get(key)
+    }
+}
+
+/// Extracts the join subtree from a logical plan, separating it from wrapper operators.
+///
+/// This function traverses the plan tree from the root downward, collecting all non-join
+/// operators until it finds the topmost join node. The join subtree (all consecutive joins)
+/// is extracted and returned separately from the wrapper operators.
+///
+/// # Arguments
+///
+/// * `plan` - The logical plan to extract from
+///
+/// # Returns
+///
+/// Returns a tuple of (join_subtree, wrapper_operators) where:
+/// - `join_subtree` is the topmost join and all joins beneath it
+/// - `wrapper_operators` is a vector of non-join operators above the joins, in order from root to join
+///
+/// # Errors
+///
+/// Returns an error if the plan doesn't contain any joins.
+pub(crate) fn extract_join_subtree(
+    plan: LogicalPlan,
+) -> Result<(LogicalPlan, Vec<LogicalPlan>)> {
+    let mut wrappers = Vec::new();
+    let mut current = plan;
+
+    // Descend through non-join nodes until we find a join
+    loop {
+        match current {
+            LogicalPlan::Join(_) => {
+                // Found the join subtree root
+                return Ok((current, wrappers));
+            }
+            other => {
+                // Check if this node contains joins in its children
+                if !contains_join(&other) {
+                    return plan_err!(
+                        "Plan does not contain any join nodes: {}",
+                        other.display()
+                    );
+                }
+
+                // This node is a wrapper - store it and descend to its child
+                // For now, we only support single-child wrappers (Filter, Sort, Limit, Aggregate, etc.)
+                let inputs = other.inputs();
+                if inputs.len() != 1 {
+                    return plan_err!(
+                        "Join extraction only supports single-input operators, found {} inputs in: {}",
+                        inputs.len(),
+                        other.display()
+                    );
+                }
+
+                wrappers.push(other.clone());
+                current = (*inputs[0]).clone();
+            }
+        }
+    }
+}
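To illustrate the contract of `extract_join_subtree` together with `reconstruct_plan` (defined just below): both are `pub(crate)`, so a sketch like this would live in the module's tests; the `t1`/`t2` schemas are invented for the example:

```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::{col, logical_plan::table_scan, JoinType, LogicalPlan};

let t1 = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let t2 = Schema::new(vec![Field::new("b", DataType::Int32, false)]);

// Build a plan shaped Limit -> Sort -> Join(t1, t2)
let plan = table_scan(Some("t1"), &t1, None)?
    .join(
        table_scan(Some("t2"), &t2, None)?.build()?,
        JoinType::Inner,
        (vec!["a"], vec!["b"]),
        None,
    )?
    .sort(vec![col("a").sort(true, false)])?
    .limit(0, Some(10))?
    .build()?;

// Wrappers come back in root-to-join order: [Limit, Sort]
let (join_subtree, wrappers) = extract_join_subtree(plan)?;
assert!(matches!(join_subtree, LogicalPlan::Join(_)));
assert_eq!(wrappers.len(), 2);

// reconstruct_plan re-applies them innermost-first, restoring the shape
let rebuilt = reconstruct_plan(join_subtree, wrappers)?;
assert!(matches!(rebuilt, LogicalPlan::Limit(_)));
```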
+/// Reconstructs a logical plan by wrapping an optimized join plan with the original wrapper operators.
+///
+/// This function takes an optimized join plan and re-applies the wrapper operators (Filter, Sort,
+/// Aggregate, etc.) that were removed during extraction. The wrappers are applied in reverse order
+/// (innermost to outermost) to reconstruct the original plan structure.
+///
+/// # Arguments
+///
+/// * `join_plan` - The optimized join plan to wrap
+/// * `wrappers` - Vector of wrapper operators in order from outermost to innermost (root to join)
+///
+/// # Returns
+///
+/// Returns the fully reconstructed logical plan with all wrapper operators reapplied.
+///
+/// # Errors
+///
+/// Returns an error if reconstructing any wrapper operator fails.
+pub(crate) fn reconstruct_plan(
+    join_plan: LogicalPlan,
+    wrappers: Vec<LogicalPlan>,
+) -> Result<LogicalPlan> {
+    let mut current = join_plan;
+
+    // Apply wrappers in reverse order (from innermost to outermost)
+    for wrapper in wrappers.into_iter().rev() {
+        // Use with_new_exprs to reconstruct the wrapper with the new input
+        current = wrapper.with_new_exprs(wrapper.expressions(), vec![current])?;
+    }
+
+    Ok(current)
+}
+
+impl TryFrom<LogicalPlan> for QueryGraph {
+    type Error = DataFusionError;
+
+    fn try_from(value: LogicalPlan) -> Result<Self> {
+        // First, extract the join subtree from any wrapper operators
+        let (join_subtree, _wrappers) = extract_join_subtree(value)?;
+
+        // Now convert only the join subtree to a query graph
+        let mut query_graph = QueryGraph::new();
+        flatten_joins_recursive(join_subtree, &mut query_graph)?;
+        Ok(query_graph)
+    }
+}
+
+fn flatten_joins_recursive(
+    plan: LogicalPlan,
+    query_graph: &mut QueryGraph,
+) -> Result<()> {
+    match plan {
+        LogicalPlan::Join(join) => {
+            flatten_joins_recursive(
+                Arc::unwrap_or_clone(Arc::clone(&join.left)),
+                query_graph,
+            )?;
+            flatten_joins_recursive(
+                Arc::unwrap_or_clone(Arc::clone(&join.right)),
+                query_graph,
+            )?;
+
+            // Process each equijoin predicate to find which nodes it connects
+            for (left_key, right_key) in &join.on {
+                // Extract column references from both join keys
+                let left_columns = left_key.column_refs();
+                let right_columns = right_key.column_refs();
+
+                // Filter nodes by checking which ones contain the columns from each expression
+                let matching_nodes: Vec<NodeId> = query_graph
+                    .nodes()
+                    .filter_map(|(node_id, node)| {
+                        let schema = node.plan.schema();
+                        // Check if this node's schema contains columns from either left or right key
+                        let has_left =
+                            check_all_columns_from_schema(&left_columns, schema.as_ref())
+                                .unwrap_or(false);
+                        let has_right = check_all_columns_from_schema(
+                            &right_columns,
+                            schema.as_ref(),
+                        )
+                        .unwrap_or(false);
+
+                        // Include node if it contains columns from either key (but not both, as that would be invalid)
+                        if (has_left && !has_right) || (!has_left && has_right) {
+                            Some(node_id)
+                        } else {
+                            None
+                        }
+                    })
+                    .collect();
+
+                // We should have exactly two nodes: one with left_key columns, one with right_key columns
+                if matching_nodes.len() != 2 {
+                    return plan_err!(
+                        "Could not find exactly two nodes for join predicate: {} = {} (found {} nodes)",
+                        left_key,
+                        right_key,
+                        matching_nodes.len()
+                    );
+                }
+
+                let node_id_a = matching_nodes[0];
+                let node_id_b = matching_nodes[1];
+
+                // Add an edge if one doesn't exist yet
+                if let Some(node_a) = query_graph.get_node(node_id_a) {
+                    if node_a.connection_with(node_id_b, query_graph).is_none() {
+                        // No edge exists yet, create one with this join
+                        query_graph.add_edge(node_id_a, node_id_b, join.clone());
+                    }
+                }
+            }
+
+            Ok(())
+        }
+        x => {
+            if contains_join(&x) {
+                plan_err!(
+                    "Join reordering requires joins to be consecutive in the plan tree. \
+                    Found a non-join node that contains nested joins: {}",
+                    x.display()
+                )
+            } else {
+                query_graph.add_node(Arc::new(x));
+                Ok(())
+            }
+        }
+    }
+}
+
+/// Checks if a LogicalPlan contains any join nodes
+///
+/// Uses a TreeNode visitor to traverse the plan tree and detect the presence
+/// of any `LogicalPlan::Join` nodes.
+///
+/// # Arguments
+///
+/// * `plan` - The logical plan to check
+///
+/// # Returns
+///
+/// `true` if the plan contains at least one join node, `false` otherwise
+pub(crate) fn contains_join(plan: &LogicalPlan) -> bool {
+    let mut has_join = false;
+
+    // Use TreeNode's apply method to traverse the plan
+    let _ = plan.apply(|node| {
+        if matches!(node, LogicalPlan::Join(_)) {
+            has_join = true;
+            // Stop traversal once we find a join
+            Ok(TreeNodeRecursion::Stop)
+        } else {
+            // Continue traversal
+            Ok(TreeNodeRecursion::Continue)
+        }
+    });
+
+    has_join
+}
+
+/// A simple Vec-based map that uses Option for sparse storage.
+/// Keys are never reused once removed.
+pub(crate) struct VecMap<V>(Vec<Option<V>>);
+
+impl<V> VecMap<V> {
+    pub(crate) fn new() -> Self {
+        Self(Vec::new())
+    }
+
+    pub(crate) fn insert(&mut self, value: V) -> usize {
+        let idx = self.0.len();
+        self.0.push(Some(value));
+        idx
+    }
+
+    pub(crate) fn get(&self, key: usize) -> Option<&V> {
+        self.0.get(key)?.as_ref()
+    }
+
+    pub(crate) fn get_mut(&mut self, key: usize) -> Option<&mut V> {
+        self.0.get_mut(key)?.as_mut()
+    }
+
+    pub(crate) fn remove(&mut self, key: usize) -> Option<V> {
+        self.0.get_mut(key)?.take()
+    }
+
+    pub(crate) fn contains_key(&self, key: usize) -> bool {
+        self.0.get(key).and_then(|v| v.as_ref()).is_some()
+    }
+
+    pub(crate) fn iter(&self) -> impl Iterator<Item = (usize, &V)> {
+        self.0
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, slot)| slot.as_ref().map(|v| (idx, v)))
+    }
+}
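A quick sketch of the `VecMap` contract (illustrative usage, not code from the diff): indices are handed out sequentially and a removed slot stays empty rather than being reused, so `NodeId`/`EdgeId` values remain stable for the lifetime of the graph:

```rust
let mut map: VecMap<&str> = VecMap::new();
let a = map.insert("customer"); // 0
let b = map.insert("orders");   // 1
assert_eq!(map.remove(a), Some("customer"));
assert!(!map.contains_key(a));          // slot stays empty...
assert_eq!(map.insert("lineitem"), 2);  // ...and ids keep growing
assert_eq!(map.iter().count(), 2);      // iter skips removed slots
let _ = b;
```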
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::*;
+    use datafusion_expr::logical_plan::JoinType;
+    use datafusion_expr::{col, LogicalPlanBuilder};
+
+    /// Test converting a three-way join with filter into a QueryGraph
+    #[test]
+    fn test_try_from_three_way_join_with_filter() -> Result<(), DataFusionError> {
+        // Create three-way join: customer JOIN orders JOIN lineitem
+        // with a filter on the orders-lineitem join
+        let customer = scan_tpch_table("customer");
+        let orders = scan_tpch_table("orders");
+        let lineitem = scan_tpch_table("lineitem");
+
+        let plan = LogicalPlanBuilder::from(customer.clone())
+            .join(
+                orders.clone(),
+                JoinType::Inner,
+                (vec!["c_custkey"], vec!["o_custkey"]),
+                None,
+            )
+            .unwrap()
+            .join_with_expr_keys(
+                lineitem.clone(),
+                JoinType::Inner,
+                (vec![col("o_orderkey")], vec![col("l_orderkey")]),
+                Some(col("l_quantity").gt(datafusion_expr::lit(10.0))),
+            )
+            .unwrap()
+            .build()
+            .unwrap();
+
+        // Convert to QueryGraph
+        let query_graph = QueryGraph::try_from(plan)?;
+
+        // Verify structure: 3 nodes, 2 edges
+        assert_eq!(query_graph.nodes().count(), 3);
+        assert_eq!(query_graph.edges.iter().count(), 2);
+
+        // Verify connectivity: one node has 2 connections (orders), two nodes have 1
+        let mut connections: Vec<usize> = query_graph
+            .nodes()
+            .map(|(_, node)| node.connections().len())
+            .collect();
+        connections.sort();
+        assert_eq!(connections, vec![1, 1, 2]);
+
+        // Verify edges have correct join predicates
+        let edges: Vec<&Edge> = query_graph.edges.iter().map(|(_, e)| e).collect();
+
+        // One edge should have c_custkey = o_custkey
+        let has_customer_orders = edges.iter().any(|e| {
+            e.join.on.iter().any(|(l, r)| {
+                let s = format!("{l}{r}");
+                s.contains("c_custkey") && s.contains("o_custkey")
+            })
+        });
+        assert!(has_customer_orders, "Missing customer-orders join");
+
+        // One edge should have o_orderkey = l_orderkey with a filter
+        let has_orders_lineitem = edges.iter().any(|e| {
+            e.join.on.iter().any(|(l, r)| {
+                let s = format!("{l}{r}");
+                s.contains("o_orderkey") && s.contains("l_orderkey")
+            }) && e.join.filter.is_some()
+        });
+        assert!(
+            has_orders_lineitem,
+            "Missing orders-lineitem join with filter"
+        );
+
+        Ok(())
+    }
+}