diff --git a/MaskHIT/configs/config_default.yaml b/MaskHIT/configs/config_default.yaml deleted file mode 100644 index c09679e..0000000 --- a/MaskHIT/configs/config_default.yaml +++ /dev/null @@ -1,81 +0,0 @@ -# NOTE: To override any values in this file, please create config_user.yaml -# (or any YAML file name of you like). -# In config_user.yaml, only define the variables you wish to update. -# Unmentioned variables will use the default values specified here. -# Please avoid directly modifying values in this file. -# -# Documentation Tags: -# (default): variables can be left with their default values -# (custom): more likely to need modification for each user -#--------------------------------------------------- - - -dataset: - ## Dataset configuration - - # Path to the folder containing meta info about svs data - # Example: ../../../SlidePreprocessing/for_vit/meta/IBD_PROJECT/svs_meta.pickle - meta_svs: !!str - - # Path to the folder containing meta info about the dataset - # Example: ../../../SlidePreprocessing/for_vit/meta/ibd_project_meta.pickle - meta_all: !!str - - # Outcome our model is trying to predict - # Example: "Dx (U=UC, C=Cr, I=Ind)" - outcome: !!str - - # classification type - outcome_type: !!str classification - - # the name of the study - study: !!str - - # type of disease; whether it is cancer or not - is_cancer: !!bool False - - # title of project/disease name - disease: !!str - - # names of classes in your dataset - classes: !!str - - # number of folds experiment uses - num_folds: !!int 5 - -patch: - ## Patch configuration - - # number of patches from each region. 
If 0 will sample all patches - num_patches: !!int 0 - - # magnification level at which patches were extracted at - magnification: !!int 10 - - # intensity of weight decay - wd: !!float 0.01 - -model: - ## Model configuration - - # used for uneven class distribution - weighted_loss: !!bool False - - # learning rate - lr: !!float 1e-5 - - # Dropout rate - dropout: !!float 0.2 - - # Batch size for processing slide patches - batch_size: !!int 16 - - # determines whether old logs should stay - override_logs: !!bool True - - # number of svs sampled in sample-patient mode - regions_per_svs: !!int 64 - - # TBD - sample_patient: !!bool True - diff --git a/requirement.sh b/MaskHIT/install_requirement.sh similarity index 100% rename from requirement.sh rename to MaskHIT/install_requirement.sh diff --git a/install_requirement_for_container.sh b/MaskHIT/install_requirement_for_container.sh similarity index 92% rename from install_requirement_for_container.sh rename to MaskHIT/install_requirement_for_container.sh index b6a8863..5daade9 100755 --- a/install_requirement_for_container.sh +++ b/MaskHIT/install_requirement_for_container.sh @@ -2,7 +2,7 @@ pip install openslide-python # install additional packages -pip install pandarallel pandas scikit-image scikit-learn einops tqdm lifelines pyyaml +pip install pandarallel pandas scikit-image scikit-learn einops tqdm lifelines pyyaml seaborn pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git pip install opencv-python # When you see: AttributeError: module 'cv2.dnn' has no attribute 'DictValue' diff --git a/MaskHIT/.gitignore b/MaskHIT/maskhit/.gitignore similarity index 100% rename from MaskHIT/.gitignore rename to MaskHIT/maskhit/.gitignore diff --git a/MaskHIT/README.md b/MaskHIT/maskhit/README.md similarity index 100% rename from MaskHIT/README.md rename to MaskHIT/maskhit/README.md diff --git a/MaskHIT/model/archs/vit/__init__.py b/MaskHIT/maskhit/__init__.py similarity index 100% rename from 
MaskHIT/model/archs/vit/__init__.py rename to MaskHIT/maskhit/__init__.py diff --git a/MaskHIT/maskhit/configs/config_default.yaml b/MaskHIT/maskhit/configs/config_default.yaml new file mode 100644 index 0000000..e877000 --- /dev/null +++ b/MaskHIT/maskhit/configs/config_default.yaml @@ -0,0 +1,137 @@ +# Dataset and Model Configuration File +# This configuration file sets parameters for dataset preprocessing and model training. +# Modify values in config_user.yaml to override defaults. + +# NOTE: To override any values in this file, please create config_user.yaml +# (or any YAML file name of you like). +# In config_user.yaml, only define the variables you wish to update. +# Unmentioned variables will use the default values specified here. +# Please avoid directly modifying values in this file. +# +# Documentation Tags: +# (default): variables can be left with their default values +# (custom): more likely to need modification for each user +#--------------------------------------------------- + +### DATASET CONFIGURATION ### +dataset: + ## Dataset configuration + + # Path to the file (svs_meta.pickle) containing meta info about svs data + # File is generated by SlidePrep/MaskHIT_Prep/05_post_process.py + meta_svs: !!str + + # Path to the file (meta.pickle) containing meta info about the dataset + # File is generated by SlidePrep/MaskHIT_Prep/01_get_svs_meta.py + meta_all: !!str + + # Outcome our model is trying to predict + # Example: "Dx (U=UC, C=Cr, I=Ind)" + #TODO: Not sure what this is meant. In what format? If this is for classification, clarify. + outcome: !!str + + # classification type + # Available options: survival, classification, regression + outcome_type: !!str classification + + # the name of the study + #TODO: suggest to rename to 'study_name' for clarity. + study: !!str + + # type of disease; whether it is cancer or not + #TODO: Add doc to discuss why this is matter. If necessary, we should rename. 
+ is_cancer: !!bool False + + # title of project/disease name + #TODO: what's the difference between this and 'study' + #This seems to be used for folder value:?? + # `meta_svs['folder'] = config.dataset.disease` + #Better rename or explain the main intent of this parameter. + disease: !!str + + # names of classes in your dataset + #TODO: need doc and example. + classes: !!str + + # Number of folds for nested cross-validation + num_folds: !!int 5 + +patch: + ## Patch configuration + + # number of patches from each region. If 0 will sample all patches + num_patches: !!int 0 + + # magnification level at which patches were extracted at + #TODO: Why manually setting this again? Or can we specify a config file from Prep to extract this info? + magnification: !!int 10 + + # intensity of weight decay + #TODO: Why this is under patch section? + wd: !!float 0.01 + +model: + ## Model configuration + + # used for uneven class distribution + weighted_loss: !!bool False + + # learning rate + lr: !!float 1e-5 + + # Dropout rate + dropout: !!float 0.2 + + # Batch size for processing slide patches + batch_size: !!int 16 + + # which fold to use after kfold cross validation + #TODO: Not sure what this means. Also suggest to change variable name to be more descriptive. + fold: !!int 0 + + # determines whether old logs should stay + override_logs: !!bool True + + # number of svs sampled in sample-patient mode + regions_per_svs: !!int 64 + + #TODO: Copied from config_ibd_train.yaml. Need to double check the default values. + # Weight Decays + wd_attn: !!float 1e-3 + wd_fuse: !!float 1e-2 # changed from 1e-3 to 1e-2 + wd_loss: !!float 1e-2 # changed from 1e-3 to 1e-2 + wd_pred: !!float 0.002 + + #TODO: Copied from config_ibd_train.yaml. Need to double check the default values. + # Learning Rates + lr_attn: !!float 1e-5 # lowered since we are using pre-trained model + lr_fuse: !!float 1e-4 + lr_loss: !!float 1e-4 + lr_pred: !!float 7e-4 + + #TODO: Copied from config_ibd_train.yaml. 
Need to double check the default values. + #Not sure what this is for, as in the code it chooses a measure based on outcome_type? + performance_measure: !!str f1 + + #TODO: Copied from config_ibd_train.yaml. Add doc. + accumulation_steps: !!int 1 + + #TODO: Copied from config_ibd_train.yaml. Need to double check the default values. + dropout: !!float 0.2 + + #TODO: Copied from config_ibd_train.yaml. Need to double check the default values. + #TODO: What is this for? and what's # for visualization (64) + batch_size: !!int 16 # for visualization (64) + + #TODO: Copied from config_ibd_train.yaml. Maybe: overwrite_logs + override_logs: !!bool True + + # TBD + #TODO: Missing docs. + sample_patient: !!bool True + + # Check-point path? + #TODO: Missing docs. + resume: null + + diff --git a/MaskHIT/configs/config_default_visualization.yml b/MaskHIT/maskhit/configs/config_default_visualization.yml similarity index 100% rename from MaskHIT/configs/config_default_visualization.yml rename to MaskHIT/maskhit/configs/config_default_visualization.yml diff --git a/MaskHIT/configs/config_ibd_train.yml b/MaskHIT/maskhit/configs/config_ibd_train.yml similarity index 100% rename from MaskHIT/configs/config_ibd_train.yml rename to MaskHIT/maskhit/configs/config_ibd_train.yml diff --git a/MaskHIT/configs/config_ibd_visualization.yml b/MaskHIT/maskhit/configs/config_ibd_visualization.yml similarity index 100% rename from MaskHIT/configs/config_ibd_visualization.yml rename to MaskHIT/maskhit/configs/config_ibd_visualization.yml diff --git a/MaskHIT/configs/config_tcga.yml b/MaskHIT/maskhit/configs/config_tcga.yml similarity index 100% rename from MaskHIT/configs/config_tcga.yml rename to MaskHIT/maskhit/configs/config_tcga.yml diff --git a/MaskHIT/create_attention_maps.py b/MaskHIT/maskhit/create_attention_maps.py similarity index 100% rename from MaskHIT/create_attention_maps.py rename to MaskHIT/maskhit/create_attention_maps.py diff --git a/MaskHIT/cross_validation.py 
b/MaskHIT/maskhit/cross_validation.py similarity index 100% rename from MaskHIT/cross_validation.py rename to MaskHIT/maskhit/cross_validation.py diff --git a/MaskHIT/model/.gitignore b/MaskHIT/maskhit/model/.gitignore similarity index 100% rename from MaskHIT/model/.gitignore rename to MaskHIT/maskhit/model/.gitignore diff --git a/MaskHIT/maskhit/model/__init__.py b/MaskHIT/maskhit/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MaskHIT/model/archs/__init__.py b/MaskHIT/maskhit/model/archs/__init__.py similarity index 100% rename from MaskHIT/model/archs/__init__.py rename to MaskHIT/maskhit/model/archs/__init__.py diff --git a/MaskHIT/model/archs/agg_ap.py b/MaskHIT/maskhit/model/archs/agg_ap.py similarity index 100% rename from MaskHIT/model/archs/agg_ap.py rename to MaskHIT/maskhit/model/archs/agg_ap.py diff --git a/MaskHIT/model/archs/agg_attn.py b/MaskHIT/maskhit/model/archs/agg_attn.py similarity index 100% rename from MaskHIT/model/archs/agg_attn.py rename to MaskHIT/maskhit/model/archs/agg_attn.py diff --git a/MaskHIT/model/archs/agg_deepattnmisl.py b/MaskHIT/maskhit/model/archs/agg_deepattnmisl.py similarity index 100% rename from MaskHIT/model/archs/agg_deepattnmisl.py rename to MaskHIT/maskhit/model/archs/agg_deepattnmisl.py diff --git a/MaskHIT/model/archs/agg_mhattn.py b/MaskHIT/maskhit/model/archs/agg_mhattn.py similarity index 100% rename from MaskHIT/model/archs/agg_mhattn.py rename to MaskHIT/maskhit/model/archs/agg_mhattn.py diff --git a/MaskHIT/model/archs/agg_vit.py b/MaskHIT/maskhit/model/archs/agg_vit.py similarity index 100% rename from MaskHIT/model/archs/agg_vit.py rename to MaskHIT/maskhit/model/archs/agg_vit.py diff --git a/MaskHIT/model/archs/utils/__init__.py b/MaskHIT/maskhit/model/archs/utils/__init__.py similarity index 100% rename from MaskHIT/model/archs/utils/__init__.py rename to MaskHIT/maskhit/model/archs/utils/__init__.py diff --git a/MaskHIT/model/archs/utils/masking_generator.py 
b/MaskHIT/maskhit/model/archs/utils/masking_generator.py similarity index 100% rename from MaskHIT/model/archs/utils/masking_generator.py rename to MaskHIT/maskhit/model/archs/utils/masking_generator.py diff --git a/MaskHIT/maskhit/model/archs/vit/__init__.py b/MaskHIT/maskhit/model/archs/vit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MaskHIT/model/archs/vit/deepvit.py b/MaskHIT/maskhit/model/archs/vit/deepvit.py similarity index 100% rename from MaskHIT/model/archs/vit/deepvit.py rename to MaskHIT/maskhit/model/archs/vit/deepvit.py diff --git a/MaskHIT/model/backbone.py b/MaskHIT/maskhit/model/backbone.py similarity index 100% rename from MaskHIT/model/backbone.py rename to MaskHIT/maskhit/model/backbone.py diff --git a/MaskHIT/model/helper.py b/MaskHIT/maskhit/model/helper.py similarity index 100% rename from MaskHIT/model/helper.py rename to MaskHIT/maskhit/model/helper.py diff --git a/MaskHIT/model/models.py b/MaskHIT/maskhit/model/models.py similarity index 100% rename from MaskHIT/model/models.py rename to MaskHIT/maskhit/model/models.py diff --git a/MaskHIT/maskhit/options/__init__.py b/MaskHIT/maskhit/options/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MaskHIT/options/base_options.py b/MaskHIT/maskhit/options/base_options.py similarity index 100% rename from MaskHIT/options/base_options.py rename to MaskHIT/maskhit/options/base_options.py diff --git a/MaskHIT/options/read_config.py b/MaskHIT/maskhit/options/read_config.py similarity index 100% rename from MaskHIT/options/read_config.py rename to MaskHIT/maskhit/options/read_config.py diff --git a/MaskHIT/options/train_options.py b/MaskHIT/maskhit/options/train_options.py similarity index 98% rename from MaskHIT/options/train_options.py rename to MaskHIT/maskhit/options/train_options.py index 2000531..fba0633 100644 --- a/MaskHIT/options/train_options.py +++ b/MaskHIT/maskhit/options/train_options.py @@ -53,6 +53,7 @@ def initialize(self): type=str, 
default='', help='select cancer subset, if empty then use entire dataset') + #TODO: lack of doc. so this 'fold' means 0-indexed test fold ID? self.parser.add_argument('--fold', type=int, default=0, @@ -254,10 +255,12 @@ def initialize(self): help='turn off strict mode') # patch region masking + #TODO: Avoid too similar arg names self.parser.add_argument('--prob-mask', type=float, default=0, help='mask probability in BERT') + #TODO: Need doc for "masked:original:random". What are these? self.parser.add_argument('--prop-mask', type=str, default='0, 1, 0', @@ -298,7 +301,7 @@ def initialize(self): # experiment optional options self.parser.add_argument('--checkpoints-folder', type=str, - default='checkpoints_new', + default='checkpoints', help='path to the checkpoints folder') self.parser.add_argument('--log-freq', type=int, diff --git a/MaskHIT/plot_results.py b/MaskHIT/maskhit/plot_results.py similarity index 100% rename from MaskHIT/plot_results.py rename to MaskHIT/maskhit/plot_results.py diff --git a/MaskHIT/quick_test.py b/MaskHIT/maskhit/quick_test.py similarity index 100% rename from MaskHIT/quick_test.py rename to MaskHIT/maskhit/quick_test.py diff --git a/MaskHIT/train.py b/MaskHIT/maskhit/train.py similarity index 82% rename from MaskHIT/train.py rename to MaskHIT/maskhit/train.py index 828af66..00edab7 100644 --- a/MaskHIT/train.py +++ b/MaskHIT/maskhit/train.py @@ -33,6 +33,15 @@ from options.train_options import TrainOptions from utils.config import Config +"""TODO: The issues of this script: + - inconsistency: config vs args + - non-modular scripting style + - args without doc + - passing args to HybridFitter. this is like 'import *', should be specific. + - formatting of printing messages. 
+ +""" + # Defining a global variable to store available device global device @@ -62,12 +71,14 @@ config = Config(args.default_config_file, args.user_config_file) # string holding command-line arguments joined with spaces args.all_arguments = ' '.join(sys.argv[1:]) +#TODO: this variable above is not used. Better remove or used somewhere? assert not args.sample_all, "the argument --sample-all is deprecated, use --num-patches=0 instead" # print(f"args.cancer: {args.cancer}") if args.cancer == '.': args.cancer = "" +#TODO: Not clear why it checks if it's a cancer data? What's the real intent? # setting weight decay values if hasattr(config.model, 'wd_attn') and hasattr(config.model, 'wd_fuse') and hasattr(config.model, 'wd_loss'): @@ -112,13 +123,17 @@ # Checking to see if region-size, region-length, and grid-size are valid # These parameters control the subdivision of patches within a given region if args.region_length is not None and args.region_length > 0: - assert_message = "grid size is measured in patches and need to be a positive number no larger than the region size / patch size" + assert_message = ( + "Grid size is measured in patches and needs to be a positive number. " + "It should not exceed the region size divided by the patch size." + ) assert args.grid_size <= args.region_length and args.grid_size > 0, assert_message args.prop_mask = [int(x) for x in args.prop_mask.split(',')] args.prop_mask = [x / sum(args.prop_mask) for x in args.prop_mask] # initializing sampling and outcome arguments +#TODO: lack of docs. So guess is that this is a variable name in string for an id column? if args.sample_svs: args.id_var = 'id_svs_num' else: @@ -131,7 +146,7 @@ args.patch_spec = f"mag_{float(config.patch.magnification):.1f}-size_{args.patch_size}" - +#TODO: We should write a function to handle data partitioning. 
args.mode_ops = {'train': {}, 'val': {}, 'predict': {}} # initializing num_patches argument for train mode @@ -219,10 +234,12 @@ def get_resume_checkpoint(checkpoints_name, epoch_to_resume): files = glob.glob( os.path.join(args.checkpoints_folder, checkpoints_name, "*.pt")) - checkpoint_to_resume = [ - fname for fname in files - if get_checkpoint_epoch(fname) == epoch_to_resume - ][0] + checkpoint_to_resume = None + for fname in files: + if get_checkpoint_epoch(fname) == epoch_to_resume: + checkpoint_to_resume = fname + break + return checkpoint_to_resume @@ -253,21 +270,37 @@ def prepare_data(meta_split, meta_file, vars_to_include=[]): - outcome: patient outcome variable, encoded for classification models e.g. 0, 1, 2 for three classes """ + + #TODO: Temp Duct-tape solution for existing datasets. Need update. + #Ideally, the target column name should be parameterized in a config file. ids_to_add = [] for index, row in meta_split.iterrows(): - value_to_split = row['case_number'] - split_value = value_to_split.split('.')[0] - ids_to_add.append(split_value) + if 'case_number' in row: + value_to_split = row['case_number'] + split_value = value_to_split.split('.')[0] + elif 'barcode' in row: + split_value = row['barcode'] + #TODO: Temp Duct-tape solution: Formatting the ID + split_value = "-".join(split_value.split('-')[:3]) + else: + raise ValueError("Row does not contain 'case_number' or 'barcode'") + ids_to_add.append(split_value) + meta_split['id_patient'] = ids_to_add - + #TODO: This whole block should be in another script to be run before train.py if 'id_patient' not in meta_split.columns: patient_ids = [] # iterating over the meta_split dataframe for index, row in meta_split.iterrows(): + #TODO: Debug: This code is basically repeating what we have done in MaskHIT_Prep + # obtaining the paths of the files to the related slide - file_names = ast.literal_eval(row['Path']) + #TODO: Debug: temp fix but need to check with other datasets to see if literal_eval is 
required and why. + #file_names = ast.literal_eval(row['Path']) + file_names = str(row['path']) + patient_id = file_names[0].split('/')[5].split(' ')[0] patient_ids.append(patient_id) # adding patient id to the list meta_split['id_patient'] = patient_ids # adding column to the meta_split dataframe @@ -283,19 +316,11 @@ def prepare_data(meta_split, meta_file, vars_to_include=[]): # Selects columns from meta_file df and merges them into meta_split based on a shared 'id_patient' column # includes all the columns from meta_split and only the selected columns from meta_file - try: - meta_split = meta_split.merge(meta_file[vars_to_include], - on='id_patient', - how='inner') - except KeyError as e: - print(f"KeyError: {e}") - meta_split['id_svs'] = meta_split['id_patient'] - print(vars_to_include) - if 'id_patient' in vars_to_include: - vars_to_include.remove('id_patient') - meta_split = meta_split.merge(meta_file[vars_to_include], - on='id_svs', - how='inner') + meta_split = meta_split.merge(meta_file[vars_to_include], + on='id_patient', + how='inner') + #TODO: Need check consistency after merge. + assert meta_split.shape[0] > 0, "Merge operation failed." 
meta_split['id_patient_num'] = meta_split.id_patient.astype( 'category').cat.codes @@ -309,6 +334,10 @@ def prepare_data(meta_split, meta_file, vars_to_include=[]): elif config.dataset.outcome_type == 'survival': meta_split = meta_split.loc[~meta_split.status.isna() & ~meta_split.time.isna()] + elif config.dataset.outcome_type == 'regression': + meta_split = meta_split.loc[~meta_split[config.dataset.outcome].isna()] + meta_split[config.dataset.outcome] = meta_split[config.dataset.outcome].astype( + 'float') return meta_split @@ -325,7 +354,7 @@ def main(): model_name = f"{TIMESTR}-{args.fold}" # if we want to resume previous training - if len(config.model.resume): + if config.model.resume: checkpoint_to_resume = get_resume_checkpoint(config.model.resume, config.model.resume_epoch) if args.resume_train: @@ -355,6 +384,7 @@ def main(): # loading datasets meta_svs = pd.read_pickle(config.dataset.meta_svs) + #TODO: Need doc why this option exists and what it does. if args.ffpe_only: meta_svs = meta_svs.loc[meta_svs.slide_type == 'ffpe'] if args.ffpe_exclude: @@ -362,10 +392,15 @@ def main(): if config.dataset.meta_all is not None: meta_all = pd.read_pickle(config.dataset.meta_all) + #Debug: Checking input data + print("Debug: meta_all:\n", meta_all.head(2)) + + #TODO: mode=extract is not expected in the train_options. Need doc. if args.mode == 'extract': meta_train = meta_val = meta_all elif 'fold' in meta_all.columns: if meta_all.fold.nunique() == 5: + #TODO: Why expecting to have 5 folds? What about other cases? Need generalize. val_fold = (args.fold + 1) % 5 test_fold = args.fold @@ -398,6 +433,7 @@ def main(): meta_val = pd.read_pickle(args.meta_val) # select cancer subset + #TODO: lack of docs. Here, cancer='' means use all data. A better term would be 'filtering'. if args.cancer == '': pass else: @@ -406,11 +442,13 @@ def main(): meta_val = meta_val.loc[meta_val.cancer == args.cancer] + #TODO: Not sure why we do this. Should clarify variables and their intent. 
if config.dataset.is_cancer: meta_svs['folder'] = meta_svs['cancer'] else: meta_svs['folder'] = config.dataset.disease + #TODO: Is safe to hard-code the weights? meta_svs['sampling_weights'] = 1 vars_to_include = ['id_patient', 'folder', 'id_svs', 'sampling_weights'] if 'svs_path' in meta_svs: diff --git a/MaskHIT/maskhit/trainer/__init__.py b/MaskHIT/maskhit/trainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MaskHIT/trainer/earlystopping.py b/MaskHIT/maskhit/trainer/earlystopping.py similarity index 100% rename from MaskHIT/trainer/earlystopping.py rename to MaskHIT/maskhit/trainer/earlystopping.py diff --git a/MaskHIT/trainer/fitter.py b/MaskHIT/maskhit/trainer/fitter.py similarity index 99% rename from MaskHIT/trainer/fitter.py rename to MaskHIT/maskhit/trainer/fitter.py index 25b769a..66cb6a7 100644 --- a/MaskHIT/trainer/fitter.py +++ b/MaskHIT/maskhit/trainer/fitter.py @@ -17,6 +17,8 @@ import numpy as np from utils.config import Config from pathlib import Path +#TODO: Sort imports. Also avoid os.path use +#TODO: This file is way too long. Should clean up and split the features. 
from maskhit.model.models import HybridModel from maskhit.trainer.losses import ContrasiveLoss @@ -429,6 +431,7 @@ def train(self, df_train=None, epoch=0, accumulation_steps=1): preds=preds.data.cpu().numpy(), targets=targets.data.cpu().numpy(), outcome_type=self.config.dataset.outcome_type, label_classes = label_classes) + perfs.update(metrics[self.metric], num_samples) # measure elapsed time @@ -659,6 +662,7 @@ def fit(self, data_dict, procedure='train'): self.device = torch.device('cpu') self.current_epoch = 1 + #TODO: Better define constant values in __init__() metrics = { 'classification': 'auc', 'survival': 'c-index', diff --git a/MaskHIT/trainer/helper.py b/MaskHIT/maskhit/trainer/helper.py similarity index 100% rename from MaskHIT/trainer/helper.py rename to MaskHIT/maskhit/trainer/helper.py diff --git a/MaskHIT/trainer/info_nce_loss.py b/MaskHIT/maskhit/trainer/info_nce_loss.py similarity index 100% rename from MaskHIT/trainer/info_nce_loss.py rename to MaskHIT/maskhit/trainer/info_nce_loss.py diff --git a/MaskHIT/trainer/logger.py b/MaskHIT/maskhit/trainer/logger.py similarity index 100% rename from MaskHIT/trainer/logger.py rename to MaskHIT/maskhit/trainer/logger.py diff --git a/MaskHIT/trainer/losses.py b/MaskHIT/maskhit/trainer/losses.py similarity index 100% rename from MaskHIT/trainer/losses.py rename to MaskHIT/maskhit/trainer/losses.py diff --git a/MaskHIT/trainer/meters.py b/MaskHIT/maskhit/trainer/meters.py similarity index 100% rename from MaskHIT/trainer/meters.py rename to MaskHIT/maskhit/trainer/meters.py diff --git a/MaskHIT/trainer/metrics.py b/MaskHIT/maskhit/trainer/metrics.py similarity index 89% rename from MaskHIT/trainer/metrics.py rename to MaskHIT/maskhit/trainer/metrics.py index ace3f6f..0674b91 100644 --- a/MaskHIT/trainer/metrics.py +++ b/MaskHIT/maskhit/trainer/metrics.py @@ -2,7 +2,16 @@ import numpy as np import torch from lifelines.utils import concordance_index -from sklearn.metrics import roc_auc_score, f1_score, 
confusion_matrix, roc_curve, auc +from sklearn.metrics import ( + roc_auc_score, + f1_score, + mean_absolute_error, + mean_squared_error, + r2_score, + confusion_matrix, + roc_curve, + auc +) from scipy.special import softmax import seaborn as sns import matplotlib.pyplot as plt @@ -34,6 +43,7 @@ def find_confident_instance(preds): return preds[preds.max(1).argmax()] def read_and_adjust_csv(file_name, last_max_id): + #TODO: I doubt this function is generalizable. Need review. """ Reads a CSV file and adjusts the IDs in the second column by incrementing them based on a given maximum ID. It then extracts specific columns into numpy arrays. This function is used with aggregate_predictions function. @@ -95,8 +105,12 @@ def aggregate_predictions(study_name, timestr, classes, num_files = 5): def calculate_metrics(ids, preds, targets, label_classes, outcome_type='survival', mode = ''): + #TODO: Should support all the outcome_types, or should raise error if a provided type is not available. + #Can we create a class with the logics and feed the object, instead of implementing the logic here. + #TODO: mode seems duct-tape implementation. Need review. 
+ df = pd.DataFrame(np.concatenate([ids, preds, targets], axis=1)) + if outcome_type == 'survival': - df = pd.DataFrame(np.concatenate([ids, targets, preds], axis=1)) df.columns = ['id', 'time', 'event', 'pred'] df = df.groupby('id').mean() c = c_index(df.time, -df.pred, df.event) @@ -161,6 +175,22 @@ def calculate_metrics(ids, preds, targets, label_classes, outcome_type='survival res = {'f1': f1, 'auc': auc_score} return res + + elif outcome_type == 'regression': + df.columns = ['id', 'pred', 'target'] + grouped = df.groupby('id').mean() # Assuming averaging is the desired approach + + mae = mean_absolute_error(grouped['target'], grouped['pred']) + mse = mean_squared_error(grouped['target'], grouped['pred']) + r2 = r2_score(grouped['target'], grouped['pred']) + + res = {'MAE': mae, 'MSE': mse, 'r2': r2} + return res + + else: + raise ValueError(f"Unsupported outcome type: {outcome_type}") + + def save_multi_class_auc(probs, targets, label_classes, save_path = 'auc_plot_multi_class.png'): """ diff --git a/MaskHIT/trainer/slidedataset.py b/MaskHIT/maskhit/trainer/slidedataset.py similarity index 98% rename from MaskHIT/trainer/slidedataset.py rename to MaskHIT/maskhit/trainer/slidedataset.py index 991e9fe..666ec75 100755 --- a/MaskHIT/trainer/slidedataset.py +++ b/MaskHIT/maskhit/trainer/slidedataset.py @@ -91,6 +91,8 @@ def __len__(self): def _get_patch_meta(self, folder, fname, loc): # get all the patches for one wsi + #TODO: This path doesn't seem to be well generalized. + #Confusing as folder=args.config.disease. 
meta_one = pd.read_pickle( f'{self.args.data}/{folder}/{fname}/{self.args.patch_spec}/meta.pickle') meta_one['valid'].fillna(0, inplace=True) diff --git a/MaskHIT/trainer/transforms.py b/MaskHIT/maskhit/trainer/transforms.py similarity index 100% rename from MaskHIT/trainer/transforms.py rename to MaskHIT/maskhit/trainer/transforms.py diff --git a/MaskHIT/trainer/wsitilesampler.py b/MaskHIT/maskhit/trainer/wsitilesampler.py similarity index 100% rename from MaskHIT/trainer/wsitilesampler.py rename to MaskHIT/maskhit/trainer/wsitilesampler.py diff --git a/MaskHIT/maskhit/utils/__init__.py b/MaskHIT/maskhit/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MaskHIT/utils/aggregate_predictions.py b/MaskHIT/maskhit/utils/aggregate_predictions.py similarity index 100% rename from MaskHIT/utils/aggregate_predictions.py rename to MaskHIT/maskhit/utils/aggregate_predictions.py diff --git a/MaskHIT/utils/collect_predictions.py b/MaskHIT/maskhit/utils/collect_predictions.py similarity index 100% rename from MaskHIT/utils/collect_predictions.py rename to MaskHIT/maskhit/utils/collect_predictions.py diff --git a/MaskHIT/utils/collect_validations.py b/MaskHIT/maskhit/utils/collect_validations.py similarity index 100% rename from MaskHIT/utils/collect_validations.py rename to MaskHIT/maskhit/utils/collect_validations.py diff --git a/MaskHIT/utils/config.py b/MaskHIT/maskhit/utils/config.py similarity index 100% rename from MaskHIT/utils/config.py rename to MaskHIT/maskhit/utils/config.py diff --git a/MaskHIT/utils/find_best_cases.py b/MaskHIT/maskhit/utils/find_best_cases.py similarity index 100% rename from MaskHIT/utils/find_best_cases.py rename to MaskHIT/maskhit/utils/find_best_cases.py diff --git a/MaskHIT/utils/get_region_info.py b/MaskHIT/maskhit/utils/get_region_info.py similarity index 100% rename from MaskHIT/utils/get_region_info.py rename to MaskHIT/maskhit/utils/get_region_info.py diff --git a/MaskHIT/utils/hard_grid_ibd.py 
b/MaskHIT/maskhit/utils/hard_grid_ibd.py similarity index 100% rename from MaskHIT/utils/hard_grid_ibd.py rename to MaskHIT/maskhit/utils/hard_grid_ibd.py diff --git a/MaskHIT/utils/prepare_visualization_data.py b/MaskHIT/maskhit/utils/prepare_visualization_data.py similarity index 100% rename from MaskHIT/utils/prepare_visualization_data.py rename to MaskHIT/maskhit/utils/prepare_visualization_data.py diff --git a/MaskHIT/utils/random_grid_search.py b/MaskHIT/maskhit/utils/random_grid_search.py similarity index 100% rename from MaskHIT/utils/random_grid_search.py rename to MaskHIT/maskhit/utils/random_grid_search.py diff --git a/MaskHIT/maskhit/vis/__init__.py b/MaskHIT/maskhit/vis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MaskHIT/vis/wsi.py b/MaskHIT/maskhit/vis/wsi.py similarity index 100% rename from MaskHIT/vis/wsi.py rename to MaskHIT/maskhit/vis/wsi.py diff --git a/MaskHIT/setup_maskhit.py b/MaskHIT/setup_maskhit.py new file mode 100644 index 0000000..2e1c762 --- /dev/null +++ b/MaskHIT/setup_maskhit.py @@ -0,0 +1,25 @@ +from setuptools import setup, find_packages + +# Setup script for the MaskHIT package
# $ python setup_maskhit.py install
# To install this package in development mode, use:
# $ python setup_maskhit.py develop + +setup( + name='maskhit', + version='1.0', + description='MaskHIT', + author='Shuai Jiang, Naofumi Tomita', + packages=find_packages() + # install_requires=[], #external packages as dependencies +) + +# Notes for uninstallation: +# To uninstall the MaskHIT package, use the command: +# $ pip uninstall MaskHIT + +# For earlier versions of the package named 'maskhit', use: +# $ pip uninstall maskhit + +# For uninstalling the develop version, use the commands above and +# remove the directory called `maskhit.egg-info` \ No newline at end of file diff --git a/README.md b/README.md index 93d674d..1542eb9 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,55 @@ + +## Quick Setup for MaskHIT +For a faster MaskHIT 
setup: +* Run `MaskHIT/install_requirement.sh` to install needed packages + +* Install MaskHIT with `setup_maskhit.py` + + # WSI-PLP: Whole Slide Image Analysis for Patient-level Predictions The WSI-PLP repository provides tools for analyzing Whole Slide Images (WSI) with a focus on making predictions at the patient level. It's built for handling large slide images in pathology. +This repository includes two weakly-supervised deep learning methods for digital pathology and whole slide image (WSI) analysis: -## Installation -Start using WSI-PLP by selecting a project and following the setup instructions: - -### POPPSlide +## 1. POPPSlide: Patient Outcome Prediction Pipeline using Whole Slide Images 1. Go to the `POPPSlide` folder. 2. See the README there for installation details. -### MaskHIT +POPPSlide offers a comprehensive pipeline for predicting patient outcomes (categorical, time to event, or continuous) using WSIs. + +Detailed method description: https://www.nature.com/articles/s41598-021-95948-x. + +## 2. MaskHIT: Masked Pre-Training of Transformers for Histology Image Analysis 1. Visit the `MaskHIT` folder. 2. The README explains setup and configurations. -## Quick Setup for MaskHIT -For a faster MaskHIT setup: -* Run `requirement.sh` to install needed packages +MaskHIT utilizes a masked language model-like pretext task to train transformers on WSIs without labeled data. +- **Performance**: Outperforms various multiple instance learning approaches by 3% in survival prediction and 2% in cancer subtype classification tasks, and exceeds recent transformer-based methods. +- **Validation**: Attention maps generated align with pathologist annotations, indicating accurate identification of relevant histological structures. + +For more information: +https://arxiv.org/abs/2304.07434 + +# Installation +1. 
**Dependencies** + +* For installing necessary dependencies, use the provided script: +`install_requirement.sh` + +* For Singularity/Docker environment, use: +`install_requirement_for_container.sh` + +2. **Package Installation** + +* Install this MaskHIT package with a setup script `python setup_maskhit.py install`. + +# Usage +- **POPPSlide** + +Navigate to the `POPPSlide` subfolder for details and instructions. + +- **MaskHIT** + +For utilizing the latest pipeline, refer to the `maskhit` subfolder. -* Install MaskHIT with `setup_maskhit.py` \ No newline at end of file +* TODO: Rename maskhit folder +* Note: The maskhit folder will be renamed for consistency. diff --git a/setup_maskhit.py b/setup_maskhit.py deleted file mode 100644 index 85b8c74..0000000 --- a/setup_maskhit.py +++ /dev/null @@ -1,12 +0,0 @@ -from setuptools import setup, find_packages - -# $ python setup_maskhit.py develop - -setup( - name='MaskHIT', - version='1.0', - description='MaskHIT', - author='Shuai Jiang, Naofumi Tomita', - packages=find_packages() - # install_requires=[], #external packages as dependencies -) \ No newline at end of file