2 changes: 1 addition & 1 deletion src/pruna/evaluation/metrics/metric_dino_score.py
@@ -69,7 +69,7 @@ def __init__(self, device: str | torch.device | None = None, call_type: str = SI
self.device = set_to_best_available_device(device)
if device is not None and not any(self.device.startswith(prefix) for prefix in self.runs_on):
pruna_logger.error(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
- raise
+ raise ValueError(f"DinoScore: device {device} not supported. Supported devices: {self.runs_on}")
self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
# Load the DINO ViT-S/16 model once
self.model = timm.create_model("vit_small_patch16_224.dino", pretrained=True)
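Note on the pattern being fixed here and in the metric files below: each replaced line was a bare `raise` immediately after a `pruna_logger.error(...)` call, outside of any `except` block. A bare `raise` with no active exception fails with `RuntimeError: No active exception to re-raise`, so the intended error never reached the caller with a useful message. A minimal sketch of the difference (the helper names are illustrative, not from the PR):

```python
# Illustrative only; these helpers are not part of pruna.
def check_device_bare(device: str) -> None:
    if device not in ("cpu", "cuda"):
        # Old pattern: outside an `except` block there is no active exception,
        # so this fails with "RuntimeError: No active exception to re-raise".
        raise


def check_device_explicit(device: str) -> None:
    if device not in ("cpu", "cuda"):
        # New pattern: raise an explicit, descriptive exception.
        raise ValueError(f"device {device} not supported; expected 'cpu' or 'cuda'")


try:
    check_device_bare("tpu")
except RuntimeError as err:
    print(err)  # -> No active exception to re-raise

try:
    check_device_explicit("tpu")
except ValueError as err:
    print(err)  # -> device tpu not supported; expected 'cpu' or 'cuda'
```

The same replacement is applied in metric_memory.py and metric_sharpness.py below.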
2 changes: 1 addition & 1 deletion src/pruna/evaluation/metrics/metric_memory.py
@@ -189,7 +189,7 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> MetricResult:

if len(model_device_indices) > 1 and self.device_map is None:
pruna_logger.error("Multiple GPUs detected, but no device map found. Please check the model configuration.")
- raise
+ raise RuntimeError("Multiple GPUs detected, but no device map found. Please check the model configuration.")
else:
move_to_device(model, "cuda")
return MetricResult(self.mode, self.__dict__.copy(), peak_memory)
4 changes: 2 additions & 2 deletions src/pruna/evaluation/metrics/metric_sharpness.py
@@ -99,7 +99,7 @@ def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) ->

if images.ndim != 4:
pruna_logger.error(f"Expected 4‑D tensor (B, C, H, W); got shape {tuple(images.shape)}")
- raise
+ raise ValueError(f"Expected 4-D tensor (B, C, H, W); got shape {tuple(images.shape)}")

# Move to CPU - OpenCV only works on numpy
imgs = images.detach().cpu()
@@ -134,7 +134,7 @@ def update(self, x: List[Any] | Tensor, gt: List[Any] | Tensor, outputs: Any) ->

else:
pruna_logger.error("SharpnessMetric: unsupported channel count")
- raise
+ raise ValueError(f"SharpnessMetric: unsupported channel count {img.shape[0]}. Expected 1 or 3 channels.")

self.scores.append(sharpness_score)

2 changes: 1 addition & 1 deletion src/pruna/evaluation/metrics/registry.py
@@ -125,7 +125,7 @@ def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric:
elif isinstance(metric_cls, partial): # For the mock tests
return metric_cls(**kwargs)
else:
raise ValueError(f"Metric '{metric_cls}' dos not inherit from a valid metric class.")
raise ValueError(f"Metric '{metric_cls}' does not inherit from a valid metric class.")

@classmethod
def get_metrics(cls, names: List[str], **kwargs) -> List[BaseMetric | StatefulMetric]:
10 changes: 8 additions & 2 deletions src/pruna/evaluation/task.py
@@ -28,7 +28,7 @@
from pruna.evaluation.metrics.utils import get_hyperparameters
from pruna.logging.logger import pruna_logger

- AVAILABLE_REQUESTS = ("image_generation_quality",)
+ AVAILABLE_REQUESTS = ("image_generation_quality", "text_generation_quality")
PARENT_METRICS = (
"ModelArchitectureStats",
"InferenceTimeStats",
@@ -164,7 +164,8 @@ def get_metrics(
Parameters
----------
request : str | List[str]
- The user request. Right now, it only supports image generation quality.
+ The user request. Supports named requests like 'image_generation_quality'
+ and 'text_generation_quality', or a list of metric names.
inference_device : str | torch.device | None, optional
The device to be used for inference, e.g., 'cuda' or 'cpu'. Default is None.
If None, the best available device will be used.
@@ -255,6 +256,11 @@ def _process_single_request(
TorchMetricWrapper("clip_score", call_type="pairwise", device=stateful_metric_device),
CMMD(device=stateful_metric_device),
]
+ elif request == "text_generation_quality":
+     pruna_logger.info("An evaluation task for text generation quality is being created.")
+     return [
+         TorchMetricWrapper("perplexity", device=stateful_metric_device),
+     ]
else:
pruna_logger.error(f"Metric {request} not found. Available requests: {AVAILABLE_REQUESTS}.")
raise ValueError(f"Metric {request} not found. Available requests: {AVAILABLE_REQUESTS}.")
16 changes: 16 additions & 0 deletions tests/evaluation/test_task.py
@@ -107,3 +107,19 @@ def test_task_from_string_request():
assert isinstance(task.metrics[0], CMMD)
assert isinstance(task.metrics[1], PairwiseClipScore)
assert isinstance(task.metrics[2], TorchMetricWrapper)


+ @pytest.mark.cpu
+ def test_task_text_generation_quality_request():
+     """Test that 'text_generation_quality' named request creates perplexity metric."""
+     task = Task(request="text_generation_quality", datamodule=PrunaDataModule.from_string("TinyWikiText"), device="cpu")
+     assert len(task.metrics) == 1
+     assert isinstance(task.metrics[0], TorchMetricWrapper)
+     assert task.metrics[0].metric_name == "perplexity"
+
+
+ @pytest.mark.cpu
+ def test_task_invalid_named_request():
+     """Test that an invalid named request raises a ValueError."""
+     with pytest.raises(ValueError, match="not found"):
+         Task(request="nonexistent_quality", datamodule=PrunaDataModule.from_string("LAION256"), device="cpu")