
Commit 83ebf56

Refactor oneshot function parameters to use Optional types and enhance documentation
1 parent 99474d7 commit 83ebf56

File tree: 2 files changed (+30 additions, -19 deletions)

  src/llmcompressor/entrypoints/oneshot.py
  tests/llmcompressor/transformers/oneshot/test_api_inputs.py
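The signature change follows the usual pattern of annotating None-defaulted parameters as Optional and normalizing them inside the function body. Below is a minimal, self-contained sketch of that pattern; the function name is hypothetical and not part of llmcompressor:

from typing import Any, Dict, List, Optional


def configure(
    tracing_ignore: Optional[List[str]] = None,
    raw_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    # Hypothetical helper: Optional[...] keeps the annotation honest for type
    # checkers, and normalizing None here avoids mutable default arguments.
    if tracing_ignore is None:
        tracing_ignore = []
    if raw_kwargs is None:
        raw_kwargs = {}
    print(tracing_ignore, raw_kwargs)


configure()                             # prints: [] {}
configure(tracing_ignore=["lm_head"])   # prints: ['lm_head'] {}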

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 19 additions & 3 deletions
@@ -243,8 +243,8 @@ def oneshot(
     min_tokens_per_module: Optional[float] = None,
     calibrate_moe_context: bool = False,
     pipeline: str = "independent",
-    tracing_ignore: List[str] = None,
-    raw_kwargs: Dict[str, Any] = None,
+    tracing_ignore: Optional[List[str]] = None,
+    raw_kwargs: Optional[Dict[str, Any]] = None,
     preprocessing_func: Optional[Callable] = None,
     max_train_samples: Optional[int] = None,
     remove_columns: Optional[List[str]] = None,
@@ -320,6 +320,16 @@ def oneshot(
         during forward pass in calibration. When False, quantization is disabled
         during forward pass in calibration. Default is set to True.

+    :param pipeline: The pipeline configuration to use for calibration. Options include
+        'independent', 'sequential', or 'layer_sequential'.
+    :param tracing_ignore: List of module names to ignore during tracing.
+    :param raw_kwargs: Dictionary of raw keyword arguments passed to the function.
+    :param preprocessing_func: Optional callable for preprocessing the dataset.
+    :param max_train_samples: Maximum number of training samples to use.
+    :param remove_columns: List of column names to remove from the dataset.
+    :param dvc_data_repository: Path to the DVC data repository, if applicable.
+    :param sequential_targets: List of sequential targets for calibration.
+
     # Miscellaneous arguments
     :param output_dir: Path to save the output model after calibration.
         Nothing is saved if None.
@@ -333,11 +343,17 @@ def oneshot(
         raise ValueError(
             "Invalid configuration: "
             "sequential_targets' cannot be used with 'independent' pipeline. "
-            "Please use 'sequential' or 'layer_sequential' pipeline when specifying"
+            "Please use 'sequential' or 'layer_sequential' pipeline when specifying "
             "sequential_targets."
         )

     # pass all args directly into Oneshot
+    if raw_kwargs is None:
+        raw_kwargs = {}
+
+    local_args = {
+        k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
+    }
     local_args = {
         k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
     }
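For context, a hedged sketch of how the newly typed and documented keyword arguments might be passed; the model, dataset, recipe, and target values below are placeholders, and the exact set of other required arguments depends on the llmcompressor version in use:

from llmcompressor import oneshot

# Placeholder values throughout; per the validation shown in the diff, a
# 'sequential' or 'layer_sequential' pipeline must be selected whenever
# sequential_targets is provided.
oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",   # placeholder model id
    dataset="open_platypus",                      # placeholder dataset name
    recipe="recipe.yaml",                         # placeholder recipe path
    pipeline="sequential",
    sequential_targets=["LlamaDecoderLayer"],     # placeholder target module
    tracing_ignore=["lm_head"],                   # placeholder ignore entry
    raw_kwargs=None,                              # normalized to {} inside oneshot
    output_dir="./oneshot-output",
)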

tests/llmcompressor/transformers/oneshot/test_api_inputs.py

Lines changed: 11 additions & 16 deletions
@@ -1,5 +1,10 @@
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import os
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 from llmcompressor import oneshot
 from tests.llmcompressor.transformers.oneshot.dataset_processing import get_data_utils
@@ -46,7 +51,7 @@ def wrapped_preprocess_func(sample):
     args["sequential_targets"] = config.get("sequential_targets", None)
     args["tracing_ignore"] = config.get("tracing_ignore", [])
     args["raw_kwargs"] = config.get("raw_kwargs", {})
-    args["preprocessing_func"] = (config.get("preprocessing_func", lambda x: x),)
+    args["preprocessing_func"] = config.get("preprocessing_func", lambda x: x)
     args["max_train_samples"] = config.get("max_train_samples", 50)
     args["remove_columns"] = config.get("remove_columns", None)
     args["dvc_data_repository"] = config.get("dvc_data_repository", None)
@@ -59,10 +64,10 @@ def wrapped_preprocess_func(sample):
 @pytest.mark.smoke
 @pytest.mark.integration
 def test_one_shot_inputs(one_shot_args, tmp_path):
-    print(f"Dataset type: {type(one_shot_args.get('dataset'))}")
+    logger.info(f"Dataset type: {type(one_shot_args.get('dataset'))}")
     if isinstance(one_shot_args.get("dataset"), str):
-        print(f"Dataset name: {one_shot_args.get('dataset')}")
-        print(f"Dataset config: {one_shot_args.get('dataset_config_name')}")
+        logger.info(f"Dataset name: {one_shot_args.get('dataset')}")
+        logger.info(f"Dataset config: {one_shot_args.get('dataset_config_name')}")
     try:
         # Call oneshot with all parameters as flat arguments
         oneshot(
@@ -76,18 +81,8 @@ def test_one_shot_inputs(one_shot_args, tmp_path):
         if "num_samples should be a positive integer value" in str(
             e
         ) or "Dataset is empty. Cannot create a calibration dataloader" in str(e):
-            print(f"Dataset is empty: {one_shot_args.get('dataset')}")
+            logger.warning(f"Dataset is empty: {one_shot_args.get('dataset')}")
             pytest.skip(f"Dataset is empty: {one_shot_args.get('dataset')}")
         else:
             raise  # Re-raise other ValueError exceptions
-    finally:
-        # Clean up temporary files to avoid the "megabytes of temp files" error
-        import os
-
-        # Clean up the output directory
-        if os.path.exists(tmp_path):
-            print(f"Cleaning up temp directory: {tmp_path}")
-            # Remove files but keep the directory structure
-            for root, dirs, files in os.walk(tmp_path):
-                for file in files:
-                    os.remove(os.path.join(root, file))
+
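Two small notes on the test changes. First, the preprocessing_func fix removes an accidental one-element tuple: wrapping an expression in parentheses with a trailing comma builds a tuple, so the old line stored a tuple instead of the callable. A quick illustration of that pitfall:

config = {}

# Old form: parentheses plus trailing comma create a 1-tuple, not a callable.
wrapped = (config.get("preprocessing_func", lambda x: x),)
print(type(wrapped))   # prints: <class 'tuple'>

# New form: the callable (or the identity default) is stored directly.
func = config.get("preprocessing_func", lambda x: x)
print(func("sample"))  # prints: sample

Second, with the test now logging through the logging module instead of print, the messages can be surfaced during a run via pytest's live-logging options (for example, --log-cli-level=INFO).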
