|
16 | 16 |
|
17 | 17 | import torch.utils.data |
18 | 18 | import numpy as np |
| 19 | +import pandas as pd |
19 | 20 |
|
20 | 21 | from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
21 | 22 | from .dataset import IterableImageDataset, ImageDataset |
@@ -228,6 +229,7 @@ def create_loader( |
228 | 229 | worker_seeding: str = 'all', |
229 | 230 | tf_preprocessing: bool = False, |
230 | 231 | balance_classes: bool = False, |
| 232 | + dataset_csv_path: Optional[str] = None |
231 | 233 | ): |
232 | 234 | """ |
233 | 235 |
|
@@ -272,10 +274,12 @@ def create_loader( |
272 | 274 | worker_seeding: Control worker random seeding at init. |
273 | 275 | tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports. |
274 | 276 | balance_classes: Sample classes with uniform probability |
| 277 | + dataset_csv_path: Path to dataset csv, used for class balancing |
275 | 278 |
|
276 | 279 | Returns: |
277 | 280 | DataLoader |
278 | 281 | """ |
| 282 | + |
279 | 283 | re_num_splits = 0 |
280 | 284 | if re_split: |
281 | 285 | # apply RE to second half of batch if no aug split otherwise line up with aug split |
@@ -329,7 +333,9 @@ def create_loader( |
329 | 333 | else: |
330 | 334 | assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use" |
331 | 335 | if balance_classes: |
332 | | - all_labels = [c for (_, c) in dataset] |
| 336 | + assert dataset_csv_path, "Provide csv with labels to use balance_classes." |
| 337 | + dataset_csv = pd.read_csv(dataset_csv_path) |
| 338 | + all_labels = dataset_csv["label"].values |
333 | 339 | unique, counts = np.unique(all_labels, return_counts=True) |
334 | 340 | unique_counts = {v: c for v, c in zip(unique, counts)} |
335 | 341 | label_weights = np.array([1 / unique_counts[num] for num in all_labels]) |
|
0 commit comments