diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml
index e37f3c4130..139dc32e98 100644
--- a/.azure/gpu-tests.yml
+++ b/.azure/gpu-tests.yml
@@ -185,7 +185,9 @@ jobs:
       pytest \
         thunder/tests/distributed \
         -v --durations=0 \
+        --ignore=thunder/tests/distributed/test_dtensor.py \
         --random-order-seed=42
+      torchrun --nproc-per-node 2 --no-python pytest thunder/tests/distributed/test_dtensor.py
       # compile coverage results
       # TODO: collect and merge reports
       # python -m coverage report
diff --git a/.lightning/workflows/all-tests.yaml b/.lightning/workflows/all-tests.yaml
index c46b1dea2c..5ea5b292cb 100644
--- a/.lightning/workflows/all-tests.yaml
+++ b/.lightning/workflows/all-tests.yaml
@@ -91,7 +91,9 @@ run: |
   elif [ "${testing}" == "distributed" ]; then
     pytest thunder/tests/distributed \
       -v --durations=0 \
+      --ignore=thunder/tests/distributed/test_dtensor.py \
       --random-order-seed=42
+    torchrun --nproc-per-node 2 --no-python pytest thunder/tests/distributed/test_dtensor.py
   else
     echo "Unknown testing type: ${testing}"
     exit 1
diff --git a/thunder/tests/distributed/test_dtensor.py b/thunder/tests/distributed/test_dtensor.py
index c8938b0252..d2578ece84 100644
--- a/thunder/tests/distributed/test_dtensor.py
+++ b/thunder/tests/distributed/test_dtensor.py
@@ -1,21 +1,18 @@
-import unittest
-from itertools import product
 from collections.abc import Sequence
+import os

 import pytest
 import torch

-if not torch.distributed.is_available():
+if not torch.distributed.is_available() or not (torch.cuda.is_available() and torch.distributed.is_nccl_available()):
     pytest.skip(allow_module_level=True)

 import thunder
-from thunder.tests.distributed.helper import DistributedParallelTestCase

 from torch.distributed._tensor import DeviceMesh, distribute_tensor
 from torch.distributed.tensor.placement_types import Shard
 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
-from torch.testing._internal import common_utils

 from thunder.tests.distributed.helper import executors_map
 from thunder.tests.opinfos import OpInfo, get_opinfo
@@ -75,229 +72,239 @@ def __init__(self, *, name, op, torch_reference, supports_grad, sample_inputs):
         )


-@unittest.skipUnless(
-    torch.cuda.is_available() and torch.distributed.is_nccl_available(),
-    "DTensor test requires CUDA and NCCL `torch.distributed` backend",
-)
-class DTensorTest(DistributedParallelTestCase):
-    @common_utils.parametrize("executor, fn_key", product(tuple(executors_map.keys()), functions_to_test.keys()))
-    def test_dtensor_basic_op(self, executor, fn_key):
-        num_devices = self.world_size
-        mesh = DeviceMesh("cuda", list(range(num_devices)))
+@pytest.fixture(scope="module", autouse=True)
+def setup():
+    local_rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(f"cuda:{local_rank}")

-        dim_size = 16

-        def _helper(fn, in_dtensor, w_dtensor):
-            expected = torch.compile(fn)(in_dtensor, w_dtensor)
-            tmodel = thunder.jit(fn, executors=executors_map[executor].executors_list())
-            actual = tmodel(in_dtensor, w_dtensor)
+def get_num_devices():
+    world_size = int(os.environ["WORLD_SIZE"])
+    return world_size

-            torch.testing.assert_close(actual, expected)

-            g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(0)])
-            expected_g = torch.autograd.grad(
-                expected,
-                (in_dtensor, w_dtensor),
-                g_o,
-            )
-            actual_g = torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)
+@pytest.fixture(scope="session", autouse=True)
+def teardown():
+    yield
+    torch.distributed.destroy_process_group()

-            torch.testing.assert_close(actual_g, expected_g)

