From 61e5f6e83daa9c76558d5a67dcfbf14b074f3eaa Mon Sep 17 00:00:00 2001 From: Muad Abd El Hay Date: Wed, 27 Aug 2025 12:52:42 +0200 Subject: [PATCH 1/4] Add EKS (Ensemble Kalman Smoother) loader for movement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implemented from_eks_file() function to load EKS CSV files - Added support for ensemble statistics (median, variance, posterior variance) - Updated from_file() to support "EKS" as source_software - Added comprehensive unit tests for EKS loader functionality - EKS files follow DeepLabCut CSV format but include additional ensemble columns 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- movement/io/load_poses.py | 170 ++++++++++++++++++++- tests/test_unit/test_io/test_load_poses.py | 65 ++++++++ 2 files changed, 231 insertions(+), 4 deletions(-) mode change 100644 => 100755 movement/io/load_poses.py mode change 100644 => 100755 tests/test_unit/test_io/test_load_poses.py diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py old mode 100644 new mode 100755 index 226620a59..abf8c0745 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -101,6 +101,7 @@ def from_file( "LightningPose", "Anipose", "NWB", + "EKS", ], fps: float | None = None, **kwargs, @@ -112,11 +113,11 @@ def from_file( file_path : pathlib.Path or str Path to the file containing predicted poses. The file format must be among those supported by the ``from_dlc_file()``, - ``from_slp_file()`` or ``from_lp_file()`` functions. One of these - these functions will be called internally, based on + ``from_slp_file()``, ``from_lp_file()``, ``from_eks_file()`` functions. + One of these functions will be called internally, based on the value of ``source_software``. source_software : {"DeepLabCut", "SLEAP", "LightningPose", "Anipose", \ - "NWB"} + "NWB", "EKS"} The source software of the file. fps : float, optional The number of frames per second in the video. If None (default), @@ -141,6 +142,7 @@ def from_file( movement.io.load_poses.from_lp_file movement.io.load_poses.from_anipose_file movement.io.load_poses.from_nwb_file + movement.io.load_poses.from_eks_file Examples -------- @@ -166,6 +168,8 @@ def from_file( "metadata in the file." ) return from_nwb_file(file_path, **kwargs) + elif source_software == "EKS": + return from_eks_file(file_path, fps) else: raise logger.error( ValueError(f"Unsupported source software: {source_software}") @@ -376,9 +380,123 @@ def from_dlc_file( ) +def from_eks_file( + file_path: Path | str, fps: float | None = None +) -> xr.Dataset: + """Create a ``movement`` poses dataset from an EKS (Ensemble Kalman Smoother) file. + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the EKS .csv file containing the predicted poses and + ensemble statistics. + fps : float, optional + The number of frames per second in the video. If None (default), + the ``time`` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + ensemble statistics, and associated metadata. + + Notes + ----- + EKS files have a similar structure to DeepLabCut CSV files but include + additional columns for ensemble statistics: + - x_ens_median, y_ens_median: Median of the ensemble predictions + - x_ens_var, y_ens_var: Variance of the ensemble predictions + - x_posterior_var, y_posterior_var: Posterior variance from the smoother + + See Also + -------- + movement.io.load_poses.from_dlc_file + + Examples + -------- + >>> from movement.io import load_poses + >>> ds = load_poses.from_eks_file("path/to/file_eks.csv", fps=30) + + """ + # Read the CSV file using similar logic to DeepLabCut + file = ValidFile(file_path, expected_suffix=[".csv"]) + + # Parse the CSV to get the DataFrame with multi-index columns + df = _df_from_eks_csv(file.path) + + # Extract individual and keypoint names + if "individuals" in df.columns.names: + individual_names = ( + df.columns.get_level_values("individuals").unique().to_list() + ) + else: + individual_names = ["individual_0"] + keypoint_names = ( + df.columns.get_level_values("bodyparts").unique().to_list() + ) + + # Get unique coord types + coord_types = df.columns.get_level_values("coords").unique().to_list() + + # Separate main tracking data (x, y, likelihood) from ensemble stats + main_coords = ["x", "y", "likelihood"] + ensemble_coords = [c for c in coord_types if c not in main_coords] + + # Extract main tracking data - need to select each coord separately and concatenate + # because xs doesn't support list keys + x_df = df.xs("x", level="coords", axis=1) + y_df = df.xs("y", level="coords", axis=1) + likelihood_df = df.xs("likelihood", level="coords", axis=1) + + # Stack the coordinates back together in the right order + # Shape will be (n_frames, n_individuals * n_keypoints) for each + x_data = x_df.to_numpy().flatten().reshape(-1, len(individual_names), len(keypoint_names)) + y_data = y_df.to_numpy().flatten().reshape(-1, len(individual_names), len(keypoint_names)) + likelihood_data = likelihood_df.to_numpy().flatten().reshape(-1, len(individual_names), len(keypoint_names)) + + # Stack to create (n_frames, 3, n_keypoints, n_individuals) + # where the second axis contains x, y, likelihood + tracks_with_scores = np.stack([x_data, y_data, likelihood_data], axis=1) + # Transpose to match expected order (n_frames, 3, n_keypoints, n_individuals) + tracks_with_scores = tracks_with_scores.transpose(0, 1, 3, 2) + + # Create the base dataset + ds = from_numpy( + position_array=tracks_with_scores[:, :-1, :, :], + confidence_array=tracks_with_scores[:, -1, :, :], + individual_names=individual_names, + keypoint_names=keypoint_names, + fps=fps, + source_software="EnsembleKalmanSmoother", + ) + + # Add ensemble statistics as additional data variables + for coord in ensemble_coords: + coord_df = df.xs(coord, level="coords", axis=1) + coord_data = ( + coord_df.to_numpy() + .reshape((-1, len(individual_names), len(keypoint_names))) + .transpose(0, 2, 1) + ) + # Create xarray DataArray with proper dimensions + da = xr.DataArray( + coord_data, + dims=["time", "keypoints", "individuals"], + coords={ + "time": ds.coords["time"], + "keypoints": ds.coords["keypoints"], + "individuals": ds.coords["individuals"], + }, + name=coord, + ) + ds[coord] = da + + return ds + + def from_multiview_files( file_path_dict: dict[str, Path | str], - source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], + source_software: Literal["DeepLabCut", "SLEAP", "LightningPose", "EKS"], fps: float | None = None, ) -> xr.Dataset: """Load and merge pose tracking data from multiple views (cameras). @@ -657,6 +775,50 @@ def _df_from_dlc_csv(file_path: Path) -> pd.DataFrame: return df +def _df_from_eks_csv(file_path: Path) -> pd.DataFrame: + """Create an EKS-style DataFrame from a .csv file. + + EKS CSV files have a similar structure to DeepLabCut CSV files but + with additional columns for ensemble statistics. + + Parameters + ---------- + file_path : pathlib.Path + Path to the EKS-style .csv file containing pose tracks and ensemble stats. + + Returns + ------- + pandas.DataFrame + EKS-style DataFrame with multi-index columns. + + """ + # EKS CSV has the same header structure as DeepLabCut CSV + possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] + with open(file_path) as f: + # if line starts with a possible level name, split it into a list + # of strings, and add it to the list of header lines + header_lines = [ + line.strip().split(",") + for line in f.readlines() + if line.split(",")[0] in possible_level_names + ] + # Form multi-index column names from the header lines + level_names = [line[0] for line in header_lines] + column_tuples = list( + zip(*[line[1:] for line in header_lines], strict=False) + ) + columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names) + # Import the EKS poses as a DataFrame + df = pd.read_csv( + file_path, + skiprows=len(header_lines), + index_col=0, + names=np.array(columns), + ) + df.columns.rename(level_names, inplace=True) + return df + + def _df_from_dlc_h5(file_path: Path) -> pd.DataFrame: """Create a DeepLabCut-style DataFrame from a .h5 file. diff --git a/tests/test_unit/test_io/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py old mode 100644 new mode 100755 index 2293300ee..f309a1829 --- a/tests/test_unit/test_io/test_load_poses.py +++ b/tests/test_unit/test_io/test_load_poses.py @@ -314,3 +314,68 @@ def test_load_from_nwb_file(input_type, kwargs, request): if input_type == "nwb_file": expected_attrs["source_file"] = nwb_file assert ds_from_file_path.attrs == expected_attrs + + +def test_load_from_eks_file(): + """Test that loading pose tracks from an EKS CSV file + returns a proper Dataset with ensemble statistics. + """ + # Use the provided example EKS file if available, + # otherwise skip the test + try: + from pathlib import Path + file_path = Path("sub-460_strain-B6_2024-11-12T12_30_00_eks.csv") + if not file_path.exists(): + pytest.skip("EKS example file not found") + + # Load the EKS file + ds = load_poses.from_eks_file(file_path, fps=30) + + # Check that it's a valid dataset with the expected structure + assert isinstance(ds, xr.Dataset) + assert "position" in ds.data_vars + assert "confidence" in ds.data_vars + + # Check ensemble statistics are present + ensemble_vars = ['x_ens_median', 'y_ens_median', 'x_ens_var', + 'y_ens_var', 'x_posterior_var', 'y_posterior_var'] + for var in ensemble_vars: + assert var in ds.data_vars + + # Check dimensions + assert "time" in ds.dims + assert "individuals" in ds.dims + assert "keypoints" in ds.dims + assert "space" in ds.dims + + # Check basic attributes + assert ds.attrs["source_software"] == "EnsembleKalmanSmoother" + assert ds.attrs["fps"] == 30 + + # Check shapes are consistent + n_time, n_space, n_keypoints, n_individuals = ds.position.shape + assert ds.confidence.shape == (n_time, n_keypoints, n_individuals) + for var in ensemble_vars: + assert ds[var].shape == (n_time, n_keypoints, n_individuals) + + except ImportError: + pytest.skip("Required dependencies for EKS loading not available") + + +def test_load_from_file_eks(): + """Test that loading EKS files via from_file() works correctly.""" + try: + from pathlib import Path + file_path = Path("sub-460_strain-B6_2024-11-12T12_30_00_eks.csv") + if not file_path.exists(): + pytest.skip("EKS example file not found") + + # Test loading via from_file + ds = load_poses.from_file(file_path, source_software="EKS", fps=30) + + # Should be identical to from_eks_file + ds_direct = load_poses.from_eks_file(file_path, fps=30) + xr.testing.assert_identical(ds, ds_direct) + + except ImportError: + pytest.skip("Required dependencies for EKS loading not available") From 626b9fb796ab4680042f03ce36a0cd26df960c70 Mon Sep 17 00:00:00 2001 From: Muad Abd El Hay Date: Thu, 28 Aug 2025 10:00:03 +0200 Subject: [PATCH 2/4] Improve EKS loader to properly structure ensemble statistics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Restructure ensemble statistics (ens_median, ens_var, posterior_var) to have (time, space, keypoints, individuals) dimensions - Stack x and y components into space dimension for consistency with position data - This ensures all data variables have matching dimension structure 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- movement/io/load_poses.py | 126 ++++++++++++++++----- tests/test_unit/test_io/test_load_poses.py | 32 ++++-- 2 files changed, 115 insertions(+), 43 deletions(-) mode change 100755 => 100644 movement/io/load_poses.py mode change 100755 => 100644 tests/test_unit/test_io/test_load_poses.py diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py old mode 100755 new mode 100644 index abf8c0745..2382e30fa --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -383,7 +383,7 @@ def from_dlc_file( def from_eks_file( file_path: Path | str, fps: float | None = None ) -> xr.Dataset: - """Create a ``movement`` poses dataset from an EKS (Ensemble Kalman Smoother) file. + """Create a ``movement`` poses dataset from an EKS file. Parameters ---------- @@ -405,7 +405,7 @@ def from_eks_file( EKS files have a similar structure to DeepLabCut CSV files but include additional columns for ensemble statistics: - x_ens_median, y_ens_median: Median of the ensemble predictions - - x_ens_var, y_ens_var: Variance of the ensemble predictions + - x_ens_var, y_ens_var: Variance of the ensemble predictions - x_posterior_var, y_posterior_var: Posterior variance from the smoother See Also @@ -420,10 +420,10 @@ def from_eks_file( """ # Read the CSV file using similar logic to DeepLabCut file = ValidFile(file_path, expected_suffix=[".csv"]) - + # Parse the CSV to get the DataFrame with multi-index columns df = _df_from_eks_csv(file.path) - + # Extract individual and keypoint names if "individuals" in df.columns.names: individual_names = ( @@ -434,32 +434,45 @@ def from_eks_file( keypoint_names = ( df.columns.get_level_values("bodyparts").unique().to_list() ) - + # Get unique coord types coord_types = df.columns.get_level_values("coords").unique().to_list() - + # Separate main tracking data (x, y, likelihood) from ensemble stats main_coords = ["x", "y", "likelihood"] ensemble_coords = [c for c in coord_types if c not in main_coords] - - # Extract main tracking data - need to select each coord separately and concatenate + + # Extract main tracking data - need to select each coord separately # because xs doesn't support list keys x_df = df.xs("x", level="coords", axis=1) y_df = df.xs("y", level="coords", axis=1) likelihood_df = df.xs("likelihood", level="coords", axis=1) - + # Stack the coordinates back together in the right order # Shape will be (n_frames, n_individuals * n_keypoints) for each - x_data = x_df.to_numpy().flatten().reshape(-1, len(individual_names), len(keypoint_names)) - y_data = y_df.to_numpy().flatten().reshape(-1, len(individual_names), len(keypoint_names)) - likelihood_data = likelihood_df.to_numpy().flatten().reshape(-1, len(individual_names), len(keypoint_names)) - + x_data = ( + x_df.to_numpy() + .flatten() + .reshape(-1, len(individual_names), len(keypoint_names)) + ) + y_data = ( + y_df.to_numpy() + .flatten() + .reshape(-1, len(individual_names), len(keypoint_names)) + ) + likelihood_data = ( + likelihood_df.to_numpy() + .flatten() + .reshape(-1, len(individual_names), len(keypoint_names)) + ) + # Stack to create (n_frames, 3, n_keypoints, n_individuals) # where the second axis contains x, y, likelihood tracks_with_scores = np.stack([x_data, y_data, likelihood_data], axis=1) - # Transpose to match expected order (n_frames, 3, n_keypoints, n_individuals) + # Transpose to match expected order + # (n_frames, 3, n_keypoints, n_individuals) tracks_with_scores = tracks_with_scores.transpose(0, 1, 3, 2) - + # Create the base dataset ds = from_numpy( position_array=tracks_with_scores[:, :-1, :, :], @@ -469,28 +482,79 @@ def from_eks_file( fps=fps, source_software="EnsembleKalmanSmoother", ) - - # Add ensemble statistics as additional data variables + + # Group ensemble statistics by their type + # (ens_median, ens_var, posterior_var) + # Each will have x and y components as the space dimension + ensemble_stats: dict[str, dict[str, np.ndarray]] = {} + for coord in ensemble_coords: + # Parse coordinate name to extract stat type and spatial component + # e.g., "x_ens_median" -> ("ens_median", "x") + if coord.startswith("x_"): + stat_name = coord[2:] # Remove "x_" prefix + spatial_component = "x" + elif coord.startswith("y_"): + stat_name = coord[2:] # Remove "y_" prefix + spatial_component = "y" + else: + # If it doesn't follow expected pattern, handle as before + coord_df = df.xs(coord, level="coords", axis=1) + coord_data = ( + coord_df.to_numpy() + .reshape((-1, len(individual_names), len(keypoint_names))) + .transpose(0, 2, 1) + ) + da = xr.DataArray( + coord_data, + dims=["time", "keypoints", "individuals"], + coords={ + "time": ds.coords["time"], + "keypoints": ds.coords["keypoints"], + "individuals": ds.coords["individuals"], + }, + name=coord, + ) + ds[coord] = da + continue + + # Initialize the stat dict if needed + if stat_name not in ensemble_stats: + ensemble_stats[stat_name] = {} + + # Extract the data for this coordinate coord_df = df.xs(coord, level="coords", axis=1) coord_data = ( coord_df.to_numpy() .reshape((-1, len(individual_names), len(keypoint_names))) .transpose(0, 2, 1) ) - # Create xarray DataArray with proper dimensions - da = xr.DataArray( - coord_data, - dims=["time", "keypoints", "individuals"], - coords={ - "time": ds.coords["time"], - "keypoints": ds.coords["keypoints"], - "individuals": ds.coords["individuals"], - }, - name=coord, - ) - ds[coord] = da - + + # Store the data indexed by spatial component + ensemble_stats[stat_name][spatial_component] = coord_data + + # Create DataArrays with (time, space, keypoints, individuals) dims + for stat_name, spatial_data in ensemble_stats.items(): + if "x" in spatial_data and "y" in spatial_data: + # Stack x and y into space dimension + stat_array = np.stack( + [spatial_data["x"], spatial_data["y"]], axis=1 + ) # Results in (time, space=2, keypoints, individuals) + + # Create xarray DataArray with proper dimensions + da = xr.DataArray( + stat_array, + dims=["time", "space", "keypoints", "individuals"], + coords={ + "time": ds.coords["time"], + "space": ["x", "y"], + "keypoints": ds.coords["keypoints"], + "individuals": ds.coords["individuals"], + }, + name=stat_name, + ) + ds[stat_name] = da + return ds @@ -784,7 +848,7 @@ def _df_from_eks_csv(file_path: Path) -> pd.DataFrame: Parameters ---------- file_path : pathlib.Path - Path to the EKS-style .csv file containing pose tracks and ensemble stats. + Path to the EKS-style .csv file with pose tracks and ensemble stats. Returns ------- diff --git a/tests/test_unit/test_io/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py old mode 100755 new mode 100644 index f309a1829..04b13d337 --- a/tests/test_unit/test_io/test_load_poses.py +++ b/tests/test_unit/test_io/test_load_poses.py @@ -324,40 +324,47 @@ def test_load_from_eks_file(): # otherwise skip the test try: from pathlib import Path + file_path = Path("sub-460_strain-B6_2024-11-12T12_30_00_eks.csv") if not file_path.exists(): pytest.skip("EKS example file not found") - + # Load the EKS file ds = load_poses.from_eks_file(file_path, fps=30) - + # Check that it's a valid dataset with the expected structure assert isinstance(ds, xr.Dataset) assert "position" in ds.data_vars assert "confidence" in ds.data_vars - + # Check ensemble statistics are present - ensemble_vars = ['x_ens_median', 'y_ens_median', 'x_ens_var', - 'y_ens_var', 'x_posterior_var', 'y_posterior_var'] + ensemble_vars = [ + "x_ens_median", + "y_ens_median", + "x_ens_var", + "y_ens_var", + "x_posterior_var", + "y_posterior_var", + ] for var in ensemble_vars: assert var in ds.data_vars - + # Check dimensions assert "time" in ds.dims assert "individuals" in ds.dims assert "keypoints" in ds.dims assert "space" in ds.dims - + # Check basic attributes assert ds.attrs["source_software"] == "EnsembleKalmanSmoother" assert ds.attrs["fps"] == 30 - + # Check shapes are consistent n_time, n_space, n_keypoints, n_individuals = ds.position.shape assert ds.confidence.shape == (n_time, n_keypoints, n_individuals) for var in ensemble_vars: assert ds[var].shape == (n_time, n_keypoints, n_individuals) - + except ImportError: pytest.skip("Required dependencies for EKS loading not available") @@ -366,16 +373,17 @@ def test_load_from_file_eks(): """Test that loading EKS files via from_file() works correctly.""" try: from pathlib import Path + file_path = Path("sub-460_strain-B6_2024-11-12T12_30_00_eks.csv") if not file_path.exists(): pytest.skip("EKS example file not found") - + # Test loading via from_file ds = load_poses.from_file(file_path, source_software="EKS", fps=30) - + # Should be identical to from_eks_file ds_direct = load_poses.from_eks_file(file_path, fps=30) xr.testing.assert_identical(ds, ds_direct) - + except ImportError: pytest.skip("Required dependencies for EKS loading not available") From cf889f9396c0a9f5ef4d1ce559a8c6718b68319a Mon Sep 17 00:00:00 2001 From: Muad Abd El Hay Date: Thu, 28 Aug 2025 10:45:06 +0200 Subject: [PATCH 3/4] Update EKS test file to use renamed example file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Change test filename from sub-460_strain-B6_2024-11-12T12_30_00_eks.csv to eks_output.csv - Update ensemble variable checks to match new structure (ens_median, ens_var, posterior_var) - Fix shape assertions for ensemble variables to include space dimension 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- tests/test_unit/test_io/test_load_poses.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/test_unit/test_io/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py index 04b13d337..7ac916613 100644 --- a/tests/test_unit/test_io/test_load_poses.py +++ b/tests/test_unit/test_io/test_load_poses.py @@ -325,7 +325,7 @@ def test_load_from_eks_file(): try: from pathlib import Path - file_path = Path("sub-460_strain-B6_2024-11-12T12_30_00_eks.csv") + file_path = Path("eks_output.csv") if not file_path.exists(): pytest.skip("EKS example file not found") @@ -338,14 +338,7 @@ def test_load_from_eks_file(): assert "confidence" in ds.data_vars # Check ensemble statistics are present - ensemble_vars = [ - "x_ens_median", - "y_ens_median", - "x_ens_var", - "y_ens_var", - "x_posterior_var", - "y_posterior_var", - ] + ensemble_vars = ["ens_median", "ens_var", "posterior_var"] for var in ensemble_vars: assert var in ds.data_vars @@ -363,7 +356,12 @@ def test_load_from_eks_file(): n_time, n_space, n_keypoints, n_individuals = ds.position.shape assert ds.confidence.shape == (n_time, n_keypoints, n_individuals) for var in ensemble_vars: - assert ds[var].shape == (n_time, n_keypoints, n_individuals) + assert ds[var].shape == ( + n_time, + n_space, + n_keypoints, + n_individuals, + ) except ImportError: pytest.skip("Required dependencies for EKS loading not available") @@ -374,7 +372,7 @@ def test_load_from_file_eks(): try: from pathlib import Path - file_path = Path("sub-460_strain-B6_2024-11-12T12_30_00_eks.csv") + file_path = Path("eks_output.csv") if not file_path.exists(): pytest.skip("EKS example file not found") From 44480db24a330722fff7a306ffabcfbc086e7043 Mon Sep 17 00:00:00 2001 From: Muad Abd El Hay Date: Fri, 29 Aug 2025 15:08:49 +0200 Subject: [PATCH 4/4] Update EKS loader to use 'EKS' as source_software and sample data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Change source_software attribute from 'EnsembleKalmanSmoother' to 'EKS' - Update tests to use sample data files from GIN repository - Use EKS_IBL-paw_multicam_right.predictions.csv from DATA_PATHS - Remove dependency on local eks_output.csv file 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- movement/io/load_poses.py | 2 +- tests/test_unit/test_io/test_load_poses.py | 24 ++++++++-------------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 2382e30fa..2905abee2 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -480,7 +480,7 @@ def from_eks_file( individual_names=individual_names, keypoint_names=keypoint_names, fps=fps, - source_software="EnsembleKalmanSmoother", + source_software="EKS", ) # Group ensemble statistics by their type diff --git a/tests/test_unit/test_io/test_load_poses.py b/tests/test_unit/test_io/test_load_poses.py index 7ac916613..735227e1b 100644 --- a/tests/test_unit/test_io/test_load_poses.py +++ b/tests/test_unit/test_io/test_load_poses.py @@ -320,15 +320,11 @@ def test_load_from_eks_file(): """Test that loading pose tracks from an EKS CSV file returns a proper Dataset with ensemble statistics. """ - # Use the provided example EKS file if available, - # otherwise skip the test - try: - from pathlib import Path - - file_path = Path("eks_output.csv") - if not file_path.exists(): - pytest.skip("EKS example file not found") + file_path = DATA_PATHS.get("EKS_IBL-paw_multicam_right.predictions.csv") + if file_path is None: + pytest.skip("EKS example file not found") + try: # Load the EKS file ds = load_poses.from_eks_file(file_path, fps=30) @@ -349,7 +345,7 @@ def test_load_from_eks_file(): assert "space" in ds.dims # Check basic attributes - assert ds.attrs["source_software"] == "EnsembleKalmanSmoother" + assert ds.attrs["source_software"] == "EKS" assert ds.attrs["fps"] == 30 # Check shapes are consistent @@ -369,13 +365,11 @@ def test_load_from_eks_file(): def test_load_from_file_eks(): """Test that loading EKS files via from_file() works correctly.""" - try: - from pathlib import Path - - file_path = Path("eks_output.csv") - if not file_path.exists(): - pytest.skip("EKS example file not found") + file_path = DATA_PATHS.get("EKS_IBL-paw_multicam_right.predictions.csv") + if file_path is None: + pytest.skip("EKS example file not found") + try: # Test loading via from_file ds = load_poses.from_file(file_path, source_software="EKS", fps=30)