# Process-wide registry of comparators, keyed by (comparison_method, target, epsilon).
# Values are held weakly: an entry disappears as soon as nothing else keeps a strong
# reference to its comparator, so the registry itself never pins a comparator in memory.
_comparator_cache: WeakValueDictionary = WeakValueDictionary()


def _get_shared_comparator(
    comparison_method: "ComparisonMethod", target: Optional[float] = None, epsilon: float = 1e-9
) -> "MetricComparator":
    """
    Get or create the shared MetricComparator for a parameter combination.

    Calls with the same (comparison_method, target, epsilon) return the same object for as
    long as some caller still holds a reference to it.

    :param comparison_method: The comparison method.
    :param target: Optional target value for TARGET_IS_BETTER.
    :param epsilon: Tolerance for floating-point comparisons.
    :return: A shared MetricComparator instance.
    """
    cache_key = (comparison_method, target, epsilon)

    # One lookup that immediately takes a strong reference. Testing `key in cache` and then
    # indexing is two lookups, and a weakly-held entry may be collected in between, which
    # would surface as a KeyError.
    comparator = _comparator_cache.get(cache_key)
    if comparator is None:
        comparator = MetricComparator(comparison_method, target, epsilon)
        _comparator_cache[cache_key] = comparator
    return comparator
class _MetricDefinition:
    """
    Internal, value-free description of a metric: its name plus the comparator used to rank it.

    Keeping this apart from the measured value separates "what the metric is" from "what was
    measured". Instances expose read-only properties only, and both equality and hashing are
    derived from (name, comparison method, target, epsilon).

    This is an internal implementation detail - users should not interact with this class directly.
    """

    def __init__(self, name: str, comparator: "MetricComparator"):
        """
        Initialize a metric definition.

        :param name: The name of the metric.
        :param comparator: The shared comparator instance.
        """
        self._name = name
        self._comparator = comparator

    @property
    def name(self) -> str:
        """The name of the metric."""
        return self._name

    @property
    def comparator(self) -> "MetricComparator":
        """The shared comparator instance."""
        return self._comparator

    def _key(self) -> tuple:
        """Return the tuple that identifies this definition; the single source for __eq__ and __hash__."""
        comparator = self._comparator
        return (self._name, comparator.comparison_method, comparator.target, comparator.epsilon)

    def __eq__(self, other) -> bool:
        """
        Check if two metric definitions are equal.

        :param other: The object to compare against.
        :return: True/False for another definition, NotImplemented for any other type so that
            Python can try the reflected comparison before settling on "not equal".
        """
        if not isinstance(other, _MetricDefinition):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        """Hash the metric definition consistently with __eq__."""
        return hash(self._key())

    def __repr__(self) -> str:
        """Return a debugging representation showing the name and the comparator."""
        return f"{type(self).__name__}(name={self._name!r}, comparator={self._comparator!r})"
""" def __init__(self, name: str, value: float = None, comparator: MetricComparator = None, is_worst: bool = False): """ Initializes a Metric object. + The comparator instance is automatically shared with other metrics that have the same + comparison method, target, and epsilon values, reducing memory usage. + :param name: The name of the metric. :param value: The numeric value of the metric. :param comparator: An instance of MetricComparator for comparison logic. :param is_worst: Indicates if the metric value is the worst possible value. """ - self.name = name + # Store the metric value (dynamic, instance-specific) self.value = value - self.comparator = comparator self.is_worst = is_worst or value is None + # Get or create a shared comparator instance + if comparator is not None: + # Use the shared comparator cache to ensure we reuse identical comparators + # This is the key optimization: identical comparators are shared across all metrics + shared_comparator = _get_shared_comparator( + comparison_method=comparator.comparison_method, target=comparator.target, epsilon=comparator.epsilon + ) + else: + # If no comparator provided, raise an error as it's required for a valid metric + # This maintains the same behavior as before + raise ValueError("Metric requires a comparator. Provide a MetricComparator instance.") + + # Create internal metric definition (separates type from value) + # This is the key separation: definition (what it is) vs value (measurement) + self._definition = _MetricDefinition(name=name, comparator=shared_comparator) + + @property + def name(self) -> str: + """The name of the metric (for backward compatibility).""" + return self._definition.name + + @property + def comparator(self) -> MetricComparator: + """The shared comparator instance (for backward compatibility).""" + return self._definition.comparator + def __gt__(self, other) -> bool: """ Determine if this metric is better than another metric. 
@@ -135,23 +245,24 @@ def __gt__(self, other) -> bool: if not isinstance(other, Metric): return NotImplemented - if self.is_worst or (self.is_worst and other.is_worst): + if self.is_worst: return False if other.is_worst: return True - if self.name != other.name: - raise ValueError("Cannot compare metrics with different names.") - - if self.comparator.comparison_method != other.comparator.comparison_method: - raise ValueError("Cannot compare metrics with different comparison methods.") - - if ( - self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER - and self.comparator.target != other.comparator.target - ): - raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.") + # Compare using definitions - this is cleaner and ensures consistency + if self._definition != other._definition: + # Provide detailed error message for backward compatibility + if self.name != other.name: + raise ValueError("Cannot compare metrics with different names.") + if self.comparator.comparison_method != other.comparator.comparison_method: + raise ValueError("Cannot compare metrics with different comparison methods.") + if ( + self.comparator.comparison_method == ComparisonMethod.TARGET_IS_BETTER + and self.comparator.target != other.comparator.target + ): + raise ValueError("Cannot compare 'TARGET_IS_BETTER' metrics with different target values.") return self.comparator.compare(self.value, other.value) < 0 @@ -171,11 +282,8 @@ def __eq__(self, other) -> bool: if self.is_worst or other.is_worst: return False - return ( - self.name == other.name - and self.comparator.comparison_method == other.comparator.comparison_method - and self.comparator.compare(self.value, other.value) == 0 - ) + # Use definition equality for cleaner comparison + return self._definition == other._definition and self.comparator.compare(self.value, other.value) == 0 def __repr__(self) -> str: """