-        w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
-        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+@pytest.mark.parametrize("executor", tuple(executors_map.keys()))
+@pytest.mark.parametrize("fn_key", functions_to_test.keys())
+def test_dtensor_basic_op(executor, fn_key):
+    num_devices = get_num_devices()
+    mesh = DeviceMesh("cuda", list(range(num_devices)))

-        # Verify torch API works
-        _helper(functions_to_test[fn_key], in_dtensor, w_dtensor)
+    dim_size = 16

-    def test_dtensor_unsupported(self):
-        num_devices = self.world_size
-        mesh = DeviceMesh("cuda", list(range(num_devices)))
+    def _helper(fn, in_dtensor, w_dtensor):
+        expected = torch.compile(fn)(in_dtensor, w_dtensor)
+        tmodel = thunder.jit(fn, executors=executors_map[executor].executors_list())
+        actual = tmodel(in_dtensor, w_dtensor)

-        dim_size = 16
+        torch.testing.assert_close(actual, expected)

-        w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+        g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(0)])
+        expected_g = torch.autograd.grad(
+            expected,
+            (in_dtensor, w_dtensor),
+            g_o,
+        )
+        actual_g = torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)

-        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+        torch.testing.assert_close(actual_g, expected_g)

-        def fn(x, w):
-            return torch.div(x, w)
+    w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+    in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

-        tmodel = thunder.jit(fn)
-        with pytest.raises(AssertionError):
-            tmodel(in_dtensor, w_dtensor)
+    # Verify torch API works
+    _helper(functions_to_test[fn_key], in_dtensor, w_dtensor)

-        def fn(x, w):
-            return x / w

-        tmodel = thunder.jit(fn)
-        with pytest.raises(AssertionError):
-            tmodel(in_dtensor, w_dtensor)
+def test_dtensor_unsupported():
+    num_devices = get_num_devices()
+    mesh = DeviceMesh("cuda", list(range(num_devices)))

-    def test_dtensor_unsupported_mixed_input(self):
-        num_devices = self.world_size
-        mesh = DeviceMesh("cuda", list(range(num_devices)))
+    dim_size = 16

-        dim_size = 16
+    w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

-        def fn(x, w):
-            return torch.div(x, w)
+    in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

-        w = torch.randn(dim_size, dim_size, requires_grad=True)
+    def fn(x, w):
+        return torch.div(x, w)

-        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+    tmodel = thunder.jit(fn)
+    with pytest.raises(AssertionError):
+        tmodel(in_dtensor, w_dtensor)

-        tmodel = thunder.jit(fn, executors=thunder.get_always_executors())
-        with pytest.raises(AssertionError):
-            tmodel(in_dtensor, w)
+    def fn(x, w):
+        return x / w

-    def test_dtensor_incorrect_cotangent(self):
-        num_devices = self.world_size
-        mesh = DeviceMesh("cuda", list(range(num_devices)))
+    tmodel = thunder.jit(fn)
+    with pytest.raises(AssertionError):
+        tmodel(in_dtensor, w_dtensor)

-        dim_size = 16
-        w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
-        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+def test_dtensor_unsupported_mixed_input():
+    num_devices = get_num_devices()
+    mesh = DeviceMesh("cuda", list(range(num_devices)))

-        def fn(x, w):
-            return torch.mul(x, w)
+    dim_size = 16

-        tmodel = thunder.jit(fn, executors=thunder.get_always_executors())
-        actual = tmodel(in_dtensor, w_dtensor)
-        g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(1)])
+    def fn(x, w):
+        return torch.div(x, w)

-        with pytest.raises(RuntimeError, match="has changed for cotangent between tracing and runtime"):
-            torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)
+    w = torch.randn(dim_size, dim_size, requires_grad=True)

-    @common_utils.parametrize("executor", tuple(executors_map.keys()))
-    def test_dtensor_convert_element_type(self, executor):
-        from thunder.torch.experimental.dtensor_torch_and_prims import dtensor_convert_element_type_prim
+    in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

