Conversation

@emersodb
Collaborator

PR Type

Refactor and Testing

Short Description

Clickup Ticket(s): https://app.clickup.com/t/868ew1hny

Primarily, this PR aims to address the TODOs in the above ticket:

# TODO: figure out if there is a way of getting rid of the cast
# TODO consider moving all the functions below into the Dataset class

This includes a fairly sizeable refactor of the various transformations in the dataset file as part of addressing the second TODO. The dataset file had become overloaded with functionality that could be split out into other modules, which I have now done.

The main modules introduced are dataset_utils.py and dataset_transformations.py, which hold general utilities and transformation functionality used to process and create dataset objects for training TabDDPM models, among other things.
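For orientation, the new modules are imported along these lines. The module paths and function names below are the ones introduced in this PR; no signatures are shown, since they are not part of this description:

```python
# Orientation only: these imports reflect the new module layout introduced in this PR.
# dataset_transformations.py holds the per-feature transformation steps, while
# dataset_utils.py holds general helpers used when assembling Dataset objects.
from midst_toolkit.models.clavaddpm.dataset_transformations import (
    collapse_rare_categories,
    encode_categorical_features,
    normalize,
    transform_targets,
)
from midst_toolkit.models.clavaddpm.dataset_utils import (
    encode_and_merge_features,
    get_category_sizes,
    load_pickle,
)
```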

Tests Added

This PR also adds a bunch of tests for various components that I touched.

return {k: normalizer.transform(v) for k, v in x.items()}, normalizer


def process_nans_in_numerical_features(dataset: Dataset, policy: NumericalNaNPolicy | None) -> Dataset:
Collaborator Author

This one is sticking around in dataset.py, because it relies on the definition of Dataset and I didn't want to jump through circular-import hoops.
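To spell out the circular-import concern, here is a small illustration. The module and symbol names come from this PR, but the file contents shown are simplified sketches, not the actual code:

```python
# --- dataset.py (simplified) ---
# After this PR, dataset.py already imports from the new transformations module:
from midst_toolkit.models.clavaddpm.dataset_transformations import normalize, drop_rows_according_to_mask


class Dataset:
    """Simplified stand-in for the real Dataset class."""


def process_nans_in_numerical_features(dataset: Dataset, policy=None) -> Dataset:
    # Needs the Dataset definition, so it stays in this file.
    return dataset


# --- dataset_transformations.py (hypothetical, if the function were moved there) ---
# It would need "from midst_toolkit.models.clavaddpm.dataset import Dataset", while
# dataset.py imports this module at the top, so Python would hit a circular import at
# load time unless TYPE_CHECKING guards or deferred imports were used.
```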

@emersodb marked this pull request as ready for review October 31, 2025 19:16
@coderabbitai (coderabbitai bot) commented Oct 31, 2025

📝 Walkthrough

This PR refactors the clavaddpm dataset handling architecture by extracting and modularizing data preprocessing logic into three new utility modules (dataset_transformations.py, dataset_utils.py, metrics.py). The main dataset.py module is restructured to use these utilities and introduces a caching mechanism for transformed datasets. The make_dataset_from_df function transitions from module-level to a Dataset classmethod. The parameter data_split_ratios is renamed to data_split_percentages across multiple call sites in fine_tuning.py and train.py. Tests are updated to reflect API changes, and two new test modules are added for the new utilities.
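To make the call-site migration concrete, a rough before/after sketch follows. Only the function/class names and the data_split_ratios → data_split_percentages rename come from this PR; the argument list and the assignment of the return value are illustrative:

```python
# Before: module-level function with the old parameter name (arguments abridged).
from midst_toolkit.models.clavaddpm.dataset import make_dataset_from_df

dataset = make_dataset_from_df(dataframe, transformations, info=info, data_split_ratios=[0.7, 0.2, 0.1])

# After: the same functionality lives on the Dataset class and the parameter is renamed.
from midst_toolkit.models.clavaddpm.dataset import Dataset

dataset = Dataset.make_dataset_from_df(dataframe, transformations, info=info, data_split_percentages=[0.7, 0.2, 0.1])
```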

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45–75 minutes

  • src/midst_toolkit/models/clavaddpm/dataset_transformations.py: Six new transformation functions with distinct data processing logic (normalize, drop_rows_according_to_mask, process_nans_in_categorical_features, collapse_rare_categories, encode_categorical_features, transform_targets); each handles different aspects of categorical/numerical/target transformations with specific validation and error paths.
  • src/midst_toolkit/models/clavaddpm/dataset_utils.py: Five new utility functions including feature encoding/merging and category size computation; encode_and_merge_features contains conditional branching for categorical/numerical feature handling and label encoder management.
  • src/midst_toolkit/models/clavaddpm/dataset.py: Major architectural changes introducing caching logic (setup_cache_path, get_cached_dataset) and substantial refactoring of make_dataset_from_df; migration from in-file implementations to modular imports requires careful verification of equivalent behavior.
  • src/midst_toolkit/models/clavaddpm/metrics.py: New module with task-specific metric calculations; get_predicted_labels_and_probs and calculate_metrics contain conditional logic for different task types and prediction types.
  • API migration patterns: Multiple files (clavaddpm_fine_tuning.py, train.py) undergo correlated changes from module-level function calls to Dataset classmethod calls, plus parameter renaming from data_split_ratios to data_split_percentages; consistency of these changes across all call sites requires verification.
  • Test coverage: New test modules for dataset_transformations.py and dataset_utils.py need review to ensure comprehensive coverage of the new utility functions and edge cases.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 66.67%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check ✅ Passed: The PR title "Dataset.py Refactors, Simplifications, and Testing" directly and clearly captures the primary changes in the changeset. The title encompasses the three main aspects of the PR: (1) refactoring dataset-related code, (2) simplifying the codebase by modularizing functionality into dedicated files (dataset_utils.py and dataset_transformations.py), and (3) adding comprehensive test coverage. The title is concise, specific, and avoids vague terminology; a reviewer scanning the history would immediately understand the core nature of these changes.
  • Description Check ✅ Passed: The pull request description follows the required template structure with all key sections present. The PR Type section identifies the work as "Refactor and Testing" (a reasonable interpretation of the template's suggested options). The Short Description is comprehensive, providing the ClickUp ticket link, explaining the TODOs being addressed, describing the refactoring approach, and naming the new modules introduced. The Tests Added section confirms tests were added for the modified components. While the Tests Added section is relatively brief and could detail specific test files or coverage more explicitly, the description provides sufficient context to understand the PR's objectives and scope.

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai (coderabbitai bot) left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5104f85 and 17944bb.

📒 Files selected for processing (9)
  • src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py (3 hunks)
  • src/midst_toolkit/models/clavaddpm/dataset.py (9 hunks)
  • src/midst_toolkit/models/clavaddpm/dataset_transformations.py (1 hunks)
  • src/midst_toolkit/models/clavaddpm/dataset_utils.py (1 hunks)
  • src/midst_toolkit/models/clavaddpm/metrics.py (1 hunks)
  • src/midst_toolkit/models/clavaddpm/train.py (3 hunks)
  • tests/unit/models/clavaddpm/test_dataset.py (2 hunks)
  • tests/unit/models/clavaddpm/test_dataset_transformations.py (1 hunks)
  • tests/unit/models/clavaddpm/test_dataset_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
tests/unit/models/clavaddpm/test_dataset_utils.py (4)
src/midst_toolkit/common/enumerations.py (1)
  • DataSplit (24-27)
src/midst_toolkit/common/random.py (2)
  • set_all_random_seeds (11-55)
  • unset_all_random_seeds (58-67)
src/midst_toolkit/models/clavaddpm/dataset_utils.py (3)
  • encode_and_merge_features (86-171)
  • get_categorical_and_numerical_column_names (56-83)
  • get_category_sizes (41-53)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
  • IsTargetConditioned (26-51)
src/midst_toolkit/models/clavaddpm/metrics.py (2)
src/midst_toolkit/common/enumerations.py (2)
  • PredictionType (19-21)
  • TaskType (4-16)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
  • calculate_metrics (231-258)
tests/unit/models/clavaddpm/test_dataset_transformations.py (4)
src/midst_toolkit/common/enumerations.py (1)
  • TaskType (4-16)
src/midst_toolkit/common/random.py (2)
  • set_all_random_seeds (11-55)
  • unset_all_random_seeds (58-67)
src/midst_toolkit/models/clavaddpm/dataset_transformations.py (4)
  • collapse_rare_categories (136-172)
  • encode_categorical_features (175-251)
  • process_nans_in_categorical_features (97-133)
  • transform_targets (254-290)
src/midst_toolkit/models/clavaddpm/enumerations.py (3)
  • CategoricalEncoding (83-88)
  • CategoricalNaNPolicy (77-80)
  • TargetPolicy (91-94)
src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py (1)
src/midst_toolkit/models/clavaddpm/dataset.py (3)
  • Dataset (68-381)
  • Transformations (52-64)
  • make_dataset_from_df (261-381)
src/midst_toolkit/models/clavaddpm/dataset.py (4)
src/midst_toolkit/common/enumerations.py (2)
  • DataSplit (24-27)
  • TaskType (4-16)
src/midst_toolkit/models/clavaddpm/dataset_transformations.py (5)
  • collapse_rare_categories (136-172)
  • drop_rows_according_to_mask (72-94)
  • normalize (34-69)
  • process_nans_in_categorical_features (97-133)
  • transform_targets (254-290)
src/midst_toolkit/models/clavaddpm/dataset_utils.py (5)
  • dump_pickle (29-38)
  • encode_and_merge_features (86-171)
  • get_categorical_and_numerical_column_names (56-83)
  • get_category_sizes (41-53)
  • load_pickle (15-26)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
  • IsTargetConditioned (26-51)
src/midst_toolkit/models/clavaddpm/dataset_utils.py (3)
src/midst_toolkit/common/enumerations.py (1)
  • DataSplit (24-27)
src/midst_toolkit/models/clavaddpm/enumerations.py (1)
  • IsTargetConditioned (26-51)
src/midst_toolkit/models/clavaddpm/dataset.py (1)
  • get_category_sizes (219-229)
src/midst_toolkit/models/clavaddpm/train.py (1)
src/midst_toolkit/models/clavaddpm/dataset.py (3)
  • Dataset (68-381)
  • Transformations (52-64)
  • make_dataset_from_df (261-381)
tests/unit/models/clavaddpm/test_dataset.py (5)
src/midst_toolkit/common/enumerations.py (1)
  • TaskType (4-16)
src/midst_toolkit/common/random.py (2)
  • set_all_random_seeds (11-55)
  • unset_all_random_seeds (58-67)
src/midst_toolkit/models/clavaddpm/dataset.py (5)
  • Dataset (68-381)
  • Transformations (52-64)
  • get_cached_dataset (405-426)
  • setup_cache_path (384-402)
  • _load_datasets (102-122)
src/midst_toolkit/models/clavaddpm/dataset_utils.py (1)
  • dump_pickle (29-38)
src/midst_toolkit/models/clavaddpm/enumerations.py (3)
  • CategoricalEncoding (83-88)
  • Normalization (62-67)
  • NumericalNaNPolicy (70-74)
src/midst_toolkit/models/clavaddpm/dataset_transformations.py (2)
src/midst_toolkit/common/enumerations.py (2)
  • DataSplit (24-27)
  • TaskType (4-16)
src/midst_toolkit/models/clavaddpm/enumerations.py (4)
  • CategoricalEncoding (83-88)
  • CategoricalNaNPolicy (77-80)
  • Normalization (62-67)
  • TargetPolicy (91-94)
🪛 Ruff (0.14.2)
src/midst_toolkit/models/clavaddpm/metrics.py

61-61: Avoid specifying long messages outside the exception class

(TRY003)

src/midst_toolkit/models/clavaddpm/dataset.py

400-400: Probable use of insecure hash functions in hashlib: md5

(S324)


426-426: Avoid specifying long messages outside the exception class

(TRY003)

src/midst_toolkit/models/clavaddpm/dataset_utils.py

26-26: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue

(S301)

src/midst_toolkit/models/clavaddpm/dataset_transformations.py

66-66: Avoid specifying long messages outside the exception class

(TRY003)


86-86: Avoid specifying long messages outside the exception class

(TRY003)


92-92: Avoid specifying long messages outside the exception class

(TRY003)


130-130: Avoid specifying long messages outside the exception class

(TRY003)


247-247: Avoid specifying long messages outside the exception class

(TRY003)


288-288: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
tests/unit/models/clavaddpm/test_dataset_transformations.py (1)

1-174: No issues found

The new tests thoroughly exercise the transformation utilities, and I have no change requests here.

src/midst_toolkit/models/clavaddpm/dataset_transformations.py (1)

1-290: No issues found

The shared transformation utilities look solid, and I have no requested changes here.

@bzamanlooy (Collaborator) left a comment

Mostly questions I have and a few suggestions. I will continue with the rest of the PR next week.

@bzamanlooy (Collaborator) left a comment

Just added one or two very minor comments.

@emersodb requested a review from bzamanlooy November 4, 2025 19:06
@fatemetkl (Collaborator) left a comment

Looks great overall! Just a few minor comments.

info: dict[str, Any],
data_split_percentages: list[float] | None = None,
noise_scale: float = 0,
data_split_random_state: int = 42,
Collaborator

Super minor: data_split_random_state: int | None = None,

Collaborator

This is a tricky one... Because it's already been like that in the code, I don't know the implications of removing it. Now that we have examples and such, I believe we can test whether removing it breaks anything.

if task_type == TaskType.BINCLASS:
    result["roc_auc"] = roc_auc_score(y_true, probs)
return result
column_orders = numerical_column_names + categorical_column_names
Collaborator

Are we sure get_categorical_and_numerical_column_names returns numerical_column_names and categorical_column_names in the correct order? You’ve probably checked this already, but I was wondering if the order is enforced somewhere, or if we rely on the user to provide the column names in the right order in the table domain JSON.
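If the ordering isn't enforced anywhere, one option would be a cheap sanity check near this line. Everything below, other than the names quoted from the snippet above, is a hypothetical sketch, including the assumption that df and the two name lists are in scope:

```python
# Hypothetical guard: cross-check the declared numerical columns against the DataFrame's
# dtypes instead of trusting the ordering supplied in the table domain JSON.
numerical_from_df = set(df.select_dtypes(include="number").columns)
assert set(numerical_column_names) <= numerical_from_df, (
    "Columns declared numerical in the domain JSON are not numeric in the DataFrame"
)

column_orders = numerical_column_names + categorical_column_names
```

Note this only checks membership; the numerical-before-categorical ordering itself would still need to be documented or enforced wherever column_orders is consumed.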

Optional, default is None. If not None, will check if the transformations exist in the cache directory.
If they do, will returned the cached transformed dataset. If not, will transform the dataset and cache it.
cache_dir: The directory to cache the transformed dataset. Optional, default is None. If not None, will check
if the transformations and dataset exist in the cache directory. If they do, will returned the cached
Collaborator

will return

@lotif (Collaborator) left a comment

Great refactor, thanks for doing it! Approving with some minor comments.

dataset = np.load(directory / f"{dataset_name}_{split}.npy", allow_pickle=True)
assert isinstance(dataset, np.ndarray), "Dataset must be of type Numpy Array"
datasets[split] = dataset
return datasets
Collaborator

Nitpicking, but this block could benefit from some line breaks. I'd add one after the raise and another one before the return.
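For concreteness, the loading helper with the suggested spacing might look roughly like this. The raise and the surrounding loop are not visible in the snippet above, so the signature and their placement here are reconstructions for illustration only (the real helper is dataset.py's _load_datasets):

```python
from pathlib import Path

import numpy as np


def _load_datasets_sketch(directory: Path, dataset_name: str, splits: list[str]) -> dict[str, np.ndarray]:
    # Illustrative reconstruction only, showing a blank line after the raise block and
    # another before the return, as suggested in the comment above.
    datasets: dict[str, np.ndarray] = {}
    for split in splits:
        file_path = directory / f"{dataset_name}_{split}.npy"
        if not file_path.exists():
            raise FileNotFoundError(f"No dataset file found at {file_path}")

        dataset = np.load(file_path, allow_pickle=True)
        assert isinstance(dataset, np.ndarray), "Dataset must be of type Numpy Array"
        datasets[split] = dataset

    return datasets
```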

return metrics

@staticmethod
def make_dataset_from_df(
Collaborator

Nitpicking: the dataset in the name becomes redundant once this is turned into a Dataset function. It could be renamed to make_from_df so calling it becomes a bit cleaner: Dataset.make_from_df.
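For what it's worth, the caller-side difference would be small (arguments abridged; only the names involved in the rename come from this PR, the rest is illustrative):

```python
# As merged in this PR:
dataset = Dataset.make_dataset_from_df(dataframe, transformations, info=info)

# With the suggested rename, the repetition of "dataset" disappears:
dataset = Dataset.make_from_df(dataframe, transformations, info=info)
```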

if transformations == cache_transformations:
    log(INFO, f"Using cached features: {cache_path}")
    return transformed_dataset
raise RuntimeError(f"Hash collision for {cache_path}")
Collaborator

Nitpicking: add a line break before this line.
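For readers less familiar with the pattern, here is a self-contained sketch of the hash-based cache lookup this comment refers to. The real setup_cache_path and get_cached_dataset in dataset.py differ in their signatures and details, so treat everything below as illustrative only:

```python
import hashlib
import pickle
from pathlib import Path
from typing import Any


def cache_path_sketch(cache_dir: Path, transformations_repr: str) -> Path:
    # md5 is only used to derive a stable file name, not for security, which is why the
    # Ruff S324 finding above is likely acceptable here.
    digest = hashlib.md5(transformations_repr.encode()).hexdigest()
    return cache_dir / f"transformed_dataset_{digest}.pickle"


def get_cached_sketch(cache_path: Path, transformations_repr: str) -> Any | None:
    if not cache_path.exists():
        return None

    cached_repr, transformed_dataset = pickle.loads(cache_path.read_bytes())
    if cached_repr == transformations_repr:
        return transformed_dataset

    # Two different transformation configurations hashed to the same file name.
    raise RuntimeError(f"Hash collision for {cache_path}")
```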

if row_mask.ndim != 1 or row_mask.shape[0] != data.shape[0]:
    raise ValueError(f"Mask for split '{split_name}' has shape {row_mask.shape}; expected ({data.shape[0]},)")
filtered_data_split[split_name] = data[row_mask]
return filtered_data_split
Collaborator

This whole function could also benefit from some line breaks.
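A possible spacing for the whole function, reconstructed around the lines quoted above: the enclosing loop and the signature are assumptions, and only the validation and filtering lines come from the diff (the real function is dataset_transformations.drop_rows_according_to_mask):

```python
import numpy as np


def drop_rows_according_to_mask_sketch(
    data_split: dict[str, np.ndarray], masks: dict[str, np.ndarray]
) -> dict[str, np.ndarray]:
    # Illustrative reconstruction only, with the blank lines suggested in the comment above.
    filtered_data_split: dict[str, np.ndarray] = {}
    for split_name, data in data_split.items():
        row_mask = masks[split_name]
        if row_mask.ndim != 1 or row_mask.shape[0] != data.shape[0]:
            raise ValueError(
                f"Mask for split '{split_name}' has shape {row_mask.shape}; expected ({data.shape[0]},)"
            )

        filtered_data_split[split_name] = data[row_mask]

    return filtered_data_split
```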
