diff --git a/instance_segmentation/garfvdb/garfvdb/evaluation/clustering/__init__.py b/instance_segmentation/garfvdb/garfvdb/evaluation/clustering/__init__.py
new file mode 100644
index 0000000..af24880
--- /dev/null
+++ b/instance_segmentation/garfvdb/garfvdb/evaluation/clustering/__init__.py
@@ -0,0 +1,3 @@
+from .clustering import compute_cluster_labels, split_gaussians_into_clusters
+
+__all__ = ["compute_cluster_labels", "split_gaussians_into_clusters"]
diff --git a/instance_segmentation/garfvdb/garfvdb/evaluation/clustering/clustering.py b/instance_segmentation/garfvdb/garfvdb/evaluation/clustering/clustering.py
new file mode 100644
index 0000000..9ed2e31
--- /dev/null
+++ b/instance_segmentation/garfvdb/garfvdb/evaluation/clustering/clustering.py
@@ -0,0 +1,172 @@
+import logging
+
+import cuml
+import cupy as cp
+import numpy as np
+import torch
+from fvdb import GaussianSplat3d
+
+logger = logging.getLogger(__name__)
+
+
+def compute_cluster_labels(
+    mask_features_output: torch.Tensor,
+    pca_n_components: int = 128,
+    umap_n_components: int = 32,
+    umap_n_neighbors: int = 15,
+    hdbscan_min_samples: int = 100,
+    hdbscan_min_cluster_size: int = 200,
+    fitting_sample_size: int = 300_000,
+    random_seed: int = 42,
+    device: str | torch.device = "cuda",
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Cluster per-gaussian features.
+
+    To speed up clustering on typically large GaussianSplat3d models (million+ gaussians),
+    we perform feature reduction and clustering in a three-stage pipeline:
+    1. PCA: Pre-reduction of high-dimensional features to an intermediate representation
+    2. UMAP: Non-linear reduction to a low-dimensional manifold
+    3. HDBSCAN: Density-based clustering to identify groups of similar gaussians
+
+    Additionally, for large scenes (>300k points by default), subsampling is used during fitting
+    to improve performance, and all points are transformed/predicted afterwards.
+
+    Args:
+        mask_features_output: Per-gaussian feature vectors from the segmentation
+            model. Shape: [N, feature_dim].
+        pca_n_components: Number of PCA components for initial reduction.
+        umap_n_components: Number of UMAP dimensions for manifold embedding.
+        umap_n_neighbors: UMAP neighbor count (higher = more global structure).
+        hdbscan_min_samples: Minimum samples for HDBSCAN core points.
+        hdbscan_min_cluster_size: Minimum cluster size for HDBSCAN.
+        fitting_sample_size: Sample size for fitting UMAP and HDBSCAN.
+        random_seed: Random seed for reproducibility.
+        device: Device to perform clustering on.
+
+    Returns:
+        cluster_labels: Cluster assignment for each gaussian. Shape: [N].
+            Label -1 indicates noise points.
+        cluster_probs: Membership probability for each gaussian. Shape: [N].
+            Higher values indicate stronger cluster membership.
+ """ + cp.random.seed(random_seed) + np.random.seed(random_seed) + torch.manual_seed(random_seed) + device = torch.device(device) + + assert umap_n_neighbors < fitting_sample_size, "UMAP n_neighbors must be less than fitting_sample_size" + + # PCA pre-reduction + n_samples, n_features = mask_features_output.shape[0], mask_features_output.shape[1] + max_pca_components = min(n_samples, n_features) + if pca_n_components > max_pca_components: + logger.warning( + "Requested pca_n_components=%d is greater than min(n_samples=%d, n_features=%d); " "clamping to %d.", + pca_n_components, + n_samples, + n_features, + max_pca_components, + ) + pca_n_components = max_pca_components + logger.info(f"PCA pre-reduction ({n_features} -> {pca_n_components} dimensions)...") + + pca = cuml.PCA(n_components=pca_n_components) + features_pca = pca.fit_transform(mask_features_output) + logger.info(f"PCA reduced shape: {features_pca.shape}") + + # UMAP reduction + n_points = features_pca.shape[0] + reduction_sample_size = min(fitting_sample_size, n_points) + + logger.info( + f"UMAP reduction ({pca_n_components} -> {umap_n_components} dimensions, fitting on {reduction_sample_size:,} / {n_points:,} points)..." + ) + umap_reducer = cuml.UMAP( + n_components=umap_n_components, + n_neighbors=umap_n_neighbors, + min_dist=0.0, + metric="euclidean", + random_state=random_seed, + ) + + if n_points > reduction_sample_size: + # Subsample for fitting, then transform all points + sample_idx = cp.random.permutation(n_points)[:reduction_sample_size] + umap_reducer.fit(features_pca[sample_idx]) + features_reduced = umap_reducer.transform(features_pca) + else: + features_reduced = umap_reducer.fit_transform(features_pca) + + logger.info(f"UMAP reduced shape: {features_reduced.shape}") + + # Cluster HDBSCAN + logger.info(f"Clustering with HDBSCAN (fitting on {reduction_sample_size:,} / {n_points:,} points)...") + + clusterer = cuml.HDBSCAN( + min_samples=hdbscan_min_samples, + min_cluster_size=hdbscan_min_cluster_size, + prediction_data=True, # Required for approximate_predict + ) + + if n_points > reduction_sample_size: + hdbscan_sample_idx = cp.random.permutation(n_points)[:reduction_sample_size] + clusterer.fit(features_reduced[hdbscan_sample_idx]) + # Use approximate_predict to assign labels to all points + cluster_labels_cp, cluster_probs_cp = cuml.cluster.hdbscan.approximate_predict(clusterer, features_reduced) + cluster_labels = torch.as_tensor(cluster_labels_cp, device=device) + cluster_probs = torch.as_tensor(cluster_probs_cp, device=device) + else: + clusterer.fit(features_reduced) + cluster_labels = torch.as_tensor(clusterer.labels_, device=device) + cluster_probs = torch.as_tensor(clusterer.probabilities_, device=device) + + return cluster_labels, cluster_probs + + +def split_gaussians_into_clusters( + cluster_labels: torch.Tensor, cluster_probs: torch.Tensor, gs_model: GaussianSplat3d +) -> tuple[dict[int, GaussianSplat3d], dict[int, float], GaussianSplat3d]: + """Split a GaussianSplat3d model into per-cluster subsets. + + Groups gaussians by their cluster labels and computes coherence scores + (mean membership probability) for each cluster. + + Args: + cluster_labels: Cluster assignment for each gaussian. Shape: [N]. + Label -1 indicates noise points. + cluster_probs: Membership probability for each gaussian. Shape: [N]. + gs_model: The GaussianSplat3d model to split. + + Returns: + cluster_splats: Dictionary mapping cluster ID to GaussianSplat3d subset. + Excludes noise points (label -1). 
+        cluster_coherence: Dictionary mapping cluster ID to mean membership
+            probability. Higher values indicate tighter, more confident clusters.
+        noise_splats: GaussianSplat3d containing all noise points (label -1).
+    """
+    unique_labels = torch.unique(cluster_labels)
+    num_clusters = (unique_labels >= 0).sum().item()  # Exclude noise label (-1)
+    logger.info(f"Found {num_clusters} clusters (+ {(cluster_labels == -1).sum().item()} noise points)")
+
+    # Split gaussians into separate GaussianSplat3d instances per cluster
+    # Also compute cluster coherence (mean membership probability)
+    cluster_splats: dict[int, GaussianSplat3d] = {}
+    cluster_coherence: dict[int, float] = {}
+    for label in unique_labels.tolist():
+        if label == -1:
+            # Skip noise points here; they are returned separately as noise_splats
+            continue
+        cluster_mask = cluster_labels == label
+        cluster_splats[label] = gs_model[cluster_mask]
+        cluster_coherence[label] = cluster_probs[cluster_mask].mean().item()
+        logger.info(
+            f"  Cluster {label}: {cluster_splats[label].num_gaussians:,} gaussians, "
+            f"coherence: {cluster_coherence[label]:.3f}"
+        )
+
+    # Also store noise points
+    noise_mask = cluster_labels == -1
+    noise_splats = gs_model[noise_mask]
+    if noise_mask.any():
+        logger.info(f"  Noise: {noise_splats.num_gaussians:,} gaussians")
+
+    return cluster_splats, cluster_coherence, noise_splats
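
For review convenience, a minimal usage sketch of the two new helpers. The cluster_scene wrapper below is hypothetical and not part of this diff; segmentation_model and gs_model are assumed to come from the segmentation checkpoint and splat file, as in the updated visualization script that follows.

from garfvdb.evaluation.clustering import compute_cluster_labels, split_gaussians_into_clusters


def cluster_scene(segmentation_model, gs_model, scale_fraction: float = 0.1):
    # Query per-gaussian affinity features at a fraction of the model's max grouping scale
    scale = scale_fraction * float(segmentation_model.max_grouping_scale.item())
    features = segmentation_model.get_gaussian_affinity_output(scale)  # [N, feature_dim]

    # PCA -> UMAP -> HDBSCAN pipeline; label -1 marks noise points
    labels, probs = compute_cluster_labels(features, device=gs_model.means.device)

    # Per-cluster GaussianSplat3d subsets, coherence scores, and the noise remainder
    return split_gaussians_into_clusters(labels, probs, gs_model)
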
diff --git a/instance_segmentation/garfvdb/visualize_segmentation_clusters.py b/instance_segmentation/garfvdb/visualize_segmentation_clusters.py
index e1e0b7b..b097757 100644
--- a/instance_segmentation/garfvdb/visualize_segmentation_clusters.py
+++ b/instance_segmentation/garfvdb/visualize_segmentation_clusters.py
@@ -7,16 +7,21 @@
 from dataclasses import dataclass
 from typing import Annotated
 
-import cuml
-import cuml.cluster.hdbscan
-import cupy as cp
 import fvdb.viz as fviz
 import numpy as np
 import torch
 import tyro
 from fvdb import GaussianSplat3d
 from fvdb.types import to_Mat33fBatch, to_Mat44fBatch, to_Vec2iBatch
-from fvdb_reality_capture.tools import filter_splats_above_scale
+from fvdb_reality_capture.tools import (
+    filter_splats_above_scale,
+    filter_splats_by_mean_percentile,
+    filter_splats_by_opacity_percentile,
+)
+from garfvdb.evaluation.clustering import (
+    compute_cluster_labels,
+    split_gaussians_into_clusters,
+)
 from garfvdb.training.segmentation import GaussianSplatScaleConditionedSegmentation
 from garfvdb.util import load_splats_from_file
 from tyro.conf import arg
@@ -95,6 +100,13 @@ class ViewCheckpoint:
     scale: float = 0.1
     """Segmentation scale as a fraction of max scale."""
 
+    filter_high_variance: bool = True
+    """Remove clusters with high spatial variance (multi-center clusters)."""
+
+    variance_threshold: float = 0.1
+    """Clusters with normalized variance above this threshold are removed.
+    Normalized variance = variance / extent^2. Typical values: ~0.03 (tight), ~0.08 (uniform), >0.1 (scattered)."""
+
     def execute(self) -> None:
         """Execute the viewer command."""
         log_level = logging.DEBUG if self.verbose else logging.INFO
@@ -112,11 +124,11 @@ def execute(self) -> None:
             raise FileNotFoundError(f"Reconstruction checkpoint {self.reconstruction_path} does not exist.")
         logger.info(f"Loading Gaussian splat model from {self.reconstruction_path}")
         gs_model, metadata = load_splats_from_file(self.reconstruction_path, device)
-        gs_model = filter_splats_above_scale(gs_model, 0.1)
-        logger.info(f"Loaded {gs_model.num_gaussians:,} Gaussians")
 
-        # Load the segmentation runner from checkpoint
-        logger.info(f"Loading segmentation checkpoint from {self.segmentation_path}")
+        # Filter GS model
+        gs_model = filter_splats_above_scale(gs_model, 0.1)
+        gs_model = filter_splats_by_opacity_percentile(gs_model, percentile=0.85)
+        gs_model = filter_splats_by_mean_percentile(gs_model, percentile=[0.96, 0.96, 0.96, 0.96, 0.98, 0.99])
         runner = load_segmentation_runner_from_checkpoint(
             checkpoint_path=self.segmentation_path,
             gs_model=gs_model,
@@ -126,95 +138,70 @@ def execute(self) -> None:
         gs_model = runner.gs_model
         segmentation_model = runner.model
-        sfm_scene = runner.sfm_scene
-
-        # Get per-gaussian features at a given scale
-        scale = self.scale * float(segmentation_model.max_grouping_scale.item())
-
-        mask_features_output = segmentation_model.get_gaussian_affinity_output(scale)  # [N, 256]
-        logger.info(f"Got mask features with shape: {mask_features_output.shape}")
-
-        # PCA pre-reduction (256 -> 128)
-        logger.info("PCA pre-reduction (256 -> 128 dimensions)...")
-        pca = cuml.PCA(n_components=128)
-        features_pca = pca.fit_transform(mask_features_output)
-        logger.info(f"PCA reduced shape: {features_pca.shape}")
-
-        # UMAP reduction (128 -> 32)
-        n_points = features_pca.shape[0]
-        reduction_sample_size = min(300_000, n_points)
-
-        logger.info(
-            f"UMAP reduction (128 -> 32 dimensions, fitting on {reduction_sample_size:,} / {n_points:,} points)..."
-        )
-        umap_reducer = cuml.UMAP(
-            n_components=32,
-            n_neighbors=15,
-            min_dist=0.0,
-            metric="euclidean",
-            random_state=42,
-        )
-
-        if n_points > reduction_sample_size:
-            # Subsample for fitting, then transform all points
-            sample_idx = cp.random.permutation(n_points)[:reduction_sample_size]
-            umap_reducer.fit(features_pca[sample_idx])
-            features_reduced = umap_reducer.transform(features_pca)
-        else:
-            features_reduced = umap_reducer.fit_transform(features_pca)
-
-        logger.info(f"UMAP reduced shape: {features_reduced.shape}")
-
-        # Cluster HDBSCAN
-        logger.info(f"Clustering with HDBSCAN (fitting on {reduction_sample_size:,} / {n_points:,} points)...")
-
-        clusterer = cuml.HDBSCAN(
-            min_samples=100,
-            min_cluster_size=200,
-            prediction_data=True,  # Required for approximate_predict
-        )
-
-        if n_points > reduction_sample_size:
-            hdbscan_sample_idx = cp.random.permutation(n_points)[:reduction_sample_size]
-            clusterer.fit(features_reduced[hdbscan_sample_idx])
-            # Use approximate_predict to assign labels to all points
-            cluster_labels_cp, _ = cuml.cluster.hdbscan.approximate_predict(clusterer, features_reduced)
-            cluster_labels = torch.as_tensor(cluster_labels_cp, device=gs_model.means.device)
-        else:
-            clusterer.fit(features_reduced)
-            cluster_labels = torch.as_tensor(clusterer.labels_, device=gs_model.means.device)
-        unique_labels = torch.unique(cluster_labels)
-        num_clusters = (unique_labels >= 0).sum().item()  # Exclude noise label (-1)
-        logger.info(f"Found {num_clusters} clusters (+ {(cluster_labels == -1).sum().item()} noise points)")
-
-        # Split gaussians into separate GaussianSplat3d instances per cluster
-        cluster_splats: dict[int, GaussianSplat3d] = {}
-        for label in unique_labels.tolist():
-            if label == -1:
-                # Optionally skip noise points, or include them as a separate "noise" cluster
-                continue
-            cluster_mask = cluster_labels == label
-            cluster_splats[label] = gs_model[cluster_mask]
-            logger.info(f"  Cluster {label}: {cluster_splats[label].num_gaussians:,} gaussians")
-
-        # Also store noise points if you want them
-        noise_mask = cluster_labels == -1
-        if noise_mask.any():
-            noise_splats = gs_model[noise_mask]
-            logger.info(f"  Noise: {noise_splats.num_gaussians:,} gaussians")
 
         logger.info(f"Loaded {gs_model.num_gaussians:,} Gaussians")
-        logger.info(f"Restored SfmScene with {sfm_scene.num_images} images (with correct scale transforms)")
         logger.info(f"Segmentation model max scale: {segmentation_model.max_grouping_scale:.4f}")
 
+        ## Segmentation and Clustering
+        # Query per-gaussian features at a given scale
+        scale = self.scale * float(segmentation_model.max_grouping_scale.item())
+        mask_features_output = segmentation_model.get_gaussian_affinity_output(scale)  # [N, 256]
+
+        # Perform clustering
+        cluster_labels, cluster_probs = compute_cluster_labels(mask_features_output, device=device)
+
+        # Split gaussian scene into separate GaussianSplat3d instances per cluster
+        cluster_splats, cluster_coherence, _ = split_gaussians_into_clusters(cluster_labels, cluster_probs, gs_model)
+
+        ## Filtering
+        # Filter clusters by spatial variance (remove multi-center clusters)
+        if self.filter_high_variance:
+            # Compute normalized spatial variance for filtering multi-center clusters
+            cluster_norm_variance: dict[int, float] = {}
+            for label, splat in cluster_splats.items():
+                means = splat.means  # [N, 3]
+                # Spatial extent: max range across any axis
+                extent = (means.max(dim=0).values - means.min(dim=0).values).max().item()
+                # Normalized variance: mean variance across axes / extent^2
+                # High values indicate scattered points relative to the cluster size
+                if extent > 1e-6:
+                    variance = means.var(dim=0).mean().item()
+                    cluster_norm_variance[label] = variance / (extent**2)
+                else:
+                    cluster_norm_variance[label] = 0.0
+
+            # Log variance statistics to help with threshold tuning
+            variance_values = list(cluster_norm_variance.values())
+            if variance_values:
+                logger.info(
+                    f"Normalized variance stats: min={min(variance_values):.4f}, "
+                    f"max={max(variance_values):.4f}, median={np.median(variance_values):.4f}"
+                )
+
+            removed_variance = [
+                label for label in cluster_splats.keys() if cluster_norm_variance[label] > self.variance_threshold
+            ]
+            for label in removed_variance:
+                logger.info(
+                    f"  Removing cluster {label}: normalized variance {cluster_norm_variance[label]:.4f} "
+                    f"above threshold {self.variance_threshold:.4f}"
+                )
+                del cluster_splats[label]
+                del cluster_coherence[label]
+            if removed_variance:
+                logger.info(f"Removed {len(removed_variance)} spatially incoherent clusters")
+
+        logger.info(f"Remaining clusters: {len(cluster_splats)}")
+
+        ## Visualization
         # Initialize fvdb.viz
         logger.info(f"Starting viewer server on {self.viewer_ip_address}:{self.viewer_port}")
         fviz.init(ip_address=self.viewer_ip_address, port=self.viewer_port, verbose=self.verbose)
         viz_scene = fviz.get_scene("GarfVDB Segmentation Viewer")
 
-        # Add the Gaussian splat models to the scene
-        logger.info(f"Adding {len(cluster_splats)} clusters to the scene")
-        for cluster_id, splat in cluster_splats.items():
+        # Add the Gaussian splat models to the scene, sorted by coherence
+        sorted_clusters = sorted(cluster_splats.items(), key=lambda x: cluster_coherence[x[0]])
+        for cluster_id, splat in sorted_clusters:
             viz_scene.add_gaussian_splat_3d(f"Cluster {cluster_id}", splat)
 
         # Set initial camera position
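
As a possible follow-up, the normalized-variance filter inside execute() could be factored out next to the clustering helpers. A rough standalone sketch of the same computation follows; the function name filter_clusters_by_spatial_variance is hypothetical and not part of this diff.

from fvdb import GaussianSplat3d


def filter_clusters_by_spatial_variance(
    cluster_splats: dict[int, GaussianSplat3d],
    cluster_coherence: dict[int, float],
    variance_threshold: float = 0.1,
) -> tuple[dict[int, GaussianSplat3d], dict[int, float]]:
    """Drop clusters whose normalized spatial variance exceeds the threshold."""
    kept_splats: dict[int, GaussianSplat3d] = {}
    kept_coherence: dict[int, float] = {}
    for label, splat in cluster_splats.items():
        means = splat.means  # [N, 3] gaussian centers
        # Extent: largest axis-aligned range of the cluster
        extent = (means.max(dim=0).values - means.min(dim=0).values).max().item()
        # Normalized variance = mean per-axis variance / extent^2 (0 for degenerate clusters)
        norm_variance = means.var(dim=0).mean().item() / (extent**2) if extent > 1e-6 else 0.0
        if norm_variance <= variance_threshold:
            kept_splats[label] = splat
            kept_coherence[label] = cluster_coherence[label]
    return kept_splats, kept_coherence
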