-        num_devices = self.world_size
-        mesh = DeviceMesh("cuda", list(range(num_devices)))
+    tmodel = thunder.jit(fn, executors=thunder.get_always_executors())
+    with pytest.raises(AssertionError):
+        tmodel(in_dtensor, w)

-        dim_size = 16
-        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+def test_dtensor_incorrect_cotangent():
+    num_devices = get_num_devices()
+    mesh = DeviceMesh("cuda", list(range(num_devices)))

-        def fn(x):
-            return dtensor_convert_element_type_prim(x, dtypes.bfloat16)
+    dim_size = 16

-        tmodel = thunder.jit(fn, executors=executors_map[executor].executors_list())
-        actual = tmodel(in_dtensor)
-        expected = in_dtensor.to(torch.bfloat16)
+    w_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+    in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])

-        torch.testing.assert_close(actual, expected)
+    def fn(x, w):
+        return torch.mul(x, w)

-        g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(0)])
-        expected_g = torch.autograd.grad(
-            expected,
-            (in_dtensor,),
-            g_o,
-        )
-        actual_g = torch.autograd.grad(actual, (in_dtensor,), g_o)
+    tmodel = thunder.jit(fn, executors=thunder.get_always_executors())
+    actual = tmodel(in_dtensor, w_dtensor)
+    g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(1)])

-        torch.testing.assert_close(actual_g, expected_g)
+    with pytest.raises(RuntimeError, match="has changed for cotangent between tracing and runtime"):
+        torch.autograd.grad(actual, (in_dtensor, w_dtensor), g_o)

-    @common_utils.parametrize("executor", tuple(executors_map.keys()))
-    def test_dtensor_broadcast_in_dim(self, executor):
-        from thunder.torch.experimental.dtensor_torch_and_prims import dtensor_broadcast_in_dim_prim
-        num_devices = self.world_size
-        mesh = DeviceMesh("cuda", list(range(num_devices)))
-        dim_size = 16
-        in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=False), mesh, [Shard(0)])
+@pytest.mark.parametrize("executor", tuple(executors_map.keys()))
+def test_dtensor_convert_element_type(executor):
+    from thunder.torch.experimental.dtensor_torch_and_prims import dtensor_convert_element_type_prim

-        def fn(x):
-            return dtensor_broadcast_in_dim_prim(x, (dim_size, dim_size), (0, 1))
+    num_devices = get_num_devices()
+    mesh = DeviceMesh("cuda", list(range(num_devices)))

-        tmodel = thunder.jit(fn, executors=executors_map[executor].executors_list())
-        actual = tmodel(in_dtensor)
-        expected = in_dtensor.broadcast_to((dim_size, dim_size))
+    dim_size = 16

-        torch.testing.assert_close(actual, expected)
+    in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=True), mesh, [Shard(0)])
+
+    def fn(x):
+        return dtensor_convert_element_type_prim(x, dtypes.bfloat16)
+
+    tmodel = thunder.jit(fn, executors=executors_map[executor].executors_list())
+    actual = tmodel(in_dtensor)
+    expected = in_dtensor.to(torch.bfloat16)
+
+    torch.testing.assert_close(actual, expected)

