Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 129 additions & 21 deletions plexe/internal/models/entities/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from enum import Enum
from functools import total_ordering
from typing import Optional
from weakref import WeakValueDictionary


class ComparisonMethod(Enum):
Expand Down Expand Up @@ -98,32 +100,140 @@ def compare(self, value1: float, value2: float) -> int:
raise ValueError("Invalid comparison method.")


# todo: this class is a mess as it mixes concerns of a metric and a metric value; needs refactoring
# Module-level cache of MetricComparator instances, keyed by the tuple
# (comparison_method, target, epsilon).
# Values are held through weak references (WeakValueDictionary), so an entry is
# dropped automatically once nothing else holds a strong reference to that
# comparator; the cache never keeps a comparator alive on its own.
_comparator_cache: WeakValueDictionary = WeakValueDictionary()


def _get_shared_comparator(
    comparison_method: ComparisonMethod, target: Optional[float] = None, epsilon: float = 1e-9
) -> MetricComparator:
    """
    Get or create a shared MetricComparator instance.

    Comparators are looked up in the module-level ``_comparator_cache`` by the tuple
    ``(comparison_method, target, epsilon)``. When no live entry exists, a new
    comparator is built from those same three arguments, stored in the cache, and returned.

    :param comparison_method: The comparison method.
    :param target: Optional target value for TARGET_IS_BETTER.
    :param epsilon: Tolerance for floating-point comparisons.
    :return: A shared MetricComparator instance.
    """
    cache_key = (comparison_method, target, epsilon)

    # Fetch with a single .get() rather than "key in cache" followed by "cache[key]":
    # the cache holds weak references, so an entry may be collected between a
    # membership test and the subsequent lookup, which would surface as a KeyError.
    comparator = _comparator_cache.get(cache_key)
    if comparator is None:
        # Cache miss (or the previous instance was already collected): build and register.
        comparator = MetricComparator(comparison_method, target, epsilon)
        _comparator_cache[cache_key] = comparator
    return comparator


class _MetricDefinition:
    """
    Internal class representing a metric type definition.

    A definition pairs a metric name with the comparator used to rank its values.
    It carries no measurement itself, so one definition can describe many metric values.
    The name and comparator are exposed through read-only properties.

    Two definitions are equal when their names match and their comparators agree on
    comparison method, target and epsilon; the hash is derived from the same fields.

    This is an internal implementation detail - users should not interact with this class directly.
    """

    def __init__(self, name: str, comparator: MetricComparator):
        """
        Initialize a metric definition.

        :param name: The name of the metric.
        :param comparator: The shared comparator instance.
        """
        self._name = name
        self._comparator = comparator

    @property
    def name(self) -> str:
        """The name of the metric."""
        return self._name

    @property
    def comparator(self) -> MetricComparator:
        """The shared comparator instance."""
        return self._comparator

    def _key(self) -> tuple:
        """Return the tuple of fields that defines identity for equality and hashing."""
        comparator = self._comparator
        return (self._name, comparator.comparison_method, comparator.target, comparator.epsilon)

    def __eq__(self, other) -> bool:
        """
        Check if two metric definitions are equal.

        :param other: The object to compare against.
        :return: True when both definitions share name, comparison method, target and epsilon;
            NotImplemented when ``other`` is not a metric definition, so Python can try the
            reflected comparison before falling back to its default (unequal) result.
        """
        if not isinstance(other, _MetricDefinition):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        """Hash the metric definition using the same fields as equality."""
        return hash(self._key())


@total_ordering
class Metric:
"""
Represents a metric with a name, a value, and a comparator for determining which metric is better.

This class internally separates the metric definition (type) from the metric value (measurement),
and automatically shares comparator instances to reduce memory usage.

Attributes:
name (str): The name of the metric (e.g., 'accuracy', 'loss').
value (float): The numeric value of the metric.
comparator (MetricComparator): The comparison logic for the metric.
comparator (MetricComparator): The comparison logic for the metric (shared instance).
"""

def __init__(self, name: str, value: float = None, comparator: MetricComparator = None, is_worst: bool = False):
    """
    Initializes a Metric object.

    The supplied comparator is not stored directly. Its comparison method, target and
    epsilon are used to obtain a shared comparator via ``_get_shared_comparator``, and that
    shared instance is what the ``comparator`` property exposes afterwards.

    :param name: The name of the metric.
    :param value: The numeric value of the metric.
    :param comparator: An instance of MetricComparator for comparison logic.
    :param is_worst: Indicates if the metric value is the worst possible value.
    :raises ValueError: If no comparator is provided.
    """
    # A comparator is mandatory; fail before touching any instance state.
    if comparator is None:
        raise ValueError("Metric requires a comparator. Provide a MetricComparator instance.")

    # Store the metric value (dynamic, instance-specific).
    self.value = value
    # A metric without a value is always treated as the worst possible one.
    self.is_worst = is_worst or value is None

    # Reuse an equivalent cached comparator instead of keeping the caller's instance.
    shared_comparator = _get_shared_comparator(
        comparison_method=comparator.comparison_method, target=comparator.target, epsilon=comparator.epsilon
    )

    # name and comparator are read-only properties backed by this definition object,
    # so they must not be assigned on self directly.
    self._definition = _MetricDefinition(name=name, comparator=shared_comparator)

@property
def name(self) -> str:
    """
    The name of the metric (for backward compatibility).

    Read-only; the value is taken from the internal metric definition.
    """
    return self._definition.name

@property
def comparator(self) -> MetricComparator:
    """
    The shared comparator instance (for backward compatibility).

    Read-only; the value is taken from the internal metric definition.
    """
    return self._definition.comparator

def __gt__(self, other) -> bool:
"""
Determine if this metric is better than another metric.
Expand All @@ -135,23 +245,24 @@ def __gt__(self, other) -> bool:
if not isinstance(other, Metric):
return NotImplemented

if self.is_worst or (self.is_worst and other.is_worst):
if self.is_worst:
return False

if other.is_worst:
return True

if self.name != other.name:
raise ValueError("Cannot compare metrics with different names.")

if self.comparator.comparison_method != other.comparator.comparison_method:
raise ValueError("Cannot compare metrics with different comparison methods.")

if (
self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER
and self.comparator.target != other.comparator.target
):
raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.")
# Compare using definitions - this is cleaner and ensures consistency
if self._definition != other._definition:
# Provide detailed error message for backward compatibility
if self.name != other.name:
raise ValueError("Cannot compare metrics with different names.")
if self.comparator.comparison_method != other.comparator.comparison_method:
raise ValueError("Cannot compare metrics with different comparison methods.")
if (
self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER
and self.comparator.target != other.comparator.target
):
raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.")

return self.comparator.compare(self.value, other.value) < 0

Expand All @@ -171,11 +282,8 @@ def __eq__(self, other) -> bool:
if self.is_worst or other.is_worst:
return False

return (
self.name == other.name
and self.comparator.comparison_method == other.comparator.comparison_method
and self.comparator.compare(self.value, other.value) == 0
)
# Use definition equality for cleaner comparison
return self._definition == other._definition and self.comparator.compare(self.value, other.value) == 0

def __repr__(self) -> str:
"""
Expand Down