diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py index 8a08549..5a64c11 100644 --- a/advanced/pytorch-example/main.py +++ b/advanced/pytorch-example/main.py @@ -8,6 +8,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler +from datasets import load_dataset + + class MNISTModel(nn.Module): def __init__(self): super(MNISTModel, self).__init__() @@ -47,11 +50,15 @@ def train(): print(f"Running on rank {rank} (local_rank: {local_rank})") - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) + def transform(example): + imgs = [transforms.ToTensor()(img) for img in example["image"]] + imgs = [transforms.Normalize((0.1307,), (0.3081,))(img) for img in imgs] + example["image"] = torch.stack(imgs) + example["label"] = torch.tensor(example["label"]) + return example + + dataset = load_dataset("mnist", split="train") + dataset = dataset.with_transform(transform) sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) train_loader = DataLoader(dataset, batch_size=64, sampler=sampler) @@ -62,8 +69,8 @@ def train(): model.train() for epoch in range(1, 11): sampler.set_epoch(epoch) - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) + for batch_idx, batch_data in enumerate(train_loader): + data, target = batch_data["image"].to(device), batch_data["label"].to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) diff --git a/advanced/pytorch-example/requirements.txt b/advanced/pytorch-example/requirements.txt new file mode 100644 index 0000000..c0190e1 --- /dev/null +++ b/advanced/pytorch-example/requirements.txt @@ -0,0 +1 @@ +datasets \ No newline at end of file