From 515731b3c96334e857050c9479740b8d4be125a4 Mon Sep 17 00:00:00 2001 From: brando90 Date: Sat, 5 Feb 2022 13:34:09 -0600 Subject: [PATCH] data_augmentation for get_tasks --- learn2learn/vision/benchmarks/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/learn2learn/vision/benchmarks/__init__.py b/learn2learn/vision/benchmarks/__init__.py index 988927ed..f202f579 100644 --- a/learn2learn/vision/benchmarks/__init__.py +++ b/learn2learn/vision/benchmarks/__init__.py @@ -59,6 +59,7 @@ def get_tasksets( test_samples=10, num_tasks=-1, root='~/data', + data_augmentation=None, device=None, **kwargs, ): @@ -103,6 +104,7 @@ def get_tasksets( test_ways=test_ways, test_samples=test_samples, root=root, + data_augmentation=data_augmentation, device=device, **kwargs) train_dataset, validation_dataset, test_dataset = datasets