-    @common_utils.parametrize(
-        "op, executor",
-        product(dtensor_supported_opinfos, tuple(executors_map.keys())),
-        lambda op, executor: op.name + "_" + executor,
+    g_o = distribute_tensor(torch.ones(dim_size, dim_size), mesh, [Shard(0)])
+    expected_g = torch.autograd.grad(
+        expected,
+        (in_dtensor,),
+        g_o,
     )
-    def test_dtensor_opinfo(self, op: OpInfo, executor):
-        if op.name in skip_opinfos:
-            raise unittest.SkipTest(f"test_dtensor_opinfo: Skipping {op.name} as it is in skip_opinfos")
-
-        # NOTE: This test only tests for dtype=torch.float32 and requires_grad=True
-        # not for all dtype which are supported by the operation.
-        num_devices = self.world_size
-        mesh = DeviceMesh("cuda", list(range(num_devices)))
-
-        thunder_op = thunder.jit(op.op, executors=executors_map[executor].executors_list(), nv_enable_linear=True)
-        torch_op = op.torch_reference
-
-        tested_sample_count = 0
-
-        for sample in op.sample_inputs("cpu", dtypes.float32, requires_grad=op.supports_grad):
-            # DTensorConverter converts inputs tensors to DTensor and creates DTensor
-            # with possible placements based on the input shapes.
-            # See - https://github.com/pytorch/pytorch/blob/eaa5d9d3d3dc642832b269b184f0c3ab8c990274/torch/testing/_internal/distributed/_tensor/common_dtensor.py#L521
-            dtensor_converter = DTensorConverter(mesh, sample.args, sample.kwargs)
-            for dtensor_args, dtensor_kwargs in dtensor_converter:
-                if not dtensor_converter.successful():
-                    continue
-
-                # Computes PyTorch result
-                try:
-                    torch_result = torch_op(*dtensor_args, **dtensor_kwargs)
-                except Exception:
-                    # Unsupported input passed to `torch_op`, we expect an exception from `thunder_op` as well.
-                    with pytest.raises(Exception):
-                        thunder_op(*dtensor_args, **dtensor_kwargs)
-                    continue
-
-                thunder_result = thunder_op(*dtensor_args, **dtensor_kwargs)
-                torch.testing.assert_close(thunder_result, torch_result)
-
-                trace = thunder.last_traces(thunder_op)[0]
-                assert any("dtensor" in bsym.sym.name for bsym in trace.bound_symbols)
-
-                if op.supports_grad:
-                    torch_flats, _ = tree_flatten((dtensor_args, dtensor_kwargs))
-                    torch_result = filter_differentiable_outputs(torch_result)
-                    if torch_result == []:
-                        raise RuntimeError("test_dtensor_opinfo: Expected atleast 1 differentiable output.")
-
-                    grads = []
-                    assert isinstance(torch_result, torch.Tensor) or isinstance(torch_result, Sequence), (
-                        "test_dtensor_opinfo:Expected a single torch tensor or a sequence of torch tensors"
-                    )
-                    if isinstance(torch_result, Sequence):
-                        for x in torch_result:
-                            assert isinstance(x, torch.Tensor), (
-                                "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
-                            )
-                            if is_output_differentiable(x):
-                                grads.append(torch.ones_like(x))
-                    else:
-                        if is_output_differentiable(torch_result):
-                            grads = [torch.ones_like(torch_result)]
-
-                    torch_tensors_requiring_grad = tuple(
-                        f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad
-                    )
-                    torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
-
-                    thunder_result = filter_differentiable_outputs(thunder_result)
-                    thunder_grad_result = torch.autograd.grad(thunder_result, torch_tensors_requiring_grad, grads)
-                    torch.testing.assert_close(thunder_grad_result, torch_grad_result)
-
-                # Increment tested sample count
-                tested_sample_count += 1
-
-        assert tested_sample_count > 0, f"test_dtensor_opinfo:No samples tested for {op.name} with {executor} executor"
-
-
-common_utils.instantiate_parametrized_tests(DTensorTest)
-
-if __name__ == "__main__":
-    common_utils.run_tests()
+    actual_g = torch.autograd.grad(actual, (in_dtensor,), g_o)
+
+    torch.testing.assert_close(actual_g, expected_g)
+
+
+@pytest.mark.parametrize("executor", tuple(executors_map.keys()))
+def test_dtensor_broadcast_in_dim(executor):
+    from thunder.torch.experimental.dtensor_torch_and_prims import dtensor_broadcast_in_dim_prim
+
+    num_devices = get_num_devices()
+    mesh = DeviceMesh("cuda", list(range(num_devices)))
+    dim_size = 16
+    in_dtensor = distribute_tensor(torch.randn(dim_size, dim_size, requires_grad=False), mesh, [Shard(0)])
+
+    def fn(x):
+        return dtensor_broadcast_in_dim_prim(x, (dim_size, dim_size), (0, 1))
+
+    tmodel = thunder.jit(fn, executors=executors_map[executor].executors_list())
+    actual = tmodel(in_dtensor)
+    expected = in_dtensor.broadcast_to((dim_size, dim_size))
+
+    torch.testing.assert_close(actual, expected)
+
+
+@pytest.mark.parametrize("op", dtensor_supported_opinfos)
+@pytest.mark.parametrize("executor", tuple(executors_map.keys()))
+def test_dtensor_opinfo(op: OpInfo, executor):
+    if op.name in skip_opinfos:
+        pytest.skip(f"test_dtensor_opinfo: Skipping {op.name} as it is in skip_opinfos")
+
+    # NOTE: This test only covers dtype=torch.float32 with requires_grad=True,
+    # not all dtypes supported by the operation.
+    num_devices = get_num_devices()
+    mesh = DeviceMesh("cuda", list(range(num_devices)))
+
+    thunder_op = thunder.jit(op.op, executors=executors_map[executor].executors_list(), nv_enable_linear=True)
+    torch_op = op.torch_reference
+
+    tested_sample_count = 0
+
+    for sample in op.sample_inputs("cpu", dtypes.float32, requires_grad=op.supports_grad):
+        # DTensorConverter converts input tensors to DTensors and creates DTensors
+        # with possible placements based on the input shapes.
+        # See - https://github.com/pytorch/pytorch/blob/eaa5d9d3d3dc642832b269b184f0c3ab8c990274/torch/testing/_internal/distributed/_tensor/common_dtensor.py#L521
+        dtensor_converter = DTensorConverter(mesh, sample.args, sample.kwargs)
+        for dtensor_args, dtensor_kwargs in dtensor_converter:
+            if not dtensor_converter.successful():
+                continue
+
+            # Computes PyTorch result
+            try:
+                torch_result = torch_op(*dtensor_args, **dtensor_kwargs)
+            except Exception:
+                # Unsupported input passed to `torch_op`, we expect an exception from `thunder_op` as well.
+                with pytest.raises(Exception):
+                    thunder_op(*dtensor_args, **dtensor_kwargs)
+                continue
+
+            thunder_result = thunder_op(*dtensor_args, **dtensor_kwargs)
+            torch.testing.assert_close(thunder_result, torch_result)
+
+            trace = thunder.last_traces(thunder_op)[0]
+            assert any("dtensor" in bsym.sym.name for bsym in trace.bound_symbols)
+
+            if op.supports_grad:
+                torch_flats, _ = tree_flatten((dtensor_args, dtensor_kwargs))
+                torch_result = filter_differentiable_outputs(torch_result)
+                if torch_result == []:
+                    raise RuntimeError("test_dtensor_opinfo: Expected at least 1 differentiable output.")
+
+                grads = []
+                assert isinstance(torch_result, torch.Tensor) or isinstance(torch_result, Sequence), (
+                    "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
+                )
+                if isinstance(torch_result, Sequence):
+                    for x in torch_result:
+                        assert isinstance(x, torch.Tensor), (
+                            "test_dtensor_opinfo: Expected a single torch tensor or a sequence of torch tensors"
+                        )
+                        if is_output_differentiable(x):
+                            grads.append(torch.ones_like(x))
+                else:
+                    if is_output_differentiable(torch_result):
+                        grads = [torch.ones_like(torch_result)]
+
+                torch_tensors_requiring_grad = tuple(
+                    f for f in torch_flats if isinstance(f, torch.Tensor) and f.requires_grad
+                )
+                torch_grad_result = torch.autograd.grad(torch_result, torch_tensors_requiring_grad, grads)
+
+                thunder_result = filter_differentiable_outputs(thunder_result)
+                thunder_grad_result = torch.autograd.grad(thunder_result, torch_tensors_requiring_grad, grads)
+                torch.testing.assert_close(thunder_grad_result, torch_grad_result)
+
+            # Increment tested sample count
+            tested_sample_count += 1
+
+    assert tested_sample_count > 0, f"test_dtensor_opinfo: No samples tested for {op.name} with {executor} executor"