Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 159 additions & 5 deletions city2graph/metapath.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,34 @@ def _materialize_metapath(
right_on=f"src_{idx}",
how="inner",
copy=False,
suffixes=("", "_right"),
)
joined["path_nodes"] = [
left + right[1:]
for left, right in zip(
joined["path_nodes"],
joined["path_nodes_right"],
strict=False,
)
]
joined["path_edges"] = [
left + right
for left, right in zip(
joined["path_edges"],
joined["path_edges_right"],
strict=False,
)
]
# Drop intermediate join columns to save memory
joined = joined.drop(columns=[f"dst_{idx - 1}", f"src_{idx}"], errors="ignore")
joined = joined.drop(
columns=[
f"dst_{idx - 1}",
f"src_{idx}",
"path_nodes_right",
"path_edges_right",
],
errors="ignore",
)

if joined.empty:
return (
Expand All @@ -778,6 +803,7 @@ def _materialize_metapath(
aggregation=aggregation,
start_index_name=start_index_name,
end_index_name=end_index_name,
directed=directed,
)

if aggregated.empty:
Expand Down Expand Up @@ -872,6 +898,14 @@ def _build_hop_frame(
data = {
src_col: edge_gdf.index.get_level_values(src_level).to_numpy(),
dst_col: edge_gdf.index.get_level_values(dst_level).to_numpy(),
"path_nodes": list(
zip(
edge_gdf.index.get_level_values(src_level).to_numpy(),
edge_gdf.index.get_level_values(dst_level).to_numpy(),
strict=False,
)
),
"path_edges": [(index_value,) for index_value in edge_gdf.index.to_list()],
}

if edge_attrs:
Expand All @@ -893,6 +927,7 @@ def _aggregate_paths(
aggregation: _EdgeAttrAggregation,
start_index_name: str,
end_index_name: str,
directed: bool,
) -> pd.DataFrame:
"""
Group joined paths into terminal node pairs with aggregated weights.
Expand All @@ -913,6 +948,8 @@ def _aggregate_paths(
Name of the start node index.
end_index_name : str
Name of the end node index.
directed : bool
Whether the metapath should preserve edge orientation.

Returns
-------
Expand All @@ -926,12 +963,25 @@ def _aggregate_paths(
agg_map: dict[str, str | Callable[[pd.Series], float]] = {"weight": "sum"}

# Base workload with path count (weight=1 for each path)
workload_data = {
path_nodes = cast("list[tuple[object, ...]]", combined["path_nodes"].to_list())
workload_data: dict[str, object] = {
"src": combined[src_col].to_numpy(),
"dst": combined[dst_col].to_numpy(),
"weight": np.ones(len(combined), dtype=float),
}

if not directed:
canonical_paths = [_canonicalize_undirected_sequence(path) for path in path_nodes]
canonical_edge_paths = [
_canonicalize_undirected_sequence(
tuple(_canonicalize_undirected_edge_id(edge_id) for edge_id in edge_path)
)
for edge_path in cast("list[tuple[object, ...]]", combined["path_edges"].to_list())
]
workload_data["src"] = [path[0] for path in canonical_paths]
workload_data["dst"] = [path[-1] for path in canonical_paths]
workload_data["path_signature"] = canonical_edge_paths

if edge_attrs:
for attr in edge_attrs:
# Collect columns for this attribute across all steps
Expand All @@ -948,6 +998,9 @@ def _aggregate_paths(

workload = pd.DataFrame(workload_data)

if not directed:
workload = workload.drop_duplicates(subset=["path_signature"], keep="first")

# Group by terminal nodes and aggregate (e.g. sum of weights = number of paths)
aggregated = workload.groupby(["src", "dst"], sort=False).agg(agg_map)

Expand All @@ -958,6 +1011,96 @@ def _aggregate_paths(
return aggregated


def _stable_value_key(value: object) -> tuple[str, str]:
    """
    Return a deterministic sort key for heterogeneous identifiers.

    The key is built from strings only, so values whose types cannot be
    compared directly (for example ``int`` and ``str``) can still be ordered
    without raising ``TypeError``.

    Parameters
    ----------
    value : object
        Value to convert into a stable ordering key.

    Returns
    -------
    tuple[str, str]
        Key composed of the type name and repr string.

    Notes
    -----
    Keys compare lexicographically as strings: values are grouped by type
    name first and then ordered by their ``repr``. This is *not* a natural
    ordering for numbers (``10`` sorts before ``9``); it is only intended to
    give canonicalization a consistent tie-breaker.
    """
    type_name = type(value).__name__
    return (type_name, repr(value))


def _canonicalize_undirected_pair(value_a: object, value_b: object) -> tuple[object, object]:
    """
    Return a deterministic ordering for an undirected pair.

    The pair is ordered via :func:`_stable_value_key` so mirrored edges collapse
    to the same representation.

    Parameters
    ----------
    value_a : object
        First value in the pair.
    value_b : object
        Second value in the pair.

    Returns
    -------
    tuple[object, object]
        Canonically ordered pair. When both values produce the same key the
        input order is kept.
    """
    # Compare string-based keys rather than the values themselves so that
    # mixed-type identifiers never trigger a TypeError.
    if _stable_value_key(value_a) <= _stable_value_key(value_b):
        return (value_a, value_b)
    return (value_b, value_a)


def _canonicalize_undirected_edge_id(edge_id: object) -> object:
    """
    Canonicalize an edge identifier for undirected path comparison.

    For tuple edge identifiers with at least two elements, only the first two
    elements (the terminal node pair) are reordered; any additional index
    levels are preserved as-is and in their original positions.

    Parameters
    ----------
    edge_id : object
        Edge identifier, typically a tuple derived from a MultiIndex row.

    Returns
    -------
    object
        Edge identifier with its terminal node pair canonically ordered.
        Non-tuple identifiers and tuples shorter than two elements are
        returned unchanged.
    """
    # Anything that is not an (u, v, ...) tuple has no endpoint pair to swap.
    if not isinstance(edge_id, tuple) or len(edge_id) < 2:
        return edge_id

    edge_u, edge_v = _canonicalize_undirected_pair(edge_id[0], edge_id[1])
    return (edge_u, edge_v, *edge_id[2:])


def _canonicalize_undirected_sequence(values: tuple[object, ...]) -> tuple[object, ...]:
    """
    Canonicalize a path-like sequence against its reversal.

    The lexicographically smaller orientation under :func:`_stable_value_key`
    is retained so forward and reversed traversals share one signature.

    Parameters
    ----------
    values : tuple[object, ...]
        Sequence to compare with its reversed orientation.

    Returns
    -------
    tuple[object, ...]
        Deterministically oriented sequence: ``values`` itself when its key
        sequence is less than or equal to the reversed one (this includes
        empty and palindromic sequences), otherwise ``values[::-1]``.
    """
    forward_key = tuple(_stable_value_key(value) for value in values)
    # The key of the reversed sequence is just the forward key reversed, so
    # each element only needs to be keyed once.
    reverse_key = forward_key[::-1]
    if forward_key <= reverse_key:
        return values
    return values[::-1]


def _empty_metapath_frame(
edge_attrs: list[str] | None,
start_index_name: str,
Expand Down Expand Up @@ -1545,6 +1688,7 @@ def _extract_metapath_edges(
"""
new_edges_data = []
endpoint_indices_arr = np.array(endpoint_indices)
seen_pairs: set[tuple[object, object]] = set()

for i, start_idx in enumerate(endpoint_indices):
dists = dist_matrix[i]
Expand Down Expand Up @@ -1581,10 +1725,20 @@ def _extract_metapath_edges(
start_orig = graph.nodes[start_node].get("_original_index", start_node)
end_orig = graph.nodes[end_node].get("_original_index", end_node)

orig_edge_index = (start_orig, end_orig)
if not directed:
start_key = (type(start_orig).__name__, repr(start_orig))
end_key = (type(end_orig).__name__, repr(end_orig))

if not directed and start_orig > end_orig: # type: ignore[operator]
orig_edge_index = (end_orig, start_orig)
if start_key > end_key:
start_node, end_node = end_node, start_node
start_orig, end_orig = end_orig, start_orig

canonical_pair = (start_orig, end_orig)
if canonical_pair in seen_pairs:
continue
seen_pairs.add(canonical_pair)

orig_edge_index = (start_orig, end_orig)

new_edges_data.append(
{
Expand Down
135 changes: 135 additions & 0 deletions tests/test_metapath.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,76 @@ def test_add_metapaths_edge_direction_and_lookup(
directed=True,
)

def test_add_metapaths_undirected_deduplicates_mirrored_paths(self) -> None:
    """Undirected metapaths should not materialize mirrored terminal pairs twice."""
    # Two buildings, each attached to its own street node.
    buildings = gpd.GeoDataFrame(
        {
            "geometry": [Point(0, 0), Point(10, 0)],
            "node_type": "building",
        },
        index=[1, 2],
        crs="EPSG:4326",
    )
    streets = gpd.GeoDataFrame(
        {
            "geometry": [Point(0, 1), Point(10, 1)],
            "node_type": "street",
        },
        index=[101, 102],
        crs="EPSG:4326",
    )
    nodes = {"building": buildings, "street": streets}

    # Every relation is stored in both orientations, so the same physical
    # route can be walked as 1 -> 101 -> 102 -> 2 and as 2 -> 102 -> 101 -> 1.
    edges = {
        ("building", "access", "street"): gpd.GeoDataFrame(
            {
                "edge_type": "access",
                "geometry": [
                    LineString([(0, 0), (0, 1)]),
                    LineString([(10, 0), (10, 1)]),
                ],
            },
            index=pd.MultiIndex.from_tuples([(1, 101), (2, 102)]),
            crs="EPSG:4326",
        ),
        ("street", "access", "building"): gpd.GeoDataFrame(
            {
                "edge_type": "access",
                "geometry": [
                    LineString([(0, 1), (0, 0)]),
                    LineString([(10, 1), (10, 0)]),
                ],
            },
            index=pd.MultiIndex.from_tuples([(101, 1), (102, 2)]),
            crs="EPSG:4326",
        ),
        ("street", "road", "street"): gpd.GeoDataFrame(
            {
                "edge_type": "road",
                "geometry": [
                    LineString([(0, 1), (10, 1)]),
                    LineString([(10, 1), (0, 1)]),
                ],
            },
            index=pd.MultiIndex.from_tuples([(101, 102), (102, 101)]),
            crs="EPSG:4326",
        ),
    }
    sequence = [
        ("building", "access", "street"),
        ("street", "road", "street"),
        ("street", "access", "building"),
    ]
    relation = ("building", "metapath_0", "building")

    # Undirected: the two mirrored traversals must collapse into a single
    # edge whose path-count weight is 1, not 2.
    _, undirected_edges = add_metapaths((nodes, edges), sequence=sequence, directed=False)
    undirected_pairs = set(undirected_edges[relation].index.tolist())
    assert undirected_pairs == {(1, 2)}
    assert undirected_edges[relation].loc[(1, 2), "weight"] == 1

    # Directed: both orientations are distinct edges and must be kept.
    _, directed_edges = add_metapaths((nodes, edges), sequence=sequence, directed=True)
    directed_pairs = set(directed_edges[relation].index.tolist())
    assert directed_pairs == {(1, 2), (2, 1)}

@pytest.mark.parametrize(
"join_case", ["empty_hop", "disjoint", "nan_sources", "index_normalization"]
)
Expand Down Expand Up @@ -556,6 +626,32 @@ def test_add_metapaths_by_weight_basic(self, sample_weight_graph_data: WeightGra
assert isinstance(geom, LineString)
assert list(geom.coords) == [(0.0, 0.0), (10.0, 0.0)]

def test_add_metapaths_by_weight_undirected_deduplicates_pairs(
    self,
    sample_weight_graph_data: WeightGraphData,
) -> None:
    """Undirected weighted metapaths should emit each endpoint pair once."""
    nodes_dict, edges_dict = sample_weight_graph_data

    _, edges_out = add_metapaths_by_weight(
        (nodes_dict, edges_dict),
        endpoint_type="building",
        weight="weight",
        threshold=15.0,
        directed=False,
    )

    relation = ("building", "connected_within_0.0_15.0", "building")
    assert relation in edges_out

    # NOTE(review): the expected pairs and the 12.0 weights depend on the
    # ``sample_weight_graph_data`` fixture, which is defined elsewhere.
    new_edges = edges_out[relation]
    pairs = set(new_edges.index.tolist())
    assert pairs == {(1, 2), (2, 3)}
    # Spell out the mirrored orientations so a regression is easy to read.
    assert (2, 1) not in pairs
    assert (3, 2) not in pairs
    assert new_edges.loc[(1, 2), "weight"] == pytest.approx(12.0)
    assert new_edges.loc[(2, 3), "weight"] == pytest.approx(12.0)

def test_add_metapaths_by_weight_threshold_controls(
self,
sample_weight_graph_data: WeightGraphData,
Expand Down Expand Up @@ -649,6 +745,45 @@ def test_add_metapaths_by_weight_networkx_io(
for _, _, data in nx_roundtrip.edges(data=True)
)

def test_add_metapaths_by_weight_networkx_multigraph_undirected_deduplicates(
    self,
    sample_weight_graph_data: WeightGraphData,
) -> None:
    """Undirected MultiGraph output should not include mirrored metapath duplicates."""
    nodes_dict, edges_dict = sample_weight_graph_data
    nx_multigraph = gdf_to_nx(nodes=nodes_dict, edges=edges_dict, multigraph=True)

    nx_result = add_metapaths_by_weight(
        nx_multigraph,
        endpoint_type="building",
        weight="weight",
        threshold=15.0,
        directed=False,
        as_nx=True,
        multigraph=True,
    )

    assert isinstance(nx_result, nx.MultiGraph)

    # Keep only the edges created for the new metapath relation.
    metapath_edges = [
        (u, v, data)
        for u, v, _, data in nx_result.edges(data=True, keys=True)
        if data.get("edge_type") == ("building", "connected_within_0.0_15.0", "building")
    ]

    # Map graph node ids back to their original indices (falling back to the
    # node id itself) and sort each pair so orientation does not matter.
    endpoint_pairs = {
        tuple(
            sorted(
                (
                    nx_result.nodes[u].get("_original_index", u),
                    nx_result.nodes[v].get("_original_index", v),
                )
            )
        )
        for u, v, _ in metapath_edges
    }

    # NOTE(review): expected pairs depend on the ``sample_weight_graph_data``
    # fixture, which is defined elsewhere.
    assert endpoint_pairs == {(1, 2), (2, 3)}
    # The set alone would hide parallel duplicates in a MultiGraph, so the
    # raw edge count is checked as well.
    assert len(metapath_edges) == 2

def test_add_metapaths_by_weight_missing_endpoint(
self,
sample_weight_graph_data: WeightGraphData,
Expand Down