diff --git a/src/pruna/evaluation/metrics/metric_dino_score.py b/src/pruna/evaluation/metrics/metric_dino_score.py
index 6ac5ed37..488b0a2c 100644
--- a/src/pruna/evaluation/metrics/metric_dino_score.py
+++ b/src/pruna/evaluation/metrics/metric_dino_score.py
@@ -69,7 +69,7 @@ def __init__(self, device: str | torch.device | None = None, call_type: str = SI
         self.device = set_to_best_available_device(device)
         if device is not None and not any(self.device.startswith(prefix) for prefix in self.runs_on):
             pruna_logger.error(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
-            raise
+            raise ValueError(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
         self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
         # Load the DINO ViT-S/16 model once
         self.model = timm.create_model("vit_small_patch16_224.dino", pretrained=True)
diff --git a/src/pruna/evaluation/metrics/metric_memory.py b/src/pruna/evaluation/metrics/metric_memory.py
index e01494c8..5e77b8b5 100644
--- a/src/pruna/evaluation/metrics/metric_memory.py
+++ b/src/pruna/evaluation/metrics/metric_memory.py
@@ -189,7 +189,7 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> MetricResult:
         if len(model_device_indices) > 1 and self.device_map is None:
             pruna_logger.error("Multiple GPUs detected, but no device map found. Please check the model configuration.")
-            raise
+            raise RuntimeError("Multiple GPUs detected, but no device map found. Please check the model configuration.")
         else:
             move_to_device(model, "cuda")
         return MetricResult(self.mode, self.__dict__.copy(), peak_memory)
diff --git a/src/pruna/evaluation/metrics/metric_sharpness.py b/src/pruna/evaluation/metrics/metric_sharpness.py
index b09067de..557066ba 100644
--- a/src/pruna/evaluation/metrics/metric_sharpness.py
+++ b/src/pruna/evaluation/metrics/metric_sharpness.py
@@ -99,7 +99,7 @@ def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) ->
         if images.ndim != 4:
             pruna_logger.error(f"Expected 4‑D tensor (B, C, H, W); got shape {tuple(images.shape)}")
-            raise
+            raise ValueError(f"Expected 4-D tensor (B, C, H, W); got shape {tuple(images.shape)}")
 
         # Move to CPU - OpenCV only works on numpy
         imgs = images.detach().cpu()
@@ -134,7 +134,7 @@ def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) ->
             else:
                 pruna_logger.error("SharpnessMetric: unsupported channel count")
-                raise
+                raise ValueError(f"SharpnessMetric: unsupported channel count {img.shape[0]}. Expected 1 or 3 channels.")
 
             self.scores.append(sharpness_score)
diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py
index f2af4ea1..cbebeded 100644
--- a/src/pruna/evaluation/metrics/registry.py
+++ b/src/pruna/evaluation/metrics/registry.py
@@ -125,7 +125,7 @@ def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric:
         elif isinstance(metric_cls, partial):
             # For the mock tests
             return metric_cls(**kwargs)
         else:
-            raise ValueError(f"Metric '{metric_cls}' dos not inherit from a valid metric class.")
+            raise ValueError(f"Metric '{metric_cls}' does not inherit from a valid metric class.")
 
     @classmethod
     def get_metrics(cls, names: List[str], **kwargs) -> List[BaseMetric | StatefulMetric]:
diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py
index e8c63688..1f3d8e78 100644
--- a/src/pruna/evaluation/task.py
+++ b/src/pruna/evaluation/task.py
@@ -28,7 +28,7 @@
 from pruna.evaluation.metrics.utils import get_hyperparameters
 from pruna.logging.logger import pruna_logger
 
-AVAILABLE_REQUESTS = ("image_generation_quality",)
+AVAILABLE_REQUESTS = ("image_generation_quality", "text_generation_quality")
 PARENT_METRICS = (
     "ModelArchitectureStats",
     "InferenceTimeStats",
@@ -164,7 +164,8 @@ def get_metrics(
     Parameters
     ----------
     request : str | List[str]
-        The user request. Right now, it only supports image generation quality.
+        The user request. Supports named requests like 'image_generation_quality'
+        and 'text_generation_quality', or a list of metric names.
     inference_device : str | torch.device | None, optional
         The device to be used for inference, e.g., 'cuda' or 'cpu'. Default is None.
         If None, the best available device will be used.
@@ -255,6 +256,11 @@ def _process_single_request(
             TorchMetricWrapper("clip_score", call_type="pairwise", device=stateful_metric_device),
             CMMD(device=stateful_metric_device),
         ]
+    elif request == "text_generation_quality":
+        pruna_logger.info("An evaluation task for text generation quality is being created.")
+        return [
+            TorchMetricWrapper("perplexity", device=stateful_metric_device),
+        ]
     else:
         pruna_logger.error(f"Metric {request} not found. Available requests: {AVAILABLE_REQUESTS}.")
         raise ValueError(f"Metric {request} not found. Available requests: {AVAILABLE_REQUESTS}.")
diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py
index 5420a17b..18ad67e1 100644
--- a/tests/evaluation/test_task.py
+++ b/tests/evaluation/test_task.py
@@ -107,3 +107,19 @@ def test_task_from_string_request():
     assert isinstance(task.metrics[0], CMMD)
     assert isinstance(task.metrics[1], PairwiseClipScore)
     assert isinstance(task.metrics[2], TorchMetricWrapper)
+
+
+@pytest.mark.cpu
+def test_task_text_generation_quality_request():
+    """Test that the 'text_generation_quality' named request creates a perplexity metric."""
+    task = Task(request="text_generation_quality", datamodule=PrunaDataModule.from_string("TinyWikiText"), device="cpu")
+    assert len(task.metrics) == 1
+    assert isinstance(task.metrics[0], TorchMetricWrapper)
+    assert task.metrics[0].metric_name == "perplexity"
+
+
+@pytest.mark.cpu
+def test_task_invalid_named_request():
+    """Test that an invalid named request raises a ValueError."""
+    with pytest.raises(ValueError, match="not found"):
+        Task(request="nonexistent_quality", datamodule=PrunaDataModule.from_string("LAION256"), device="cpu")
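
Note: below is a minimal usage sketch of the new named request, mirroring the test added in this patch. The import paths are assumptions inferred from the file layout in the diff ("TinyWikiText" and the Task arguments are taken directly from the test); treat it as a sketch, not the library's documented API.

# Sketch only: import paths are assumed from this diff's file layout.
from pruna.data.pruna_data_module import PrunaDataModule  # assumed module path
from pruna.evaluation.task import Task

# "text_generation_quality" now resolves to a single perplexity metric
# (a TorchMetricWrapper), the same way "image_generation_quality" resolves
# to its CMMD / CLIP score bundle in _process_single_request.
task = Task(
    request="text_generation_quality",
    datamodule=PrunaDataModule.from_string("TinyWikiText"),
    device="cpu",
)
print([type(m).__name__ for m in task.metrics])  # expected: ['TorchMetricWrapper']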