@@ -0,0 +1,3 @@
from .clustering import compute_cluster_labels, split_gaussians_into_clusters

__all__ = ["compute_cluster_labels", "split_gaussians_into_clusters"]
@@ -0,0 +1,172 @@
import logging

import cuml
import cupy as cp
import numpy as np
import torch
from fvdb import GaussianSplat3d

logger = logging.getLogger(__name__)


def compute_cluster_labels(
mask_features_output: torch.Tensor,
pca_n_components: int = 128,
umap_n_components: int = 32,
umap_n_neighbors: int = 15,
hdbscan_min_samples: int = 100,
hdbscan_min_cluster_size: int = 200,
fitting_sample_size: int = 300_000,
random_seed: int = 42,
device: str | torch.device = "cuda",
) -> tuple[torch.Tensor, torch.Tensor]:
"""Cluster per-gaussian features

To speed up clustering on GaussianSplat3d models, which are typically large (a million or more gaussians),
we perform feature reduction and clustering in a three-stage pipeline:
1. PCA: Pre-reduction of high-dimensional features to an intermediate representation
2. UMAP: Non-linear reduction to a low-dimensional manifold
3. HDBSCAN: Density-based clustering to identify groups of similar gaussians

Additionally, for large scenes (more than fitting_sample_size points), a random subsample is used to
fit UMAP and HDBSCAN, and all points are then transformed/predicted with the fitted models.

Args:
mask_features_output: Per-gaussian feature vectors from the segmentation
model. Shape: [N, feature_dim].
pca_n_components: Number of PCA components for initial reduction.
umap_n_components: Number of UMAP dimensions for manifold embedding.
umap_n_neighbors: UMAP neighbor count (higher = more global structure).
hdbscan_min_samples: Minimum samples for HDBSCAN core points.
hdbscan_min_cluster_size: Minimum cluster size for HDBSCAN.
fitting_sample_size: Sample size for fitting UMAP and HDBSCAN.
random_seed: Random seed for reproducibility.
device: Device to perform clustering on.
Returns:
cluster_labels: Cluster assignment for each gaussian. Shape: [N].
Label -1 indicates noise points.
cluster_probs: Membership probability for each gaussian. Shape: [N].
Higher values indicate stronger cluster membership.
"""
cp.random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
device = torch.device(device)

assert umap_n_neighbors < fitting_sample_size, "UMAP n_neighbors must be less than fitting_sample_size"

# PCA pre-reduction
n_samples, n_features = mask_features_output.shape
max_pca_components = min(n_samples, n_features)
if pca_n_components > max_pca_components:
logger.warning(
"Requested pca_n_components=%d is greater than min(n_samples=%d, n_features=%d); " "clamping to %d.",
pca_n_components,
n_samples,
n_features,
max_pca_components,
)
pca_n_components = max_pca_components
logger.info(f"PCA pre-reduction ({n_features} -> {pca_n_components} dimensions)...")

pca = cuml.PCA(n_components=pca_n_components)
features_pca = pca.fit_transform(mask_features_output)
logger.info(f"PCA reduced shape: {features_pca.shape}")

# UMAP reduction
n_points = features_pca.shape[0]
reduction_sample_size = min(fitting_sample_size, n_points)

logger.info(
f"UMAP reduction ({pca_n_components} -> {umap_n_components} dimensions, fitting on {reduction_sample_size:,} / {n_points:,} points)..."
)
umap_reducer = cuml.UMAP(
n_components=umap_n_components,
n_neighbors=umap_n_neighbors,
min_dist=0.0,
metric="euclidean",
random_state=random_seed,
)

if n_points > reduction_sample_size:
# Subsample for fitting, then transform all points
sample_idx = cp.random.permutation(n_points)[:reduction_sample_size]
umap_reducer.fit(features_pca[sample_idx])
features_reduced = umap_reducer.transform(features_pca)
else:
features_reduced = umap_reducer.fit_transform(features_pca)

logger.info(f"UMAP reduced shape: {features_reduced.shape}")

# Cluster with HDBSCAN
logger.info(f"Clustering with HDBSCAN (fitting on {reduction_sample_size:,} / {n_points:,} points)...")

clusterer = cuml.HDBSCAN(
min_samples=hdbscan_min_samples,
min_cluster_size=hdbscan_min_cluster_size,
prediction_data=True, # Required for approximate_predict
)

if n_points > reduction_sample_size:
hdbscan_sample_idx = cp.random.permutation(n_points)[:reduction_sample_size]
clusterer.fit(features_reduced[hdbscan_sample_idx])
# Use approximate_predict to assign labels to all points
cluster_labels_cp, cluster_probs_cp = cuml.cluster.hdbscan.approximate_predict(clusterer, features_reduced)
cluster_labels = torch.as_tensor(cluster_labels_cp, device=device)
cluster_probs = torch.as_tensor(cluster_probs_cp, device=device)
else:
clusterer.fit(features_reduced)
cluster_labels = torch.as_tensor(clusterer.labels_, device=device)
cluster_probs = torch.as_tensor(clusterer.probabilities_, device=device)

return cluster_labels, cluster_probs


def split_gaussians_into_clusters(
cluster_labels: torch.Tensor, cluster_probs: torch.Tensor, gs_model: GaussianSplat3d
) -> tuple[dict[int, GaussianSplat3d], dict[int, float], GaussianSplat3d]:
"""Split a GaussianSplat3d model into per-cluster subsets.

Groups gaussians by their cluster labels and computes coherence scores
(mean membership probability) for each cluster.

Args:
cluster_labels: Cluster assignment for each gaussian. Shape: [N].
Label -1 indicates noise points.
cluster_probs: Membership probability for each gaussian. Shape: [N].
gs_model: The GaussianSplat3d model to split.

Returns:
cluster_splats: Dictionary mapping cluster ID to GaussianSplat3d subset.
Excludes noise points (label -1).
cluster_coherence: Dictionary mapping cluster ID to mean membership
probability. Higher values indicate tighter, more confident clusters.
noise_splats: GaussianSplat3d containing all noise points (label -1).
"""
unique_labels = torch.unique(cluster_labels)
num_clusters = (unique_labels >= 0).sum().item() # Exclude noise label (-1)
logger.info(f"Found {num_clusters} clusters (+ {(cluster_labels == -1).sum().item()} noise points)")

# Split gaussians into separate GaussianSplat3d instances per cluster
# Also compute cluster coherence (mean membership probability)
cluster_splats: dict[int, GaussianSplat3d] = {}
cluster_coherence: dict[int, float] = {}
for label in unique_labels.tolist():
if label == -1:
# Skip noise points (label -1) here; they are collected into noise_splats below
continue
cluster_mask = cluster_labels == label
cluster_splats[label] = gs_model[cluster_mask]
cluster_coherence[label] = cluster_probs[cluster_mask].mean().item()
logger.info(
f" Cluster {label}: {cluster_splats[label].num_gaussians:,} gaussians, "
f"coherence: {cluster_coherence[label]:.3f}"
)

# Also store noise points
noise_mask = cluster_labels == -1
noise_splats = gs_model[noise_mask]
if noise_mask.any():
logger.info(f" Noise: {noise_splats.num_gaussians:,} gaussians")

return cluster_splats, cluster_coherence, noise_splats
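
A minimal usage sketch of the two helpers added in this file (illustrative only; `segmentation_model` and `gs_model` are assumed to have been restored from checkpoints, as in the viewer script below):

from garfvdb.evaluation.clustering import compute_cluster_labels, split_gaussians_into_clusters

# Query per-gaussian affinity features at a fraction of the model's max grouping scale.
scale = 0.1 * float(segmentation_model.max_grouping_scale.item())
features = segmentation_model.get_gaussian_affinity_output(scale)  # [N, 256]

# Three-stage pipeline (PCA -> UMAP -> HDBSCAN); large scenes are fitted on a subsample.
labels, probs = compute_cluster_labels(features, fitting_sample_size=300_000, device="cuda")

# Group gaussians by cluster label; noise points (label -1) are returned separately.
clusters, coherence, noise = split_gaussians_into_clusters(labels, probs, gs_model)
for cid, splat in clusters.items():
    print(f"cluster {cid}: {splat.num_gaussians:,} gaussians, coherence {coherence[cid]:.3f}")
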
instance_segmentation/garfvdb/visualize_segmentation_clusters.py (163 changes: 75 additions & 88 deletions)
@@ -7,16 +7,21 @@
from dataclasses import dataclass
from typing import Annotated

import cuml
import cuml.cluster.hdbscan
import cupy as cp
import fvdb.viz as fviz
import numpy as np
import torch
import tyro
from fvdb import GaussianSplat3d
from fvdb.types import to_Mat33fBatch, to_Mat44fBatch, to_Vec2iBatch
from fvdb_reality_capture.tools import filter_splats_above_scale
from fvdb_reality_capture.tools import (
filter_splats_above_scale,
filter_splats_by_mean_percentile,
filter_splats_by_opacity_percentile,
)
from garfvdb.evaluation.clustering import (
compute_cluster_labels,
split_gaussians_into_clusters,
)
from garfvdb.training.segmentation import GaussianSplatScaleConditionedSegmentation
from garfvdb.util import load_splats_from_file
from tyro.conf import arg
@@ -95,6 +100,13 @@ class ViewCheckpoint:
scale: float = 0.1
"""Segmentation scale as a fraction of max scale."""

filter_high_variance: bool = True
"""Remove clusters with high spatial variance (multi-center clusters)."""

variance_threshold: float = 0.1
"""Clusters with normalized variance above this threshold are removed.
Normalized variance = variance / extent^2. Typical values: ~0.03 (tight), ~0.08 (uniform), >0.1 (scattered)."""

def execute(self) -> None:
"""Execute the viewer command."""
log_level = logging.DEBUG if self.verbose else logging.INFO
@@ -112,11 +124,11 @@ def execute(self) -> None:
raise FileNotFoundError(f"Reconstruction checkpoint {self.reconstruction_path} does not exist.")
logger.info(f"Loading Gaussian splat model from {self.reconstruction_path}")
gs_model, metadata = load_splats_from_file(self.reconstruction_path, device)
gs_model = filter_splats_above_scale(gs_model, 0.1)
logger.info(f"Loaded {gs_model.num_gaussians:,} Gaussians")

# Load the segmentation runner from checkpoint
logger.info(f"Loading segmentation checkpoint from {self.segmentation_path}")
# Filter GS model
gs_model = filter_splats_above_scale(gs_model, 0.1)
gs_model = filter_splats_by_opacity_percentile(gs_model, percentile=0.85)
gs_model = filter_splats_by_mean_percentile(gs_model, percentile=[0.96, 0.96, 0.96, 0.96, 0.98, 0.99])
runner = load_segmentation_runner_from_checkpoint(
checkpoint_path=self.segmentation_path,
gs_model=gs_model,
@@ -126,95 +138,70 @@ def execute(self) -> None:

gs_model = runner.gs_model
segmentation_model = runner.model
sfm_scene = runner.sfm_scene

# Get per-gaussian features at a given scale
scale = self.scale * float(segmentation_model.max_grouping_scale.item())

mask_features_output = segmentation_model.get_gaussian_affinity_output(scale) # [N, 256]
logger.info(f"Got mask features with shape: {mask_features_output.shape}")

# PCA pre-reduction (256 -> 128)
logger.info("PCA pre-reduction (256 -> 128 dimensions)...")
pca = cuml.PCA(n_components=128)
features_pca = pca.fit_transform(mask_features_output)
logger.info(f"PCA reduced shape: {features_pca.shape}")

# UMAP reduction (128 -> 32)
n_points = features_pca.shape[0]
reduction_sample_size = min(300_000, n_points)

logger.info(
f"UMAP reduction (128 -> 32 dimensions, fitting on {reduction_sample_size:,} / {n_points:,} points)..."
)
umap_reducer = cuml.UMAP(
n_components=32,
n_neighbors=15,
min_dist=0.0,
metric="euclidean",
random_state=42,
)

if n_points > reduction_sample_size:
# Subsample for fitting, then transform all points
sample_idx = cp.random.permutation(n_points)[:reduction_sample_size]
umap_reducer.fit(features_pca[sample_idx])
features_reduced = umap_reducer.transform(features_pca)
else:
features_reduced = umap_reducer.fit_transform(features_pca)

logger.info(f"UMAP reduced shape: {features_reduced.shape}")

# Cluster HDBSCAN
logger.info(f"Clustering with HDBSCAN (fitting on {reduction_sample_size:,} / {n_points:,} points)...")

clusterer = cuml.HDBSCAN(
min_samples=100,
min_cluster_size=200,
prediction_data=True, # Required for approximate_predict
)

if n_points > reduction_sample_size:
hdbscan_sample_idx = cp.random.permutation(n_points)[:reduction_sample_size]
clusterer.fit(features_reduced[hdbscan_sample_idx])
# Use approximate_predict to assign labels to all points
cluster_labels_cp, _ = cuml.cluster.hdbscan.approximate_predict(clusterer, features_reduced)
cluster_labels = torch.as_tensor(cluster_labels_cp, device=gs_model.means.device)
else:
clusterer.fit(features_reduced)
cluster_labels = torch.as_tensor(clusterer.labels_, device=gs_model.means.device)
unique_labels = torch.unique(cluster_labels)
num_clusters = (unique_labels >= 0).sum().item() # Exclude noise label (-1)
logger.info(f"Found {num_clusters} clusters (+ {(cluster_labels == -1).sum().item()} noise points)")

# Split gaussians into separate GaussianSplat3d instances per cluster
cluster_splats: dict[int, GaussianSplat3d] = {}
for label in unique_labels.tolist():
if label == -1:
# Optionally skip noise points, or include them as a separate "noise" cluster
continue
cluster_mask = cluster_labels == label
cluster_splats[label] = gs_model[cluster_mask]
logger.info(f" Cluster {label}: {cluster_splats[label].num_gaussians:,} gaussians")

# Also store noise points if you want them
noise_mask = cluster_labels == -1
if noise_mask.any():
noise_splats = gs_model[noise_mask]
logger.info(f" Noise: {noise_splats.num_gaussians:,} gaussians")

logger.info(f"Loaded {gs_model.num_gaussians:,} Gaussians")
logger.info(f"Restored SfmScene with {sfm_scene.num_images} images (with correct scale transforms)")
logger.info(f"Segmentation model max scale: {segmentation_model.max_grouping_scale:.4f}")

## Segmentation and Clustering
# Query per-gaussian features at a given scale
scale = self.scale * float(segmentation_model.max_grouping_scale.item())
mask_features_output = segmentation_model.get_gaussian_affinity_output(scale) # [N, 256]

# Perform clustering
cluster_labels, cluster_probs = compute_cluster_labels(mask_features_output, device=device)

# Split gaussian scene into separate GaussianSplat3d instances per cluster
cluster_splats, cluster_coherence, _ = split_gaussians_into_clusters(cluster_labels, cluster_probs, gs_model)

## Filtering
# Filter clusters by spatial variance (remove multi-center clusters)
if self.filter_high_variance:
# Compute normalized spatial variance for filtering multi-center clusters
cluster_norm_variance: dict[int, float] = {}
for label, splat in cluster_splats.items():
means = splat.means # [N, 3]
# Spatial extent: max range across any axis
extent = (means.max(dim=0).values - means.min(dim=0).values).max().item()
# Normalized variance: mean variance across axes / extent^2
# High values indicate scattered points relative to the cluster size
if extent > 1e-6:
variance = means.var(dim=0).mean().item()
cluster_norm_variance[label] = variance / (extent**2)
else:
cluster_norm_variance[label] = 0.0

# Log variance statistics to help with threshold tuning
variance_values = list(cluster_norm_variance.values())
if variance_values:
logger.info(
f"Normalized variance stats: min={min(variance_values):.4f}, "
f"max={max(variance_values):.4f}, median={np.median(variance_values):.4f}"
)

removed_variance = [
label for label in cluster_splats.keys() if cluster_norm_variance[label] > self.variance_threshold
]
for label in removed_variance:
logger.info(
f" Removing cluster {label}: normalized variance {cluster_norm_variance[label]:.4f} "
f"above threshold {self.variance_threshold:.4f}"
)
del cluster_splats[label]
del cluster_coherence[label]
if removed_variance:
logger.info(f"Removed {len(removed_variance)} spatially incoherent clusters")

logger.info(f"Remaining clusters: {len(cluster_splats)}")

## Visualization
# Initialize fvdb.viz
logger.info(f"Starting viewer server on {self.viewer_ip_address}:{self.viewer_port}")
fviz.init(ip_address=self.viewer_ip_address, port=self.viewer_port, verbose=self.verbose)
viz_scene = fviz.get_scene("GarfVDB Segmentation Viewer")

# Add the Gaussian splat models to the scene
logger.info(f"Adding {len(cluster_splats)} clusters to the scene")
for cluster_id, splat in cluster_splats.items():
# Add the Gaussian splat models to the scene, sorted by coherence (lowest first)
sorted_clusters = sorted(cluster_splats.items(), key=lambda x: cluster_coherence[x[0]])
for cluster_id, splat in sorted_clusters:
viz_scene.add_gaussian_splat_3d(f"Cluster {cluster_id}", splat)

# Set initial camera position
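
As a sanity check on the typical values quoted in the variance_threshold docstring above (~0.03 tight, ~0.08 uniform, >0.1 scattered), here is a small standalone sketch of the same normalized-variance statistic on synthetic point sets; the helper and test data are illustrative only, not part of the PR:

import torch

def normalized_variance(means: torch.Tensor) -> float:
    # Same statistic as the cluster filter: mean per-axis variance / (largest axis range)^2.
    extent = (means.max(dim=0).values - means.min(dim=0).values).max().item()
    if extent <= 1e-6:
        return 0.0
    return means.var(dim=0).mean().item() / (extent**2)

g = torch.Generator().manual_seed(0)
tight = torch.randn(10_000, 3, generator=g) * 0.05  # compact blob              -> roughly 0.01-0.03
uniform = torch.rand(10_000, 3, generator=g)         # uniform cube (var L^2/12) -> roughly 0.08
two_blobs = torch.cat([tight, tight + 1.0])          # two centers on a diagonal -> above 0.1
for name, pts in [("tight", tight), ("uniform", uniform), ("two_blobs", two_blobs)]:
    print(f"{name}: {normalized_variance(pts):.3f}")
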