Commit cae4dd8

Merge upstream main into issue-1927-modernize-entrypoints

2 parents bfe47dc + 0f346cf

1 file changed: 21 additions (+), 16 deletions (-)

src/llmcompressor/datasets/utils.py
@@ -9,7 +9,7 @@
 
 import multiprocessing
 import re
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable
 
 import torch
 from datasets import Dataset
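
The trimmed import reflects PEP 585 and PEP 604 throughout the file: builtin generics (dict, list) and the X | None union syntax replace typing.Dict, typing.List, and typing.Optional, leaving only Any and Callable with no builtin equivalent. A minimal sketch of the equivalence (the lookup function below is illustrative, not part of this file):

from typing import Any

# Old style (pre-PEP 585/604):
#     def lookup(keys: Optional[List[str]]) -> Optional[Dict[str, Any]]:
# New style, as used in this diff (Python 3.10+):
def lookup(keys: list[str] | None) -> dict[str, Any] | None:
    # Hypothetical helper, for illustration only.
    if keys is None:
        return None
    return {key: None for key in keys}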
@@ -24,10 +24,10 @@
 
 def get_processed_dataset(
     dataset_args: DatasetArguments,
-    processor: Optional[Processor] = None,
+    processor: Processor | None = None,
     do_oneshot: bool = False,
     do_train: bool = True,
-) -> Optional[Dict[str, Dataset]]:
+) -> dict[str, Dataset] | None:
     """
     Loads datasets for each flow based on dataset_args, stores a Dataset for each
     enabled flow in datasets
@@ -50,17 +50,22 @@ def get_processed_dataset(
 
     def _get_split_name(inp_str):
         # strip out split name, for ex train[60%:] -> train
-        match = re.match(r"(\w*)\[.*\]", inp_str)
-        if match is not None:
-            return match.group(1)
+        split_name_match = re.match(r"(\w*)\[.*\]", inp_str)
+        if split_name_match is not None:
+            return split_name_match.group(1)
         return inp_str
 
-    if splits is None:
-        splits = {"all": None}
-    elif isinstance(splits, str):
-        splits = {_get_split_name(splits): splits}
-    elif isinstance(splits, List):
-        splits = {_get_split_name(s): s for s in splits}
+    match splits:
+        case None:
+            splits = {"all": None}
+        case str():
+            splits = {_get_split_name(splits): splits}
+        case list():
+            splits = {_get_split_name(s): s for s in splits}
+        case dict():
+            pass
+        case _:
+            raise ValueError(f"Invalid splits type: {type(splits)}")
 
     # default to custom dataset if dataset provided isn't a string
     registry_id = (
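
Here the if/elif chain becomes a structural pattern match, and the local variable is renamed from match to split_name_match, presumably to avoid confusion with the match soft keyword used a few lines below. Unlike the old chain, the new version also handles the dict case explicitly and raises on unsupported types instead of passing them through silently. A standalone sketch of the same normalization, with example inputs:

import re

def _get_split_name(inp_str):
    # strip out split name, for ex train[60%:] -> train
    split_name_match = re.match(r"(\w*)\[.*\]", inp_str)
    if split_name_match is not None:
        return split_name_match.group(1)
    return inp_str

def normalize_splits(splits):
    # Copy of the new logic, returning instead of rebinding (illustrative only).
    match splits:
        case None:
            return {"all": None}
        case str():
            return {_get_split_name(splits): splits}
        case list():
            return {_get_split_name(s): s for s in splits}
        case dict():
            return splits
        case _:
            raise ValueError(f"Invalid splits type: {type(splits)}")

print(normalize_splits(None))                    # {'all': None}
print(normalize_splits("train[60%:]"))           # {'train': 'train[60%:]'}
print(normalize_splits(["train[:128]", "test"])) # {'train': 'train[:128]', 'test': 'test'}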
@@ -121,10 +126,10 @@ def get_calibration_dataloader(
 
 def format_calibration_data(
     tokenized_dataset: Dataset,
-    num_calibration_samples: Optional[int] = None,
+    num_calibration_samples: int | None = None,
     do_shuffle: bool = True,
     collate_fn: Callable = default_data_collator,
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to
     the desired number of calibration samples
@@ -172,10 +177,10 @@ def format_calibration_data(
 
 
 def make_dataset_splits(
-    tokenized_datasets: Dict[str, Any],
+    tokenized_datasets: dict[str, Any],
    do_oneshot: bool = True,
    do_train: bool = False,
-) -> Dict[str, Dataset]:
+) -> dict[str, Dataset]:
     """
     Restructures the datasets dictionary based on what tasks will be run
     train
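
Note that both the match statement and X | None in evaluated annotations require Python 3.10, so this change assumes the package's minimum supported Python is at least 3.10 (for the annotations alone, from __future__ import annotations would be an escape hatch, but the match statement has none).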
