Dataset.py Refactors, Simplifications, and Testing #80
base: main
Conversation
…ly sizeable refactor of the various transformations in the dataset file and adding a number of tests.
```python
    return {k: normalizer.transform(v) for k, v in x.items()}, normalizer


def process_nans_in_numerical_features(dataset: Dataset, policy: NumericalNaNPolicy | None) -> Dataset:
```
This guy is sticking around in the dataset file, because it relies on the definition of Dataset and I didn't want to jump through circular-import hoops.
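For context on that circular-import concern: the function's signature both takes and returns a Dataset, so moving it into dataset_transformations.py would force that module to import dataset.py, which already imports the transformations module. A minimal sketch of the shape involved — the attribute name x_num and the mean-imputation behaviour are assumptions for illustration, not the toolkit's actual logic:

```python
import numpy as np


# Hypothetical sketch only; the real function lives in dataset.py and its
# NumericalNaNPolicy handling may differ.
def process_nans_in_numerical_features(dataset: "Dataset", policy: "NumericalNaNPolicy | None") -> "Dataset":
    if policy is None:
        return dataset

    for features in dataset.x_num.values():  # x_num attribute is assumed
        nan_mask = np.isnan(features)
        if nan_mask.any():
            # Assumed policy: replace each NaN with its column mean.
            column_means = np.nanmean(features, axis=0)
            features[nan_mask] = np.take(column_means, np.nonzero(nan_mask)[1])

    return dataset
```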
📝 Walkthrough
This PR refactors the clavaddpm dataset handling architecture by extracting and modularizing data preprocessing logic into three new utility modules (dataset_transformations.py, dataset_utils.py, metrics.py). The main dataset.py module is restructured to use these utilities and introduces a caching mechanism for transformed datasets. The make_dataset_from_df function transitions from a module-level function to a Dataset classmethod. The parameter data_split_ratios is renamed to data_split_percentages across multiple call sites in fine_tuning.py and train.py. Tests are updated to reflect API changes, and two new test modules are added for the new utilities.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45–75 minutes
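To make the API change concrete, a hedged before/after of the call-site migration the walkthrough describes — the positional arguments are illustrative; only the rename and the move onto Dataset come from the summary:

```python
# Before: module-level function with the old parameter name.
dataset = make_dataset_from_df(df, transformations, data_split_ratios=[0.8, 0.1, 0.1])

# After: invoked via the Dataset class with the renamed parameter.
dataset = Dataset.make_dataset_from_df(df, transformations, data_split_percentages=[0.8, 0.1, 0.1])
```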
Actionable comments posted: 3
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py (3 hunks)
src/midst_toolkit/models/clavaddpm/dataset.py (9 hunks)
src/midst_toolkit/models/clavaddpm/dataset_transformations.py (1 hunks)
src/midst_toolkit/models/clavaddpm/dataset_utils.py (1 hunks)
src/midst_toolkit/models/clavaddpm/metrics.py (1 hunks)
src/midst_toolkit/models/clavaddpm/train.py (3 hunks)
tests/unit/models/clavaddpm/test_dataset.py (2 hunks)
tests/unit/models/clavaddpm/test_dataset_transformations.py (1 hunks)
tests/unit/models/clavaddpm/test_dataset_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
tests/unit/models/clavaddpm/test_dataset_utils.py (4)
  src/midst_toolkit/common/enumerations.py (1)
    DataSplit (24-27)
  src/midst_toolkit/common/random.py (2)
    set_all_random_seeds (11-55)
    unset_all_random_seeds (58-67)
  src/midst_toolkit/models/clavaddpm/dataset_utils.py (3)
    encode_and_merge_features (86-171)
    get_categorical_and_numerical_column_names (56-83)
    get_category_sizes (41-53)
  src/midst_toolkit/models/clavaddpm/enumerations.py (1)
    IsTargetConditioned (26-51)
src/midst_toolkit/models/clavaddpm/metrics.py (2)
  src/midst_toolkit/common/enumerations.py (2)
    PredictionType (19-21)
    TaskType (4-16)
  src/midst_toolkit/models/clavaddpm/dataset.py (1)
    calculate_metrics (231-258)
tests/unit/models/clavaddpm/test_dataset_transformations.py (4)
  src/midst_toolkit/common/enumerations.py (1)
    TaskType (4-16)
  src/midst_toolkit/common/random.py (2)
    set_all_random_seeds (11-55)
    unset_all_random_seeds (58-67)
  src/midst_toolkit/models/clavaddpm/dataset_transformations.py (4)
    collapse_rare_categories (136-172)
    encode_categorical_features (175-251)
    process_nans_in_categorical_features (97-133)
    transform_targets (254-290)
  src/midst_toolkit/models/clavaddpm/enumerations.py (3)
    CategoricalEncoding (83-88)
    CategoricalNaNPolicy (77-80)
    TargetPolicy (91-94)
src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py (1)
  src/midst_toolkit/models/clavaddpm/dataset.py (3)
    Dataset (68-381)
    Transformations (52-64)
    make_dataset_from_df (261-381)
src/midst_toolkit/models/clavaddpm/dataset.py (4)
  src/midst_toolkit/common/enumerations.py (2)
    DataSplit (24-27)
    TaskType (4-16)
  src/midst_toolkit/models/clavaddpm/dataset_transformations.py (5)
    collapse_rare_categories (136-172)
    drop_rows_according_to_mask (72-94)
    normalize (34-69)
    process_nans_in_categorical_features (97-133)
    transform_targets (254-290)
  src/midst_toolkit/models/clavaddpm/dataset_utils.py (5)
    dump_pickle (29-38)
    encode_and_merge_features (86-171)
    get_categorical_and_numerical_column_names (56-83)
    get_category_sizes (41-53)
    load_pickle (15-26)
  src/midst_toolkit/models/clavaddpm/enumerations.py (1)
    IsTargetConditioned (26-51)
src/midst_toolkit/models/clavaddpm/dataset_utils.py (3)
  src/midst_toolkit/common/enumerations.py (1)
    DataSplit (24-27)
  src/midst_toolkit/models/clavaddpm/enumerations.py (1)
    IsTargetConditioned (26-51)
  src/midst_toolkit/models/clavaddpm/dataset.py (1)
    get_category_sizes (219-229)
src/midst_toolkit/models/clavaddpm/train.py (1)
  src/midst_toolkit/models/clavaddpm/dataset.py (3)
    Dataset (68-381)
    Transformations (52-64)
    make_dataset_from_df (261-381)
tests/unit/models/clavaddpm/test_dataset.py (5)
  src/midst_toolkit/common/enumerations.py (1)
    TaskType (4-16)
  src/midst_toolkit/common/random.py (2)
    set_all_random_seeds (11-55)
    unset_all_random_seeds (58-67)
  src/midst_toolkit/models/clavaddpm/dataset.py (5)
    Dataset (68-381)
    Transformations (52-64)
    get_cached_dataset (405-426)
    setup_cache_path (384-402)
    _load_datasets (102-122)
  src/midst_toolkit/models/clavaddpm/dataset_utils.py (1)
    dump_pickle (29-38)
  src/midst_toolkit/models/clavaddpm/enumerations.py (3)
    CategoricalEncoding (83-88)
    Normalization (62-67)
    NumericalNaNPolicy (70-74)
src/midst_toolkit/models/clavaddpm/dataset_transformations.py (2)
  src/midst_toolkit/common/enumerations.py (2)
    DataSplit (24-27)
    TaskType (4-16)
  src/midst_toolkit/models/clavaddpm/enumerations.py (4)
    CategoricalEncoding (83-88)
    CategoricalNaNPolicy (77-80)
    Normalization (62-67)
    TargetPolicy (91-94)
🪛 Ruff (0.14.2)
src/midst_toolkit/models/clavaddpm/metrics.py
61-61: Avoid specifying long messages outside the exception class
(TRY003)
src/midst_toolkit/models/clavaddpm/dataset.py
400-400: Probable use of insecure hash functions in hashlib: md5
(S324)
426-426: Avoid specifying long messages outside the exception class
(TRY003)
src/midst_toolkit/models/clavaddpm/dataset_utils.py
26-26: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
src/midst_toolkit/models/clavaddpm/dataset_transformations.py
66-66: Avoid specifying long messages outside the exception class
(TRY003)
86-86: Avoid specifying long messages outside the exception class
(TRY003)
92-92: Avoid specifying long messages outside the exception class
(TRY003)
130-130: Avoid specifying long messages outside the exception class
(TRY003)
247-247: Avoid specifying long messages outside the exception class
(TRY003)
288-288: Avoid specifying long messages outside the exception class
(TRY003)
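Both security advisories have conventional remediations: S301 is typically accepted (or suppressed) when the pickle files are produced and consumed locally, and S324 can be resolved by declaring the hash non-cryptographic. A sketch of the latter for a cache key like the one setup_cache_path presumably builds — the hashed payload here is a placeholder assumption:

```python
import hashlib

transformations_repr = "Transformations(normalization=quantile, ...)"  # placeholder payload

# usedforsecurity=False (Python 3.9+) declares the digest non-cryptographic,
# which satisfies Ruff's S324 check for cache-key derivation.
cache_key = hashlib.md5(transformations_repr.encode("utf-8"), usedforsecurity=False).hexdigest()
```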
🔇 Additional comments (2)
tests/unit/models/clavaddpm/test_dataset_transformations.py (1)
1-174: No issues found. The new tests thoroughly exercise the transformation utilities, and I have no change requests here.
src/midst_toolkit/models/clavaddpm/dataset_transformations.py (1)
1-290: No issues found. The shared transformation utilities look solid, and I have no requested changes here.
Mostly questions I have and a few suggestions. I will continue with the rest of the PR next week.
Just added one or two very minor comments.
Looks great overall! Just a few minor comments.
```python
    info: dict[str, Any],
    data_split_percentages: list[float] | None = None,
    noise_scale: float = 0,
    data_split_random_state: int = 42,
```
Super minor: data_split_random_state: int | None = None,
This is a tricky one... Because it's already been like that in the code, I don't know the implications of removing it. Now that we have examples and such, I believe we can test if removing this won't break anything?
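If the suggested default were adopted, the seed would only pin the split when explicitly requested. A hypothetical helper illustrating the pattern (not the toolkit's actual code):

```python
from sklearn.model_selection import train_test_split


def split_rows(df, data_split_random_state: int | None = None):
    # None lets sklearn draw fresh entropy per call; passing 42 reproduces
    # the current hard-coded behaviour when determinism is actually wanted.
    return train_test_split(df, test_size=0.2, random_state=data_split_random_state)
```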
```python
if task_type == TaskType.BINCLASS:
    result["roc_auc"] = roc_auc_score(y_true, probs)
return result

column_orders = numerical_column_names + categorical_column_names
```
Are we sure get_categorical_and_numerical_column_names returns numerical_column_names and categorical_column_names in the correct order? You’ve probably checked this already, but I was wondering if the order is enforced somewhere, or if we rely on the user to provide the column names in the right order in the table domain JSON.
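For what it's worth, the answer comes down to how the helper iterates the domain JSON. A hypothetical sketch (not the actual dataset_utils.py implementation, and the "type"/"discrete" field names are assumptions): if both lists are built in a single pass over the parsed JSON, Python's insertion-ordered dicts make the concatenation deterministic and tied to the domain file's key order.

```python
def get_categorical_and_numerical_column_names(domain: dict[str, dict]) -> tuple[list[str], list[str]]:
    categorical_column_names: list[str] = []
    numerical_column_names: list[str] = []
    for column_name, spec in domain.items():  # dict preserves the JSON key order
        if spec.get("type") == "discrete":
            categorical_column_names.append(column_name)
        else:
            numerical_column_names.append(column_name)
    return categorical_column_names, numerical_column_names
```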
```
Optional, default is None. If not None, will check if the transformations exist in the cache directory.
If they do, will returned the cached transformed dataset. If not, will transform the dataset and cache it.
cache_dir: The directory to cache the transformed dataset. Optional, default is None. If not None, will check
    if the transformations and dataset exist in the cache directory. If they do, will returned the cached
```
will return
Great refactor, thanks for doing it! Approving with some minor comments.
```python
dataset = np.load(directory / f"{dataset_name}_{split}.npy", allow_pickle=True)
assert isinstance(dataset, np.ndarray), "Dataset must be of type Numpy Array"
datasets[split] = dataset
return datasets
```
Nitpicking, but this block could benefit from some line breaks. I'd add one after the raise and another one before the return.
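A sketch of how the block might read with those breaks; the enclosing signature, the loop, and the raise's exact message are assumed context, since the quoted hunk shows only four lines:

```python
from pathlib import Path

import numpy as np


def _load_datasets(directory: Path, dataset_name: str, splits: list[str]) -> dict[str, np.ndarray]:
    datasets: dict[str, np.ndarray] = {}
    for split in splits:
        path = directory / f"{dataset_name}_{split}.npy"
        if not path.exists():
            raise FileNotFoundError(f"Missing split file: {path}")

        dataset = np.load(path, allow_pickle=True)
        assert isinstance(dataset, np.ndarray), "Dataset must be of type Numpy Array"
        datasets[split] = dataset

    return datasets
```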
```python
return metrics


@staticmethod
def make_dataset_from_df(
```
Nitpicking: the dataset in the name becomes redundant once this is turned into a Dataset function. It could be renamed to make_from_df so calling it becomes a bit cleaner: Dataset.make_from_df.
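Concretely, the rename would only change the call sites (the arguments here are placeholders):

```python
# Current name, with "dataset" repeated:
dataset = Dataset.make_dataset_from_df(df, transformations, info=info)

# Suggested name:
dataset = Dataset.make_from_df(df, transformations, info=info)
```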
```python
if transformations == cache_transformations:
    log(INFO, f"Using cached features: {cache_path}")
    return transformed_dataset
raise RuntimeError(f"Hash collision for {cache_path}")
```
Nitpicking: add a line break before this line.
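With that blank line, the lookup might read as below; treating load_pickle's payload as a (transformations, dataset) pair is an assumption inferred from the quoted variables, and Transformations, Dataset, load_pickle, and log come from the PR's modules:

```python
from logging import INFO
from pathlib import Path


def get_cached_dataset(cache_path: Path, transformations: "Transformations") -> "Dataset":
    # Sketch of the structure, not the file's exact contents.
    cache_transformations, transformed_dataset = load_pickle(cache_path)
    if transformations == cache_transformations:
        log(INFO, f"Using cached features: {cache_path}")
        return transformed_dataset

    raise RuntimeError(f"Hash collision for {cache_path}")
```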
```python
if row_mask.ndim != 1 or row_mask.shape[0] != data.shape[0]:
    raise ValueError(f"Mask for split '{split_name}' has shape {row_mask.shape}; expected ({data.shape[0]},)")
filtered_data_split[split_name] = data[row_mask]
return filtered_data_split
```
This whole function could also benefit from some line breaks.
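For illustration, the whole function with the suggested spacing; the signature and the enclosing loop are assumptions built around the quoted four lines:

```python
import numpy as np


def drop_rows_according_to_mask(
    data_split: dict[str, np.ndarray], masks: dict[str, np.ndarray]
) -> dict[str, np.ndarray]:
    filtered_data_split: dict[str, np.ndarray] = {}

    for split_name, data in data_split.items():
        row_mask = masks[split_name]
        if row_mask.ndim != 1 or row_mask.shape[0] != data.shape[0]:
            raise ValueError(f"Mask for split '{split_name}' has shape {row_mask.shape}; expected ({data.shape[0]},)")

        filtered_data_split[split_name] = data[row_mask]

    return filtered_data_split
```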
PR Type
Refactor and Testing
Short Description
Clickup Ticket(s): https://app.clickup.com/t/868ew1hny
Primarily, this PR aims to address the todos in the above ticket:
This includes a fairly sizeable refactor of the various transformations in the dataset file as part of addressing the second TODO. The dataset file was fairly overloaded with a bunch of different functionality that could be split out into various other modules, which I have done.
The main modules introduced are dataset_utils.py and dataset_transformations.py, which handle general utilities and transformation functionality used to process and create dataset objects for training TabDDPM models, among other things.

Tests Added
This PR also adds a bunch of tests for various components that I touched.
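For orientation, the post-refactor module boundaries line up with imports roughly like these (all names appear in the review's symbol listing above):

```python
from midst_toolkit.models.clavaddpm.dataset import Dataset, Transformations
from midst_toolkit.models.clavaddpm.dataset_transformations import (
    collapse_rare_categories,
    normalize,
    transform_targets,
)
from midst_toolkit.models.clavaddpm.dataset_utils import dump_pickle, get_category_sizes, load_pickle
```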