From fc6d0a1a35bb0d3989144d366466799ebb446852 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Tue, 15 Apr 2025 18:06:08 -0700 Subject: [PATCH] spec in groups --- src/baskerville/scripts/hound_eval_spec.py | 284 ++++++++++----------- 1 file changed, 133 insertions(+), 151 deletions(-) diff --git a/src/baskerville/scripts/hound_eval_spec.py b/src/baskerville/scripts/hound_eval_spec.py index 1956387..3eda754 100755 --- a/src/baskerville/scripts/hound_eval_spec.py +++ b/src/baskerville/scripts/hound_eval_spec.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========================================================================= -from optparse import OptionParser +import argparse import gc import json import os @@ -38,150 +38,134 @@ def main(): - usage = "usage: %prog [options] " - parser = OptionParser(usage) - parser.add_option( - "-c", - dest="class_min", - default=10, - type="int", - help="Minimum target class size to consider [Default: %default]", + parser = argparse.ArgumentParser(description="Evaluate a trained model.") + parser.add_argument( + "--f16", + default=False, + action="store_true", + help="use mixed precision for inference [Default: %(default)s]", ) - parser.add_option( + parser.add_argument( "--head", dest="head_i", default=0, - type="int", - help="Parameters head to test [Default: %default]", + type=int, + help="Parameters head to evaluate [Default: %(default)s]", ) - parser.add_option( + parser.add_argument( + "-m", + "--group_min", + default=20, + type=int, + help="Minimum target group size to consider [Default: %(default)s]", + ) + parser.add_argument( "-o", - dest="out_dir", - default="test_out", - help="Output directory for test statistics [Default: %default]", + "--out_dir", + default="spec_out", + help="Output directory for evaluation statistics [Default: %(default)s]", ) - parser.add_option( - "--rc", - dest="rc", + parser.add_argument( + "--rank", default=False, action="store_true", - help="Average the fwd and rc predictions [Default: %default]", + help="Compute Spearman rank correlation [Default: %(default)s]", ) - parser.add_option( - "-s", - "--step", - dest="step", - default=1, - type="int", - help="Step across positions [Default: %default]", - ) - parser.add_option( - "--f16", - dest="f16", + parser.add_argument( + "--rc", default=False, action="store_true", - help="use mixed precision for inference", + help="Average the fwd and rc predictions [Default: %(default)s]", ) - parser.add_option( + parser.add_argument( "--save", - dest="save", default=False, action="store_true", - help="Save targets and predictions numpy arrays [Default: %default]", + help="Save targets and predictions numpy arrays [Default: %(default)s]", ) - parser.add_option( + parser.add_argument( "--shifts", - dest="shifts", default="0", - help="Ensemble prediction shifts [Default: %default]", + help="Ensemble prediction shifts [Default: %(default)s]", ) - parser.add_option( + parser.add_argument( + "--split", + default="test", + help="Dataset split label for eg TFR pattern [Default: %(default)s]", + ) + parser.add_argument( + "--step", + default=1, + type=int, + help="Step across positions [Default: %(default)s]", + ) + parser.add_argument( "-t", - dest="targets_file", + "--targets_file", default=None, - type="str", help="File specifying target indexes and labels in table format", ) - parser.add_option( - "--target_classes", - dest="target_classes", + parser.add_argument( + "--target_groups", default=None, - type="str", - help="comma separated string of target classes", - ) - parser.add_option( - "--split", - dest="split_label", - default="test", - help="Dataset split label for eg TFR pattern [Default: %default]", + type=str, + help="Comma separated string of target groups", ) - parser.add_option( + parser.add_argument( "--tfr", - dest="tfr_pattern", default=None, - help="TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]", + help="Subsetting TFR pattern appended to data_dir/tfrecords [Default: %(default)s]", ) - parser.add_option( - "-v", - dest="high_var_pct", + parser.add_argument( + "--var_pct", default=1.0, - type="float", + type=float, help="Highly variable site proportion to take [Default: %default]", ) - (options, args) = parser.parse_args() + parser.add_argument("params_file", help="JSON file with model parameters") + parser.add_argument("model_file", help="Trained model file.") + parser.add_argument("data_dir", help="Train/valid/test data directory") + args = parser.parse_args() - if len(args) != 3: - parser.error("Must provide parameters, model, and test data HDF5") - else: - params_file = args[0] - model_file = args[1] - data_dir = args[2] - - if not os.path.isdir(options.out_dir): - os.mkdir(options.out_dir) + if not os.path.isdir(args.out_dir): + os.mkdir(args.out_dir) # parse shifts to integers - options.shifts = [int(shift) for shift in options.shifts.split(",")] + args.shifts = [int(shift) for shift in args.shifts.split(",")] ####################################################### # targets # read table - if options.targets_file is None: - options.targets_file = "%s/targets.txt" % data_dir - targets_df = pd.read_csv(options.targets_file, index_col=0, sep="\t") + if args.targets_file is None: + args.targets_file = f"{args.data_dir}/targets.txt" + targets_df = pd.read_csv(args.targets_file, index_col=0, sep="\t") num_targets = targets_df.shape[0] - # classify - target_classes = [] - - if options.target_classes is None: + # set target groups + if "group" not in targets_df.columns: + targets_group = [] for ti in range(num_targets): description = targets_df.iloc[ti].description if description.find(":") == -1: - tc = "*" + tg = "*" else: desc_split = description.split(":") if desc_split[0] == "CHIP": - tc = "/".join(desc_split[:2]) + tg = "/".join(desc_split[:2]) else: - tc = desc_split[0] - target_classes.append(tc) - targets_df["class"] = target_classes - target_classes = sorted(set(target_classes)) - else: - targets_df["class"] = targets_df["description"].str.replace( - ":.*", "", regex=True - ) - target_classes = options.target_classes.split(",") + tg = desc_split[0] + targets_group.append(tg) + targets_df["group"] = targets_group - print(target_classes) + if args.target_groups is None: + args.target_groups = sorted(set(targets_df.group)) ####################################################### - # model + # setup # read parameters - with open(params_file) as params_open: + with open(args.params_file) as params_open: params = json.load(params_open) params_model = params["model"] params_train = params["train"] @@ -192,31 +176,28 @@ def main(): # construct eval data eval_data = dataset.SeqDataset( - data_dir, - split_label=options.split_label, + args.data_dir, + split_label=args.split, batch_size=params_train["batch_size"], mode="eval", - tfr_pattern=options.tfr_pattern, + tfr_pattern=args.tfr, ) # initialize model - ################### - # mixed precision # - ################### - if options.f16: + if args.f16: mixed_precision.set_global_policy("mixed_float16") # set global policy seqnn_model = seqnn.SeqNN(params_model) # create model - seqnn_model.restore(model_file, options.head_i) + seqnn_model.restore(args.model_file, args.head_i) seqnn_model.append_activation() # add additional activation to cast float16 output to float32 else: # initialize model seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file, options.head_i) + seqnn_model.restore(args.model_file, args.head_i) seqnn_model.build_slice(targets_df.index) - if options.step > 1: - seqnn_model.step(options.step) - seqnn_model.build_ensemble(options.rc, options.shifts) + if args.step > 1: + seqnn_model.step(args.step) + seqnn_model.build_ensemble(args.rc, args.shifts) ####################################################### # targets/predictions @@ -227,7 +208,6 @@ def main(): eval_preds = [] eval_targets = [] - si = 0 for x, y in tqdm(eval_data.dataset): # predict yh = seqnn_model(x) @@ -235,8 +215,8 @@ def main(): y = y.numpy().astype("float16") y = y[:, :, np.array(targets_df.index)] - if options.step > 1: - step_i = np.arange(0, eval_data.target_length, options.step) + if args.step > 1: + step_i = np.arange(0, eval_data.target_length, args.step) y = y[:, step_i, :] eval_targets.append(y) @@ -249,61 +229,62 @@ def main(): print("targets", eval_targets.shape) ####################################################### - # process classes + # process groups targets_spec = np.zeros(num_targets) - for tc in target_classes: - class_mask = np.array(targets_df["class"] == tc) - class_df = targets_df[class_mask] - num_targets_class = class_mask.sum() - print("%-15s %4d" % (tc, num_targets_class), flush=True) + for tg in args.target_groups: + group_mask = np.array(targets_df.group == tg) + group_df = targets_df[group_mask] + num_targets_group = group_mask.sum() + print("%-15s %4d" % (tg, num_targets_group), flush=True) - if num_targets_class < options.class_min: - targets_spec[class_mask] = np.nan + if num_targets_group < args.group_min: + targets_spec[group_mask] = np.nan else: - # slice class - eval_preds_class = eval_preds[:, :, class_mask] - eval_preds_class = eval_preds_class.reshape((-1, num_targets_class)) - eval_preds_class = eval_preds_class.astype("float32") - eval_targets_class = eval_targets[:, :, class_mask] - eval_targets_class = eval_targets_class.reshape((-1, num_targets_class)) - eval_targets_class = eval_targets_class.astype("float32") + # slice group + eval_preds_group = eval_preds[:, :, group_mask] + eval_preds_group = eval_preds_group.reshape((-1, num_targets_group)) + eval_targets_group = eval_targets[:, :, group_mask] + eval_targets_group = eval_targets_group.reshape((-1, num_targets_group)) # fix stranded stranded = False - if "strand_pair" in class_df.columns: - stranded = (class_df.strand_pair != class_df.index).all() + if "strand_pair" in group_df.columns: + stranded = (group_df.strand_pair != group_df.index).all() if stranded: # reshape to concat +/-, assuming they're adjacent - num_targets_class //= 2 - eval_preds_class = np.reshape(eval_preds_class, (-1, num_targets_class)) - eval_targets_class = np.reshape( - eval_targets_class, (-1, num_targets_class) + num_targets_group //= 2 + eval_preds_group = np.reshape(eval_preds_group, (-1, num_targets_group)) + eval_targets_group = np.reshape( + eval_targets_group, (-1, num_targets_group) ) + # quantile normalize + t0 = time.time() + print(" Quantile normalize...", flush=True, end="") + eval_preds_norm = quantile_normalize(eval_preds_group, ncpus=2) + del eval_preds_group + eval_targets_norm = quantile_normalize(eval_targets_group, ncpus=2) + del eval_targets_group + print("DONE in %ds" % (time.time() - t0)) + + # upcast + eval_preds_norm = eval_preds_norm.astype(np.float32) + eval_targets_norm = eval_targets_norm.astype(np.float32) + # highly variable filter - if options.high_var_pct < 1: + if args.var_pct < 1: t0 = time.time() print(" Highly variable position filter...", flush=True, end="") - eval_targets_var = eval_targets_class.var(axis=1) - high_var_t = np.percentile( - eval_targets_var, 100 * (1 - options.high_var_pct) - ) + eval_targets_var = eval_targets_group.var(axis=1) + high_var_t = np.percentile(eval_targets_var, 100 * (1 - args.var_pct)) high_var_mask = eval_targets_var >= high_var_t + eval_preds_norm = eval_preds_norm[high_var_mask] + eval_targets_norm = eval_targets_norm[high_var_mask] print("DONE in %ds" % (time.time() - t0)) - eval_preds_class = eval_preds_class[high_var_mask] - eval_targets_class = eval_targets_class[high_var_mask] - - # quantile normalize - t0 = time.time() - print(" Quantile normalize...", flush=True, end="") - eval_preds_norm = quantile_normalize(eval_preds_class, ncpus=2) - eval_targets_norm = quantile_normalize(eval_targets_class, ncpus=2) - print("DONE in %ds" % (time.time() - t0)) - # mean normalize eval_preds_norm -= eval_preds_norm.mean(axis=-1, keepdims=True) eval_targets_norm -= eval_targets_norm.mean(axis=-1, keepdims=True) @@ -311,25 +292,27 @@ def main(): # compute correlations t0 = time.time() print(" Compute correlations...", flush=True, end="") - pearsonr_class = np.zeros(num_targets_class) - for ti in range(num_targets_class): + pearsonr_group = np.zeros(num_targets_group) + for ti in range(num_targets_group): eval_preds_norm_ti = eval_preds_norm[:, ti] eval_targets_norm_ti = eval_targets_norm[:, ti] - pearsonr_class[ti] = pearsonr(eval_preds_norm_ti, eval_targets_norm_ti)[ + pearsonr_group[ti] = pearsonr(eval_preds_norm_ti, eval_targets_norm_ti)[ 0 ] print("DONE in %ds" % (time.time() - t0)) if stranded: - pearsonr_class = np.repeat(pearsonr_class, 2) + pearsonr_group = np.repeat(pearsonr_group, 2) # save - targets_spec[class_mask] = pearsonr_class + targets_spec[group_mask] = pearsonr_group # print - print(" PearsonR %.4f" % pearsonr_class[ti], flush=True) + print(" PearsonR %.4f" % pearsonr_group[ti], flush=True) # clean + del eval_preds_norm + del eval_targets_norm gc.collect() # write target-level statistics @@ -341,9 +324,8 @@ def main(): "description": targets_df.description, } ) - targets_acc_df.to_csv( - "%s/acc.txt" % options.out_dir, sep="\t", index=False, float_format="%.5f" - ) + acc_file = f"{args.out_dir}/acc.txt" + targets_acc_df.to_csv(acc_file, sep="\t", index=False, float_format="%.5f") ################################################################################