Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Py-CTCMetrics
A Python implementation of the metrics used in the paper
[CHOTA: A Higher Order Accuracy Metric for Cell Tracking](https://arxiv.org/abs/2408.11571) by
[CHOTA: A Higher Order Accuracy Metric for Cell Tracking](https://link.springer.com/chapter/10.1007/978-3-031-91721-9_8) by
*Timo Kaiser et al.*. The code is
designed to evaluate tracking results in the format of the
[Cell-Tracking-Challenge](https://celltrackingchallenge.net/) but can also be used
Expand Down Expand Up @@ -156,23 +156,23 @@ Per default, all given metrics are evaluated. You can also select the metrics
you are interested in to avoid calculating metrics that are not relevant to
you. Additional arguments to select a subset of specific metrics are:

| Argument | Description |
| --- |-----------------------------------------------------------------|
| --valid | Check if the result has valid format |
| --det | The DET detection metric |
| --seg | The SEG segmentation metric |
| --tra | The TRA tracking metric |
| --lnk | The LNK linking metric |
| --ct | The CT (complete tracks) metric |
| --tf | The TF (track fraction) metric |
| --bc | The BC(i) (branching correctness) metric |
| --cca | The CCA (cell cycle accuracy) metric |
| --mota | The MOTA (Multiple Object Tracking Accuracy) metric |
| --hota | The HOTA (Higher Order Tracking Accuracy) metric |
| --idf1 | The IDF1 (ID F1) metric |
| --chota | The CHOTA (Cell-specific Higher Order Tracking Accuracy) metric |
| --mtml | The MT (Mostly Tracked) and ML (Mostly Lost) metrics |
| --faf | The FAF (False Alarm per Frame) metric |
| Argument | Description |
|----------|--------------------------------------------------------------------------------------|
| --valid | Check if the result has valid format |
| --det | The DET detection metric |
| --seg | The SEG segmentation metric |
| --tra | The TRA tracking metric |
| --lnk | The LNK linking metric |
| --ct | The CT (complete tracks) metric |
| --tf | The TF (track fraction) metric |
| --bc i | The BC(i) (branching correctness) metric. Set i >= 3 to calculate BC(0) to BC(i) |
| --cca | The CCA (cell cycle accuracy) metric |
| --mota | The MOTA (Multiple Object Tracking Accuracy) metric |
| --hota | The HOTA (Higher Order Tracking Accuracy) metric |
| --idf1 | The IDF1 (ID F1) metric |
| --chota | The CHOTA (Cell-specific Higher Order Tracking Accuracy) metric |
| --mtml | The MT (Mostly Tracked) and ML (Mostly Lost) metrics |
| --faf | The FAF (False Alarm per Frame) metric |
---

To use the evaluation protocol in your python code, the code can be imported
Expand Down
100 changes: 76 additions & 24 deletions ctc_metrics/metrics/biological/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ def is_matching(
mapped_comp: list,
ref_children: np.ndarray,
comp_children: np.ndarray,
tr: int,
tc: int
):
t_parent_end_ref: int,
t_parent_end_comp: int,
t_child_start_ref: list,
t_child_start_comp: list,
max_i: int,
): # pylint: disable=too-many-arguments,too-complex
"""
Checks if the reference and the computed track match.

Expand All @@ -66,31 +69,61 @@ def is_matching(
mapped_comp: The matched labels of the result masks.
ref_children: The children ids of the reference track.
comp_children: The children ids of the computed track.
tr: The frame of the reference track end.
tc: The frame of the computed track end.

t_parent_end_ref: The frame of the reference track end.
t_parent_end_comp: The frame of the computed track end.
t_child_start_ref: The start frames of the reference children tracks.
t_child_start_comp: The start frames of the computed children tracks.
max_i: The maximal time gap between ends of the reference and
computed mother tracks, and beginnings of daughter tracks.
Returns:
True if the reference and the computed track match, False otherwise.
"""
# Check if the number of children is the same
if len(ref_children) != len(comp_children):
return False
# Compare parents
t1, t2 = min(tr, tc), max(tr, tc)
mr, mc = mapped_ref[t1], mapped_comp[t1]
# Compare parents, for temporal distance and then for spatial overlap
if abs(t_parent_end_ref - t_parent_end_comp) > max_i:
return False
t_last_common = min(t_parent_end_ref, t_parent_end_comp)
mr, mc = mapped_ref[t_last_common], mapped_comp[t_last_common]
if np.sum(mc == id_comp) < 1 or np.sum(mr == id_ref) != 1:
return False
ind = np.argwhere(mr == id_ref).squeeze()
if mc[ind] != id_comp:
return False
# Compare children
mr, mc = np.asarray(mapped_ref[t2 + 1]), np.asarray(mapped_comp[t2 + 1])
if not np.all(np.isin(comp_children, mc)):
# Check if the parent match is unique, i.e. the computed track is only assigned to one gt
if mc.count(mc[ind]) != 1:
return False
if not np.all(np.isin(mr[np.isin(mc, comp_children)], ref_children)):
# Compare children
# Iterate over all GT ids and check if the first detection is matched to the correct reference children
# See discussion here https://github.com/CellTrackingChallenge/py-ctcmetrics/issues/22
matched_children = []
for i, t_ref in zip(ref_children, t_child_start_ref):
for j, t_comp in zip(comp_children, t_child_start_comp):
# Check if start frames of the daughters are close enough (<= max_i)
temporal_error = abs(t_ref - t_comp)
if temporal_error > max_i:
continue
# Verify if children are overlapping spatially
t_max = max(t_ref, t_comp)
# Check if the daughter match is unique, i.e. the computed track is only assigned to one gt
if mapped_comp[t_max].count(j) != 1:
continue
if i in mapped_ref[t_max] and j in mapped_comp[t_max]:
ind = mapped_ref[t_max].index(i)
if mapped_comp[t_max][ind] == j:
# There is a match!
if j not in matched_children:
matched_children.append(j)
break


if len(matched_children) != len(ref_children):
return False

return True


def raw_division_metrics(
comp_tracks: np.ndarray,
ref_tracks: np.ndarray,
Expand All @@ -100,7 +133,7 @@ def raw_division_metrics(
):
"""
Computes number of true positives, false positives, and false negatives for divisions.

Args:
comp_tracks: The result tracks. A (n,4) numpy ndarray with columns:
- label
Expand All @@ -122,7 +155,7 @@ def raw_division_metrics(
matched labels of the result masks in the respective frame. The
elements are in the same order as the corresponding elements in
mapped_ref.
i: The maximal allowed error in frames.
i: The maximal allowed temporal error (offset) in frames.

Returns:
Tuple of true positives, false positives, and false negatives.
Expand All @@ -136,36 +169,51 @@ def raw_division_metrics(
ends_with_split_comp = get_ids_that_ends_with_split(comp_tracks)
t_comp = np.asarray([comp_tracks[comp_tracks[:, 0] == comp][0, 2]
for comp in ends_with_split_comp])
# If there are no divisions in the reference

# If there are no divisions in the reference
if len(ends_with_split_ref) == 0:
return (0, len(ends_with_split_comp), 0)

# If there are no divisions in the computed result
if len(ends_with_split_comp) == 0:
return (0, 0, len(ends_with_split_ref))

# Find all matches between reference and computed branching events (mitosis)
matches = []
for comp, tc in zip(ends_with_split_comp, t_comp):
for comp, t_parent_end_start in zip(ends_with_split_comp, t_comp):
# Find potential matches
pot_matches = np.abs(t_ref - tc) <= i
pot_matches = np.abs(t_ref - t_parent_end_start) <= i
if len(pot_matches) == 0:
continue
comp_children = comp_tracks[comp_tracks[:, 3] == comp][:, 0]
t_child_start_comp = []
for j in comp_children:
t = comp_tracks[comp_tracks[:, 0] == j][0, 1]
t_child_start_comp.append(t)
# Evaluate potential matches
for ref, tr in zip(
for ref, t_parent_end_ref in zip(
ends_with_split_ref[pot_matches],
t_ref[pot_matches]
):
ref_children = ref_tracks[ref_tracks[:, 3] == ref][:, 0]
t_child_start_ref = []
for j in ref_children:
t = ref_tracks[ref_tracks[:, 0] == j][0, 1]
t_child_start_ref.append(t)
if is_matching(
comp, ref, mapped_ref, mapped_comp, ref_children,
comp_children, tr, tc
comp, ref,
mapped_ref, mapped_comp,
ref_children, comp_children,
t_parent_end_ref, t_parent_end_start,
t_child_start_ref,
t_child_start_comp,
i
):
matches.append((ref, comp))

return (len(matches), len(ends_with_split_comp) - len(matches), len(ends_with_split_ref) - len(matches))


def bc(
tp: int,
fp: int,
Expand All @@ -184,5 +232,9 @@ def bc(
Returns:
The branching correctness metric.
"""
# Return None if no split exists in the reference data
if (tp + fn) == 0:
return None

# Calculate BC(i)
return calculate_f1_score(tp, fp, fn)
24 changes: 18 additions & 6 deletions ctc_metrics/scripts/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
import argparse
from os.path import join, basename
from multiprocessing import Pool, cpu_count
Expand Down Expand Up @@ -128,7 +129,7 @@ def calculate_metrics(
segm: dict,
metrics: list = None,
is_valid: bool = None,
): # pylint: disable=too-complex
): # pylint: disable=too-complex,too-many-branches
"""
Calculate metrics for given data.

Expand Down Expand Up @@ -171,6 +172,17 @@ def calculate_metrics(
traj["labels_comp_merged"] = new_labels
traj["mapped_comp_merged"] = new_mapped

# Check if a manual i was defined for BC(i)
max_i_for_bci = 3
for m in metrics:
if m.startswith("BC("):
try:
max_i_for_bci = max(max_i_for_bci, int(m[3:-1]))
if "BC" not in metrics:
metrics.append("BC")
except ValueError:
warnings.warn(f"{m} is not a valid metric identifier!.")

# Prepare intermediate results
graph_operations = {}
if "DET" in metrics or "TRA" in metrics:
Expand Down Expand Up @@ -225,7 +237,7 @@ def calculate_metrics(
traj["labels_ref"], traj["mapped_ref"], traj["mapped_comp"])

if "BC" in metrics:
for i in range(4):
for i in range(max_i_for_bci+1):
tp, fp, fn = raw_division_metrics(comp_tracks, ref_tracks,
traj["mapped_ref"], traj["mapped_comp"],
i=i)
Expand All @@ -243,13 +255,13 @@ def calculate_metrics(

if "CT" in metrics and "BC" in metrics and \
"CCA" in metrics and "TF" in metrics:
for i in range(4):
for i in range(max_i_for_bci+1):
results[f"BIO({i})"] = bio(
results["CT"], results["TF"],
results[f"BC({i})"], results["CCA"])

if "BIO" in results and "LNK" in results:
for i in range(4):
for i in range(max_i_for_bci+1):
results[f"OP_CLB({i})"] = op_clb(
results["LNK"], results[f"BIO({i})"])

Expand Down Expand Up @@ -365,7 +377,7 @@ def parse_args():
parser.add_argument('--tra', action="store_true")
parser.add_argument('--ct', action="store_true")
parser.add_argument('--tf', action="store_true")
parser.add_argument('--bc', action="store_true")
parser.add_argument('--bc', type=int, default=0)
parser.add_argument('--cca', action="store_true")
parser.add_argument('--mota', action="store_true")
parser.add_argument('--hota', action="store_true")
Expand All @@ -391,7 +403,7 @@ def main():
("TRA", args.tra),
("CT", args.ct),
("TF", args.tf),
("BC", args.bc),
(f"BC({args.bc})", args.bc),
("CCA", args.cca),
("MOTA", args.mota),
("HOTA", args.hota),
Expand Down
11 changes: 5 additions & 6 deletions ctc_metrics/utils/handle_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,22 @@ def print_results(results: dict):
results: A dictionary containing the results.
"""

def print_line(metrics: dict):
def print_block(metrics: dict):
"""
Prints a line of the table.

Args:
metrics: A dictionary mapping metric names to their values.
"""

print(*[f"{k}: {'N/A' if v is None else float(v):.5},\t" for k, v
in metrics.items()])
for k, v in metrics.items():
print(f"{k}: {'N/A' if v is None else float(v):.5}")

if isinstance(results, dict):
print_line(results)
print_block(results)
elif isinstance(results, list):
for res in results:
print(res[0], end=":\t\t")
print_line(res[1])
print_block(res[1])


def store_results(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="py-ctcmetrics",
version="1.3.2",
version="1.3.3",
packages=find_packages(),
install_requires=[
"numpy",
Expand Down
Loading