Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions torchrec/metrics/cpu_comms_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ def _load_metric_states(
Uses aggregated states.
"""

# All update() calls were done prior. Clear previous computed state.
# Otherwise, we get warnings that compute() was called before
# update() which is not the case.
computation = cast(RecMetricComputation, computation)
set_update_called(computation)
computation._computed = None
Expand Down Expand Up @@ -157,8 +154,9 @@ def _clone_rec_metrics(self) -> RecMetricList:

def set_update_called(computation: RecMetricComputation) -> None:
"""
Set _update_called to True for RecMetricComputation.
This is a workaround for torchmetrics 1.0.3+.
All update() calls were done prior. Clear previous computed state.
Otherwise, we get warnings that compute() was called before
update() which is not the case.
"""
try:
computation._update_called = True
Expand Down
96 changes: 57 additions & 39 deletions torchrec/metrics/cpu_offloaded_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
MetricUpdateJob,
SynchronizationMarker,
)
from torchrec.metrics.metric_module import MetricValue, RecMetricModule
from torchrec.metrics.metric_module import MetricsFuture, MetricsResult, RecMetricModule
from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
from torchrec.metrics.model_utils import parse_task_model_outputs
from torchrec.metrics.rec_metric import RecMetricException
Expand Down Expand Up @@ -62,19 +62,27 @@ class CPUOffloadedRecMetricModule(RecMetricModule):

def __init__(
self,
device: torch.device,
update_queue_size: int = 100,
compute_queue_size: int = 100,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
All arguments are the same as RecMetricModule except for
- update_queue_size: Maximum size of the update queue. Default is 100.
- compute_queue_size: Maximum size of the compute queue. Default is 100.
batch_size: batch size used by this trainer.
world_size: the number of trainers.
device: the device where the model is located (used to determine whether to perform GPU to CPU transfers).
update_queue_size: Maximum size of the update queue. Default is 100.
compute_queue_size: Maximum size of the compute queue. Default is 100.
*args: Additional positional arguments passed to RecMetricModule.
**kwargs: Additional keyword arguments passed to RecMetricModule.
"""
super().__init__(*args, **kwargs)
self._shutdown_event = threading.Event()
self._device = device
self._shutdown_event: threading.Event = threading.Event()
self._captured_exception_event: threading.Event = threading.Event()
self._captured_exception: Optional[Exception] = None

self.update_queue: queue.Queue[
Union[MetricUpdateJob, SynchronizationMarker]
Expand Down Expand Up @@ -132,8 +140,16 @@ def _update_rec_metrics(
if self._shutdown_event.is_set():
raise RecMetricException("metric processor thread is shut down.")

if self._captured_exception_event.is_set():
assert self._captured_exception is not None
raise self._captured_exception

try:
cpu_model_out, transfer_completed_event = self._transfer_to_cpu(model_out)
cpu_model_out, transfer_completed_event = (
self._transfer_to_cpu(model_out)
if self._device == torch.device("cuda")
else (model_out, None)
)
self.update_queue.put_nowait(
MetricUpdateJob(
model_out=cpu_model_out,
Expand Down Expand Up @@ -191,31 +207,25 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
"""

with record_function("## CPUOffloadedRecMetricModule:update ##"):
try:
if metric_update_job.transfer_completed_event is not None:
metric_update_job.transfer_completed_event.synchronize()
labels, predictions, weights, required_inputs = (
parse_task_model_outputs(
self.rec_tasks,
metric_update_job.model_out,
self.get_required_inputs(),
)
)
if required_inputs:
metric_update_job.kwargs["required_inputs"] = required_inputs

self.rec_metrics.update(
predictions=predictions,
labels=labels,
weights=weights,
**metric_update_job.kwargs,
)

if self.throughput_metric:
self.throughput_metric.update()
labels, predictions, weights, required_inputs = parse_task_model_outputs(
self.rec_tasks,
metric_update_job.model_out,
self.get_required_inputs(),
)
if required_inputs:
metric_update_job.kwargs["required_inputs"] = required_inputs

self.rec_metrics.update(
predictions=predictions,
labels=labels,
weights=weights,
**metric_update_job.kwargs,
)

except Exception as e:
logger.exception("Error processing metric update: %s", e)
raise e
if self.throughput_metric:
self.throughput_metric.update()

@override
def shutdown(self) -> None:
Expand Down Expand Up @@ -248,30 +258,34 @@ def shutdown(self) -> None:
logger.info("CPUOffloadedRecMetricModule has been successfully shutdown.")

@override
def compute(self) -> Dict[str, MetricValue]:
def compute(self) -> MetricsResult:
raise RecMetricException(
"compute() is not supported in CPUOffloadedRecMetricModule. Use async_compute() instead."
"CPUOffloadedRecMetricModule does not support compute(). Use async_compute() instead."
)

@override
def async_compute(
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
) -> None:
def async_compute(self) -> MetricsFuture:
"""
Entry point for asynchronous metric compute. It enqueues a synchronization marker
to the update queue.

Args:
Returns:
future: Future on which the computed metrics will be set once they are available.
"""
metrics_future = concurrent.futures.Future()
if self._shutdown_event.is_set():
future.set_exception(
metrics_future.set_exception(
RecMetricException("metric processor thread is shut down.")
)
return
return metrics_future

if self._captured_exception_event.is_set():
assert self._captured_exception is not None
raise self._captured_exception

self.update_queue.put_nowait(SynchronizationMarker(future))
self.update_queue.put_nowait(SynchronizationMarker(metrics_future))
self.update_queue_size_logger.add(self.update_queue.qsize())
return metrics_future

def _process_synchronization_marker(
self, synchronization_marker: SynchronizationMarker
Expand Down Expand Up @@ -304,7 +318,7 @@ def _process_synchronization_marker(

def _process_metric_compute_job(
self, metric_compute_job: MetricComputeJob
) -> Dict[str, MetricValue]:
) -> MetricsResult:
"""
Process a metric compute job:
1. Comms module performs all gather
Expand Down Expand Up @@ -355,6 +369,8 @@ def _update_loop(self) -> None:
self._do_work(self.update_queue)
except Exception as e:
logger.exception(f"Exception in update loop: {e}")
self._captured_exception_event.set()
self._captured_exception = e
raise e

remaining = self._flush_remaining_work(self.update_queue)
Expand All @@ -372,6 +388,8 @@ def _compute_loop(self) -> None:
self._do_work(self.compute_queue)
except Exception as e:
logger.exception(f"Exception in compute loop: {e}")
self._captured_exception_event.set()
self._captured_exception = e
raise e

remaining = self._flush_remaining_work(self.compute_queue)
Expand Down
8 changes: 5 additions & 3 deletions torchrec/metrics/metric_job_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# pyre-strict

import concurrent
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch
from torchrec.metrics.metric_module import MetricValue
Expand All @@ -26,7 +26,7 @@ class MetricUpdateJob:
def __init__(
self,
model_out: Dict[str, torch.Tensor],
transfer_completed_event: torch.cuda.Event,
transfer_completed_event: Optional[torch.cuda.Event],
kwargs: Dict[str, Any],
) -> None:
"""
Expand All @@ -37,7 +37,9 @@ def __init__(
"""

self.model_out: Dict[str, torch.Tensor] = model_out
self.transfer_completed_event: torch.cuda.Event = transfer_completed_event
self.transfer_completed_event: Optional[torch.cuda.Event] = (
transfer_completed_event
)
self.kwargs: Dict[str, Any] = kwargs


Expand Down
17 changes: 9 additions & 8 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@


MetricValue = Union[torch.Tensor, float]
MetricsResult = Dict[str, MetricValue]
MetricsFuture = concurrent.futures.Future[MetricsResult]
MetricsOutput = Union[MetricsResult, MetricsFuture]


class StateMetric(abc.ABC):
Expand All @@ -125,7 +128,7 @@ class StateMetric(abc.ABC):
"""

@abc.abstractmethod
def get_metrics(self) -> Dict[str, MetricValue]:
def get_metrics(self) -> MetricsResult:
pass


Expand Down Expand Up @@ -335,12 +338,12 @@ def _adjust_compute_interval(self) -> None:
def should_compute(self) -> bool:
return self.trained_batches % self.compute_interval_steps == 0

def compute(self) -> Dict[str, MetricValue]:
def compute(self) -> MetricsResult:
r"""compute() is called when the global metrics are required, usually
right before logging the metrics results to the data sink.
"""
self.compute_count += 1
ret: Dict[str, MetricValue] = {}
ret: MetricsResult = {}
with record_function("## RecMetricModule:compute ##"):
if self.rec_metrics:
self._adjust_compute_interval()
Expand All @@ -357,11 +360,11 @@ def compute(self) -> Dict[str, MetricValue]:
)
return ret

def local_compute(self) -> Dict[str, MetricValue]:
def local_compute(self) -> MetricsResult:
r"""local_compute() is called when per-trainer metrics are required. It
can be used for debugging. Currently only rec_metrics is supported.
"""
ret: Dict[str, MetricValue] = {}
ret: MetricsResult = {}
if self.rec_metrics:
ret.update(self.rec_metrics.local_compute())
return ret
Expand Down Expand Up @@ -512,9 +515,7 @@ def load_pre_compute_states(
def shutdown(self) -> None:
logger.info("Initiating graceful shutdown...")

def async_compute(
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
) -> None:
def async_compute(self) -> MetricsFuture:
raise RecMetricException("async_compute is not supported in RecMetricModule")


Expand Down
99 changes: 99 additions & 0 deletions torchrec/metrics/metrics_output_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Utility functions for handling MetricsOutput (Union[MetricsResult, MetricsFuture]) from
- RecMetricModule.compute()
- CPUOffloadedRecMetricModule.async_compute()
"""

import concurrent
import logging
from typing import Callable, TypeVar

from torchrec.metrics.metric_module import MetricsFuture, MetricsOutput, MetricsResult

logger: logging.Logger = logging.getLogger(__name__)

T = TypeVar("T")


def get_metrics_async(
    metrics_output: MetricsOutput,
    callback: Callable[[MetricsResult], T],
    *,
    on_error: Callable[[Exception], None] | None = None,
) -> T | None:
    """
    Run ``callback`` on the metrics as soon as they are available.

    If ``metrics_output`` is already a metrics dict (synchronous result, e.g. from
    RecMetricModule.compute()), ``callback`` is invoked immediately and its return
    value is returned. If it is a Future (e.g. from
    CPUOffloadedRecMetricModule.async_compute()), ``callback`` is registered as a
    done-callback so the caller is not blocked waiting for the metrics.

    Args:
        metrics_output: Either a metrics dict or a Future resolving to one.
        callback: Function to execute with the resolved metrics.
        on_error: Optional handler used on the Future path only. It receives any
            exception raised while resolving the Future or while running
            ``callback``. When omitted, the exception is logged and re-raised
            from inside the done-callback. Exceptions raised by ``callback`` on
            the synchronous path always propagate to the caller.

    Returns:
        The result of ``callback`` when the metrics were immediately available,
        or None when ``metrics_output`` is a Future (``callback`` runs later and
        its return value is discarded).
    """
    # Import the submodule explicitly: a bare ``import concurrent`` does not load
    # ``concurrent.futures``, so ``concurrent.futures.Future`` would only resolve
    # if some other module happened to import it first.
    from concurrent.futures import Future

    # Synchronous path: the metrics dict is already in hand.
    if not isinstance(metrics_output, Future):
        return callback(metrics_output)

    # Asynchronous path: defer the callback until the Future completes.
    def on_complete(future: MetricsFuture) -> None:
        try:
            result = future.result()
            callback(result)
        except Exception as e:
            if on_error:
                on_error(e)
            else:
                logger.exception("Error in metrics callback")
                # NOTE: per the concurrent.futures docs, an Exception raised by
                # a done-callback is logged and ignored by the Future machinery,
                # so this re-raise does not reach the training loop.
                raise

    metrics_output.add_done_callback(on_complete)
    return None


def get_metrics_sync(
    metrics_output: MetricsOutput,
    timeout: float | None = None,
) -> MetricsResult:
    """
    Synchronously resolve MetricsOutput to MetricsResult.

    Use this when you need the actual metrics dict immediately (e.g., to modify it).
    For async handling, use get_metrics_async() instead.

    Args:
        metrics_output: Either a metrics dict (returned unchanged) or a Future
            resolving to one (blocked on until it completes).
        timeout: Optional timeout in seconds for Future resolution. Ignored when
            ``metrics_output`` is already a metrics dict.

    Returns:
        Resolved metrics dict

    Raises:
        TimeoutError: If the Future doesn't resolve within timeout (if specified)
        Exception: Any exception from the Future computation

    Example:
        >>> metrics_output = self.metrics.compute()
        >>> metrics_result = get_metrics_sync(metrics_output)  # wait until metrics are ready
        >>> publish_metrics(metrics_result)
    """
    # Import the submodule explicitly: a bare ``import concurrent`` does not load
    # ``concurrent.futures``, so ``concurrent.futures.Future`` would only resolve
    # if some other module happened to import it first.
    from concurrent.futures import Future

    if isinstance(metrics_output, Future):
        return metrics_output.result(timeout=timeout)
    return metrics_output
Loading
Loading