22 changes: 22 additions & 0 deletions src/lightning/pytorch/strategies/ddp.py
@@ -103,6 +103,7 @@ def __init__(
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
        self._start_method = start_method
        self._pl_static_graph_delay_done = False

    @property
    def is_distributed(self) -> bool:  # pragma: no-cover
@@ -319,6 +320,27 @@ def pre_backward(self, closure_loss: Tensor) -> None:
        if not self.lightning_module.automatic_optimization:
            prepare_for_backward(self.model, closure_loss)

    @override
    def post_backward(self, closure_loss: Tensor) -> None:
        # Only for the first static-graph iteration with manual optimization
        model = self.model
        lm = self.lightning_module
        if not isinstance(model, DistributedDataParallel):
            return
        if lm is None or lm.automatic_optimization:
            return
        if not getattr(model, "static_graph", False):
            return
        if self._pl_static_graph_delay_done:
            return

        # Call DDP's own first-iteration static-graph flush.
        # This is what actually launches the bucket all-reduces.
        reducer = model.reducer
        reducer._delay_all_reduce()

        self._pl_static_graph_delay_done = True

    @override
    def model_to_device(self) -> None:
        log.debug(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
47 changes: 47 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -448,3 +448,50 @@ def creates_processes_externally(self):
        RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
    ):
        trainer.fit(model)


@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize("automatic_optimization", [True, False])
@pytest.mark.parametrize("static_graph", [True, False])
def test_ddp_gradients_synced(tmp_path, automatic_optimization, static_graph):
    """Ensure gradients are synchronized across ranks for both optimization modes and static_graph settings."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = automatic_optimization

        def training_step(self, batch, batch_idx):
            if self.automatic_optimization:
                return super().training_step(batch, batch_idx)

            # manual optimization path
            opt = self.optimizers()
            opt.zero_grad()
            out = super().training_step(batch, batch_idx)
            loss = out["loss"]
            self.manual_backward(loss)
            opt.step()
            return out

        def on_train_batch_end(self, *args, **kwargs):
            # record the local grad sum for the cross-rank sync check
            grad_sum = self.layer.bias.grad.detach().sum()
            self.log("grad_sum_min", grad_sum, sync_dist=True, reduce_fx="min")
            self.log("grad_sum_max", grad_sum, sync_dist=True, reduce_fx="max")

    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="gpu",
        devices=2,
        strategy=DDPStrategy(static_graph=static_graph),
        max_steps=1,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(TestModel(), datamodule=BoringDataModule())

    # assert all ranks saw identical grads: if gradients were synchronized, every
    # rank logged the same value, so the min- and max-reduced metrics must agree
    gmin = trainer.callback_metrics["grad_sum_min"]
    gmax = trainer.callback_metrics["grad_sum_max"]
    assert torch.allclose(gmin, gmax)
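For reference, the same cross-rank equality check can be expressed with raw collectives. This is a sketch only, assuming an already-initialized process group; the test above obtains the equivalent reductions via self.log(..., sync_dist=True, reduce_fx="min"/"max").

import torch
import torch.distributed as dist


def grads_are_synced(local_grad_sum: torch.Tensor) -> bool:
    # Reduce the local scalar to its min and max across all ranks; if every
    # rank holds identical gradients, the two reductions return the same value.
    gmin = local_grad_sum.clone()
    gmax = local_grad_sum.clone()
    dist.all_reduce(gmin, op=dist.ReduceOp.MIN)
    dist.all_reduce(gmax, op=dist.ReduceOp.MAX)
    return torch.allclose(gmin, gmax)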