Skip to content
2 changes: 1 addition & 1 deletion .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ jobs:

- name: Install dependencies and check code
run: |
uv run pytest -m "integration_test" --log-cli-level=WARNING
uv run pytest -m "integration_test" --log-cli-level=WARNING -s -vv
12 changes: 7 additions & 5 deletions src/midst_toolkit/attacks/ensemble/clavaddpm_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def fine_tune_model(
if len(category_sizes) == 0 or transformations.categorical_encoding == CategoricalEncoding.ONE_HOT:
category_sizes = np.array([0])

num_numerical_features = dataset.x_num[DataSplit.TRAIN.value].shape[1] if dataset.x_num is not None else 0
num_numerical_features = (
dataset.numerical_features[DataSplit.TRAIN.value].shape[1] if dataset.numerical_features is not None else 0
)

train_loader = prepare_fast_dataloader(dataset, split=DataSplit.TRAIN, batch_size=batch_size)

Expand All @@ -110,7 +112,7 @@ def fine_tune_model(
trainer = ClavaDDPMTrainer(
diffusion,
train_loader,
lr=lr,
learning_rate=lr,
weight_decay=weight_decay,
steps=steps,
device=str(device),
Expand Down Expand Up @@ -193,11 +195,11 @@ def fine_tune_classifier(
if len(category_sizes) == 0 or transformations.categorical_encoding == CategoricalEncoding.ONE_HOT:
category_sizes = np.array([0])

if dataset.x_num is None:
if dataset.numerical_features is None:
        log(WARNING, "dataset.numerical_features is None. num_numerical_features will be set to 0")
num_numerical_features = 0
else:
num_numerical_features = dataset.x_num[DataSplit.TRAIN.value].shape[1]
num_numerical_features = dataset.numerical_features[DataSplit.TRAIN.value].shape[1]

if model_params.is_target_conditioned == IsTargetConditioned.CONCAT:
num_numerical_features -= 1
Expand Down Expand Up @@ -279,7 +281,7 @@ def child_fine_tuning(
child_info = get_table_info(child_df_with_cluster, child_domain_dict, target_col)
child_model_params = ModelParameters(
diffusion_parameters=DiffusionParameters(
d_layers=diffusion_config["d_layers"],
layers_dimensions=diffusion_config["d_layers"],
dropout=diffusion_config["dropout"],
),
)
Expand Down
22 changes: 12 additions & 10 deletions src/midst_toolkit/models/clavaddpm/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,23 +537,25 @@ def prepare_fast_dataloader(
Returns:
A generator of batches of data from the dataset.
"""
if dataset.x_cat is not None:
if dataset.x_num is not None:
concatenated_features = np.concatenate([dataset.x_num[split.value], dataset.x_cat[split.value]], axis=1)
x = torch.from_numpy(concatenated_features).float()
if dataset.categorical_features is not None:
if dataset.numerical_features is not None:
concatenated_features = np.concatenate(
[dataset.numerical_features[split.value], dataset.categorical_features[split.value]], axis=1
)
features = torch.from_numpy(concatenated_features).float()
else:
x = torch.from_numpy(dataset.x_cat[split.value]).float()
features = torch.from_numpy(dataset.categorical_features[split.value]).float()
else:
assert dataset.x_num is not None
x = torch.from_numpy(dataset.x_num[split.value]).float()
assert dataset.numerical_features is not None
features = torch.from_numpy(dataset.numerical_features[split.value]).float()

if target_type == TargetType.FLOAT:
y = torch.from_numpy(dataset.y[split.value]).float()
target = torch.from_numpy(dataset.target[split.value]).float()
elif target_type == TargetType.LONG:
y = torch.from_numpy(dataset.y[split.value]).long()
target = torch.from_numpy(dataset.target[split.value]).long()
else:
raise ValueError(f"Unsupported target type: {target_type}")

dataloader = FastTensorDataLoader([x, y], batch_size=batch_size, shuffle=(split == DataSplit.TRAIN))
dataloader = FastTensorDataLoader([features, target], batch_size=batch_size, shuffle=(split == DataSplit.TRAIN))
while True:
yield from dataloader
Loading