From 6160be1142803d21686fce4d739283d56dd57c95 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 5 Oct 2025 00:57:10 +0000 Subject: [PATCH 1/4] Initial plan From 96cd400984b93470042fd27d08c5b6723b90c600 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 5 Oct 2025 01:07:41 +0000 Subject: [PATCH 2/4] Fix devices_list type issue by changing devices type to Any Co-authored-by: manujosephv <10508493+manujosephv@users.noreply.github.com> --- src/pytorch_tabular/config/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 999c2c4a..c14b0b27 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -6,7 +6,7 @@ import os import re from dataclasses import MISSING, dataclass, field -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Union from omegaconf import OmegaConf @@ -400,7 +400,7 @@ class TrainerConfig: "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"], }, ) - devices: Optional[int] = field( + devices: Any = field( default=-1, metadata={ "help": "Number of devices to train on. -1 uses all available devices." From ac4c5fd4837f8665bfb0ea37a8b7610212e52554 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 5 Oct 2025 01:09:35 +0000 Subject: [PATCH 3/4] Add tests and update documentation for devices field Co-authored-by: manujosephv <10508493+manujosephv@users.noreply.github.com> --- src/pytorch_tabular/config/config.py | 4 +- tests/test_config.py | 77 ++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 tests/test_config.py diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index c14b0b27..58a72a3a 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -287,8 +287,8 @@ class TrainerConfig: 'cpu','gpu','tpu','ipu', 'mps', 'auto'. Defaults to 'auto'. Choices are: [`cpu`,`gpu`,`tpu`,`ipu`,'mps',`auto`]. - devices (Optional[int]): Number of devices to train on (int). -1 uses all available devices. By - default, uses all available devices (-1) + devices (Union[int, List[int]]): Number of devices to train on (int), or list of device indices. + -1 uses all available devices. By default, uses all available devices (-1) devices_list (Optional[List[int]]): List of devices to train on (list). If specified, takes precedence over `devices` argument. Defaults to None diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..e1747406 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +"""Tests for config classes.""" + +import pytest +from omegaconf import OmegaConf + +from pytorch_tabular.config import TrainerConfig + + +class TestTrainerConfig: + """Tests for TrainerConfig class.""" + + def test_devices_list_to_devices_conversion(self): + """Test that devices_list is properly converted to devices.""" + # Test with a list of devices + trainer_config = TrainerConfig(devices_list=[0, 1]) + assert trainer_config.devices == [0, 1] + + # Wrap with OmegaConf as done in TabularModel + config = OmegaConf.structured(trainer_config) + assert config.devices == [0, 1] + + def test_devices_list_multiple_gpus(self): + """Test devices_list with multiple GPU IDs as documented.""" + trainer_config = TrainerConfig(devices_list=[1, 2, 3, 4]) + assert trainer_config.devices == [1, 2, 3, 4] + + config = OmegaConf.structured(trainer_config) + assert config.devices == [1, 2, 3, 4] + + def test_devices_int_value(self): + """Test that devices accepts integer values.""" + trainer_config = TrainerConfig(devices=2) + assert trainer_config.devices == 2 + + config = OmegaConf.structured(trainer_config) + assert config.devices == 2 + + def test_devices_default_value(self): + """Test that devices has default value of -1.""" + trainer_config = TrainerConfig() + assert trainer_config.devices == -1 + + config = OmegaConf.structured(trainer_config) + assert config.devices == -1 + + def test_devices_list_single_device(self): + """Test devices_list with a single device.""" + trainer_config = TrainerConfig(devices_list=[0]) + assert trainer_config.devices == [0] + + config = OmegaConf.structured(trainer_config) + assert config.devices == [0] + + def test_devices_list_precedence(self): + """Test that devices_list takes precedence over devices.""" + # When both are provided, devices_list should take precedence + trainer_config = TrainerConfig(devices=2, devices_list=[0, 1]) + assert trainer_config.devices == [0, 1] + + config = OmegaConf.structured(trainer_config) + assert config.devices == [0, 1] + + def test_omegaconf_merge_compatibility(self): + """Test that config works correctly with OmegaConf.merge.""" + trainer_config = TrainerConfig(devices_list=[0, 1], max_epochs=10) + config = OmegaConf.structured(trainer_config) + + # Simulate merging as done in TabularModel + merged = OmegaConf.merge( + OmegaConf.to_container(config), + {"accelerator": "gpu"} + ) + + assert merged.devices == [0, 1] + assert merged.max_epochs == 10 + assert merged.accelerator == "gpu" From 6841b1a90e75b8e7f9bc86786bd8c443be0b6cc6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 5 Oct 2025 01:14:33 +0000 Subject: [PATCH 4/4] Apply ruff linting and formatting Co-authored-by: manujosephv <10508493+manujosephv@users.noreply.github.com> --- src/pytorch_tabular/config/config.py | 14 +++++++------- tests/test_config.py | 22 +++++++++------------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 58a72a3a..ce6e1d0d 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -6,7 +6,7 @@ import os import re from dataclasses import MISSING, dataclass, field -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional from omegaconf import OmegaConf @@ -192,9 +192,9 @@ class DataConfig: ) def __post_init__(self): - assert ( - len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0 - ), "There should be at-least one feature defined in categorical, continuous, or date columns" + assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, ( + "There should be at-least one feature defined in categorical, continuous, or date columns" + ) _validate_choices(self) if os.name == "nt" and self.num_workers != 0: print("Windows does not support num_workers > 0. Setting num_workers to 0") @@ -255,9 +255,9 @@ class InferredConfig: def __post_init__(self): if self.embedding_dims is not None: - assert all( - (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims - ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)" + assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), ( + "embedding_dims must be a list of tuples (cardinality, embedding_dim)" + ) self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims]) else: self.embedded_cat_dim = 0 diff --git a/tests/test_config.py b/tests/test_config.py index e1747406..a59dab16 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,6 @@ #!/usr/bin/env python """Tests for config classes.""" -import pytest from omegaconf import OmegaConf from pytorch_tabular.config import TrainerConfig @@ -15,7 +14,7 @@ def test_devices_list_to_devices_conversion(self): # Test with a list of devices trainer_config = TrainerConfig(devices_list=[0, 1]) assert trainer_config.devices == [0, 1] - + # Wrap with OmegaConf as done in TabularModel config = OmegaConf.structured(trainer_config) assert config.devices == [0, 1] @@ -24,7 +23,7 @@ def test_devices_list_multiple_gpus(self): """Test devices_list with multiple GPU IDs as documented.""" trainer_config = TrainerConfig(devices_list=[1, 2, 3, 4]) assert trainer_config.devices == [1, 2, 3, 4] - + config = OmegaConf.structured(trainer_config) assert config.devices == [1, 2, 3, 4] @@ -32,7 +31,7 @@ def test_devices_int_value(self): """Test that devices accepts integer values.""" trainer_config = TrainerConfig(devices=2) assert trainer_config.devices == 2 - + config = OmegaConf.structured(trainer_config) assert config.devices == 2 @@ -40,7 +39,7 @@ def test_devices_default_value(self): """Test that devices has default value of -1.""" trainer_config = TrainerConfig() assert trainer_config.devices == -1 - + config = OmegaConf.structured(trainer_config) assert config.devices == -1 @@ -48,7 +47,7 @@ def test_devices_list_single_device(self): """Test devices_list with a single device.""" trainer_config = TrainerConfig(devices_list=[0]) assert trainer_config.devices == [0] - + config = OmegaConf.structured(trainer_config) assert config.devices == [0] @@ -57,7 +56,7 @@ def test_devices_list_precedence(self): # When both are provided, devices_list should take precedence trainer_config = TrainerConfig(devices=2, devices_list=[0, 1]) assert trainer_config.devices == [0, 1] - + config = OmegaConf.structured(trainer_config) assert config.devices == [0, 1] @@ -65,13 +64,10 @@ def test_omegaconf_merge_compatibility(self): """Test that config works correctly with OmegaConf.merge.""" trainer_config = TrainerConfig(devices_list=[0, 1], max_epochs=10) config = OmegaConf.structured(trainer_config) - + # Simulate merging as done in TabularModel - merged = OmegaConf.merge( - OmegaConf.to_container(config), - {"accelerator": "gpu"} - ) - + merged = OmegaConf.merge(OmegaConf.to_container(config), {"accelerator": "gpu"}) + assert merged.devices == [0, 1] assert merged.max_epochs == 10 assert merged.accelerator == "gpu"