diff --git a/README.md b/README.md index fee7ff6..82efa06 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ # Py-CTCMetrics A python implementation of the metrics used in the paper -[CHOTA: A Higher Order Accuracy Metric for Cell Tracking](https://arxiv.org/abs/2408.11571) by +[CHOTA: A Higher Order Accuracy Metric for Cell Tracking](https://link.springer.com/chapter/10.1007/978-3-031-91721-9_8) by *Timo Kaiser et al.*. The code is designed to evaluate tracking results in the format of the [Cell-Tracking-Challenge](https://celltrackingchallenge.net/) but can also be used @@ -156,23 +156,23 @@ Per default, all given metrics are evaluated. You can also select the metrics you are interested in to avoid the calculation of metrics that are not in your interest. Additional arguments to select a subset of specific metrics are: -| Argument | Description | -| --- |-----------------------------------------------------------------| -| --valid | Check if the result has valid format | -| --det | The DET detection metric | -| --seg | The SEG segmentation metric | -| --tra | The TRA tracking metric | -| --lnk | The LNK linking metric | -| --ct | The CT (complete tracks) metric | -| --tf | The TF (track fraction) metric | -| --bc | The BC(i) (branching correctness) metric | -| --cca | The CCA (cell cycle accuracy) metric | -| --mota | The MOTA (Multiple Object Tracking Accuracy) metric | -| --hota | The HOTA (Higher Order Tracking Accuracy) metric | -| --idf1 | The IDF1 (ID F1) metric | -| --chota | The CHOTA (Cell-specific Higher Order Tracking Accuracy) metric | -| --mtml | The MT (Mostly Tracked) and ML (Mostly Lost) metrics | -| --faf | The FAF (False Alarm per Frame) metric | +| Argument | Description | +|----------|--------------------------------------------------------------------------------------| +| --valid | Check if the result has valid format | +| --det | The DET detection metric | +| --seg | The SEG segmentation metric | +| --tra | The TRA tracking metric | +| --lnk | The LNK linking metric | +| --ct | The CT (complete tracks) metric | +| --tf | The TF (track fraction) metric | +| --bc i | The BC(i) (branching correctness) metric. Set i >= 3 to an calculate BC(0) to BC(i) | +| --cca | The CCA (cell cycle accuracy) metric | +| --mota | The MOTA (Multiple Object Tracking Accuracy) metric | +| --hota | The HOTA (Higher Order Tracking Accuracy) metric | +| --idf1 | The IDF1 (ID F1) metric | +| --chota | The CHOTA (Cell-specific Higher Order Tracking Accuracy) metric | +| --mtml | The MT (Mostly Tracked) and ML (Mostly Lost) metrics | +| --faf | The FAF (False Alarm per Frame) metric | --- To use the evaluation protocol in your python code, the code can be imported diff --git a/ctc_metrics/metrics/biological/bc.py b/ctc_metrics/metrics/biological/bc.py index 93bf46b..0f988d9 100644 --- a/ctc_metrics/metrics/biological/bc.py +++ b/ctc_metrics/metrics/biological/bc.py @@ -53,9 +53,12 @@ def is_matching( mapped_comp: list, ref_children: np.ndarray, comp_children: np.ndarray, - tr: int, - tc: int -): + t_parent_end_ref: int, + t_parent_end_comp: int, + t_child_start_ref: list, + t_child_start_comp: list, + max_i: int, +): # pylint: disable=too-many-arguments,too-complex """ Checks if the reference and the computed track match. @@ -66,31 +69,61 @@ def is_matching( mapped_comp: The matched labels of the result masks. ref_children: The children ids of the reference track. comp_children: The children ids of the computed track. - tr: The frame of the reference track end. - tc: The frame of the computed track end. - + t_parent_end_ref: The frame of the reference track end. + t_parent_end_comp: The frame of the computed track end. + t_child_start_ref: The frame of the reference track start. + t_child_start_comp: The frame of the computed track start. + max_i: The maximal time gap between ends of the reference and + computed mother tracks, and beginnings of daughter tracks. Returns: True if the reference and the computed track match, False otherwise. """ # Check if the number of children is the same if len(ref_children) != len(comp_children): return False - # Compare parents - t1, t2 = min(tr, tc), max(tr, tc) - mr, mc = mapped_ref[t1], mapped_comp[t1] + # Compare parents, for temporal distance and then for spatial overlap + if abs(t_parent_end_ref - t_parent_end_comp) > max_i: + return False + t_last_common = min(t_parent_end_ref, t_parent_end_comp) + mr, mc = mapped_ref[t_last_common], mapped_comp[t_last_common] if np.sum(mc == id_comp) < 1 or np.sum(mr == id_ref) != 1: return False ind = np.argwhere(mr == id_ref).squeeze() if mc[ind] != id_comp: return False - # Compare children - mr, mc = np.asarray(mapped_ref[t2 + 1]), np.asarray(mapped_comp[t2 + 1]) - if not np.all(np.isin(comp_children, mc)): + # Check if the parent match is unique, i.e. the computed track is only assigned to one gt + if mc.count(mc[ind]) != 1: return False - if not np.all(np.isin(mr[np.isin(mc, comp_children)], ref_children)): + # Compare children + # Iterate over all GT ids and check if the first detection is matched to the correct reference children + # See discussion here https://github.com/CellTrackingChallenge/py-ctcmetrics/issues/22 + matched_children = [] + for i, t_ref in zip(ref_children, t_child_start_ref): + for j, t_comp in zip(comp_children, t_child_start_comp): + # Check if start frames of the daughters are close enough <= i_max + temporal_error = abs(t_ref - t_comp) + if temporal_error > max_i: + continue + # Verify if children are overlapping spatially + t_max = max(t_ref, t_comp) + # Check if the daughter match is unique, i.e. the computed track is only assigned to one gt + if mapped_comp[t_max].count(j) != 1: + continue + if i in mapped_ref[t_max] and j in mapped_comp[t_max]: + ind = mapped_ref[t_max].index(i) + if mapped_comp[t_max][ind] == j: + # There is a match! + if j not in matched_children: + matched_children.append(j) + break + + + if len(matched_children) != len(ref_children): return False + return True + def raw_division_metrics( comp_tracks: np.ndarray, ref_tracks: np.ndarray, @@ -100,7 +133,7 @@ def raw_division_metrics( ): """ Computes number of true positives, false positives, and false negatives for divisions. - + Args: comp_tracks: The result tracks. A (n,4) numpy ndarray with columns: - label @@ -122,7 +155,7 @@ def raw_division_metrics( matched labels of the result masks in the respective frame. The elements are in the same order as the corresponding elements in mapped_ref. - i: The maximal allowed error in frames. + i: The maximal allowed temporal error (offset) in frames. Returns: Tuple of true positives, false positives, and false negatives. @@ -136,36 +169,51 @@ def raw_division_metrics( ends_with_split_comp = get_ids_that_ends_with_split(comp_tracks) t_comp = np.asarray([comp_tracks[comp_tracks[:, 0] == comp][0, 2] for comp in ends_with_split_comp]) - - # If there are no divisions in the reference + + # If there are no divisions in the reference if len(ends_with_split_ref) == 0: return (0, len(ends_with_split_comp), 0) - + # If there are no divisions in the computed result if len(ends_with_split_comp) == 0: return (0, 0, len(ends_with_split_ref)) - + # Find all matches between reference and computed branching events (mitosis) matches = [] - for comp, tc in zip(ends_with_split_comp, t_comp): + for comp, t_parent_end_start in zip(ends_with_split_comp, t_comp): # Find potential matches - pot_matches = np.abs(t_ref - tc) <= i + pot_matches = np.abs(t_ref - t_parent_end_start) <= i if len(pot_matches) == 0: continue comp_children = comp_tracks[comp_tracks[:, 3] == comp][:, 0] + t_child_start_comp = [] + for j in comp_children: + t = comp_tracks[comp_tracks[:, 0] == j][0, 1] + t_child_start_comp.append(t) # Evaluate potential matches - for ref, tr in zip( + for ref, t_parent_end_ref in zip( ends_with_split_ref[pot_matches], t_ref[pot_matches] ): ref_children = ref_tracks[ref_tracks[:, 3] == ref][:, 0] + t_child_start_ref = [] + for j in ref_children: + t = ref_tracks[ref_tracks[:, 0] == j][0, 1] + t_child_start_ref.append(t) if is_matching( - comp, ref, mapped_ref, mapped_comp, ref_children, - comp_children, tr, tc + comp, ref, + mapped_ref, mapped_comp, + ref_children, comp_children, + t_parent_end_ref, t_parent_end_start, + t_child_start_ref, + t_child_start_comp, + i ): matches.append((ref, comp)) + return (len(matches), len(ends_with_split_comp) - len(matches), len(ends_with_split_ref) - len(matches)) + def bc( tp: int, fp: int, @@ -184,5 +232,9 @@ def bc( Returns: The branching correctness metric. """ + # Return None if no split is existing in the reference data + if (tp + fn) == 0: + return None + # Calculate BC(i) return calculate_f1_score(tp, fp, fn) diff --git a/ctc_metrics/scripts/evaluate.py b/ctc_metrics/scripts/evaluate.py index da650f0..33c1e76 100644 --- a/ctc_metrics/scripts/evaluate.py +++ b/ctc_metrics/scripts/evaluate.py @@ -1,3 +1,4 @@ +import warnings import argparse from os.path import join, basename from multiprocessing import Pool, cpu_count @@ -128,7 +129,7 @@ def calculate_metrics( segm: dict, metrics: list = None, is_valid: bool = None, -): # pylint: disable=too-complex +): # pylint: disable=too-complex,too-many-branches """ Calculate metrics for given data. @@ -171,6 +172,17 @@ def calculate_metrics( traj["labels_comp_merged"] = new_labels traj["mapped_comp_merged"] = new_mapped + # Check if a manual i was defined for BC(i) + max_i_for_bci = 3 + for m in metrics: + if m.startswith("BC("): + try: + max_i_for_bci = max(max_i_for_bci, int(m[3:-1])) + if "BC" not in metrics: + metrics.append("BC") + except ValueError: + warnings.warn(f"{m} is not a valid metric identifier!.") + # Prepare intermediate results graph_operations = {} if "DET" in metrics or "TRA" in metrics: @@ -225,7 +237,7 @@ def calculate_metrics( traj["labels_ref"], traj["mapped_ref"], traj["mapped_comp"]) if "BC" in metrics: - for i in range(4): + for i in range(max_i_for_bci+1): tp, fp, fn = raw_division_metrics(comp_tracks, ref_tracks, traj["mapped_ref"], traj["mapped_comp"], i=i) @@ -243,13 +255,13 @@ def calculate_metrics( if "CT" in metrics and "BC" in metrics and \ "CCA" in metrics and "TF" in metrics: - for i in range(4): + for i in range(max_i_for_bci+1): results[f"BIO({i})"] = bio( results["CT"], results["TF"], results[f"BC({i})"], results["CCA"]) if "BIO" in results and "LNK" in results: - for i in range(4): + for i in range(max_i_for_bci+1): results[f"OP_CLB({i})"] = op_clb( results["LNK"], results[f"BIO({i})"]) @@ -365,7 +377,7 @@ def parse_args(): parser.add_argument('--tra', action="store_true") parser.add_argument('--ct', action="store_true") parser.add_argument('--tf', action="store_true") - parser.add_argument('--bc', action="store_true") + parser.add_argument('--bc', type=int, default=0) parser.add_argument('--cca', action="store_true") parser.add_argument('--mota', action="store_true") parser.add_argument('--hota', action="store_true") @@ -391,7 +403,7 @@ def main(): ("TRA", args.tra), ("CT", args.ct), ("TF", args.tf), - ("BC", args.bc), + (f"BC({args.bc})", args.bc), ("CCA", args.cca), ("MOTA", args.mota), ("HOTA", args.hota), diff --git a/ctc_metrics/utils/handle_results.py b/ctc_metrics/utils/handle_results.py index c8dcd54..4bfd9a0 100644 --- a/ctc_metrics/utils/handle_results.py +++ b/ctc_metrics/utils/handle_results.py @@ -8,23 +8,22 @@ def print_results(results: dict): results: A dictionary containing the results. """ - def print_line(metrics: dict): + def print_block(metrics: dict): """ Prints a line of the table. Args: metrics: A list containing the arguments for the line. """ - - print(*[f"{k}: {'N/A' if v is None else float(v):.5},\t" for k, v - in metrics.items()]) + for k, v in metrics.items(): + print(f"{k}: {'N/A' if v is None else float(v):.5}") if isinstance(results, dict): - print_line(results) + print_block(results) elif isinstance(results, list): for res in results: print(res[0], end=":\t\t") - print_line(res[1]) + print_block(res[1]) def store_results( diff --git a/setup.py b/setup.py index 692c302..219d0dd 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="py-ctcmetrics", - version="1.3.2", + version="1.3.3", packages=find_packages(), install_requires=[ "numpy",