diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py new file mode 100644 index 0000000..6e7f462 --- /dev/null +++ b/advanced/pytorch-example/main.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torch.distributed as dist +from torchvision import datasets, transforms +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler + +class MNISTModel(nn.Module): + def __init__(self): + super(MNISTModel, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +def train(rank, world_size): + print(f"Running on rank {rank}.") + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) + train_loader = DataLoader(dataset, batch_size=64, sampler=sampler) + + model = MNISTModel().to(rank) + model = DDP(model, device_ids=[rank]) + optimizer = optim.Adam(model.parameters(), lr=0.001) + + model.train() + for epoch in range(1, 11): + sampler.set_epoch(epoch) + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(rank), target.to(rank) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + if batch_idx % 10 == 0: + print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}") + + if rank == 0: + torch.save(model.module.state_dict(), "mnist_model.pth") + print("Model saved as mnist_model.pth") + + dist.destroy_process_group() + +def main(): + world_size = torch.cuda.device_count() + torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True) + +if __name__ == "__main__": + main() \ No newline at end of file