diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index 226620a59..2905abee2 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -101,6 +101,7 @@ def from_file(
         "LightningPose",
         "Anipose",
         "NWB",
+        "EKS",
     ],
     fps: float | None = None,
     **kwargs,
@@ -112,11 +113,11 @@ def from_file(
     file_path : pathlib.Path or str
         Path to the file containing predicted poses. The file format must
         be among those supported by the ``from_dlc_file()``,
-        ``from_slp_file()`` or ``from_lp_file()`` functions. One of these
-        these functions will be called internally, based on
+        ``from_slp_file()``, ``from_lp_file()`` or ``from_eks_file()`` functions.
+        One of these functions will be called internally, based on
         the value of ``source_software``.
     source_software : {"DeepLabCut", "SLEAP", "LightningPose", "Anipose", \
-        "NWB"}
+        "NWB", "EKS"}
         The source software of the file.
     fps : float, optional
         The number of frames per second in the video. If None (default),
@@ -141,6 +142,7 @@ def from_file(
     movement.io.load_poses.from_lp_file
     movement.io.load_poses.from_anipose_file
     movement.io.load_poses.from_nwb_file
+    movement.io.load_poses.from_eks_file
 
     Examples
     --------
@@ -166,6 +168,8 @@ def from_file(
                 "metadata in the file."
             )
         return from_nwb_file(file_path, **kwargs)
+    elif source_software == "EKS":
+        return from_eks_file(file_path, fps)
     else:
         raise logger.error(
             ValueError(f"Unsupported source software: {source_software}")
@@ -376,9 +380,187 @@ def from_dlc_file(
     )
 
 
+def from_eks_file(
+    file_path: Path | str, fps: float | None = None
+) -> xr.Dataset:
+    """Create a ``movement`` poses dataset from an EKS file.
+
+    Parameters
+    ----------
+    file_path : pathlib.Path or str
+        Path to the EKS .csv file containing the predicted poses and
+        ensemble statistics.
+    fps : float, optional
+        The number of frames per second in the video. If None (default),
+        the ``time`` coordinates will be in frame numbers.
+
+    Returns
+    -------
+    xarray.Dataset
+        ``movement`` dataset containing the pose tracks, confidence scores,
+        ensemble statistics, and associated metadata.
+
+    Notes
+    -----
+    EKS files have a similar structure to DeepLabCut CSV files but include
+    additional columns for ensemble statistics:
+    - x_ens_median, y_ens_median: Median of the ensemble predictions
+    - x_ens_var, y_ens_var: Variance of the ensemble predictions
+    - x_posterior_var, y_posterior_var: Posterior variance from the smoother
+
+    See Also
+    --------
+    movement.io.load_poses.from_dlc_file
+
+    Examples
+    --------
+    >>> from movement.io import load_poses
+    >>> ds = load_poses.from_eks_file("path/to/file_eks.csv", fps=30)
+
+    """
+    # Read the CSV file using similar logic to DeepLabCut
+    file = ValidFile(file_path, expected_suffix=[".csv"])
+
+    # Parse the CSV to get the DataFrame with multi-index columns
+    df = _df_from_eks_csv(file.path)
+
+    # Extract individual and keypoint names
+    if "individuals" in df.columns.names:
+        individual_names = (
+            df.columns.get_level_values("individuals").unique().to_list()
+        )
+    else:
+        individual_names = ["individual_0"]
+    keypoint_names = (
+        df.columns.get_level_values("bodyparts").unique().to_list()
+    )
+
+    # Get unique coord types
+    coord_types = df.columns.get_level_values("coords").unique().to_list()
+
+    # Separate main tracking data (x, y, likelihood) from ensemble stats
+    main_coords = ["x", "y", "likelihood"]
+    ensemble_coords = [c for c in coord_types if c not in main_coords]
+
+    # Extract main tracking data - need to select each coord separately
+    # because xs doesn't support list keys
+    x_df = df.xs("x", level="coords", axis=1)
+    y_df = df.xs("y", level="coords", axis=1)
+    likelihood_df = df.xs("likelihood", level="coords", axis=1)
+
+    # Stack the coordinates back together in the right order
+    # Shape will be (n_frames, n_individuals * n_keypoints) for each
+    x_data = (
+        x_df.to_numpy()
+        .flatten()
+        .reshape(-1, len(individual_names), len(keypoint_names))
+    )
+    y_data = (
+        y_df.to_numpy()
+        .flatten()
+        .reshape(-1, len(individual_names), len(keypoint_names))
+    )
+    likelihood_data = (
+        likelihood_df.to_numpy()
+        .flatten()
+        .reshape(-1, len(individual_names), len(keypoint_names))
+    )
+
+    # Stack to create (n_frames, 3, n_individuals, n_keypoints)
+    # where the second axis contains x, y, likelihood
+    tracks_with_scores = np.stack([x_data, y_data, likelihood_data], axis=1)
+    # Transpose to match expected order
+    # (n_frames, 3, n_keypoints, n_individuals)
+    tracks_with_scores = tracks_with_scores.transpose(0, 1, 3, 2)
+
+    # Create the base dataset
+    ds = from_numpy(
+        position_array=tracks_with_scores[:, :-1, :, :],
+        confidence_array=tracks_with_scores[:, -1, :, :],
+        individual_names=individual_names,
+        keypoint_names=keypoint_names,
+        fps=fps,
+        source_software="EKS",
+    )
+
+    # Group ensemble statistics by their type
+    # (ens_median, ens_var, posterior_var)
+    # Each will have x and y components as the space dimension
+    ensemble_stats: dict[str, dict[str, np.ndarray]] = {}
+
+    for coord in ensemble_coords:
+        # Parse coordinate name to extract stat type and spatial component
+        # e.g., "x_ens_median" -> ("ens_median", "x")
+        if coord.startswith("x_"):
+            stat_name = coord[2:]  # Remove "x_" prefix
+            spatial_component = "x"
+        elif coord.startswith("y_"):
+            stat_name = coord[2:]  # Remove "y_" prefix
+            spatial_component = "y"
+        else:
+            # If it doesn't follow expected pattern, handle as before
+            coord_df = df.xs(coord, level="coords", axis=1)
+            coord_data = (
+                coord_df.to_numpy()
+                .reshape((-1, len(individual_names), len(keypoint_names)))
+                .transpose(0, 2, 1)
+            )
+            da = xr.DataArray(
+                coord_data,
+                dims=["time", "keypoints", "individuals"],
+                coords={
+                    "time": ds.coords["time"],
+                    "keypoints": ds.coords["keypoints"],
+                    "individuals": ds.coords["individuals"],
+                },
+                name=coord,
+            )
+            ds[coord] = da
+            continue
+
+        # Initialize the stat dict if needed
+        if stat_name not in ensemble_stats:
+            ensemble_stats[stat_name] = {}
+
+        # Extract the data for this coordinate
+        coord_df = df.xs(coord, level="coords", axis=1)
+        coord_data = (
+            coord_df.to_numpy()
+            .reshape((-1, len(individual_names), len(keypoint_names)))
+            .transpose(0, 2, 1)
+        )
+
+        # Store the data indexed by spatial component
+        ensemble_stats[stat_name][spatial_component] = coord_data
+
+    # Create DataArrays with (time, space, keypoints, individuals) dims
+    for stat_name, spatial_data in ensemble_stats.items():
+        if "x" in spatial_data and "y" in spatial_data:
+            # Stack x and y into space dimension
+            stat_array = np.stack(
+                [spatial_data["x"], spatial_data["y"]], axis=1
+            )  # Results in (time, space=2, keypoints, individuals)
+
+            # Create xarray DataArray with proper dimensions
+            da = xr.DataArray(
+                stat_array,
+                dims=["time", "space", "keypoints", "individuals"],
+                coords={
+                    "time": ds.coords["time"],
+                    "space": ["x", "y"],
+                    "keypoints": ds.coords["keypoints"],
+                    "individuals": ds.coords["individuals"],
+                },
+                name=stat_name,
+            )
+            ds[stat_name] = da
+
+    return ds
+
+
 def from_multiview_files(
     file_path_dict: dict[str, Path | str],
-    source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
+    source_software: Literal["DeepLabCut", "SLEAP", "LightningPose", "EKS"],
     fps: float | None = None,
 ) -> xr.Dataset:
     """Load and merge pose tracking data from multiple views (cameras).
@@ -657,6 +839,50 @@ def _df_from_dlc_csv(file_path: Path) -> pd.DataFrame:
     return df
 
 
+def _df_from_eks_csv(file_path: Path) -> pd.DataFrame:
+    """Create an EKS-style DataFrame from a .csv file.
+
+    EKS CSV files have a similar structure to DeepLabCut CSV files but
+    with additional columns for ensemble statistics.
+
+    Parameters
+    ----------
+    file_path : pathlib.Path
+        Path to the EKS-style .csv file with pose tracks and ensemble stats.
+
+    Returns
+    -------
+    pandas.DataFrame
+        EKS-style DataFrame with multi-index columns.
+
+    """
+    # EKS CSV has the same header structure as DeepLabCut CSV
+    possible_level_names = ["scorer", "individuals", "bodyparts", "coords"]
+    with open(file_path) as f:
+        # if line starts with a possible level name, split it into a list
+        # of strings, and add it to the list of header lines
+        header_lines = [
+            line.strip().split(",")
+            for line in f.readlines()
+            if line.split(",")[0] in possible_level_names
+        ]
+    # Form multi-index column names from the header lines
+    level_names = [line[0] for line in header_lines]
+    column_tuples = list(
+        zip(*[line[1:] for line in header_lines], strict=False)
+    )
+    columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names)
+    # Import the EKS poses as a DataFrame
+    df = pd.read_csv(
+        file_path,
+        skiprows=len(header_lines),
+        index_col=0,
+        names=np.array(columns),
+    )
+    df.columns.rename(level_names, inplace=True)
+    return df
+
+
 def _df_from_dlc_h5(file_path: Path) -> pd.DataFrame:
     """Create a DeepLabCut-style DataFrame from a .h5 file.
 
diff --git a/tests/test_unit/test_io/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py
index 2293300ee..735227e1b 100644
--- a/tests/test_unit/test_io/test_load_poses.py
+++ b/tests/test_unit/test_io/test_load_poses.py
@@ -314,3 +314,68 @@ def test_load_from_nwb_file(input_type, kwargs, request):
     if input_type == "nwb_file":
         expected_attrs["source_file"] = nwb_file
     assert ds_from_file_path.attrs == expected_attrs
+
+
+def test_load_from_eks_file():
+    """Test that loading pose tracks from an EKS CSV file
+    returns a proper Dataset with ensemble statistics.
+    """
+    file_path = DATA_PATHS.get("EKS_IBL-paw_multicam_right.predictions.csv")
+    if file_path is None:
+        pytest.skip("EKS example file not found")
+
+    try:
+        # Load the EKS file
+        ds = load_poses.from_eks_file(file_path, fps=30)
+
+        # Check that it's a valid dataset with the expected structure
+        assert isinstance(ds, xr.Dataset)
+        assert "position" in ds.data_vars
+        assert "confidence" in ds.data_vars
+
+        # Check ensemble statistics are present
+        ensemble_vars = ["ens_median", "ens_var", "posterior_var"]
+        for var in ensemble_vars:
+            assert var in ds.data_vars
+
+        # Check dimensions
+        assert "time" in ds.dims
+        assert "individuals" in ds.dims
+        assert "keypoints" in ds.dims
+        assert "space" in ds.dims
+
+        # Check basic attributes
+        assert ds.attrs["source_software"] == "EKS"
+        assert ds.attrs["fps"] == 30
+
+        # Check shapes are consistent
+        n_time, n_space, n_keypoints, n_individuals = ds.position.shape
+        assert ds.confidence.shape == (n_time, n_keypoints, n_individuals)
+        for var in ensemble_vars:
+            assert ds[var].shape == (
+                n_time,
+                n_space,
+                n_keypoints,
+                n_individuals,
+            )
+
+    except ImportError:
+        pytest.skip("Required dependencies for EKS loading not available")
+
+
+def test_load_from_file_eks():
+    """Test that loading EKS files via from_file() works correctly."""
+    file_path = DATA_PATHS.get("EKS_IBL-paw_multicam_right.predictions.csv")
+    if file_path is None:
+        pytest.skip("EKS example file not found")
+
+    try:
+        # Test loading via from_file
+        ds = load_poses.from_file(file_path, source_software="EKS", fps=30)
+
+        # Should be identical to from_eks_file
+        ds_direct = load_poses.from_eks_file(file_path, fps=30)
+        xr.testing.assert_identical(ds, ds_direct)
+
+    except ImportError:
+        pytest.skip("Required dependencies for EKS loading not available")