Changes from all commits

34 commits
d7dc9b2
ci: split common & distributed
Borda Jul 25, 2025
bf756d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2025
324a610
Merge branch 'master' into ci/split
SkafteNicki Jul 28, 2025
dfeee3d
try random port instead
SkafteNicki Aug 4, 2025
c17511d
try instead a free socket approach
SkafteNicki Aug 4, 2025
608310f
Merge branch 'master' into ci/split
Borda Aug 4, 2025
e2af420
fix ports not syncing
SkafteNicki Aug 4, 2025
060939b
Merge branch 'ci/split' of https://github.com/Lightning-AI/torchmetri…
SkafteNicki Aug 4, 2025
6c92ed9
try limit number of parallel jobs for debugging
SkafteNicki Aug 5, 2025
61b7ada
Merge branch 'master' into ci/split
Borda Aug 6, 2025
86f33bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2025
ce90ca4
Merge branch 'master' into ci/split
Borda Aug 6, 2025
0e502a6
pytest
Borda Aug 6, 2025
913cea6
gpu
Borda Aug 6, 2025
9f2d1cb
Merge branch 'master' into ci/split
Borda Aug 6, 2025
5637d3b
doctest
Borda Aug 6, 2025
030c6b6
Merge branch 'master' into ci/split
Borda Aug 6, 2025
a0ad488
rev GPU
Borda Aug 7, 2025
40e7840
Merge branch 'master' into ci/split
Borda Aug 7, 2025
d10a688
get_free_port
Borda Aug 11, 2025
1c93193
get_free_port
Borda Aug 11, 2025
6f9bec7
get_free_port
Borda Aug 11, 2025
5f81e75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 11, 2025
4b3c37e
get_free_port()
Borda Aug 12, 2025
b5f9ae3
settimeout
Borda Aug 12, 2025
666cb7f
setup_ddp
Borda Aug 12, 2025
12253e6
timeout-minutes: 70
Borda Aug 12, 2025
e73e1d3
enumerate
Borda Aug 12, 2025
b8eae2c
get_free_port
Borda Aug 12, 2025
2021ca5
get_free_port
Borda Aug 12, 2025
f7fd132
master port
Borda Aug 12, 2025
154a4c6
no rerun
Borda Aug 13, 2025
4f5aedc
Merge branch 'master' into ci/split
Borda Aug 20, 2025
ac87706
Merge branch 'master' into ci/split
Borda Sep 10, 2025
49 changes: 32 additions & 17 deletions .github/workflows/ci-tests.yml
@@ -46,19 +46,38 @@ jobs:
- "2.6.0"
- "2.7.1"
- "2.8.0"
testing: ["common", "distributed"]
include:
# cover additional python and PT combinations
- { os: "ubuntu-22.04", python-version: "3.9", pytorch-version: "2.0.1", requires: "oldest" }
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.7.1" }
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.8.0" }
- {
os: "ubuntu-22.04",
python-version: "3.9",
pytorch-version: "2.0.1",
requires: "oldest",
testing: "common",
}
- {
os: "ubuntu-22.04",
python-version: "3.9",
pytorch-version: "2.0.1",
requires: "oldest",
testing: "distributed",
}
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.8.0", testing: "common" }
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.8.0", testing: "distributed" }
# standard mac machine, not the M1
- { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1", testing: "common" }
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1", testing: "distributed" }
# using the ARM based M1 machine
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "macOS-14", python-version: "3.12", pytorch-version: "2.8.0" }
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1", testing: "common" }
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1", testing: "distributed" }
- { os: "macOS-14", python-version: "3.12", pytorch-version: "2.8.0", testing: "common" }
- { os: "macOS-14", python-version: "3.12", pytorch-version: "2.8.0", testing: "distributed" }
# some windows
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "windows-2022", python-version: "3.12", pytorch-version: "2.8.0" }
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1", testing: "common" }
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1", testing: "distributed" }
- { os: "windows-2022", python-version: "3.12", pytorch-version: "2.8.0", testing: "common" }
- { os: "windows-2022", python-version: "3.12", pytorch-version: "2.8.0", testing: "distributed" }
# Future released version
#- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.8.0" }
#- { os: "macOS-14", python-version: "3.11", pytorch-version: "2.8.0" }
@@ -73,7 +73,7 @@ jobs:

# Timeout: https://stackoverflow.com/a/59076067/4521646
# seems that macOS jobs take much more than other OS
timeout-minutes: 120
timeout-minutes: 70

steps:
- uses: actions/checkout@v5
@@ -182,34 +201,30 @@ jobs:

- name: Unittests common
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
if: ${{ env.TEST_DIRS != '' }}
if: ${{ env.TEST_DIRS != '' && matrix.testing == 'common' }}
working-directory: ./tests
run: |
python -m pytest \
pytest \
$TEST_DIRS \
--cov=torchmetrics \
--durations=50 \
--reruns 3 \
--reruns-delay 1 \
-m "not DDP" \
-n auto \
--dist=load \
${{ env.UNITTEST_TIMEOUT }}

- name: Unittests DDP
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
if: ${{ env.TEST_DIRS != '' }}
if: ${{ env.TEST_DIRS != '' && matrix.testing == 'distributed' }}
working-directory: ./tests
env:
USE_PYTEST_POOL: "1"
run: |
python -m pytest -v \
pytest -v \
$TEST_DIRS \
--cov=torchmetrics \
--durations=50 \
-m DDP \
--reruns 3 \
--reruns-delay 1 \
${{ env.UNITTEST_TIMEOUT }}

- name: Statistics
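The workflow split keys off the existing `DDP` pytest marker: the "common" matrix entry deselects marked tests with `-m "not DDP"`, while the "distributed" entry selects only them with `-m DDP`. A minimal sketch of how a marker-gated test pairs with those two invocations (the test names and bodies here are hypothetical, not part of the suite):

```python
import pytest
import torch
import torch.distributed as dist


@pytest.mark.DDP  # selected only by the "distributed" job via `pytest -m DDP`
def test_metric_state_sync_across_processes():
    # placeholder body; the real DDP tests run inside the session-wide process pool
    assert dist.is_available()


def test_metric_on_a_single_process():
    # unmarked, so the "common" job collects it via `pytest -m "not DDP"`
    assert torch.ones(2).sum().item() == 2.0
```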
13 changes: 7 additions & 6 deletions tests/unittests/bases/test_ddp.py
@@ -26,7 +26,7 @@
from unittests import NUM_PROCESSES, USE_PYTEST_POOL
from unittests._helpers import _IS_WINDOWS, seed_all
from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from unittests.conftest import setup_ddp
from unittests.conftest import get_free_port, setup_ddp

seed_all(42)

@@ -105,9 +105,9 @@ def test_ddp(process):
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int, port: int) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks."""
setup_ddp(rank, worldsize)
setup_ddp(rank, worldsize, port)
x = (rank + 1) * torch.ones(10, requires_grad=True)

# random linear transformation, it should really not matter what we do here
@@ -120,9 +120,9 @@ def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PRO
assert torch.allclose(grad, a * torch.ones_like(x))


def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int, port: int) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks."""
setup_ddp(rank, worldsize)
setup_ddp(rank, worldsize, port)
x = (rank + 1) * torch.ones(rank + 1, 2 - rank, requires_grad=True)

# random linear transformation, it should really not matter what we do here
@@ -143,7 +143,8 @@ def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NU
)
def test_ddp_autograd(process):
"""Test ddp functions for autograd compatibility."""
pytest.pool.map(process, range(NUM_PROCESSES))
port = get_free_port()
pytest.pool.starmap(process, [(rank, NUM_PROCESSES, port) for rank in range(NUM_PROCESSES)])


def _test_non_contiguous_tensors(rank):
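The autograd helpers now take an explicit `(rank, world_size, port)` tuple, so one port from `get_free_port()` is shared by every rank of a single test, and the driver switches from `pool.map` to `pool.starmap`. A standalone sketch of that pattern, assuming a toy `work` function and a hard-coded port for illustration:

```python
from multiprocessing import Pool


def work(rank: int, world_size: int, port: int) -> str:
    # every rank receives the same world_size and port, but its own rank id
    return f"rank={rank} world_size={world_size} port={port}"


if __name__ == "__main__":
    world_size = 2
    port = 29500  # in the test suite this value would come from get_free_port()
    jobs = [(rank, world_size, port) for rank in range(world_size)]
    with Pool(processes=world_size) as pool:
        # starmap unpacks each tuple into positional arguments of `work`,
        # whereas map would pass the whole tuple as a single argument
        print(pool.starmap(work, jobs))
```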
33 changes: 21 additions & 12 deletions tests/unittests/conftest.py
@@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
import os
import socket
import sys

import pytest
@@ -30,9 +31,6 @@
EXTRA_DIM = 3
THRESHOLD = 0.5

MAX_PORT = 8100
START_PORT = 8088
CURRENT_PORT = START_PORT
USE_PYTEST_POOL = os.getenv("USE_PYTEST_POOL", "0") == "1"


@@ -44,7 +42,16 @@ def use_deterministic_algorithms():
torch.use_deterministic_algorithms(False)


def setup_ddp(rank, world_size):
def get_free_port() -> int:
"""Find an available free port on localhost and keep it reserved."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", 0)) # Bind to a free port provided by the OS
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
return int(port)


def setup_ddp(rank: int, world_size: int, port: int) -> None:
"""Initialize ddp environment.

If a particular test relies on the order of the processes in the pool to be [0, 1, 2, ...], then this function
@@ -54,16 +61,11 @@ def setup_ddp(rank, world_size):
Args:
rank: the rank of the process
world_size: the number of processes
port: the port to use for communication

"""
global CURRENT_PORT

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(CURRENT_PORT)

CURRENT_PORT += 1
if CURRENT_PORT > MAX_PORT:
CURRENT_PORT = START_PORT
os.environ["MASTER_PORT"] = str(port)

if torch.distributed.group.WORLD is not None: # if already initialized, destroy the process group
torch.distributed.destroy_process_group()
@@ -72,12 +74,19 @@ def setup_ddp(rank, world_size):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup_ddp():
"""Clean up the DDP process group if initialized."""
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()


def pytest_sessionstart():
"""Global initialization of multiprocessing pool; runs before any test."""
if not USE_PYTEST_POOL:
return
port = get_free_port()
pool = Pool(processes=NUM_PROCESSES)
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES) for rank in range(NUM_PROCESSES)])
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, port) for rank in range(NUM_PROCESSES)])
pytest.pool = pool


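Taken together, `get_free_port`, `setup_ddp`, and `cleanup_ddp` replace the old fixed port range with one OS-assigned rendezvous port per session. A minimal sketch of exercising these helpers outside pytest, assuming a two-process gloo group and an `all_reduce` sanity check that is purely illustrative:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from unittests.conftest import cleanup_ddp, get_free_port, setup_ddp


def _worker(rank: int, world_size: int, port: int) -> None:
    setup_ddp(rank, world_size, port)  # gloo process group on localhost:<port>
    t = torch.tensor([float(rank + 1)])
    dist.all_reduce(t)  # default op is SUM: 1 + 2 = 3 for world_size=2
    assert t.item() == sum(range(1, world_size + 1))
    cleanup_ddp()


if __name__ == "__main__":
    world_size = 2
    port = get_free_port()  # asked for once, then shared by every rank
    mp.spawn(_worker, args=(world_size, port), nprocs=world_size, join=True)
```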
17 changes: 0 additions & 17 deletions tests/unittests/image/__init__.py
@@ -13,23 +13,6 @@
# limitations under the License.
import os

import torch
import torch.distributed as dist

from unittests import _PATH_ALL_TESTS

_SAMPLE_IMAGE = os.path.join(_PATH_ALL_TESTS, "_data", "image", "i01_01_5.bmp")


def setup_ddp(rank: int, world_size: int, free_port: int):
"""Set up DDP with a free port and assign CUDA device to the given rank."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)


def cleanup_ddp():
"""Clean up the DDP process group if initialized."""
if dist.is_initialized():
dist.destroy_process_group()
5 changes: 2 additions & 3 deletions tests/unittests/image/test_ms_ssim.py
@@ -23,8 +23,7 @@
from unittests import NUM_BATCHES, _Input
from unittests._helpers import _IS_WINDOWS, seed_all
from unittests._helpers.testers import MetricTester
from unittests.image import cleanup_ddp, setup_ddp
from unittests.utilities.test_utilities import find_free_port
from unittests.conftest import cleanup_ddp, get_free_port, setup_ddp

seed_all(42)

@@ -136,7 +135,7 @@ def test_ms_ssim_reduction_none_ddp():

"""
world_size = 2
free_port = find_free_port()
free_port = get_free_port()
if free_port == -1:
pytest.skip("No free port available for DDP test.")
mp.spawn(_run_ms_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True)
5 changes: 2 additions & 3 deletions tests/unittests/image/test_ssim.py
@@ -26,8 +26,7 @@
from unittests import NUM_BATCHES, _Input
from unittests._helpers import _IS_WINDOWS, seed_all
from unittests._helpers.testers import MetricTester
from unittests.image import cleanup_ddp, setup_ddp
from unittests.utilities.test_utilities import find_free_port
from unittests.conftest import cleanup_ddp, get_free_port, setup_ddp

seed_all(42)

@@ -391,7 +390,7 @@ def test_ssim_reduction_none_ddp():

"""
world_size = 2
free_port = find_free_port()
free_port = get_free_port()
if free_port == -1:
pytest.skip("No free port available for DDP test.")
mp.spawn(_run_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True)
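Both SSIM tests now obtain the port up front via `get_free_port()` and pass it to a spawned worker that initializes and tears down the process group itself. The body of `_run_ssim_ddp` is not shown in this diff; the following is only a hedged sketch of what such a spawn-driven worker commonly looks like, where the metric arguments and tensor shapes are assumptions:

```python
import torch
import torch.multiprocessing as mp

from torchmetrics.image import StructuralSimilarityIndexMeasure
from unittests.conftest import cleanup_ddp, get_free_port, setup_ddp


def _run_ssim_ddp_sketch(rank: int, world_size: int, port: int) -> None:
    # hypothetical stand-in for the worker spawned by the test above
    setup_ddp(rank, world_size, port)
    metric = StructuralSimilarityIndexMeasure(reduction="none", data_range=1.0)
    preds = torch.rand(2, 3, 64, 64)
    target = preds * 0.9
    metric.update(preds, target)
    result = metric.compute()  # metric states sync across ranks before compute
    if rank == 0:
        print(result.shape)  # per-image scores gathered from all ranks
    cleanup_ddp()


if __name__ == "__main__":
    world_size = 2
    port = get_free_port()
    mp.spawn(_run_ssim_ddp_sketch, args=(world_size, port), nprocs=world_size, join=True)
```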
14 changes: 0 additions & 14 deletions tests/unittests/utilities/test_utilities.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import sys

import numpy as np
@@ -20,7 +19,6 @@
from lightning_utilities.test.warning import no_warning_call
from torch import tensor
from unittests._helpers import _IS_WINDOWS
from unittests.conftest import MAX_PORT, START_PORT

from torchmetrics.regression import MeanSquaredError, PearsonCorrCoef
from torchmetrics.utilities import check_forward_full_state_property, rank_zero_debug, rank_zero_info, rank_zero_warn
@@ -240,15 +238,3 @@ def test_half_precision_top_k_cpu_raises_error():
x = torch.randn(100, 10, dtype=torch.half)
with pytest.raises(RuntimeError, match="\"topk_cpu\" not implemented for 'Half'"):
torch.topk(x, k=3, dim=1)


def find_free_port(start=START_PORT, end=MAX_PORT):
"""Returns an available localhost port in the given range or returns -1 if no port available."""
for port in range(start, end + 1):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("localhost", port))
return port
except OSError:
continue
return -1