diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py
index 4a2c68a..a415874 100644
--- a/advanced/pytorch-example/main.py
+++ b/advanced/pytorch-example/main.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 import torch.distributed as dist
 import os
-from torchvision import datasets, transforms
+from torchvision import transforms
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -36,13 +36,19 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)
 
 def train():
+    master_addr = os.environ.get("MASTER_ADDR", "localhost")
+    master_port = os.environ.get("MASTER_PORT", "29500")
+    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+    local_rank = rank % torch.cuda.device_count()
+
     # Initialize process group
-    dist.init_process_group(backend="nccl")
-
-    # Get local rank from environment variable
-    local_rank = int(os.environ["LOCAL_RANK"])
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group(
+        backend="nccl",
+        init_method=f"tcp://{master_addr}:{master_port}",
+        world_size=world_size,
+        rank=rank
+    )
 
     # Set device
     torch.cuda.set_device(local_rank)
@@ -77,7 +83,8 @@ def transform(example):
             loss.backward()
             optimizer.step()
 
-            if batch_idx % 10 == 0:
+            dist.all_reduce(loss, op=dist.ReduceOp.AVG)
+            if rank == 0 and batch_idx % 10 == 0:
                 print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")
 
         if rank == 0:
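
Usage note (not part of the patch): after this change main.py reads OpenMPI's OMPI_COMM_WORLD_SIZE / OMPI_COMM_WORLD_RANK instead of torchrun's RANK / WORLD_SIZE / LOCAL_RANK, so it is meant to be launched with mpirun rather than torchrun, for example: mpirun -np 8 -x MASTER_ADDR=<rank0-host> -x MASTER_PORT=29500 python main.py (the host name and process count here are placeholders). The following is a minimal, self-contained sketch of the same rendezvous logic; the gloo fallback and the all_reduce sanity check are illustrative additions and are not in the diff.

import os
import torch
import torch.distributed as dist

def init_from_mpi_env():
    """Initialize torch.distributed from OpenMPI environment variables.

    Assumes the process was started by mpirun (so OMPI_COMM_WORLD_SIZE and
    OMPI_COMM_WORLD_RANK are set) and that MASTER_ADDR / MASTER_PORT point
    at a host/port reachable from every rank.
    """
    master_addr = os.environ.get("MASTER_ADDR", "localhost")
    master_port = os.environ.get("MASTER_PORT", "29500")
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])

    # Illustrative fallback: use gloo on CPU-only machines, otherwise pin
    # one GPU per local rank exactly as the patched train() does.
    if torch.cuda.is_available():
        backend = "nccl"
        local_rank = rank % torch.cuda.device_count()
        torch.cuda.set_device(local_rank)
    else:
        backend = "gloo"
        local_rank = 0

    dist.init_process_group(
        backend=backend,
        init_method=f"tcp://{master_addr}:{master_port}",
        world_size=world_size,
        rank=rank,
    )
    return rank, world_size, local_rank

if __name__ == "__main__":
    rank, world_size, local_rank = init_from_mpi_env()

    # Sanity check: every rank contributes 1, so the reduced value equals world_size.
    device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
    t = torch.ones(1, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    if rank == 0:
        print(f"world_size={world_size}, all_reduce result={t.item():.0f}")

    dist.destroy_process_group()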