diff --git a/advanced/pytorch-example/main.py b/advanced/pytorch-example/main.py index 5a64c11..4a2c68a 100644 --- a/advanced/pytorch-example/main.py +++ b/advanced/pytorch-example/main.py @@ -84,6 +84,7 @@ def transform(example): torch.save(model.module.state_dict(), "mnist_model.pth") print("Model saved as mnist_model.pth") + dist.barrier() dist.destroy_process_group() if __name__ == "__main__":