
Commit 039d07d

feat(examples): better torch distributed training example with mpi (#77)
* feat: get init_process_group parameters directly from OMPI and the environment
* fix: pass lint
* fix: print avg loss
* fix: only rank 0 prints the progress status
1 parent 8894f6c commit 039d07d

1 file changed

advanced/pytorch-example/main.py

Lines changed: 15 additions & 8 deletions
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 import torch.distributed as dist
 import os
-from torchvision import datasets, transforms
+from torchvision import transforms
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, DistributedSampler

@@ -36,13 +36,19 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)

 def train():
+    master_addr = os.environ.get("MASTER_ADDR", "localhost")
+    master_port = os.environ.get("MASTER_PORT", "29500")
+    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+    local_rank = rank % torch.cuda.device_count()
+
     # Initialize process group
-    dist.init_process_group(backend="nccl")
-
-    # Get local rank from environment variable
-    local_rank = int(os.environ["LOCAL_RANK"])
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://{master_addr}:{master_port}",
+        world_size=world_size,
+        rank=rank
+    )

     # Set device
     torch.cuda.set_device(local_rank)

@@ -77,7 +83,8 @@ def transform(example):
         loss.backward()
         optimizer.step()

-        if batch_idx % 10 == 0:
+        dist.all_reduce(loss, op=dist.ReduceOp.AVG)
+        if rank == 0 and batch_idx % 10 == 0:
             print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

    if rank == 0:
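
Below is a minimal, self-contained sketch (not part of the commit) of the pattern the diff adopts: rank and world size are read from Open MPI's OMPI_COMM_WORLD_* environment variables, the process group rendezvouses over TCP at MASTER_ADDR:MASTER_PORT, and a per-rank value is averaged with all_reduce before only rank 0 prints. The mpirun launch line, the script name, and the gloo/CPU fallback are assumptions for illustration, not something the example itself does.

# sketch: ompi_rendezvous_sketch.py (hypothetical name)
# hypothetical launch:  mpirun -np 4 python ompi_rendezvous_sketch.py
import os

import torch
import torch.distributed as dist


def main():
    # Rendezvous endpoint; Open MPI does not set these, so they must come
    # from the launcher or default to a single-host setup.
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = os.environ.get("MASTER_PORT", "29500")

    # Open MPI exports these for every process it launches.
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))

    # Assumption: fall back to gloo on CPU-only machines so the sketch runs anywhere.
    use_cuda = torch.cuda.is_available()
    backend = "nccl" if use_cuda else "gloo"

    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=rank,
    )

    if use_cuda:
        # Map the global rank onto the GPUs visible on this node.
        local_rank = rank % torch.cuda.device_count()
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    # Average a per-rank value across all ranks, mirroring the loss averaging
    # in the diff. ReduceOp.AVG is NCCL-only; with gloo, sum and divide instead.
    value = torch.tensor([float(rank)], device=device)
    if backend == "nccl":
        dist.all_reduce(value, op=dist.ReduceOp.AVG)
    else:
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value /= world_size

    if rank == 0:
        print(f"averaged value across {world_size} ranks: {value.item():.3f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()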
