Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ public class SqlTaskExecution
private final List<PlanNodeId> sourceStartOrder;
@GuardedBy("this")
private int schedulingPlanNodeOrdinal;
@GuardedBy("this")
private ListenableFuture<Void> pipelineDependenciesSatisfied = immediateVoidFuture();

@GuardedBy("this")
private final Map<PlanNodeId, PendingSplitsForPlanNode> pendingSplitsByPlanNode;
Expand Down Expand Up @@ -309,7 +311,8 @@ private synchronized Set<PlanNodeId> updateSplitAssignments(List<SplitAssignment
// update task with new sources
for (SplitAssignment splitAssignment : unacknowledgedSplitAssignment) {
if (driverRunnerFactoriesWithSplitLifeCycle.containsKey(splitAssignment.getPlanNodeId())) {
schedulePartitionedSource(splitAssignment);
mergeIntoPendingSplits(splitAssignment.getPlanNodeId(), splitAssignment.getSplits(), splitAssignment.isNoMoreSplits());
schedulePartitionedSourcePendingSplits();
}
else {
// tell existing drivers about the new splits
Expand All @@ -331,23 +334,40 @@ private void mergeIntoPendingSplits(PlanNodeId planNodeId, Set<ScheduledSplit> s
PendingSplitsForPlanNode pendingSplitsForPlanNode = pendingSplitsByPlanNode.get(planNodeId);

partitionedDriverFactory.splitsAdded(scheduledSplits.size(), SplitWeight.rawValueSum(scheduledSplits, scheduledSplit -> scheduledSplit.getSplit().getSplitWeight()));
for (ScheduledSplit scheduledSplit : scheduledSplits) {
pendingSplitsForPlanNode.addSplit(scheduledSplit);
}
pendingSplitsForPlanNode.addSplits(scheduledSplits);
if (noMoreSplits) {
pendingSplitsForPlanNode.setNoMoreSplits();
}
}

private synchronized void schedulePartitionedSource(SplitAssignment splitAssignmentUpdate)
private synchronized void scheduleSourcePartitionedSplitsAfterPipelineUnblocked()
{
mergeIntoPendingSplits(splitAssignmentUpdate.getPlanNodeId(), splitAssignmentUpdate.getSplits(), splitAssignmentUpdate.isNoMoreSplits());
try (SetThreadName _ = new SetThreadName("Task-" + taskId)) {
// Enqueue pending splits as split runners after unblocking
schedulePartitionedSourcePendingSplits();
// Re-check for task completion since we may have just set no more splits
checkTaskCompletion();
}
}

private synchronized void schedulePartitionedSourcePendingSplits()
{
while (schedulingPlanNodeOrdinal < sourceStartOrder.size()) {
PlanNodeId schedulingPlanNode = sourceStartOrder.get(schedulingPlanNodeOrdinal);

DriverSplitRunnerFactory partitionedDriverRunnerFactory = driverRunnerFactoriesWithSplitLifeCycle.get(schedulingPlanNode);

// Avoid creating split runners for pipelines that are awaiting another pipeline completing (e.g. probe side of a join waiting
// on the broadcast completion). Otherwise, build side pipelines will have reduced concurrency available.
ListenableFuture<Void> pipelineDependenciesSatisfied = partitionedDriverRunnerFactory.getPipelineDependenciesSatisfied();
if (!pipelineDependenciesSatisfied.isDone()) {
// Only register a single re-schedule listener if we're blocked on pipeline dependencies
if (this.pipelineDependenciesSatisfied.isDone()) {
this.pipelineDependenciesSatisfied = pipelineDependenciesSatisfied;
pipelineDependenciesSatisfied.addListener(this::scheduleSourcePartitionedSplitsAfterPipelineUnblocked, notificationExecutor);
}
break;
}
PendingSplitsForPlanNode pendingSplits = pendingSplitsByPlanNode.get(schedulingPlanNode);

// Enqueue driver runners with split lifecycle for this plan node and driver life cycle combination.
Expand Down Expand Up @@ -541,10 +561,10 @@ public SplitsState getState()
return state;
}

public void addSplit(ScheduledSplit scheduledSplit)
public void addSplits(Set<ScheduledSplit> scheduledSplits)
{
checkState(state == ADDING_SPLITS);
splits.add(scheduledSplit);
splits.addAll(scheduledSplits);
}

public Set<ScheduledSplit> removeAllSplits()
Expand Down Expand Up @@ -602,6 +622,11 @@ private DriverSplitRunnerFactory(DriverFactory driverFactory, Tracer tracer, boo
.startSpan();
}

/**
* Returns the future reported by the wrapped {@code driverFactory} via
* {@code getPipelineDependenciesSatisfied()}. The split scheduling loop checks
* {@code isDone()} on this future and stops creating split runners for the
* pipeline while it is still pending.
*/
public ListenableFuture<Void> getPipelineDependenciesSatisfied()
{
return driverFactory.getPipelineDependenciesSatisfied();
}

public DriverSplitRunner createPartitionedDriverRunner(ScheduledSplit partitionedSplit)
{
return createDriverRunner(partitionedSplit, partitionedSplit.getSplit().getSplitWeight().getRawValue());
Expand Down
14 changes: 14 additions & 0 deletions core/trino-main/src/main/java/io/trino/operator/DriverFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package io.trino.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.trino.sql.planner.plan.PlanNodeId;
import jakarta.annotation.Nullable;
Expand All @@ -26,6 +28,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull;

public class DriverFactory
Expand All @@ -35,6 +38,7 @@ public class DriverFactory
private final boolean outputDriver;
private final Optional<PlanNodeId> sourceId;
private final OptionalInt driverInstances;
private final ListenableFuture<Void> pipelineDependenciesSatisfied;

// must synchronize between createDriver() and noMoreDrivers(), but isNoMoreDrivers() is safe without synchronizing
@GuardedBy("this")
Expand All @@ -57,6 +61,11 @@ public DriverFactory(int pipelineId, boolean inputDriver, boolean outputDriver,
.collect(toImmutableList());
checkArgument(sourceIds.size() <= 1, "Expected at most one source operator in driver factory, but found %s", sourceIds);
this.sourceId = sourceIds.isEmpty() ? Optional.empty() : Optional.of(sourceIds.get(0));
List<ListenableFuture<Void>> pipelineDependencies = operatorFactories.stream()
.map(OperatorFactory::pipelineDependenciesSatisfied)
.filter(future -> !future.isDone())
.collect(toImmutableList());
this.pipelineDependenciesSatisfied = pipelineDependencies.isEmpty() ? Futures.immediateVoidFuture() : Futures.whenAllComplete(pipelineDependencies).call(() -> null, directExecutor());
}

public int getPipelineId()
Expand All @@ -74,6 +83,11 @@ public boolean isOutputDriver()
return outputDriver;
}

/**
* Returns the combined pipeline-dependency future computed once in the constructor.
* It is an immediate void future when no operator factory reported a pending
* dependency at construction time; otherwise it is built with
* {@code Futures.whenAllComplete} over the dependencies that were still pending.
* NOTE(review): {@code whenAllComplete} does not require the inputs to succeed, so a
* failed dependency presumably also unblocks this future - confirm this is intended.
*/
public ListenableFuture<Void> getPipelineDependenciesSatisfied()
{
return pipelineDependenciesSatisfied;
}

/**
* return the sourceId of this DriverFactory.
* A DriverFactory doesn't always have source node.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
*/
package io.trino.operator;

import com.google.common.util.concurrent.ListenableFuture;

import static com.google.common.util.concurrent.Futures.immediateVoidFuture;

public interface OperatorFactory
{
Operator createOperator(DriverContext driverContext);
Expand All @@ -27,4 +31,14 @@ public interface OperatorFactory
void noMoreOperators();

OperatorFactory duplicate();

/**
* Returns a future indicating that any dependencies this operator has on other pipelines have been satisfied and that leaf splits
* should be allowed to start for this operator. This is used to prevent join probe splits from starting before the build side
* of a join is ready when the two are in the same stage (e.g. a broadcast join on top of a table scan).
* <p>
* The default implementation reports no dependencies by returning an already-completed future.
*/
default ListenableFuture<Void> pipelineDependenciesSatisfied()
{
return immediateVoidFuture();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.trino.memory.context.MemoryTrackingContext;
import io.trino.operator.join.JoinOperatorFactory;
Expand Down Expand Up @@ -86,6 +87,21 @@ public Optional<OperatorFactory> createOuterOperatorFactory()
return lookupJoin.createOuterOperatorFactory();
}

/**
 * Reports when the build side feeding the wrapped operator factory is ready.
 * When the wrapped factory is a {@link JoinOperatorFactory} its own
 * {@code buildPipelineReady()} future is returned; any other factory has no
 * build side, so an already-completed future is returned instead.
 */
@Override
public ListenableFuture<Void> buildPipelineReady()
{
    if (operatorFactory instanceof JoinOperatorFactory joinFactory) {
        // Join factories know when their build side has finished
        return joinFactory.buildPipelineReady();
    }
    // Nothing to wait for
    return Futures.immediateVoidFuture();
}

/**
* The only pipeline dependency this adapter reports is the one returned by
* {@link #buildPipelineReady()}, so this simply forwards to it.
*/
@Override
public ListenableFuture<Void> pipelineDependenciesSatisfied()
{
return buildPipelineReady();
}

@VisibleForTesting
public WorkProcessorOperatorFactory getWorkProcessorOperatorFactory()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public static JoinBridgeManager<PartitionedLookupSourceFactory> lookupAllAtOnce(
private final List<Type> buildOutputTypes;
private final boolean buildOuter;
private final T joinBridge;
private final ListenableFuture<Void> whenBuildFinishes;

private final AtomicBoolean initialized = new AtomicBoolean();
private JoinLifecycle joinLifecycle;
Expand All @@ -57,6 +58,7 @@ public JoinBridgeManager(
this.buildOuter = buildOuter;
this.joinBridge = requireNonNull(joinBridge, "joinBridge is null");
this.buildOutputTypes = requireNonNull(buildOutputTypes, "buildOutputTypes is null");
this.whenBuildFinishes = requireNonNull(joinBridge.whenBuildFinishes(), "whenBuildFinishes is null");
}

private void initializeIfNecessary()
Expand All @@ -67,7 +69,7 @@ private void initializeIfNecessary()
return;
}
int finalProbeFactoryCount = probeFactoryCount.get();
joinLifecycle = new JoinLifecycle(joinBridge, finalProbeFactoryCount, buildOuter ? 1 : 0);
joinLifecycle = new JoinLifecycle(whenBuildFinishes, joinBridge, finalProbeFactoryCount, buildOuter ? 1 : 0);
initialized.set(true);
}
}
Expand All @@ -83,6 +85,11 @@ public void incrementProbeFactoryCount()
probeFactoryCount.increment();
}

/**
* Returns the future obtained from {@code joinBridge.whenBuildFinishes()} when this
* manager was constructed. Unlike {@code getJoinBridge()}, this accessor does not
* call {@code initializeIfNecessary()}.
*/
public ListenableFuture<Void> getBuildFinishedFuture()
{
return whenBuildFinishes;
}

public T getJoinBridge()
{
initializeIfNecessary();
Expand Down Expand Up @@ -139,7 +146,7 @@ private static class JoinLifecycle
private final ListenableFuture<Void> whenBuildAndProbeFinishes;
private final ListenableFuture<Void> whenAllFinishes;

public JoinLifecycle(JoinBridge joinBridge, int probeFactoryCount, int outerFactoryCount)
private JoinLifecycle(ListenableFuture<Void> whenBuildFinishes, JoinBridge joinBridge, int probeFactoryCount, int outerFactoryCount)
{
// When all probe and lookup-outer operators finish, destroy the join bridge (freeing the memory)
// * Each LookupOuterOperatorFactory count as 1
Expand All @@ -152,7 +159,7 @@ public JoinLifecycle(JoinBridge joinBridge, int probeFactoryCount, int outerFact
// * Each probe operator count as 1
probeReferenceCount = new ReferenceCount(probeFactoryCount);

whenBuildAndProbeFinishes = Futures.whenAllSucceed(joinBridge.whenBuildFinishes(), probeReferenceCount.getFreeFuture()).call(() -> null, directExecutor());
whenBuildAndProbeFinishes = Futures.whenAllSucceed(whenBuildFinishes, probeReferenceCount.getFreeFuture()).call(() -> null, directExecutor());
whenAllFinishes = Futures.whenAllSucceed(whenBuildAndProbeFinishes, outerReferenceCount.getFreeFuture()).call(() -> null, directExecutor());
whenAllFinishes.addListener(joinBridge::destroy, directExecutor());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@
*/
package io.trino.operator.join;

import com.google.common.util.concurrent.ListenableFuture;
import io.trino.operator.OperatorFactory;

import java.util.Optional;

/**
* Contract for operator factories that take part in a join: they can supply an
* optional outer-side operator factory and report when the build side is ready.
*/
public interface JoinOperatorFactory
{
// NOTE(review): presumably empty when the join has no outer side to emit - confirm against implementations
Optional<OperatorFactory> createOuterOperatorFactory();

/**
* Future that indicates when the build side of the join has been completed and probe processing
* can begin. Used by {@link OperatorFactory#pipelineDependenciesSatisfied()}.
*/
ListenableFuture<Void> buildPipelineReady();
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import io.trino.operator.HashGenerator;
import io.trino.operator.JoinOperatorType;
import io.trino.operator.OperatorFactory;
Expand Down Expand Up @@ -164,6 +165,12 @@ public String getOperatorType()
return LookupJoinOperator.class.getSimpleName();
}

/**
* Returns the build-finished future held by {@code joinBridgeManager}.
*/
@Override
public ListenableFuture<Void> buildPipelineReady()
{
return joinBridgeManager.getBuildFinishedFuture();
}

@Override
public WorkProcessorOperator create(ProcessorContext processorContext, WorkProcessor<Page> sourcePages)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.operator.join.unspilled;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import io.trino.operator.JoinOperatorType;
import io.trino.operator.OperatorFactory;
import io.trino.operator.ProcessorContext;
Expand Down Expand Up @@ -133,6 +134,12 @@ public String getOperatorType()
return LookupJoinOperator.class.getSimpleName();
}

/**
* Returns the build-finished future held by {@code joinBridgeManager}.
*/
@Override
public ListenableFuture<Void> buildPipelineReady()
{
return joinBridgeManager.getBuildFinishedFuture();
}

@Override
public WorkProcessorOperator create(ProcessorContext processorContext, WorkProcessor<Page> sourcePages)
{
Expand Down