Merged
70 commits
8b7a781
Data processing and shadow models
fatemetkl Aug 21, 2025
71178c0
Merged main
fatemetkl Aug 21, 2025
5e4e3d9
mypy fixes
fatemetkl Aug 21, 2025
a47fe95
Merged main
fatemetkl Sep 16, 2025
5ab4ade
Removed an old code
fatemetkl Sep 16, 2025
6166728
A working example
fatemetkl Sep 23, 2025
dbea030
Merged main
fatemetkl Sep 23, 2025
e5a3263
small fixes
fatemetkl Sep 23, 2025
3ce5786
Merged main
fatemetkl Sep 25, 2025
d7977b3
Improved the shadow training pipeline
fatemetkl Sep 26, 2025
7f7d120
Small fixes
fatemetkl Sep 26, 2025
5ac0172
Merge branch 'main' into ft/shadow_models
fatemetkl Sep 26, 2025
822e0a3
small change
fatemetkl Sep 29, 2025
8e410b6
Merge branch 'main' into ft/shadow_models
fatemetkl Sep 29, 2025
c2a5115
Small improvements
fatemetkl Sep 30, 2025
260f1cb
Added tests, fixed mypy and ruff errors
fatemetkl Oct 2, 2025
85c43c3
Merged main into branch, and addressed conflicts, refactors
fatemetkl Oct 3, 2025
3de5e07
Fixed mypy errors
fatemetkl Oct 3, 2025
2772609
Small fixes
fatemetkl Oct 3, 2025
0bd7849
Added more to the docstrings and comments
fatemetkl Oct 3, 2025
3d95d9f
Merged main, addressed conflicts and some fixes
fatemetkl Oct 9, 2025
1afc8e4
Small fix
fatemetkl Oct 9, 2025
fc2a772
Addressed Marcelo's comments part 1
fatemetkl Oct 9, 2025
65234a7
Sync shadow model data with blendingplusplus
fatemetkl Oct 10, 2025
3f22c21
Unified attack's load_multi_table, and the one currently in our codebase
fatemetkl Oct 10, 2025
a0b5122
Addressed Marcelo's comments part 2
fatemetkl Oct 10, 2025
945658e
Fixed segmentation fault error due to dependencies by seperating impo…
fatemetkl Oct 14, 2025
b70e818
Initial RMIA structure
sarakodeiri Oct 14, 2025
c5d9a0a
Merge branch 'main' into ft/shadow_models
fatemetkl Oct 14, 2025
85873c5
Seperated metaclassifier and shadow pipeline scripts to fix segmentat…
fatemetkl Oct 14, 2025
f2de6dc
Manually add .pkl files
sarakodeiri Oct 14, 2025
ef7c6ee
Small fixes
fatemetkl Oct 14, 2025
812a54e
Directory naming fix
fatemetkl Oct 14, 2025
1151622
Final set of Marcelo's comments
fatemetkl Oct 15, 2025
fc5a6aa
Implement most of RMIA
sarakodeiri Oct 15, 2025
8e67448
Move all RMIA code
sarakodeiri Oct 15, 2025
2211a00
Merge branch 'ft/shadow_models' into sk/rmia
sarakodeiri Oct 15, 2025
cc49d04
Finalize RMIA and cleanup
sarakodeiri Oct 16, 2025
7c6981a
Merge branch 'main' into sk/rmia
sarakodeiri Oct 16, 2025
7b8f2f5
Resolve conflicts
sarakodeiri Oct 16, 2025
79c26b0
Small fix
sarakodeiri Oct 16, 2025
a17bfd4
Add tests
sarakodeiri Oct 16, 2025
91df09d
Ruff fix
sarakodeiri Oct 16, 2025
895dda0
Merge branch 'main' into sk/rmia
sarakodeiri Oct 16, 2025
9798a0f
Merge branch 'main' into sk/rmia
emersodb Oct 17, 2025
0c35030
First draft of target model training
sarakodeiri Oct 17, 2025
ca49fdc
Merge branch 'sk/rmia' into target_model
sarakodeiri Oct 17, 2025
fa0fb01
Address David's comments + minor changes
sarakodeiri Oct 22, 2025
972f4ee
Merge branch 'main' into sk/rmia
sarakodeiri Oct 22, 2025
ee1fb7a
Merge branch 'sk/rmia' into target_model
sarakodeiri Oct 22, 2025
8fa101c
Merge branch 'main' into sk/rmia
emersodb Oct 22, 2025
507a452
Merge branch 'main' into sk/rmia
emersodb Oct 22, 2025
77fa27e
Finalize proper target model train
sarakodeiri Oct 22, 2025
396b9c6
Merge branch 'target_model' into sk/rmia
sarakodeiri Oct 22, 2025
8f7a6dc
Minor ruff fix
sarakodeiri Oct 22, 2025
6cb82f0
Second round of David's comments
sarakodeiri Oct 24, 2025
c51af3c
Minor fixes
sarakodeiri Oct 24, 2025
0f8c851
Merge branch 'main' into sk/rmia
sarakodeiri Oct 24, 2025
90b09fe
Merge branch 'main' into sk/rmia
lotif Oct 27, 2025
b00d22b
Merge branch 'main' into sk/rmia
sarakodeiri Oct 29, 2025
a9a72b4
Addressed Fatemeh's comments
sarakodeiri Nov 5, 2025
a4f542f
Added some tests
sarakodeiri Nov 5, 2025
9706404
Finalized tests + Minor fixes
sarakodeiri Nov 5, 2025
0d1cb64
mypy fixes
sarakodeiri Nov 5, 2025
4e406fe
Merge branch 'main' into sk/rmia
sarakodeiri Nov 5, 2025
0bf4fef
Fix in test
sarakodeiri Nov 6, 2025
7f6b5b2
Changed backticks to double backticks
sarakodeiri Nov 7, 2025
88a6d4b
Merge branch 'main' into sk/rmia
sarakodeiri Nov 7, 2025
6fc757a
Docstring fix
sarakodeiri Nov 7, 2025
5cec444
Merge branch 'sk/rmia' of https://github.com/VectorInstitute/midst-to…
sarakodeiri Nov 7, 2025
2 changes: 1 addition & 1 deletion .gitignore
@@ -29,7 +29,7 @@ wheels/
# Synthcity backups
**/workspace/*.bkp

# Dataset files
# Data files
examples/**/data/

# Trained metaclassifiers
25 changes: 19 additions & 6 deletions examples/ensemble_attack/config.yaml
@@ -11,14 +11,13 @@ data_paths:
attack_results_path: ${base_example_dir}/attack_results # Path where the attack results will be stored

model_paths:
shadow_models_path: ${base_example_dir}/shadow_models # Path where the shadow models are stored
metaclassifier_model_path: ${base_example_dir}/trained_models # Path where the trained metaclassifier model will be saved

# Pipeline control
pipeline:
run_data_processing: false # Set this to false if you have already saved the processed data
run_shadow_model_training: true
run_metaclassifier_training: false
run_shadow_model_training: false # Set this to false if shadow models are already trained and saved
run_metaclassifier_training: true


# Dataset specific information used for processing in this example
@@ -54,7 +53,18 @@ shadow_training:
tabddpm_training_config_path: ${base_example_dir}/data_configs/trans.json
# Model training artifacts are saved under shadow_models_data_path/workspace_name/exp_name
# Also, training configs for each shadow model are created under shadow_models_data_path.
shadow_models_output_path: ${base_data_dir}/shadow_models_data
shadow_models_output_path: ${base_data_dir}/shadow_models_and_data
target_model_output_path: ${base_data_dir}/target_model_and_data
final_shadow_models_path: [
"${shadow_training.shadow_models_output_path}/initial_model_rmia_1/shadow_workspace/pre_trained_model/rmia_shadows.pkl",
"${shadow_training.shadow_models_output_path}/initial_model_rmia_2/shadow_workspace/pre_trained_model/rmia_shadows.pkl",
"${shadow_training.shadow_models_output_path}/shadow_model_rmia_third_set/shadow_workspace/trained_model/rmia_shadows_third_set.pkl",
] # Paths to final shadow models used for metaclassifier training (relative to shadow_models_output_path)
# These paths are a result of running the shadow model training pipeline, specifically the
# train_three_sets_of_shadow_models in shadow_model_training.py
# Each .pkl file contains the training data, trained model and training results for all shadow models in a list.
final_target_model_path: ${shadow_training.target_model_output_path}/target_model/shadow_workspace/trained_target_model/target_model.pkl
# Path to final target model (relative to target_model_output_path)
fine_tuning_config:
fine_tune_diffusion_iterations: 2
fine_tune_classifier_iterations: 2
@@ -66,10 +76,13 @@ metaclassifier:
# Data types json file is used for xgboost model training.
data_types_file_path: ${base_example_dir}/data_configs/data_types.json
model_type: "xgb"
use_gpu: true
# Model training parameters
num_optuna_trials: 10 # Original code: 100
num_kfolds: 5
use_gpu: false
# Temporary; the epochs parameter may be removed later.
epochs: 1


# General settings
random_seed: 42
random_seed: 42 # Set to null for no seed, or an integer for a fixed seed
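The comments in the config above describe where the shadow-model and target-model pickles end up. Below is a minimal sketch of resolving those interpolated paths outside the pipeline, assuming OmegaConf (which the example scripts already use via DictConfig); loading the YAML directly like this is an illustration, not part of the PR.

```python
# A minimal sketch (not part of the PR) of resolving the interpolated paths above with
# OmegaConf. Assumes config.yaml defines the base_* directories referenced by the
# interpolations shown in the diff.
from pathlib import Path

from omegaconf import OmegaConf

config = OmegaConf.load("examples/ensemble_attack/config.yaml")
# resolve=True expands interpolations such as ${shadow_training.shadow_models_output_path}
resolved = OmegaConf.to_container(config, resolve=True)

shadow_paths = [Path(p) for p in resolved["shadow_training"]["final_shadow_models_path"]]
target_path = Path(resolved["shadow_training"]["final_target_model_path"])
print(shadow_paths, target_path)
```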
3 changes: 2 additions & 1 deletion examples/ensemble_attack/data_configs/data_types.json
@@ -1,5 +1,6 @@
{
"numerical": ["trans_date", "amount", "balance", "account"],
"categorical": ["trans_type", "operation", "k_symbol", "bank"],
"variable_to_predict": "trans_type"
"variable_to_predict": "trans_type",
"id_column_name": "trans_id"
}
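data_types.json now also records the id column. The following is a hedged sketch of how such a file might be consumed when preparing metaclassifier features; split_columns is a hypothetical helper, and only the JSON keys come from the file above.

```python
# Hedged sketch (not from the PR) of consuming data_types.json; split_columns is a
# hypothetical helper, only the keys (numerical, categorical, id_column_name) come
# from the file above.
import json
from pathlib import Path

import pandas as pd


def split_columns(df: pd.DataFrame, data_types_path: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return (numerical_features, categorical_features) with the id column dropped."""
    data_types = json.loads(data_types_path.read_text())
    id_column = data_types["id_column_name"]
    if id_column in df.columns:
        df = df.drop(columns=[id_column])
    return df[data_types["numerical"]], df[data_types["categorical"]]
```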
19 changes: 17 additions & 2 deletions examples/ensemble_attack/run_attack.py
@@ -59,10 +59,25 @@ def main(config: DictConfig) -> None:
# TODO: Investigate the source of error.
if config.pipeline.run_shadow_model_training:
shadow_pipeline = importlib.import_module("examples.ensemble_attack.run_shadow_model_training")
shadow_pipeline.run_shadow_model_training(config)
attack_data_paths = shadow_pipeline.run_shadow_model_training(config)
attack_data_paths = [Path(path) for path in attack_data_paths]

target_data_path = shadow_pipeline.run_target_model_training(config)
Collaborator:

This tightly couples shadow model training with target model training. Based on our discussion, the target model is the one we're attacking, right? Theoretically, this model may already exist and we just want to attack it, i.e. we may not always want to, or be able to, train it. Again, I may still be misunderstanding our vocabulary here.

Collaborator Author:

From my understanding, the target model is the model we're attacking in the simulated setting, and it is a shadow model. The main difference between the target model and the other shadow models is the data it is trained on. I might be wrong, but I think the target model is trained on the entire "real data" while the other shadow models are trained/fine-tuned on different combinations and subsets of the population data. The docstring on train_three_sets_of_shadow_models in midst_toolkit.attacks.ensemble.rmia.shadow_model_training explains it in more detail.
Our vocabulary and the original implementation's vocabulary aren't the simplest, and I'm still not sure if I've done the right thing here. I want to keep the PR open a bit longer to get more eyes on it and the concepts.

Collaborator:

Sure, sounds good. Like I said, I'm not as deeply integrated in the vocabulary, so this may be perfectly reasonable. Just wanted to ask the question.

target_data_path = Path(target_data_path)

if config.pipeline.run_metaclassifier_training:
if not config.pipeline.run_shadow_model_training:
# If shadow model training is skipped, we need to provide the previous shadow model and target model paths.

shadow_data_paths = [Path(path) for path in config.shadow_training.final_shadow_models_path]

target_data_path = Path(config.shadow_training.final_target_model_path)

assert len(shadow_data_paths) == 3, "The attack_data_paths list must contain exactly three elements."
assert target_data_path is not None, "The target_data_path must be provided for metaclassifier training."

meta_pipeline = importlib.import_module("examples.ensemble_attack.run_metaclassifier_training")
meta_pipeline.run_metaclassifier_training(config)
meta_pipeline.run_metaclassifier_training(config, shadow_data_paths, target_data_path)


if __name__ == "__main__":
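The attack driver imports the shadow-model and metaclassifier scripts lazily via importlib; per the commit history, this avoids a segmentation fault caused by conflicting dependencies when everything is imported up front. Below is a minimal sketch of that pattern; the run_stage helper is hypothetical and for illustration only.

```python
# Minimal sketch of the lazy-import pattern used in run_attack.py. The module path in
# the example call is real, but this standalone helper (run_stage) is hypothetical.
import importlib
from typing import Any


def run_stage(module_name: str, function_name: str, *args: Any) -> Any:
    """Import a pipeline stage only when its flag is enabled.

    Deferring the import keeps heavy (and potentially conflicting) dependencies out of
    the process unless the stage actually runs.
    """
    module = importlib.import_module(module_name)
    return getattr(module, function_name)(*args)


# Example (config would be the loaded DictConfig):
# paths = run_stage("examples.ensemble_attack.run_shadow_model_training",
#                   "run_shadow_model_training", config)
```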
87 changes: 65 additions & 22 deletions examples/ensemble_attack/run_metaclassifier_training.py
@@ -11,19 +11,30 @@
from midst_toolkit.common.logger import log


def run_metaclassifier_training(config: DictConfig) -> None:
def run_metaclassifier_training(
config: DictConfig,
shadow_data_paths: list[Path],
target_data_path: Path,
) -> None:
"""
Function to run the metaclassifier training and evaluation.

Args:
config: Configuration object set in config.yaml.
shadow_data_paths: List of paths to the trained shadow models and all their attributes and synthetic data.
The list should contain three paths, one for each set of shadow models.
target_data_path: Path to the target model and all its attributes and synthetic data.
"""
log(INFO, "Running metaclassifier training...")

# Load the processed data splits.
df_meta_train = load_dataframe(
Path(config.data_paths.processed_attack_data_path),
"master_challenge_train.csv",
)

# y_meta_train consists of binary labels (0s and 1s) indicating whether each row in df_meta_train
# belongs to the target model's training set.
y_meta_train = np.load(
Path(config.data_paths.processed_attack_data_path) / "master_challenge_train_labels.npy",
)
@@ -35,69 +46,101 @@ def run_metaclassifier_training(config: DictConfig) -> None:
Path(config.data_paths.processed_attack_data_path) / "master_challenge_test_labels.npy",
)

# Synthetic data borrowed from the attack implementation repository.
# From (https://github.com/CRCHUM-CITADEL/ensemble-mia/tree/main/input/tabddpm_black_box/meta_classifier)
# TODO: Change this file path to the path where the synthetic data is stored.
df_synthetic = load_dataframe(
Path(config.data_paths.processed_attack_data_path),
"synth.csv",
# Three sets of shadow models are trained separately and their paths are provided here.

assert len(shadow_data_paths) == 3, (
"At this point of development, the shadow_data_paths list must contain exactly three elements."
)

shadow_data_collection = []

for model_path in shadow_data_paths:
assert model_path.exists(), (
f"No file found at {model_path}. Make sure the path is correct, or run shadow model training first."
)

with open(model_path, "rb") as f:
shadow_data_and_result = pickle.load(f)
shadow_data_collection.append(shadow_data_and_result)

assert target_data_path.exists(), (
f"No file found at {target_data_path}. Make sure the path is correct and that you have trained the target model."
)

with open(target_data_path, "rb") as f:
target_data_and_result = pickle.load(f)

target_synthetic = target_data_and_result["trained_results"][0].synthetic_data
assert target_synthetic is not None, "Target model pickle missing synthetic_data."
target_synthetic = target_synthetic.copy()

df_reference = load_dataframe(
Path(config.data_paths.population_path),
"population_all_with_challenge_no_id.csv",
)
# We should drop the id column from master metaclassifier train data.
if "trans_id" in df_meta_train.columns:
df_meta_train = df_meta_train.drop(columns=["trans_id", "account_id"])
if "trans_id" in df_meta_test.columns:
df_meta_test = df_meta_test.drop(columns=["trans_id", "account_id"])

# Extract trans_id from both train and test dataframes
assert "trans_id" in df_meta_train.columns, "Meta train data must have trans_id column"
train_trans_ids = df_meta_train["trans_id"]

assert "trans_id" in df_meta_test.columns, "Meta test data must have trans_id column"
test_trans_ids = df_meta_test["trans_id"]

df_meta_train = df_meta_train.drop(columns=["trans_id", "account_id"])
df_meta_test = df_meta_test.drop(columns=["trans_id", "account_id"])

# Fit the metaclassifier.
meta_classifier_enum = MetaClassifierType(config.metaclassifier.model_type)

# 1. Initialize the attacker
blending_attacker = BlendingPlusPlus(
config=config,
shadow_data_collection=shadow_data_collection,
target_data=target_data_and_result,
meta_classifier_type=meta_classifier_enum,
random_seed=config.random_seed,
)
log(
INFO,
f"{meta_classifier_enum} created with random seed {config.random_seed}, starting training...",
)

log(INFO, f"{meta_classifier_enum} created with random seed {config.random_seed}.")

# 2. Train the attacker on the meta-train set

blending_attacker.fit(
df_train=df_meta_train,
y_train=y_meta_train,
df_synthetic=df_synthetic,
df_target_synthetic=target_synthetic,
df_reference=df_reference,
id_column_data=train_trans_ids,
use_gpu=config.metaclassifier.use_gpu,
epochs=config.metaclassifier.epochs,
)

log(INFO, "Metaclassifier training finished.")

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# TODO: Create the directory if it does not exist.
model_filename = f"{timestamp}_{config.metaclassifier.model_type}_trained_metaclassifier.pkl"
with open(Path(config.model_paths.metaclassifier_model_path) / model_filename, "wb") as f:
pickle.dump(blending_attacker.trained_model, f)

log(INFO, "Metaclassifier model saved, starting evaluation...")

# Get the synthetic data provided by the challenge for evaluation
# TODO: Check if the file is the correct one.
df_synthetic_original = load_dataframe(
Path(config.data_paths.processed_attack_data_path),
"synth.csv",
)
Comment on lines +125 to +130

⚠️ Potential issue | 🔴 Critical

Address or remove the TODO comment about file correctness.

The TODO suggests uncertainty about whether synth.csv is the correct synthetic data for evaluation. Since training uses df_synthetic from the target model (line 72), while evaluation uses df_synthetic_original from this CSV, please verify this is the intended behavior. If different synthetic datasets are required for training vs. evaluation, the comment should clarify the rationale; otherwise, the inconsistency should be resolved.

Do you want me to help clarify the intended usage pattern for these two synthetic datasets, or would you like me to open an issue to track resolving this TODO?

🧩 Analysis chain

[Verification scripts were run against the repository: searching the Python, YAML, and Markdown files for references to synth.csv and the two synthetic dataframes; inspecting run_metaclassifier_training.py, run_attack.py, config.yaml, the README, and process_split_data.py; and checking which data files the pipeline generates.]

The TODO comment identifies a genuine issue: synth.csv is not created by the data processing pipeline and has no source in the codebase.

The file is loaded from processed_attack_data_path, but the process_split_data.py pipeline only generates real_train.csv, real_val.csv, real_test.csv, and the master_challenge_*.csv files. Since synth.csv is never created by the codebase, either:

  1. The evaluation should use df_synthetic from the target model (same as training) instead of loading a separate file, or
  2. synth.csv should be generated by the data processing pipeline but is missing, or
  3. synth.csv is expected to be externally provided without documentation.

Clarify the intended behavior and either fix the code path or add the missing data generation step.

🤖 Prompt for AI Agents
In examples/ensemble_attack/run_metaclassifier_training.py around lines 122 to
127, the code attempts to load "synth.csv", which is not produced by the data
processing pipeline; update the code or pipeline to match the intended flow:
either (A) stop loading synth.csv and instead use the synthetic dataframe
already created/returned by the target model (replace the load call with the
in-memory df_synthetic source used during training/evaluation), or (B) add
generation of synth.csv in the data processing pipeline (process_split_data.py)
so that it is created under processed_attack_data_path and document its schema,
or (C) if synth.csv is an external input, add explicit validation and clear
documentation and config to require the external file; choose one of these
options and implement the corresponding change and tests so the file path
lookup no longer fails.
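A sketch of option (A) above, assuming the target_synthetic variable already defined earlier in run_metaclassifier_training.py; this is only one possible resolution of the TODO, not the decided fix.

```python
# Option (A), sketched: reuse the synthetic data already carried in the target model
# pickle (the same source used for training) instead of loading synth.csv from disk.
# Assumes the target_synthetic variable defined earlier in this script; illustrative only.
df_synthetic_original = target_synthetic.copy()
```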


# 3. Get predictions on the test set
probabilities, pred_score = blending_attacker.predict(
df_test=df_meta_test,
df_synthetic=df_synthetic,
df_original_synthetic=df_synthetic_original,
df_reference=df_reference,
id_column_data=test_trans_ids,
y_test=y_meta_test,
)

# Save the prediction probabilities
# TODO: Create the attack results directory folder if it does not exist.
attack_results_path = Path(config.data_paths.attack_results_path)
attack_results_path.mkdir(parents=True, exist_ok=True)
np.save(
Path(config.data_paths.attack_results_path)
/ f"{timestamp}_{config.metaclassifier.model_type}_test_pred_proba.npy",
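run_metaclassifier_training.py saves the test-set prediction probabilities as a timestamped .npy file. The following is a hedged sketch of scoring such a file offline against the saved membership labels; the concrete paths, the example filename, and the use of scikit-learn are assumptions, not part of the PR.

```python
# Hedged sketch: score a saved prediction file against the saved test labels. The
# directory layout and timestamped filename are examples only, the array is assumed
# to be a 1-D vector of membership scores, and scikit-learn is not used by the PR.
from pathlib import Path

import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

attack_results_path = Path("examples/ensemble_attack/attack_results")         # config.data_paths.attack_results_path
processed_data_path = Path("examples/ensemble_attack/attack_data/processed")  # config.data_paths.processed_attack_data_path (hypothetical)

y_test = np.load(processed_data_path / "master_challenge_test_labels.npy")
scores = np.load(attack_results_path / "20251107_120000_xgb_test_pred_proba.npy")  # example filename

auc = roc_auc_score(y_test, scores)
fpr, tpr, _ = roc_curve(y_test, scores)
tpr_at_low_fpr = tpr[fpr <= 0.1].max()  # a common MIA summary: TPR at 10% FPR
print(f"AUC={auc:.3f}, TPR@FPR<=0.1={tpr_at_low_fpr:.3f}")
```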
88 changes: 86 additions & 2 deletions examples/ensemble_attack/run_shadow_model_training.py
@@ -1,21 +1,103 @@
import pickle
import shutil
from logging import INFO
from pathlib import Path
from typing import Any

from omegaconf import DictConfig

from midst_toolkit.attacks.ensemble.data_utils import load_dataframe
from midst_toolkit.attacks.ensemble.rmia.shadow_model_training import (
train_three_sets_of_shadow_models,
)
from midst_toolkit.attacks.ensemble.shadow_model_utils import (
save_additional_tabddpm_config,
train_tabddpm_and_synthesize,
)
from midst_toolkit.common.logger import log


def run_shadow_model_training(config: DictConfig) -> None:
def run_target_model_training(config: DictConfig) -> Path:
Collaborator:

Since training the target model follows steps very similar to shadow model training and is part of the attack design, we could consider creating a dedicated function (one suggestion for the name is train_target_model_and_synthesize) in the existing shadow_model_training.py module. This way, similar code stays together, and if we later improve the shadow model training pipeline, it will be easier to update the target model training as well. This makes sense since the target model is essentially a specific type of shadow model. We should also update the README to clarify this for readers.

"""
Function to run the target model training for RMIA attack.

Args:
config: Configuration object set in config.yaml.

Returns:
Path to the saved target model results.
"""
log(INFO, "Running target model training...")

# Load the required dataframe for target model training.
df_real_data = load_dataframe(
Path(config.data_paths.processed_attack_data_path),
"real_train.csv",
)

# TODO: Test when pipeline is complete to make sure real_data is correct.

target_model_output_path = Path(config.shadow_training.target_model_output_path)
target_training_json_config_paths = config.shadow_training.training_json_config_paths

# TODO: Add this to config or .json files
table_name = "trans"
id_column_name = "trans_id"
Collaborator:

Perhaps not for this PR, but these could be configuration parameters, yes? If we don't do it here, maybe just note it as a TODO?


target_folder = target_model_output_path / "target_model"

target_folder.mkdir(parents=True, exist_ok=True)
shutil.copyfile(
target_training_json_config_paths.table_domain_file_path,
target_folder / f"{table_name}_domain.json",
)
shutil.copyfile(
target_training_json_config_paths.dataset_meta_file_path,
target_folder / "dataset_meta.json",
)
configs, save_dir = save_additional_tabddpm_config(
data_dir=target_folder,
training_config_json_path=Path(target_training_json_config_paths.tabddpm_training_config_path),
final_config_json_path=target_folder / f"{table_name}.json", # Path to the new json
experiment_name="trained_target_model",
)

train_result = train_tabddpm_and_synthesize(
train_set=df_real_data,
configs=configs,
save_dir=save_dir,
synthesize=True,
)

# TODO: Check: Selected_id_lists should be of form [[]]
selected_id_lists = [df_real_data[id_column_name].tolist()]

attack_data: dict[str, Any] = {
"selected_sets": selected_id_lists,
"trained_results": [],
}

attack_data["trained_results"].append(train_result)

# Pickle dump the results
result_path = Path(save_dir, "target_model.pkl")
with open(result_path, "wb") as file:
pickle.dump(attack_data, file)

return result_path


def run_shadow_model_training(config: DictConfig) -> list[Path]:
"""
Function to run the shadow model training for RMIA attack.

Args:
config: Configuration object set in config.yaml.

Returns:
Paths to the saved shadow model results for the three sets of shadow models. For more details,
see the documentation and return value of `train_three_sets_of_shadow_models`
at src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py.
"""
log(INFO, "Running shadow model training...")
# Load the required dataframes for shadow model training.
@@ -55,5 +137,7 @@ def run_shadow_model_training(config: DictConfig) -> None:
)
log(
INFO,
f"Shadow model training finished and saved at 1) {first_set_result_path}, 2) {second_set_result_path}, 3) {third_set_result_path}",
f"Shadow model training finished and saved at \n1) {first_set_result_path} \n2) {second_set_result_path} \n3) {third_set_result_path}",
)

return [first_set_result_path, second_set_result_path, third_set_result_path]
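The target and shadow pipelines each pickle a dictionary with "selected_sets" and "trained_results" entries. Below is a hedged sketch of loading and inspecting one of those pickles; the path is illustrative, and the synthetic_data attribute is assumed from its use in run_metaclassifier_training.py.

```python
# Hedged sketch: load and inspect the pickle written by run_target_model_training.
# The dict keys ("selected_sets", "trained_results") come from this PR; the path is
# illustrative and synthetic_data is assumed from its use elsewhere in the example.
import pickle
from pathlib import Path

result_path = Path(
    "target_model_and_data/target_model/shadow_workspace/trained_target_model/target_model.pkl"
)

with open(result_path, "rb") as f:
    attack_data = pickle.load(f)

print(len(attack_data["selected_sets"][0]), "trans_ids in the target model's training set")
first_result = attack_data["trained_results"][0]
if first_result.synthetic_data is not None:
    print(first_result.synthetic_data.head())
```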