From 8a6a745aaa0d241471971f2eb428eaed0cb82cf6 Mon Sep 17 00:00:00 2001 From: colganwi Date: Tue, 24 Mar 2026 12:39:57 -0400 Subject: [PATCH 1/2] feat(tl): add ancestral_linkage for measuring tree-based category relatedness Computes pairwise or single-target linkage scores between cell categories using path distance or LCA depth on the lineage tree. Supports permutation testing, parallel execution (fork-based), symmetrization, and per-tree stats. Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 1 + src/pycea/tl/__init__.py | 1 + src/pycea/tl/ancestral_linkage.py | 708 ++++++++++++++++++++++++++++++ tests/test_ancestral_linkage.py | 530 ++++++++++++++++++++++ 4 files changed, 1240 insertions(+) create mode 100644 src/pycea/tl/ancestral_linkage.py create mode 100644 tests/test_ancestral_linkage.py diff --git a/pyproject.toml b/pyproject.toml index 7405979..d8bb570 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "scikit-learn", "scipy", "session-info", + "tqdm", "treedata>=0.2.2", ] diff --git a/src/pycea/tl/__init__.py b/src/pycea/tl/__init__.py index 57bca65..3a63b42 100644 --- a/src/pycea/tl/__init__.py +++ b/src/pycea/tl/__init__.py @@ -1,3 +1,4 @@ +from .ancestral_linkage import ancestral_linkage from .ancestral_states import ancestral_states from .autocorr import autocorr from .clades import clades diff --git a/src/pycea/tl/ancestral_linkage.py b/src/pycea/tl/ancestral_linkage.py new file mode 100644 index 0000000..54d5809 --- /dev/null +++ b/src/pycea/tl/ancestral_linkage.py @@ -0,0 +1,708 @@ +from __future__ import annotations + +import multiprocessing as mp +import sys +import warnings +from collections import defaultdict +from collections.abc import Callable, Sequence +from typing import Literal, overload + +import networkx as nx +import numpy as np +import pandas as pd +import scipy as sp +import treedata as td +from tqdm import tqdm + +from pycea.utils import _check_tree_overlap, check_tree_has_key, get_leaves, 
get_trees + +from ._aggregators import _get_aggregator +from ._metrics import _TreeMetric +from ._utils import _set_random_state + +# ── internal helpers ────────────────────────────────────────────────────────── + + +def _dijkstra_min_scores( + tree: nx.DiGraph, + source_leaves: list, + target_cats: list, + cat_to_leaves_in_tree: dict, + depth_key: str, + metric: str, +) -> dict: + """Per-leaf 'closest target' score for each category via multi-source Dijkstra. + + Uses ``|depth[u] - depth[v]|`` as edge weights to compute path distances. + For ``metric='lca'``, converts path distance to LCA depth via the identity: + + lca(i, j) = (depth[i] + depth[j] - path_dist(i, j)) / 2 + + Self-distances are excluded for within-category scoring. + """ + G = nx.Graph(tree) + + def _weight(u, v, _): + return abs(G.nodes[u][depth_key] - G.nodes[v][depth_key]) + + scores: dict = {leaf: {} for leaf in source_leaves} + + for cat in target_cats: + target_leaves = cat_to_leaves_in_tree.get(cat, []) + if not target_leaves: + continue + + if metric == "lca": + dists, paths = nx.multi_source_dijkstra(G, target_leaves, weight=_weight) + for leaf in source_leaves: + if leaf in dists: + nearest = paths[leaf][0] # path goes source → ... 
→ leaf + scores[leaf][cat] = (G.nodes[leaf][depth_key] + G.nodes[nearest][depth_key] - dists[leaf]) / 2 + else: + dists = nx.multi_source_dijkstra_path_length(G, target_leaves, weight=_weight) + for leaf in source_leaves: + if leaf in dists: + scores[leaf][cat] = dists[leaf] + + return scores + + +def _all_pairs_scores( + tdata: td.TreeData, + trees: dict, + source_leaves_by_tree: dict, + target_cats: list, + cat_to_leaves_by_tree: dict, + depth_key: str, + metric: str, + agg_fn: Callable, +) -> dict: + """Per-leaf aggregated distance to each target category via all-pairs distance matrix.""" + # Use precomputed dense distances if available + precomputed = None + if "tree_distances" in tdata.obsp: + D = tdata.obsp["tree_distances"] + if isinstance(D, np.ndarray): + precomputed = D + + if precomputed is None: + from pycea.tl.tree_distance import tree_distance as _td + + tree_keys = list(trees.keys()) + result = _td(tdata, depth_key=depth_key, metric=metric, tree=tree_keys, copy=True) + precomputed = result.toarray() if sp.sparse.issparse(result) else result + + scores: dict = {} + obs_names = tdata.obs_names + for tree_key, t_leaves in source_leaves_by_tree.items(): + for leaf in t_leaves: + if leaf not in obs_names: + continue + src_idx = obs_names.get_loc(leaf) + leaf_scores: dict = {} + for cat in target_cats: + tgt_leaves = cat_to_leaves_by_tree[tree_key].get(cat, []) + tgt_indices = [obs_names.get_loc(l) for l in tgt_leaves if l in obs_names] + if not tgt_indices: + continue + row = precomputed[src_idx, tgt_indices] + leaf_scores[cat] = float(agg_fn(row)) + scores[leaf] = leaf_scores + + return scores + + +def _compute_scores( + tdata: td.TreeData, + trees: dict, + leaf_to_cat: dict, + target_cats: list, + aggregate: str | Callable, + metric: str, + depth_key: str, +) -> dict: + """Route to the appropriate per-leaf scoring method and return leaf → {cat → score}.""" + # Build per-tree leaf / category maps + source_leaves_by_tree: dict = {} + 
cat_to_leaves_by_tree: dict = {} + for tree_key, t in trees.items(): + t_leaves = [l for l in get_leaves(t) if l in leaf_to_cat] + source_leaves_by_tree[tree_key] = t_leaves + cat_to_leaves_by_tree[tree_key] = defaultdict(list) + for l in t_leaves: + cat_to_leaves_by_tree[tree_key][leaf_to_cat[l]].append(l) + + # Choose strategy: Dijkstra handles the natural "closest" direction for each metric + is_named = isinstance(aggregate, str) + use_dijkstra = is_named and ((aggregate == "min" and metric == "path") or (aggregate == "max" and metric == "lca")) + + if use_dijkstra: + all_scores: dict = {} + for tree_key, t in trees.items(): + tree_scores = _dijkstra_min_scores( + t, + source_leaves_by_tree[tree_key], + target_cats, + cat_to_leaves_by_tree[tree_key], + depth_key, + metric, + ) + all_scores.update(tree_scores) + return all_scores + + # Fallback: all-pairs distance matrix (mean, max-path, or custom callable) + agg_fn = _get_aggregator(aggregate) if is_named else aggregate # type: ignore[arg-type] + return _all_pairs_scores( + tdata, + trees, + source_leaves_by_tree, + target_cats, + cat_to_leaves_by_tree, + depth_key, + metric, + agg_fn, + ) + + +def _scores_to_linkage_matrix( + all_scores: dict, + all_cats: list, + cat_to_leaves: dict, +) -> pd.DataFrame: + """Aggregate per-leaf scores to a (source category × target category) DataFrame.""" + matrix: dict = {} + for src_cat in all_cats: + row: dict = {} + src_leaves = [l for l in cat_to_leaves.get(src_cat, []) if l in all_scores] + for tgt_cat in all_cats: + values = [all_scores[l][tgt_cat] for l in src_leaves if tgt_cat in all_scores.get(l, {})] + row[tgt_cat] = float(np.mean(values)) if values else np.nan + matrix[src_cat] = row + return pd.DataFrame(matrix, dtype=float).T # index=src, columns=tgt + + +def _symmetrize_matrix(df: pd.DataFrame, mode: str) -> pd.DataFrame: + """Symmetrize a square DataFrame in-place.""" + arr = df.values.astype(float) + arr_T = arr.T + if mode == "mean": + sym = (arr + arr_T) / 2 + 
elif mode == "max": + sym = np.maximum(arr, arr_T) + else: # min + sym = np.minimum(arr, arr_T) + return pd.DataFrame(sym, index=df.index, columns=df.columns) + + +# ── fork-based parallel permutation workers ─────────────────────────────────── +# These are module-level functions so they are picklable (required even with +# fork-based pools, since tasks are dispatched via a queue). Heavy shared data +# (tdata, trees, …) is placed in the module globals below immediately before the +# pool is created; fork copies the parent's address space to child processes +# via copy-on-write, so no serialisation overhead is incurred for that data. + +_PERM_PAIRWISE_DATA: dict = {} +_PERM_SINGLE_DATA: dict = {} + + +def _perm_pairwise_worker(seed: int) -> np.ndarray: + """Run one pairwise permutation; shared data inherited from parent via fork.""" + d = _PERM_PAIRWISE_DATA + rng = np.random.default_rng(seed) + perm_cats = rng.permutation(d["all_cat_vals"]) + perm_leaf_to_cat = dict(zip(d["all_leaves"], perm_cats)) + perm_scores = _compute_scores( + d["tdata"], d["trees"], perm_leaf_to_cat, + d["target_cats"], d["aggregate"], d["metric"], d["depth_key"], + ) + perm_cat_to_leaves: dict = defaultdict(list) + for leaf, cat in perm_leaf_to_cat.items(): + perm_cat_to_leaves[cat].append(leaf) + perm_df = _scores_to_linkage_matrix(perm_scores, d["all_cats"], perm_cat_to_leaves) + return perm_df.reindex(index=d["index"], columns=d["columns"]).values.astype(float) + + +def _perm_single_target_worker(seed: int) -> dict: + """Run one single-target permutation; shared data inherited from parent via fork.""" + d = _PERM_SINGLE_DATA + rng = np.random.default_rng(seed) + perm = rng.permutation(d["all_cat_vals"]) + perm_leaf_to_cat = dict(zip(d["all_leaves"], perm)) + perm_scores = _compute_scores( + d["tdata"], d["trees"], perm_leaf_to_cat, + [d["target"]], d["single_agg"], d["metric"], d["depth_key"], + ) + perm_score_map = {leaf: s.get(d["target"], np.nan) for leaf, s in perm_scores.items()} + 
perm_cat_to_leaves: dict = defaultdict(list) + for leaf, cat in perm_leaf_to_cat.items(): + perm_cat_to_leaves[cat].append(leaf) + result: dict = {} + for cat in d["all_cats"]: + vals = [ + perm_score_map[l] + for l in perm_cat_to_leaves[cat] + if l in perm_score_map and not np.isnan(perm_score_map[l]) + ] + result[cat] = float(np.mean(vals)) if vals else np.nan + return result + + +def _run_parallel(worker_fn: Callable, seeds: np.ndarray, n_threads: int | None) -> list: + """Run *worker_fn(seed)* for each seed, optionally in parallel via fork processes.""" + max_workers = n_threads if n_threads is not None else 1 + + if max_workers > 1 and sys.platform == "linux": + ctx = mp.get_context("fork") + with ctx.Pool(max_workers) as pool: + return list(tqdm( + pool.imap_unordered(worker_fn, seeds), + total=len(seeds), desc="Permutations", leave=False, + )) + + # Single-threaded fallback (also used on non-Linux platforms) + return [worker_fn(seed) for seed in tqdm(seeds, desc="Permutations", leave=False)] + + +def _run_permutation_test( + tdata: td.TreeData, + trees: dict, + leaf_to_cat: dict, + all_cats: list, + target_cats: list, + observed_df: pd.DataFrame, + aggregate: str | Callable, + metric: str, + depth_key: str, + n_permutations: int, + n_threads: int | None, +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Permutation test: shuffle leaf labels, recompute linkage, return (z_score_df, p_value_df, null_mean_df).""" + all_leaves = list(leaf_to_cat.keys()) + perm_seeds = np.random.randint(0, 2**31, size=n_permutations) + + _PERM_PAIRWISE_DATA.clear() + _PERM_PAIRWISE_DATA.update({ + "tdata": tdata, + "trees": trees, + "all_leaves": all_leaves, + "all_cat_vals": [leaf_to_cat[l] for l in all_leaves], + "target_cats": target_cats, + "all_cats": all_cats, + "aggregate": aggregate, + "metric": metric, + "depth_key": depth_key, + "index": observed_df.index, + "columns": observed_df.columns, + }) + + null_matrices = _run_parallel(_perm_pairwise_worker, 
perm_seeds, n_threads) + + null_array = np.array(null_matrices) # (n_permutations, k, k) + null_mean = np.nanmean(null_array, axis=0) + null_std = np.nanstd(null_array, axis=0) + obs_values = observed_df.values.astype(float) + + # Positive z → more related than expected by chance + sign = 1.0 if metric == "lca" else -1.0 + z_scores = sign * (obs_values - null_mean) / (null_std + 1e-10) + + # One-tailed p-value in the "more related" direction + if metric == "lca": + p_values = np.nanmean(null_array >= obs_values[np.newaxis], axis=0) + else: + p_values = np.nanmean(null_array <= obs_values[np.newaxis], axis=0) + + z_score_df = pd.DataFrame(z_scores, index=observed_df.index, columns=observed_df.columns) + p_value_df = pd.DataFrame(p_values, index=observed_df.index, columns=observed_df.columns) + null_mean_df = pd.DataFrame(null_mean, index=observed_df.index, columns=observed_df.columns) + return z_score_df, p_value_df, null_mean_df + + +# ── public API ──────────────────────────────────────────────────────────────── + + +@overload +def ancestral_linkage( + tdata: td.TreeData, + groupby: str, + target: str | None = None, + aggregate: Literal["min", "max", "mean"] | Callable | None = None, + metric: _TreeMetric = "path", + symmetrize: Literal["mean", "max", "min", None] = None, + test: Literal["permutation", None] = None, + n_permutations: int = 100, + n_threads: int | None = None, + by_tree: bool = False, + depth_key: str = "depth", + random_state: int | None = None, + key_added: str | None = None, + tree: str | Sequence[str] | None = None, + copy: Literal[True] = ..., +) -> pd.DataFrame: ... 
+ + +@overload +def ancestral_linkage( + tdata: td.TreeData, + groupby: str, + target: str | None = None, + aggregate: Literal["min", "max", "mean"] | Callable | None = None, + metric: _TreeMetric = "path", + symmetrize: Literal["mean", "max", "min", None] = None, + test: Literal["permutation", None] = None, + n_permutations: int = 100, + n_threads: int | None = None, + by_tree: bool = False, + depth_key: str = "depth", + random_state: int | None = None, + key_added: str | None = None, + tree: str | Sequence[str] | None = None, + copy: Literal[False] = ..., +) -> None: ... + + +def ancestral_linkage( + tdata: td.TreeData, + groupby: str, + target: str | None = None, + metric: _TreeMetric = "path", + symmetrize: Literal["mean", "max", "min", None] = None, + test: Literal["permutation", None] = None, + n_permutations: int = 100, + n_threads: int | None = None, + aggregate: Literal["min", "max", "mean"] | Callable | None = None, + by_tree: bool = False, + depth_key: str = "depth", + random_state: int | None = None, + key_added: str | None = None, + tree: str | Sequence[str] | None = None, + copy: Literal[True, False] = False, +) -> None | pd.DataFrame: + r"""Measures how closely related cells of different categories are on the lineage tree. + + For each cell, the tree distance to the nearest cell of each target category is + computed. These per-cell distances are then averaged across all cells of the same + source category to produce a directional linkage score: a low path distance (or high + LCA depth) between two categories means they tend to share recent common ancestors, + i.e. they are closely related on the tree. + + **Pairwise mode** (``target=None``): computes a category × category matrix of mean + linkage scores and stores it in ``tdata.uns['{key_added}_linkage']``. + + **Single-target mode** (``target=``): computes the per-cell distance to + the nearest cell of the given category and stores it in + ``tdata.obs['{target}_linkage']``. 
 + + Parameters + ---------- + tdata + The TreeData object. + groupby + Column in ``tdata.obs`` that defines cell categories. + target + If specified, compute the per-cell distance to the nearest cell of this + category and store the result in ``tdata.obs['{target}_linkage']``. + ``aggregate`` is ignored in this mode. + If ``None`` (default), compute the full pairwise category × category matrix. + metric + How tree distance between two cells is measured: + + - ``'path'`` (default): branch-length path distance + :math:`d_i + d_j - 2\,d_{\mathrm{LCA}(i,j)}`. Smaller values mean closer + relatives. + - ``'lca'``: depth of the lowest common ancestor + :math:`d_{\mathrm{LCA}(i,j)}`. Larger values mean closer relatives. + symmetrize + If set, symmetrize the pairwise linkage matrix (pairwise mode only). + Because linkage is directional (source → target), the raw matrix is generally + asymmetric; symmetrization combines both directions: + + - ``'mean'``: average of :math:`M[i,j]` and :math:`M[j,i]`. + - ``'max'`` / ``'min'``: element-wise maximum / minimum. + test + Optional significance test: + + - ``'permutation'``: randomly shuffle cell-category labels ``n_permutations`` + times and recompute linkage each time to build a null distribution. + Z-scores and one-tailed p-values (in the direction of closer-than-expected + relatedness) are added to the stats table. The stored linkage matrix is + replaced by the observed values minus the permuted-mean values when this test is run. + n_permutations + Number of label permutations used when ``test='permutation'``. + n_threads + Number of worker processes for parallel permutation computation. + ``None`` (default) runs serially. On Linux, parallel execution uses + ``fork``-based processes, which copy the parent's memory without + serialisation overhead. On other platforms this argument is ignored. + aggregate + How per-cell distances to the target category are aggregated into a single + per-cell score (pairwise mode only). 
Defaults to ``'min'`` for + ``metric='path'`` and ``'max'`` for ``metric='lca'``, both of which + select the nearest relative: + + - ``'min'``: distance to the closest target cell (natural for ``'path'``). + - ``'max'``: depth of the deepest LCA across target cells (natural for ``'lca'``). + - ``'mean'``: mean distance across all target cells. + - A callable ``f(array) -> float`` for custom aggregation. + by_tree + If ``True``, additionally compute linkage statistics separately for each tree + and add a ``tree`` column to the stats table (pairwise mode only). + depth_key + Node attribute in ``tdata.obst[tree]`` that stores each node's depth. + random_state + Random seed for reproducibility of permutation tests. + key_added + Base key for output storage. Defaults to ``groupby``. + tree + The ``obst`` key or keys of the trees to use. If ``None``, all trees are used. + copy + If ``True``, return the result as a :class:`DataFrame `. + + Returns + ------- + Returns ``None`` if ``copy=False``, otherwise returns a :class:`DataFrame `. + + Sets the following fields: + + * ``tdata.obs['{target}_linkage']`` : :class:`Series ` (dtype ``float``) – single-target mode only. + Per-cell distance to the nearest cell of the target category. + * ``tdata.uns['{key_added}_linkage']`` : :class:`DataFrame ` – pairwise mode only. + Category × category linkage matrix (source rows, target columns). + Contains observed values minus permuted-mean values instead of raw distances when ``test='permutation'``. + * ``tdata.uns['{key_added}_linkage_params']`` : ``dict`` – pairwise mode only. + Parameters used to compute the linkage matrix. + * ``tdata.uns['{key_added}_linkage_stats']`` : :class:`DataFrame ` – pairwise mode only. + Long-form table with one row per (source, target) pair containing ``value``, + ``source_n``, ``target_n``, and (if ``test='permutation'``) ``permuted_value``, + ``z_score``, ``p_value``. 
+ + Examples + -------- + Compute pairwise linkage between all cell types using path distance: + + >>> tdata = py.datasets.koblan25() + >>> py.tl.ancestral_linkage(tdata, groupby="celltype") + + Compute per-cell distance to the closest cell of type "B" with permutation test: + + >>> py.tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation") + """ + # ── setup ───────────────────────────────────────────────────────────────── + _set_random_state(random_state) + key_added = key_added or groupby + tree_keys = tree + _check_tree_overlap(tdata, tree_keys) + trees = get_trees(tdata, tree_keys) + + if groupby not in tdata.obs.columns: + raise ValueError(f"'{groupby}' not found in tdata.obs.columns.") + + # Resolve default aggregate: 'min' for path, 'max' for lca + if aggregate is None: + aggregate = "max" if metric == "lca" else "min" + + if isinstance(aggregate, str) and aggregate not in ("min", "max", "mean"): + raise ValueError(f"aggregate must be 'min', 'max', 'mean', or a callable; got '{aggregate}'.") + + # Warn about misleading aggregate only in pairwise mode (aggregate is ignored for single target) + if target is None and metric == "lca" and isinstance(aggregate, str) and aggregate == "min": + warnings.warn( + "aggregate='min' with metric='lca' selects the *shallowest* (most distant) " + "ancestor. 
To find the most recent common ancestor use aggregate='max'.", + UserWarning, + stacklevel=2, + ) + + for t in trees.values(): + check_tree_has_key(t, depth_key) + + # ── build leaf → category mapping ───────────────────────────────────────── + obs_set = set(tdata.obs_names) + leaf_to_cat: dict = {} + for _, t in trees.items(): + for leaf in get_leaves(t): + if leaf in obs_set: + cat = tdata.obs.loc[leaf, groupby] + if pd.notna(cat): + leaf_to_cat[leaf] = str(cat) + + all_cats = sorted(set(leaf_to_cat.values())) + + cat_to_leaves: dict = defaultdict(list) + for l, c in leaf_to_cat.items(): + cat_to_leaves[c].append(l) + + # ── single-target mode ──────────────────────────────────────────────────── + if target is not None: + if target not in all_cats: + raise ValueError(f"target '{target}' not found in tdata.obs['{groupby}'].") + + # Always use "closest": min path or max lca + single_agg: str = "max" if metric == "lca" else "min" + all_scores = _compute_scores(tdata, trees, leaf_to_cat, [target], single_agg, metric, depth_key) + + # Per-leaf scores + score_map = {leaf: scores.get(target, np.nan) for leaf, scores in all_scores.items()} + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float)) + + if test == "permutation": + # Observed: mean per source category + obs_cat_scores: dict = {} + for cat in all_cats: + vals = [score_map[l] for l in cat_to_leaves[cat] if l in score_map and not np.isnan(score_map[l])] + obs_cat_scores[cat] = float(np.mean(vals)) if vals else np.nan + + all_leaf_list = list(leaf_to_cat.keys()) + perm_seeds = np.random.randint(0, 2**31, size=n_permutations) + + _PERM_SINGLE_DATA.clear() + _PERM_SINGLE_DATA.update({ + "tdata": tdata, + "trees": trees, + "all_leaves": all_leaf_list, + "all_cat_vals": [leaf_to_cat[l] for l in all_leaf_list], + "target": target, + "single_agg": single_agg, + "metric": metric, + "depth_key": depth_key, + "all_cats": all_cats, + }) + + null_results = 
_run_parallel(_perm_single_target_worker, perm_seeds, n_threads) + + null_cat_scores: dict = defaultdict(list) + for perm_result in null_results: + for cat in all_cats: + null_cat_scores[cat].append(perm_result[cat]) + + rows = [] + for cat in all_cats: + obs_val = obs_cat_scores[cat] + null_vals = np.array([v for v in null_cat_scores[cat] if not np.isnan(v)], dtype=float) + if len(null_vals) > 0: + perm_val = float(np.mean(null_vals)) + sign = 1.0 if metric == "lca" else -1.0 + z = sign * (obs_val - perm_val) / (float(np.std(null_vals)) + 1e-10) + p = ( + float(np.mean(null_vals >= obs_val)) + if metric == "lca" + else float(np.mean(null_vals <= obs_val)) + ) + else: + perm_val, z, p = np.nan, np.nan, np.nan + rows.append( + { + "source": cat, + "target": target, + "value": obs_val, + "permuted_value": perm_val, + "z_score": z, + "p_value": p, + } + ) + + test_df = pd.DataFrame(rows) + tdata.uns[f"{key_added}_test"] = test_df + if copy: + return test_df + + if copy: + result_series = pd.Series( + { + cat: float(np.nanmean([score_map.get(l, np.nan) for l in cat_to_leaves[cat]])) + for cat in all_cats + }, + name=f"{target}_linkage", + ) + return result_series.to_frame() + + # ── pairwise mode ───────────────────────────────────────────────────────── + else: + # Global linkage across all trees (always computed) + all_scores = _compute_scores(tdata, trees, leaf_to_cat, all_cats, aggregate, metric, depth_key) + linkage_df = _scores_to_linkage_matrix(all_scores, all_cats, cat_to_leaves) + + # Global permutation test + global_z_df: pd.DataFrame | None = None + global_p_df: pd.DataFrame | None = None + global_null_mean_df: pd.DataFrame | None = None + if test == "permutation": + global_z_df, global_p_df, global_null_mean_df = _run_permutation_test( + tdata, trees, leaf_to_cat, all_cats, all_cats, + linkage_df, aggregate, metric, depth_key, n_permutations, n_threads, + ) + + # Build stats rows (long format, never symmetrized) + stats_rows: list = [] + if by_tree: + for 
tree_key, t in trees.items(): + t_nodes = set(t.nodes()) + single_tree = {tree_key: t} + tree_leaf_to_cat = {l: c for l, c in leaf_to_cat.items() if l in t_nodes} + tree_cat_to_leaves: dict = defaultdict(list) + for l, c in tree_leaf_to_cat.items(): + tree_cat_to_leaves[c].append(l) + + tree_scores = _compute_scores( + tdata, single_tree, tree_leaf_to_cat, all_cats, aggregate, metric, depth_key + ) + tree_linkage_df = _scores_to_linkage_matrix(tree_scores, all_cats, tree_cat_to_leaves) + + tree_z_df: pd.DataFrame | None = None + tree_p_df: pd.DataFrame | None = None + tree_null_mean_df: pd.DataFrame | None = None + if test == "permutation": + tree_z_df, tree_p_df, tree_null_mean_df = _run_permutation_test( + tdata, single_tree, tree_leaf_to_cat, all_cats, all_cats, + tree_linkage_df, aggregate, metric, depth_key, n_permutations, n_threads, + ) + + for src_cat in all_cats: + for tgt_cat in all_cats: + row: dict = { + "source": src_cat, + "target": tgt_cat, + "tree": tree_key, + "value": tree_linkage_df.loc[src_cat, tgt_cat], + "source_n": len(tree_cat_to_leaves.get(src_cat, [])), + "target_n": len(tree_cat_to_leaves.get(tgt_cat, [])), + } + if tree_z_df is not None and tree_p_df is not None and tree_null_mean_df is not None: + row["permuted_value"] = tree_null_mean_df.loc[src_cat, tgt_cat] + row["z_score"] = tree_z_df.loc[src_cat, tgt_cat] + row["p_value"] = tree_p_df.loc[src_cat, tgt_cat] + stats_rows.append(row) + else: + for src_cat in all_cats: + for tgt_cat in all_cats: + row = { + "source": src_cat, + "target": tgt_cat, + "value": linkage_df.loc[src_cat, tgt_cat], + "source_n": len(cat_to_leaves.get(src_cat, [])), + "target_n": len(cat_to_leaves.get(tgt_cat, [])), + } + if global_z_df is not None and global_p_df is not None and global_null_mean_df is not None: + row["permuted_value"] = global_null_mean_df.loc[src_cat, tgt_cat] + row["z_score"] = global_z_df.loc[src_cat, tgt_cat] + row["p_value"] = global_p_df.loc[src_cat, tgt_cat] + stats_rows.append(row) + + 
# uns[linkage] = observed - permuted_mean (symmetrized) if test ran, else raw linkage (symmetrized) + output_df: pd.DataFrame = (linkage_df - global_null_mean_df) if test == "permutation" else linkage_df + if symmetrize is not None: + output_df = _symmetrize_matrix(output_df, symmetrize) + + params = { + "groupby": groupby, + "aggregate": aggregate, + "metric": metric, + "symmetrize": symmetrize, + "test": test, + "by_tree": by_tree, + "depth_key": depth_key, + } + stats_df = pd.DataFrame(stats_rows) + tdata.uns[f"{key_added}_linkage"] = output_df + tdata.uns[f"{key_added}_linkage_params"] = params + tdata.uns[f"{key_added}_linkage_stats"] = stats_df + + if copy: + return stats_df if test is not None else output_df diff --git a/tests/test_ancestral_linkage.py b/tests/test_ancestral_linkage.py new file mode 100644 index 0000000..6434205 --- /dev/null +++ b/tests/test_ancestral_linkage.py @@ -0,0 +1,530 @@ +"""Tests for tl.ancestral_linkage.""" + +import networkx as nx +import numpy as np +import pandas as pd +import pytest +import treedata as td + +import pycea.tl as tl + + +# ── fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def balanced_tdata(): + """Four-leaf binary tree, two categories A and B. 
+ + Structure (depth values): + root (depth=0) + ├── n1 (depth=0.5) + │ ├── a1 (depth=1.0) cat=A + │ └── a2 (depth=1.0) cat=A + └── n2 (depth=0.5) + ├── b1 (depth=1.0) cat=B + └── b2 (depth=1.0) cat=B + + LCA depths: + LCA(a1, a2) = 0.5 (within A) + LCA(b1, b2) = 0.5 (within B) + LCA(a*, b*) = 0.0 (between A and B) + + Path distances: + path(a1, a2) = 2*(1.0 - 0.5) = 1.0 + path(b1, b2) = 1.0 + path(a*, b*) = 2*(1.0 - 0.0) = 2.0 + """ + t = nx.DiGraph() + nodes = { + "root": 0.0, + "n1": 0.5, + "n2": 0.5, + "a1": 1.0, + "a2": 1.0, + "b1": 1.0, + "b2": 1.0, + } + for node, depth in nodes.items(): + t.add_node(node, depth=depth) + edges = [ + ("root", "n1"), ("root", "n2"), + ("n1", "a1"), ("n1", "a2"), + ("n2", "b1"), ("n2", "b2"), + ] + t.add_edges_from(edges) + + obs = pd.DataFrame( + {"celltype": ["A", "A", "B", "B"]}, + index=["a1", "a2", "b1", "b2"], + ) + return td.TreeData(obs=obs, obst={"tree": t}) + + +@pytest.fixture +def three_cat_tdata(): + """Six-leaf tree with categories A, B, C for fuller pairwise tests. + + Structure: + root (0) + ├── n1 (0.4) + │ ├── a1 (1.0) A + │ └── a2 (1.0) A + ├── n2 (0.4) + │ ├── b1 (1.0) B + │ └── b2 (1.0) B + └── n3 (0.4) + ├── c1 (1.0) C + └── c2 (1.0) C + """ + t = nx.DiGraph() + for node, depth in [ + ("root", 0.0), ("n1", 0.4), ("n2", 0.4), ("n3", 0.4), + ("a1", 1.0), ("a2", 1.0), ("b1", 1.0), ("b2", 1.0), ("c1", 1.0), ("c2", 1.0), + ]: + t.add_node(node, depth=depth) + for u, v in [ + ("root", "n1"), ("root", "n2"), ("root", "n3"), + ("n1", "a1"), ("n1", "a2"), + ("n2", "b1"), ("n2", "b2"), + ("n3", "c1"), ("n3", "c2"), + ]: + t.add_edge(u, v) + obs = pd.DataFrame( + {"celltype": ["A", "A", "B", "B", "C", "C"]}, + index=["a1", "a2", "b1", "b2", "c1", "c2"], + ) + return td.TreeData(obs=obs, obst={"tree": t}) + + +@pytest.fixture +def two_tree_tdata(): + """Two separate trees, each with leaves from categories A and B. 
+ + tree1: a1 (A), b1 (B) — root depth 0, leaves depth 1, n1 depth 0.5 + tree2: a2 (A), b2 (B) — root depth 0, leaves depth 1, n2 depth 0.5 + """ + t1 = nx.DiGraph() + for node, depth in [("r1", 0.0), ("n1", 0.5), ("a1", 1.0), ("b1", 1.0)]: + t1.add_node(node, depth=depth) + t1.add_edges_from([("r1", "n1"), ("n1", "a1"), ("n1", "b1")]) + + t2 = nx.DiGraph() + for node, depth in [("r2", 0.0), ("n2", 0.5), ("a2", 1.0), ("b2", 1.0)]: + t2.add_node(node, depth=depth) + t2.add_edges_from([("r2", "n2"), ("n2", "a2"), ("n2", "b2")]) + + obs = pd.DataFrame( + {"celltype": ["A", "B", "A", "B"]}, + index=["a1", "b1", "a2", "b2"], + ) + return td.TreeData(obs=obs, obst={"tree1": t1, "tree2": t2}) + + +# ── pairwise mode tests ─────────────────────────────────────────────────────── + + +def test_pairwise_path_min_within_greater_than_between(balanced_tdata): + """Within-category path distance (min) should be smaller than between-category.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="min", metric="path") + mat = tdata.uns["celltype_linkage"] + assert mat.loc["A", "A"] < mat.loc["A", "B"] + assert mat.loc["B", "B"] < mat.loc["B", "A"] + + +def test_pairwise_lca_max_within_greater_than_between(balanced_tdata): + """Within-category LCA depth (max) should be larger than between-category.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="max", metric="lca") + mat = tdata.uns["celltype_linkage"] + assert mat.loc["A", "A"] > mat.loc["A", "B"] + assert mat.loc["B", "B"] > mat.loc["B", "A"] + + +def test_pairwise_lca_max_known_values(balanced_tdata): + """Verify exact values for lca+max aggregate.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="max", metric="lca") + mat = tdata.uns["celltype_linkage"] + # within: self-pair gives lca = (1.0+1.0-0)/2 = 1.0 (self is always max) + assert np.isclose(mat.loc["A", "A"], 1.0) + assert np.isclose(mat.loc["B", "B"], 1.0) + # 
between: best LCA from a1/a2 to any b = root.depth = 0.0 + assert np.isclose(mat.loc["A", "B"], 0.0) + + +def test_pairwise_path_min_known_values(balanced_tdata): + """Verify exact values for path+min aggregate.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="min", metric="path") + mat = tdata.uns["celltype_linkage"] + # Within A: min dist is self (0.0) + assert np.isclose(mat.loc["A", "A"], 0.0) + # path(a*, b*) = |1.0 + 1.0 - 2*0.0| = 2.0 + assert np.isclose(mat.loc["A", "B"], 2.0) + + +def test_pairwise_mean_aggregate(balanced_tdata): + """mean aggregate should equal mean of all cross-category path distances.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="mean", metric="path") + mat = tdata.uns["celltype_linkage"] + # All a*-b* pairs have path distance 2.0 → mean = 2.0 + assert np.isclose(mat.loc["A", "B"], 2.0) + # Within A: mean([path(a1,a1), path(a1,a2), path(a2,a1), path(a2,a2)]) = mean([0, 1, 1, 0]) = 0.5 + assert np.isclose(mat.loc["A", "A"], 0.5) + + +def test_pairwise_stored_in_uns(balanced_tdata): + """Results stored in tdata.uns with expected keys.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype") + assert "celltype_linkage" in tdata.uns + assert "celltype_linkage_params" in tdata.uns + assert "celltype_linkage_stats" in tdata.uns + df = tdata.uns["celltype_linkage"] + assert isinstance(df, pd.DataFrame) + assert set(df.index) == {"A", "B"} + assert set(df.columns) == {"A", "B"} + + +def test_pairwise_stats_has_n_columns(balanced_tdata): + """Stats DataFrame has source_n and target_n columns (always).""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype") + stats = tdata.uns["celltype_linkage_stats"] + assert isinstance(stats, pd.DataFrame) + assert "source_n" in stats.columns + assert "target_n" in stats.columns + assert "source" in stats.columns + assert "target" in stats.columns + assert "value" in stats.columns + # 2 
categories → 4 rows + assert len(stats) == 4 + # source_n for category A = 2 leaves + assert stats.loc[stats["source"] == "A"].iloc[0]["source_n"] == 2 + + +def test_pairwise_copy_returns_dataframe(balanced_tdata): + """copy=True returns DataFrame and also stores in uns.""" + tdata = balanced_tdata + result = tl.ancestral_linkage(tdata, groupby="celltype", copy=True) + assert isinstance(result, pd.DataFrame) + assert "celltype_linkage" in tdata.uns + + +def test_pairwise_key_added(balanced_tdata): + """key_added controls storage key.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", key_added="mykey") + assert "mykey_linkage" in tdata.uns + assert "mykey_linkage_params" in tdata.uns + assert "mykey_linkage_stats" in tdata.uns + + +# ── symmetrize tests ────────────────────────────────────────────────────────── + + +def test_symmetrize_mean(three_cat_tdata): + """After symmetrize='mean', M[i,j] == M[j,i].""" + tdata = three_cat_tdata + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="min", metric="path", + symmetrize="mean") + mat = tdata.uns["celltype_linkage"] + for i in mat.index: + for j in mat.columns: + assert np.isclose(mat.loc[i, j], mat.loc[j, i]) + + +def test_symmetrize_max(three_cat_tdata): + """After symmetrize='max', M[i,j] == M[j,i] and M[i,j] >= original M[i,j].""" + tdata = three_cat_tdata + result_nosym = tl.ancestral_linkage(tdata, groupby="celltype", aggregate="min", + metric="path", copy=True) + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="min", metric="path", + symmetrize="max") + mat = tdata.uns["celltype_linkage"] + for i in mat.index: + for j in mat.columns: + assert np.isclose(mat.loc[i, j], mat.loc[j, i]) + assert mat.loc[i, j] >= result_nosym.loc[i, j] - 1e-9 + + +# ── single-target mode tests ────────────────────────────────────────────────── + + +def test_single_target_stores_in_obs(balanced_tdata): + """Single-target result stored in tdata.obs.""" + tdata = balanced_tdata + 
tl.ancestral_linkage(tdata, groupby="celltype", target="B") + assert "B_linkage" in tdata.obs.columns + assert tdata.obs["B_linkage"].notna().all() + + +def test_single_target_path_known_values(balanced_tdata): + """a1, a2 should have path distance 2.0 to nearest B leaf; b1, b2 distance to self = 0.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", target="B", metric="path") + assert np.isclose(tdata.obs.loc["a1", "B_linkage"], 2.0) + assert np.isclose(tdata.obs.loc["a2", "B_linkage"], 2.0) + # b1, b2 are in target B → distance to self = 0 + assert np.isclose(tdata.obs.loc["b1", "B_linkage"], 0.0) + assert np.isclose(tdata.obs.loc["b2", "B_linkage"], 0.0) + + +def test_single_target_lca_known_values(balanced_tdata): + """a1, a2 should have LCA depth 0.0 to best B leaf; b1/b2 in B → LCA with self = depth = 1.0.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", target="B", metric="lca") + assert np.isclose(tdata.obs.loc["a1", "B_linkage"], 0.0) + # b1 is in B; nearest is itself (path=0), lca = (1.0 + 1.0 - 0) / 2 = 1.0 + assert np.isclose(tdata.obs.loc["b1", "B_linkage"], 1.0) + + +def test_single_target_copy_returns_series_df(balanced_tdata): + """copy=True returns a DataFrame with per-category means.""" + tdata = balanced_tdata + result = tl.ancestral_linkage(tdata, groupby="celltype", target="B", + metric="path", copy=True) + assert isinstance(result, pd.DataFrame) + assert "B_linkage" in result.columns + assert "A" in result.index and "B" in result.index + + +def test_single_target_invalid_target(balanced_tdata): + """Specifying a non-existent target raises ValueError.""" + with pytest.raises(ValueError, match="not found"): + tl.ancestral_linkage(balanced_tdata, groupby="celltype", target="Z") + + +# ── by_tree tests ───────────────────────────────────────────────────────────── + + +def test_by_tree_stats_has_tree_column(two_tree_tdata): + """When by_tree=True, stats DataFrame has a 'tree' column.""" + tdata 
= two_tree_tdata + tl.ancestral_linkage(tdata, groupby="celltype", by_tree=True) + stats = tdata.uns["celltype_linkage_stats"] + assert "tree" in stats.columns + # 2 trees × 2 cats × 2 cats = 8 rows + assert len(stats) == 8 + assert set(stats["tree"]) == {"tree1", "tree2"} + + +def test_by_tree_source_target_n(two_tree_tdata): + """Per-tree source_n and target_n reflect leaves in that tree only.""" + tdata = two_tree_tdata + tl.ancestral_linkage(tdata, groupby="celltype", by_tree=True) + stats = tdata.uns["celltype_linkage_stats"] + # Each tree has 1 A leaf and 1 B leaf + tree1_rows = stats[stats["tree"] == "tree1"] + a_row = tree1_rows[tree1_rows["source"] == "A"].iloc[0] + assert a_row["source_n"] == 1 + assert a_row["target_n"] == 1 + + +def test_by_tree_linkage_matches_global(two_tree_tdata): + """linkage df with by_tree=True is the same as by_tree=False (weighted mean).""" + tdata = two_tree_tdata + result_global = tl.ancestral_linkage(tdata, groupby="celltype", copy=True) + tl.ancestral_linkage(tdata, groupby="celltype", by_tree=True) + result_by_tree = tdata.uns["celltype_linkage"] + pd.testing.assert_frame_equal(result_global, result_by_tree) + + +def test_by_tree_false_no_tree_column(balanced_tdata): + """When by_tree=False (default), stats has no 'tree' column.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype") + stats = tdata.uns["celltype_linkage_stats"] + assert "tree" not in stats.columns + + +# ── permutation test ────────────────────────────────────────────────────────── + + +def test_permutation_test_pairwise_linkage_stores_enrichment(balanced_tdata): + """Pairwise permutation test stores observed - permuted_mean in uns[linkage].""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", n_permutations=20, + random_state=42, + ) + mat = tdata.uns["celltype_linkage"] + assert isinstance(mat, pd.DataFrame) + assert set(mat.index) == {"A", "B"} + assert set(mat.columns) == {"A", "B"} 
+ # Values are finite (observed - permuted_mean, not z-scores) + assert np.isfinite(mat.values).all() + + +def test_permutation_test_pairwise_copy_returns_stats(balanced_tdata): + """copy=True with permutation test returns linkage_stats DataFrame.""" + tdata = balanced_tdata + result = tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", n_permutations=20, + random_state=42, copy=True + ) + assert isinstance(result, pd.DataFrame) + assert set(result.columns) >= {"source", "target", "value", "permuted_value", "z_score", "p_value"} + assert set(result["source"]) == {"A", "B"} + + +def test_copy_no_test_returns_linkage_matrix(balanced_tdata): + """copy=True without a test returns the linkage matrix DataFrame.""" + tdata = balanced_tdata + result = tl.ancestral_linkage(tdata, groupby="celltype", copy=True) + assert isinstance(result, pd.DataFrame) + assert set(result.index) == {"A", "B"} + assert set(result.columns) == {"A", "B"} + + +def test_permutation_test_single_target(balanced_tdata): + """Single-target permutation test returns long DataFrame.""" + tdata = balanced_tdata + result = tl.ancestral_linkage( + tdata, groupby="celltype", target="B", test="permutation", + n_permutations=20, random_state=0, copy=True + ) + assert isinstance(result, pd.DataFrame) + assert set(result.columns) >= {"source", "target", "value", "z_score", "p_value"} + assert (result["target"] == "B").all() + + +def test_permutation_test_stored_in_stats(balanced_tdata): + """Permutation test stores z_score, p_value, and permuted_value in stats.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", n_permutations=10, + random_state=1 + ) + assert "celltype_linkage" in tdata.uns + assert "celltype_linkage_stats" in tdata.uns + # No separate z_score / p_value uns keys + assert "celltype_z_score" not in tdata.uns + assert "celltype_p_value" not in tdata.uns + stats = tdata.uns["celltype_linkage_stats"] + assert "z_score" in 
stats.columns + assert "p_value" in stats.columns + assert "permuted_value" in stats.columns + assert (stats["p_value"].between(0, 1)).all() + + +def test_permuted_value_in_pairwise_stats(balanced_tdata): + """permuted_value column holds mean null distribution value for each (src, tgt) pair.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", n_permutations=20, + random_state=42 + ) + stats = tdata.uns["celltype_linkage_stats"] + assert "permuted_value" in stats.columns + assert stats["permuted_value"].notna().all() + + +def test_permuted_value_in_single_target_test(balanced_tdata): + """permuted_value in single-target test DataFrame holds mean null score.""" + tdata = balanced_tdata + result = tl.ancestral_linkage( + tdata, groupby="celltype", target="B", test="permutation", + n_permutations=20, random_state=0, copy=True + ) + assert "permuted_value" in result.columns + assert result["permuted_value"].notna().all() + + +def test_n_threads_parallel_matches_serial(balanced_tdata): + """n_threads > 1 produces same z-score DataFrame as single-threaded (same seed).""" + def run(n_threads): + tdata = balanced_tdata + return tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", n_permutations=20, + random_state=42, n_threads=n_threads, copy=True, + ) + + serial = run(None) + parallel = run(2) + pd.testing.assert_frame_equal(serial, parallel) + + +def test_permutation_test_stats_not_symmetrized(balanced_tdata): + """Stats z_scores are not symmetrized even when symmetrize is set.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", n_permutations=20, + random_state=5, symmetrize="mean" + ) + stats = tdata.uns["celltype_linkage_stats"] + # linkage (z_score) should be symmetrized + mat = tdata.uns["celltype_linkage"] + assert np.isclose(mat.loc["A", "B"], mat.loc["B", "A"]) + # stats z_scores might not be symmetric (AB vs BA could differ) + ab = 
stats[(stats["source"] == "A") & (stats["target"] == "B")]["z_score"].values[0] + ba = stats[(stats["source"] == "B") & (stats["target"] == "A")]["z_score"].values[0] + # We just check both are present; symmetry is not required + assert np.isfinite(ab) or np.isnan(ab) + assert np.isfinite(ba) or np.isnan(ba) + + +def test_permutation_random_state(balanced_tdata): + """Same random_state produces identical output DataFrames.""" + def run(seed): + tdata = balanced_tdata + return tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", n_permutations=30, + random_state=seed, copy=True + ) + + r1 = run(7) + r2 = run(7) + pd.testing.assert_frame_equal(r1, r2) + + +def test_by_tree_permutation_stats_has_z_score(two_tree_tdata): + """by_tree=True with permutation test adds z_score/p_value/permuted_value to stats rows.""" + tdata = two_tree_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", by_tree=True, test="permutation", + n_permutations=10, random_state=0, + ) + stats = tdata.uns["celltype_linkage_stats"] + assert "z_score" in stats.columns + assert "p_value" in stats.columns + assert "permuted_value" in stats.columns + assert (stats["p_value"].between(0, 1)).all() + + +# ── warning tests ───────────────────────────────────────────────────────────── + + +def test_lca_min_warning(balanced_tdata): + """aggregate='min' + metric='lca' should emit a UserWarning.""" + with pytest.warns(UserWarning, match="shallowest"): + tl.ancestral_linkage(balanced_tdata, groupby="celltype", + aggregate="min", metric="lca") + + +# ── edge / validation tests ─────────────────────────────────────────────────── + + +def test_invalid_groupby(balanced_tdata): + with pytest.raises(ValueError, match="not found"): + tl.ancestral_linkage(balanced_tdata, groupby="nonexistent") + + +def test_invalid_aggregate(balanced_tdata): + with pytest.raises(ValueError, match="aggregate"): + tl.ancestral_linkage(balanced_tdata, groupby="celltype", aggregate="median") + + +def 
test_custom_callable_aggregate(balanced_tdata): + """Custom callable is accepted for aggregate.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", aggregate=np.mean, metric="path") + mat = tdata.uns["celltype_linkage"] + # Should match mean aggregate + tl.ancestral_linkage(tdata, groupby="celltype", aggregate="mean", metric="path", + key_added="ref") + ref = tdata.uns["ref_linkage"] + pd.testing.assert_frame_equal(mat, ref) From f7780012702fab104ffecda66db155307bc529d5 Mon Sep 17 00:00:00 2001 From: colganwi Date: Tue, 24 Mar 2026 14:31:29 -0400 Subject: [PATCH 2/2] feat(tl): add alternative parameter to ancestral_linkage for two-sided permutation test Adds alternative='two-sided' to support two-tailed p-values. Default None preserves existing one-sided behavior (more-related direction). Co-Authored-By: Claude Sonnet 4.6 --- src/pycea/tl/ancestral_linkage.py | 60 +++++++++++++++++++++++-------- tests/test_ancestral_linkage.py | 50 ++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 15 deletions(-) diff --git a/src/pycea/tl/ancestral_linkage.py b/src/pycea/tl/ancestral_linkage.py index 54d5809..3001a6d 100644 --- a/src/pycea/tl/ancestral_linkage.py +++ b/src/pycea/tl/ancestral_linkage.py @@ -263,6 +263,25 @@ def _run_parallel(worker_fn: Callable, seeds: np.ndarray, n_threads: int | None) return [worker_fn(seed) for seed in tqdm(seeds, desc="Permutations", leave=False)] +def _compute_p_values( + null_array: np.ndarray, + obs_values: np.ndarray, + null_mean: np.ndarray, + metric: str, + alternative: str | None, +) -> np.ndarray: + """Compute permutation p-values given the null distribution and observed values.""" + if alternative == "two-sided": + deviation = np.abs(null_array - null_mean[np.newaxis]) + obs_deviation = np.abs(obs_values - null_mean) + return np.nanmean(deviation >= obs_deviation[np.newaxis], axis=0) + # One-tailed in the "more related" direction + if metric == "lca": + return np.nanmean(null_array >= 
obs_values[np.newaxis], axis=0) + else: + return np.nanmean(null_array <= obs_values[np.newaxis], axis=0) + + def _run_permutation_test( tdata: td.TreeData, trees: dict, @@ -275,6 +294,7 @@ def _run_permutation_test( depth_key: str, n_permutations: int, n_threads: int | None, + alternative: str | None, ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """Permutation test: shuffle leaf labels, recompute linkage, return (z_score_df, p_value_df, null_mean_df).""" all_leaves = list(leaf_to_cat.keys()) @@ -306,11 +326,7 @@ def _run_permutation_test( sign = 1.0 if metric == "lca" else -1.0 z_scores = sign * (obs_values - null_mean) / (null_std + 1e-10) - # One-tailed p-value in the "more related" direction - if metric == "lca": - p_values = np.nanmean(null_array >= obs_values[np.newaxis], axis=0) - else: - p_values = np.nanmean(null_array <= obs_values[np.newaxis], axis=0) + p_values = _compute_p_values(null_array, obs_values, null_mean, metric, alternative) z_score_df = pd.DataFrame(z_scores, index=observed_df.index, columns=observed_df.columns) p_value_df = pd.DataFrame(p_values, index=observed_df.index, columns=observed_df.columns) @@ -330,6 +346,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + alternative: Literal["two-sided", None] = None, n_permutations: int = 100, n_threads: int | None = None, by_tree: bool = False, @@ -350,6 +367,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + alternative: Literal["two-sided", None] = None, n_permutations: int = 100, n_threads: int | None = None, by_tree: bool = False, @@ -368,6 +386,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + alternative: Literal["two-sided", None] = None, n_permutations: int = 
100, n_threads: int | None = None, aggregate: Literal["min", "max", "mean"] | Callable | None = None, @@ -424,9 +443,19 @@ def ancestral_linkage( - ``'permutation'``: randomly shuffle cell-category labels ``n_permutations`` times and recompute linkage each time to build a null distribution. - Z-scores and one-tailed p-values (in the direction of closer-than-expected - relatedness) are added to the stats table. The stored linkage matrix is - replaced by z-scores when this test is run. + Z-scores and p-values are added to the stats table. The stored linkage + matrix is replaced by z-scores when this test is run. + alternative + The alternative hypothesis for the permutation test (ignored when + ``test=None``): + + - ``None`` (default): one-tailed test in the "more closely related than + chance" direction — p-value is the fraction of permutations with LCA depth + ≥ observed (``metric='lca'``) or path distance ≤ observed + (``metric='path'``). + - ``'two-sided'``: two-tailed test — p-value is the fraction of permutations + whose deviation from the null mean is at least as large as the observed + deviation. n_permutations Number of label permutations used when ``test='permutation'``. 
n_threads @@ -580,11 +609,12 @@ def ancestral_linkage( perm_val = float(np.mean(null_vals)) sign = 1.0 if metric == "lca" else -1.0 z = sign * (obs_val - perm_val) / (float(np.std(null_vals)) + 1e-10) - p = ( - float(np.mean(null_vals >= obs_val)) - if metric == "lca" - else float(np.mean(null_vals <= obs_val)) - ) + if alternative == "two-sided": + p = float(np.mean(np.abs(null_vals - perm_val) >= abs(obs_val - perm_val))) + elif metric == "lca": + p = float(np.mean(null_vals >= obs_val)) + else: + p = float(np.mean(null_vals <= obs_val)) else: perm_val, z, p = np.nan, np.nan, np.nan rows.append( @@ -626,7 +656,7 @@ def ancestral_linkage( if test == "permutation": global_z_df, global_p_df, global_null_mean_df = _run_permutation_test( tdata, trees, leaf_to_cat, all_cats, all_cats, - linkage_df, aggregate, metric, depth_key, n_permutations, n_threads, + linkage_df, aggregate, metric, depth_key, n_permutations, n_threads, alternative, ) # Build stats rows (long format, never symmetrized) @@ -651,7 +681,7 @@ def ancestral_linkage( if test == "permutation": tree_z_df, tree_p_df, tree_null_mean_df = _run_permutation_test( tdata, single_tree, tree_leaf_to_cat, all_cats, all_cats, - tree_linkage_df, aggregate, metric, depth_key, n_permutations, n_threads, + tree_linkage_df, aggregate, metric, depth_key, n_permutations, n_threads, alternative, ) for src_cat in all_cats: diff --git a/tests/test_ancestral_linkage.py b/tests/test_ancestral_linkage.py index 6434205..6e39efe 100644 --- a/tests/test_ancestral_linkage.py +++ b/tests/test_ancestral_linkage.py @@ -495,6 +495,56 @@ def test_by_tree_permutation_stats_has_z_score(two_tree_tdata): assert (stats["p_value"].between(0, 1)).all() +# ── alternative parameter tests ─────────────────────────────────────────────── + + +def test_alternative_two_sided_pairwise_p_values_in_range(balanced_tdata): + """Two-sided p-values are in [0, 1].""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", 
test="permutation", alternative="two-sided",
+        n_permutations=20, random_state=0,
+    )
+    stats = tdata.uns["celltype_linkage_stats"]
+    assert (stats["p_value"].between(0, 1)).all()
+
+
+def test_alternative_two_sided_symmetric(balanced_tdata):
+    """Two-sided p-values are symmetric: p(A→B) == p(B→A) when the tree is symmetric."""
+    tdata = balanced_tdata
+    tl.ancestral_linkage(
+        tdata, groupby="celltype", test="permutation", alternative="two-sided",
+        n_permutations=50, random_state=1,
+    )
+    stats = tdata.uns["celltype_linkage_stats"]
+    ab = stats[(stats["source"] == "A") & (stats["target"] == "B")]["p_value"].values[0]
+    ba = stats[(stats["source"] == "B") & (stats["target"] == "A")]["p_value"].values[0]
+    assert np.isclose(ab, ba)
+
+
+def test_alternative_two_sided_single_target(balanced_tdata):
+    """Two-sided p-values work in single-target mode."""
+    tdata = balanced_tdata
+    result = tl.ancestral_linkage(
+        tdata, groupby="celltype", target="B", test="permutation",
+        alternative="two-sided", n_permutations=20, random_state=0, copy=True,
+    )
+    assert isinstance(result, pd.DataFrame)
+    assert (result["p_value"].between(0, 1)).all()
+
+
+def test_alternative_none_matches_default(balanced_tdata):
+    """alternative=None produces identical results to omitting the parameter."""
+    def run(**kwargs):
+        tdata = balanced_tdata
+        return tl.ancestral_linkage(
+            tdata, groupby="celltype", test="permutation", **kwargs,
+            n_permutations=20, random_state=42, copy=True,
+        )
+
+    pd.testing.assert_frame_equal(run(alternative=None), run())
+
+
 # ── warning tests ─────────────────────────────────────────────────────────────