Skip to content

Commit 1145efa

Browse files
authored
balance classes faster
from michalpiasecki0
2 parents a5b6728 + 24b6860 commit 1145efa

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

timm/data/loader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch.utils.data
1818
import numpy as np
19+
import pandas as pd
1920

2021
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2122
from .dataset import IterableImageDataset, ImageDataset
@@ -228,6 +229,7 @@ def create_loader(
228229
worker_seeding: str = 'all',
229230
tf_preprocessing: bool = False,
230231
balance_classes: bool = False,
232+
dataset_csv_path: Optional[str] = None
231233
):
232234
"""
233235
@@ -272,10 +274,12 @@ def create_loader(
272274
worker_seeding: Control worker random seeding at init.
273275
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.
274276
balance_classes: Sample classes with uniform probability
277+
dataset_csv_path: Path to dataset csv, used for class balancing
275278
276279
Returns:
277280
DataLoader
278281
"""
282+
279283
re_num_splits = 0
280284
if re_split:
281285
# apply RE to second half of batch if no aug split otherwise line up with aug split
@@ -329,7 +333,9 @@ def create_loader(
329333
else:
330334
assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
331335
if balance_classes:
332-
all_labels = [c for (_, c) in dataset]
336+
assert dataset_csv_path, "Provide csv with labels to use balance_classes."
337+
dataset_csv = pd.read_csv(dataset_csv_path)
338+
all_labels = dataset_csv["label"].values
333339
unique, counts = np.unique(all_labels, return_counts=True)
334340
unique_counts = {v: c for v, c in zip(unique, counts)}
335341
label_weights = np.array([1 / unique_counts[num] for num in all_labels])

timm/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,7 @@ def train(config: dict[str, t.Any]):
755755
use_multi_epochs_loader=args.use_multi_epochs_loader,
756756
worker_seeding=args.worker_seeding,
757757
balance_classes=args.balance_classes,
758+
samples_csv_path=args.train_samples_csv_path
758759
)
759760

760761
loader_eval = None

0 commit comments

Comments
 (0)