From f6e124745553b02a6bc20a801915916cc0a9098f Mon Sep 17 00:00:00 2001
From: jingxuz
Date: Thu, 7 Aug 2025 16:27:12 +0800
Subject: [PATCH 1/5] feat: directly get init_process_group parameters from
 OMPI and environment

---
 advanced/pytorch-example/main.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py
index 4a2c68a..fd72e11 100644
--- a/advanced/pytorch-example/main.py
+++ b/advanced/pytorch-example/main.py
@@ -36,13 +36,19 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)
 
 def train():
+    master_addr = os.environ.get("MASTER_ADDR", "localhost")
+    master_port = os.environ.get("MASTER_PORT", "29500")
+    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+    local_rank = rank % torch.cuda.device_count()
+
     # Initialize process group
-    dist.init_process_group(backend="nccl")
-
-    # Get local rank from environment variable
-    local_rank = int(os.environ["LOCAL_RANK"])
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://{master_addr}:{master_port}",
+        world_size=world_size,
+        rank=rank
+    )
 
     # Set device
     torch.cuda.set_device(local_rank)

From b815a4aaceb98dfdfb3c0cf3459dca12d5d3f9cc Mon Sep 17 00:00:00 2001
From: jingxuz
Date: Thu, 7 Aug 2025 16:35:44 +0800
Subject: [PATCH 2/5] fix: pass lint

---
 advanced/pytorch-example/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py
index fd72e11..c6ac7c8 100644
--- a/advanced/pytorch-example/main.py
+++ b/advanced/pytorch-example/main.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 import torch.distributed as dist
 import os
-from torchvision import datasets, transforms
+from torchvision import transforms
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, DistributedSampler
 

From 854f48afd0ec991d2e1e63d547b880b653b5929a Mon Sep 17 00:00:00 2001
From: jingxuz
Date: Thu, 7 Aug 2025 17:02:33 +0800
Subject: [PATCH 3/5] fix: print avg loss

---
 advanced/pytorch-example/main.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py
index c6ac7c8..8a6ef79 100644
--- a/advanced/pytorch-example/main.py
+++ b/advanced/pytorch-example/main.py
@@ -84,6 +84,7 @@ def transform(example):
         optimizer.step()
 
         if batch_idx % 10 == 0:
+            dist.all_reduce(loss, op=dist.ReduceOp.AVG)
             print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")
 
     if rank == 0:

From c4e4771a6aa425db66135b223b351ee6ebb078b9 Mon Sep 17 00:00:00 2001
From: jingxuz
Date: Thu, 7 Aug 2025 17:06:07 +0800
Subject: [PATCH 4/5] fix: only rank 0 prints the progress status

---
 advanced/pytorch-example/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py
index 8a6ef79..4cff487 100644
--- a/advanced/pytorch-example/main.py
+++ b/advanced/pytorch-example/main.py
@@ -83,7 +83,7 @@ def transform(example):
         loss.backward()
         optimizer.step()
 
-        if batch_idx % 10 == 0:
+        if rank == 0 and batch_idx % 10 == 0:
             dist.all_reduce(loss, op=dist.ReduceOp.AVG)
             print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")
 

From c684e5d49f4d452ecea07c85441195a4eb7dd81f Mon Sep 17 00:00:00 2001
From: jingxuz
Date: Thu, 7 Aug 2025 17:11:21 +0800
Subject: [PATCH 5/5] fix: only rank 0 prints the progress status

---
 advanced/pytorch-example/main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py
index 4cff487..a415874 100644
--- a/advanced/pytorch-example/main.py
+++ b/advanced/pytorch-example/main.py
@@ -83,8 +83,8 @@ def transform(example):
         loss.backward()
         optimizer.step()
 
+        dist.all_reduce(loss, op=dist.ReduceOp.AVG)
         if rank == 0 and batch_idx % 10 == 0:
-            dist.all_reduce(loss, op=dist.ReduceOp.AVG)
             print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")
 
     if rank == 0:
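
A note on patch 1: OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_RANK are exported by Open MPI's mpirun, so after this patch the script is launched under mpirun rather than torchrun, with MASTER_ADDR and MASTER_PORT forwarded to every process (falling back to localhost:29500 for single-node runs). Deriving local_rank as rank % torch.cuda.device_count() assumes every node exposes the same number of GPUs. Below is a minimal sketch of the same rendezvous setup, shown with Open MPI's OMPI_COMM_WORLD_LOCAL_RANK as the local-rank source; that substitution is a variant for illustration, not what the patch itself does:

    import os

    import torch
    import torch.distributed as dist

    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = os.environ.get("MASTER_PORT", "29500")
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    # Open MPI exports the per-node rank directly; fall back to the patch's
    # rank % device_count derivation when the variable is absent.
    local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK",
                                    rank % torch.cuda.device_count()))

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(local_rank)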
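
A note on patches 3 through 5: dist.all_reduce is a collective call, so every rank must reach it. Patch 4 moved the call inside the rank-0-only branch, where ranks other than 0 would never enter the collective and the job would hang at the first logging step; patch 5 hoists it back out so all ranks participate on every batch and rank 0 alone prints the cross-rank average. A minimal sketch of the resulting loop pattern (model, optimizer, train_loader, rank, and local_rank are assumed to be set up as in the patched train(), which the diffs only show in part):

    import torch.distributed as dist
    import torch.nn.functional as F

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(local_rank), target.cuda(local_rank)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        # Collective: executed by every rank on every batch, so no rank is
        # left blocked inside the all-reduce while the others skip it.
        dist.all_reduce(loss, op=dist.ReduceOp.AVG)
        if rank == 0 and batch_idx % 10 == 0:
            # Only rank 0 prints; loss now holds the average across ranks.
            print(f"batch {batch_idx}: avg loss {loss.item():.6f}")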