diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index dc67d5a4e34..1c7cd37bfd4 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -230,13 +230,8 @@ def _flexible_bincount(x: Tensor) -> Tensor: Number of occurrences for each unique element in x """ - # make sure elements in x start from 0 - x = x - x.min() - unique_x = torch.unique(x) - - output = _bincount(x, minlength=torch.max(unique_x) + 1) # type: ignore[arg-type] - # remove zeros from output tensor - return output[unique_x] + unique_x, inverse_indices = torch.unique(x, return_inverse=True) + return _bincount(inverse_indices, minlength=len(unique_x)) # type: ignore[arg-type] def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: