11import pytest
22from transformers import AutoModelForCausalLM , AutoTokenizer
3+ import os
4+ import logging
5+
6+ logging .basicConfig (level = logging .INFO )
7+ logger = logging .getLogger (__name__ )
38
49from llmcompressor import oneshot
510from tests .llmcompressor .transformers .oneshot .dataset_processing import get_data_utils
@@ -46,7 +51,7 @@ def wrapped_preprocess_func(sample):
4651 args ["sequential_targets" ] = config .get ("sequential_targets" , None )
4752 args ["tracing_ignore" ] = config .get ("tracing_ignore" , [])
4853 args ["raw_kwargs" ] = config .get ("raw_kwargs" , {})
49- args ["preprocessing_func" ] = ( config .get ("preprocessing_func" , lambda x : x ), )
54+ args ["preprocessing_func" ] = config .get ("preprocessing_func" , lambda x : x )
5055 args ["max_train_samples" ] = config .get ("max_train_samples" , 50 )
5156 args ["remove_columns" ] = config .get ("remove_columns" , None )
5257 args ["dvc_data_repository" ] = config .get ("dvc_data_repository" , None )
@@ -59,10 +64,10 @@ def wrapped_preprocess_func(sample):
5964@pytest .mark .smoke
6065@pytest .mark .integration
6166def test_one_shot_inputs (one_shot_args , tmp_path ):
62- print (f"Dataset type: { type (one_shot_args .get ('dataset' ))} " )
67+ logger . info (f"Dataset type: { type (one_shot_args .get ('dataset' ))} " )
6368 if isinstance (one_shot_args .get ("dataset" ), str ):
64- print (f"Dataset name: { one_shot_args .get ('dataset' )} " )
65- print (f"Dataset config: { one_shot_args .get ('dataset_config_name' )} " )
69+ logger . info (f"Dataset name: { one_shot_args .get ('dataset' )} " )
70+ logger . info (f"Dataset config: { one_shot_args .get ('dataset_config_name' )} " )
6671 try :
6772 # Call oneshot with all parameters as flat arguments
6873 oneshot (
@@ -76,18 +81,8 @@ def test_one_shot_inputs(one_shot_args, tmp_path):
7681 if "num_samples should be a positive integer value" in str (
7782 e
7883 ) or "Dataset is empty. Cannot create a calibration dataloader" in str (e ):
79- print (f"Dataset is empty: { one_shot_args .get ('dataset' )} " )
84+ logger . warning (f"Dataset is empty: { one_shot_args .get ('dataset' )} " )
8085 pytest .skip (f"Dataset is empty: { one_shot_args .get ('dataset' )} " )
8186 else :
8287 raise # Re-raise other ValueError exceptions
83- finally :
84- # Clean up temporary files to avoid the "megabytes of temp files" error
85- import os
86-
87- # Clean up the output directory
88- if os .path .exists (tmp_path ):
89- print (f"Cleaning up temp directory: { tmp_path } " )
90- # Remove files but keep the directory structure
91- for root , dirs , files in os .walk (tmp_path ):
92- for file in files :
93- os .remove (os .path .join (root , file ))
88+
0 commit comments