diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index 6d64955198d..746fe8476cc 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -438,6 +438,9 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
             if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
                 input_dict[attr] = [dim_zero_cat(input_dict[attr])]
 
+        if dist_sync_fn == gather_all_tensors:
+            dist_sync_fn = functools.partial(gather_all_tensors, device=self.device)
+
         output_dict = apply_to_collection(
             input_dict,
             Tensor,
diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py
index 455d64c4ae0..75994c55dc3 100644
--- a/src/torchmetrics/utilities/distributed.py
+++ b/src/torchmetrics/utilities/distributed.py
@@ -94,7 +94,9 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> L
     return gathered_result
 
 
-def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
+def gather_all_tensors(
+    result: Tensor, group: Optional[Any] = None, device: Optional[torch.device] = None
+) -> List[Tensor]:
     """Gather all tensors from several ddp processes onto a list that is broadcasted to all processes.
 
     Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
@@ -103,6 +105,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
     Args:
         result: the value to sync
         group: the process group to gather results from. Defaults to all processes (world)
+        device: optional device to move the result tensor to before gathering
 
     Return:
         list with size equal to the process group where element i corresponds to result tensor from process i
@@ -117,6 +120,10 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
     world_size = torch.distributed.get_world_size(group)
     torch.distributed.barrier(group=group)
 
+    # make sure this works with CPU tensors
+    if device is not None:
+        result = result.to(device)
+
     # if the tensor is scalar, things are easy
     if result.ndim == 0:
         return _simple_gather_all_tensors(result, group, world_size)
diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py
index 1a44a2145ba..fa7dd5dacfc 100644
--- a/tests/unittests/bases/test_ddp.py
+++ b/tests/unittests/bases/test_ddp.py
@@ -277,3 +277,19 @@ def _test_sync_with_empty_lists(rank):
 def test_sync_with_empty_lists():
     """Test that synchronization of states can be enabled and disabled for compute."""
     pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES))
+
+
+def _test_compute_on_cpu_distributed(rank):
+    dummy = DummyListMetric(compute_on_cpu=True).to(f"cuda:{rank}")
+    dummy.update(tensor(rank + 1))
+    val = dummy.compute()
+    assert val == [tensor(rank + 1)]
+
+
+@pytest.mark.DDP()
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Test requires at least 2 GPUs")
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+@pytest.mark.skipif(not hasattr(pytest, "pool"), reason="DDP pool not available.")
+def test_compute_on_cpu_distributed_multi_gpu():
+    """Check that compute_on_cpu works with DDP and multiple GPUs."""
+    pytest.pool.map(_test_compute_on_cpu_distributed, range(NUM_PROCESSES))
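
Note (not part of the patch): a minimal sketch of the call pattern that the metric.py change binds via functools.partial, i.e. how the new device argument to gather_all_tensors can be used when metric states live on CPU (compute_on_cpu=True) but the process group uses a GPU backend such as NCCL. It assumes an already-initialized torch.distributed process group with one GPU per rank; the helper name gather_cpu_state is hypothetical and only for illustration.

    import torch

    from torchmetrics.utilities.distributed import gather_all_tensors

    def gather_cpu_state(rank: int) -> list:
        # metric state kept on CPU, e.g. because the metric was created with compute_on_cpu=True
        state = torch.tensor(rank + 1)
        # move the CPU state onto this rank's GPU before the all_gather,
        # mirroring what Metric._sync_dist now does via functools.partial
        return gather_all_tensors(state, device=torch.device(f"cuda:{rank}"))

Each rank receives a list with one tensor per process, as described in the updated gather_all_tensors docstring.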