234 changes: 230 additions & 4 deletions movement/io/load_poses.py
@@ -101,6 +101,7 @@ def from_file(
"LightningPose",
"Anipose",
"NWB",
"EKS",
],
fps: float | None = None,
**kwargs,
@@ -112,11 +113,11 @@ def from_file(
file_path : pathlib.Path or str
Path to the file containing predicted poses. The file format must
be among those supported by the ``from_dlc_file()``,
``from_slp_file()`` or ``from_lp_file()`` functions. One of these
these functions will be called internally, based on
``from_slp_file()``, ``from_lp_file()``, ``from_eks_file()`` functions.
One of these functions will be called internally, based on
Comment on lines 115 to +117
I realised that it's starting to become a bit unwieldy to list all the functions here (we've already neglected to add some), so let's take this chance to instead point users to the "See Also" list, which is up-to-date.

Suggested change
be among those supported by the ``from_dlc_file()``,
``from_slp_file()`` or ``from_lp_file()`` functions. One of these
these functions will be called internally, based on
``from_slp_file()``, ``from_lp_file()``, ``from_eks_file()`` functions.
One of these functions will be called internally, based on
be among those supported by the software-specific loading functions
that are listed under "See Also".

the value of ``source_software``.
source_software : {"DeepLabCut", "SLEAP", "LightningPose", "Anipose", \
"NWB"}
"NWB", "EKS"}
The source software of the file.
fps : float, optional
The number of frames per second in the video. If None (default),
Expand All @@ -141,6 +142,7 @@ def from_file(
movement.io.load_poses.from_lp_file
movement.io.load_poses.from_anipose_file
movement.io.load_poses.from_nwb_file
movement.io.load_poses.from_eks_file

Examples
--------
@@ -166,6 +168,8 @@ def from_file(
"metadata in the file."
)
return from_nwb_file(file_path, **kwargs)
elif source_software == "EKS":
return from_eks_file(file_path, fps)
else:
raise logger.error(
ValueError(f"Unsupported source software: {source_software}")
@@ -376,9 +380,187 @@ def from_dlc_file(
)


def from_eks_file(
file_path: Path | str, fps: float | None = None
) -> xr.Dataset:
"""Create a ``movement`` poses dataset from an EKS file.
Maybe under this first line we could spell out that EKS stands for Ensemble Kalman Smoother and point to their GitHub repository (you can do this with a reference, as is done for the from_nwb_file() function).
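
For example, following the reference style used in from_nwb_file()'s docstring, the opening of this docstring could look roughly like the sketch below (the repository URL is my assumption of where EKS lives and should be double-checked):

    def from_eks_file(
        file_path: Path | str, fps: float | None = None
    ) -> xr.Dataset:
        """Create a ``movement`` poses dataset from an EKS file.

        EKS stands for Ensemble Kalman Smoother [1]_.

        References
        ----------
        .. [1] https://github.com/paninski-lab/eks

        """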


Parameters
----------
file_path : pathlib.Path or str
Path to the EKS .csv file containing the predicted poses and
ensemble statistics.
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.

Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores,
ensemble statistics, and associated metadata.

Notes
-----
EKS files have a similar structure to DeepLabCut CSV files but include
additional columns for ensemble statistics:
Comment on lines +405 to +406
To match the spelling of .csv earlier in the docstring (and in other docstrings).

Suggested change
EKS files have a similar structure to DeepLabCut CSV files but include
additional columns for ensemble statistics:
EKS files have a similar structure to DeepLabCut .csv files but include
additional columns for ensemble statistics:

- x_ens_median, y_ens_median: Median of the ensemble predictions
- x_ens_var, y_ens_var: Variance of the ensemble predictions
- x_posterior_var, y_posterior_var: Posterior variance from the smoother
Comment on lines +407 to +409
If you want the bullet point list to render properly in Sphinx .rst, you have to leave a blank line before it. Moreover, and this is more of a subjective stylistic choice, I would make the column names monospace.

Suggested change
- x_ens_median, y_ens_median: Median of the ensemble predictions
- x_ens_var, y_ens_var: Variance of the ensemble predictions
- x_posterior_var, y_posterior_var: Posterior variance from the smoother
- ``x_ens_median``, ``y_ens_median``: Median of the ensemble predictions
- ``x_ens_var``, ``y_ens_var``: Variance of the ensemble predictions
- ``x_posterior_var``, ``y_posterior_var``: Posterior variance from
the Kalman smoother


See Also
--------
movement.io.load_poses.from_dlc_file

Examples
--------
>>> from movement.io import load_poses
>>> ds = load_poses.from_eks_file("path/to/file_eks.csv", fps=30)

"""
# Read the CSV file using similar logic to DeepLabCut
file = ValidFile(file_path, expected_suffix=[".csv"])

# Parse the CSV to get the DataFrame with multi-index columns
df = _df_from_eks_csv(file.path)

# Extract individual and keypoint names
if "individuals" in df.columns.names:
individual_names = (
df.columns.get_level_values("individuals").unique().to_list()
)
else:
individual_names = ["individual_0"]
keypoint_names = (
df.columns.get_level_values("bodyparts").unique().to_list()
)

# Get unique coord types
coord_types = df.columns.get_level_values("coords").unique().to_list()

# Separate main tracking data (x, y, likelihood) from ensemble stats
main_coords = ["x", "y", "likelihood"]
ensemble_coords = [c for c in coord_types if c not in main_coords]

# Extract main tracking data - need to select each coord separately
# because xs doesn't support list keys
x_df = df.xs("x", level="coords", axis=1)
y_df = df.xs("y", level="coords", axis=1)
likelihood_df = df.xs("likelihood", level="coords", axis=1)

# Stack the coordinates back together in the right order
# Shape will be (n_frames, n_individuals * n_keypoints) for each
x_data = (
x_df.to_numpy()
.flatten()
.reshape(-1, len(individual_names), len(keypoint_names))
)
y_data = (
y_df.to_numpy()
.flatten()
.reshape(-1, len(individual_names), len(keypoint_names))
)
likelihood_data = (
likelihood_df.to_numpy()
.flatten()
.reshape(-1, len(individual_names), len(keypoint_names))
)

# Stack to create (n_frames, 3, n_keypoints, n_individuals)
# where the second axis contains x, y, likelihood
tracks_with_scores = np.stack([x_data, y_data, likelihood_data], axis=1)
# Transpose to match expected order
# (n_frames, 3, n_keypoints, n_individuals)
tracks_with_scores = tracks_with_scores.transpose(0, 1, 3, 2)

# Create the base dataset
ds = from_numpy(
position_array=tracks_with_scores[:, :-1, :, :],
confidence_array=tracks_with_scores[:, -1, :, :],
individual_names=individual_names,
keypoint_names=keypoint_names,
fps=fps,
source_software="EKS",
)
Comment on lines +442 to +484
You have already nicely separated the coordinates that are shared with DLC from the ensemble coords.
This means we have a great opportunity to reduce code here: you can simply use the existing from_dlc_style_df function to do most of the work for you. This is how I would do it.

Suggested change
main_coords = ["x", "y", "likelihood"]
ensemble_coords = [c for c in coord_types if c not in main_coords]
# Extract main tracking data - need to select each coord separately
# because xs doesn't support list keys
x_df = df.xs("x", level="coords", axis=1)
y_df = df.xs("y", level="coords", axis=1)
likelihood_df = df.xs("likelihood", level="coords", axis=1)
# Stack the coordinates back together in the right order
# Shape will be (n_frames, n_individuals * n_keypoints) for each
x_data = (
x_df.to_numpy()
.flatten()
.reshape(-1, len(individual_names), len(keypoint_names))
)
y_data = (
y_df.to_numpy()
.flatten()
.reshape(-1, len(individual_names), len(keypoint_names))
)
likelihood_data = (
likelihood_df.to_numpy()
.flatten()
.reshape(-1, len(individual_names), len(keypoint_names))
)
# Stack to create (n_frames, 3, n_keypoints, n_individuals)
# where the second axis contains x, y, likelihood
tracks_with_scores = np.stack([x_data, y_data, likelihood_data], axis=1)
# Transpose to match expected order
# (n_frames, 3, n_keypoints, n_individuals)
tracks_with_scores = tracks_with_scores.transpose(0, 1, 3, 2)
# Create the base dataset
ds = from_numpy(
position_array=tracks_with_scores[:, :-1, :, :],
confidence_array=tracks_with_scores[:, -1, :, :],
individual_names=individual_names,
keypoint_names=keypoint_names,
fps=fps,
source_software="EKS",
)
dlc_coords = ["x", "y", "likelihood"]
ensemble_coords = [c for c in coord_types if c not in dlc_coords]
# Filter the DataFrame to keep only coords shared with DLC
df_main = df.loc[:, df.columns.get_level_values("coords").isin(dlc_coords)]
# Now we can use the from_dlc_style_df function
ds = from_dlc_style_df(df_main, fps=fps, source_software="EKS")

Btw, if you take this suggestion of mine, together with the next suggestions, there will no longer be a need for creating lists for individual_names and keypoint_names.
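
In other words, once from_dlc_style_df has built the dataset, the counts can be read straight off it wherever they are needed, e.g. (a minimal sketch):

    # Read the sizes from the dataset instead of building name lists up front
    n_individuals = ds.sizes["individuals"]
    n_keypoints = ds.sizes["keypoints"]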


# Group ensemble statistics by their type
# (ens_median, ens_var, posterior_var)
# Each will have x and y components as the space dimension
ensemble_stats: dict[str, dict[str, np.ndarray]] = {}

for coord in ensemble_coords:
# Parse coordinate name to extract stat type and spatial component
# e.g., "x_ens_median" -> ("ens_median", "x")
if coord.startswith("x_"):
stat_name = coord[2:] # Remove "x_" prefix
spatial_component = "x"
elif coord.startswith("y_"):
stat_name = coord[2:] # Remove "y_" prefix
spatial_component = "y"
else:
# If it doesn't follow expected pattern, handle as before
coord_df = df.xs(coord, level="coords", axis=1)
coord_data = (
coord_df.to_numpy()
.reshape((-1, len(individual_names), len(keypoint_names)))
.transpose(0, 2, 1)
)
da = xr.DataArray(
coord_data,
dims=["time", "keypoints", "individuals"],
coords={
"time": ds.coords["time"],
"keypoints": ds.coords["keypoints"],
"individuals": ds.coords["individuals"],
},
name=coord,
)
ds[coord] = da
continue
Comment on lines +501 to +519
What's the purpose of loading these extra coordinates? Do we really expect the .csv files to contain extra columns other than x, y, likelihood and the 6 ensemble stats? I would probably prefer to skip any extra columns, with a warning:

Suggested change
# If it doesn't follow expected pattern, handle as before
coord_df = df.xs(coord, level="coords", axis=1)
coord_data = (
coord_df.to_numpy()
.reshape((-1, len(individual_names), len(keypoint_names)))
.transpose(0, 2, 1)
)
da = xr.DataArray(
coord_data,
dims=["time", "keypoints", "individuals"],
coords={
"time": ds.coords["time"],
"keypoints": ds.coords["keypoints"],
"individuals": ds.coords["individuals"],
},
name=coord,
)
ds[coord] = da
continue
# If it doesn't follow expected pattern, pass and issue a warning
logger.warning(
f"Unexpected coordinate {coord} was skipped."
"Expected coordinates are: 'x', 'y', 'likelihood', "
"'x_ens_median', 'y_ens_median', 'x_ens_var', 'y_ens_var', "
"'x_posterior_var', 'y_posterior_var'."
)


Or if you prefer, we can check for those unexpected coordinates much earlier in the code, i.e. where you separate coord_types

dlc_coords = ["x", "y", "likelihood"]
# We can explicitly name ensemble coords, because we expect them
ensemble_coords = [
    "x_ens_median", "y_ens_median",
    "x_ens_var", "y_ens_var",
    "x_posterior_var", "y_posterior_var"
]
unexpected_coords = [
    c for c in coord_types if c not in dlc_coords + ensemble_coords
]
if unexpected_coords:
    logger.warning(
        f"Unexpected coordinates found in the DataFrame: {unexpected_coords}. "
        "These will be skipped."
    )


# Initialize the stat dict if needed
if stat_name not in ensemble_stats:
ensemble_stats[stat_name] = {}

# Extract the data for this coordinate
coord_df = df.xs(coord, level="coords", axis=1)
coord_data = (
coord_df.to_numpy()
.reshape((-1, len(individual_names), len(keypoint_names)))
.transpose(0, 2, 1)
)

# Store the data indexed by spatial component
ensemble_stats[stat_name][spatial_component] = coord_data

# Create DataArrays with (time, space, keypoints, individuals) dims
for stat_name, spatial_data in ensemble_stats.items():
if "x" in spatial_data and "y" in spatial_data:
# Stack x and y into space dimension
stat_array = np.stack(
[spatial_data["x"], spatial_data["y"]], axis=1
) # Results in (time, space=2, keypoints, individuals)

# Create xarray DataArray with proper dimensions
da = xr.DataArray(
stat_array,
dims=["time", "space", "keypoints", "individuals"],
coords={
"time": ds.coords["time"],
"space": ["x", "y"],
"keypoints": ds.coords["keypoints"],
"individuals": ds.coords["individuals"],
},
name=stat_name,
)
ds[stat_name] = da
Comment on lines +486 to +556
I believe there is a more compact way of achieving what you are doing in lines 486-556:

    # Add ensemble statistics to the dataset
    for stat in ["ens_median", "ens_var", "posterior_var"]:
        # (time, keypoints, individuals) for each of x and y
        stat_arrays_per_spatial_component = [
            (
                df.xs(f"{c}_{stat}", level="coords", axis=1)
                .to_numpy()
                .reshape((-1, ds.sizes["individuals"], ds.sizes["keypoints"]))
                .transpose(0, 2, 1)
            )
            for c in ["x", "y"]
        ]
        # Stack the arrays into (time, space=2, keypoints, individuals)
        stat_array = np.stack(stat_arrays_per_spatial_component, axis=1)
        # Create a DataArray for the ensemble statistic
        ds[stat] = xr.DataArray(
            stat_array,
            dims=["time", "space", "keypoints", "individuals"],
            coords={
                "time": ds.coords["time"],
                "space": ["x", "y"],
                "keypoints": ds.coords["keypoints"],
                "individuals": ds.coords["individuals"],
            },
        )

Your tests still pass with my version of the code btw, so I think it's equivalent.


return ds


def from_multiview_files(
file_path_dict: dict[str, Path | str],
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose", "EKS"],
fps: float | None = None,
) -> xr.Dataset:
"""Load and merge pose tracking data from multiple views (cameras).
@@ -657,6 +839,50 @@ def _df_from_dlc_csv(file_path: Path) -> pd.DataFrame:
return df


def _df_from_eks_csv(file_path: Path) -> pd.DataFrame:
As far as I can tell this private function is unnecessary, because it's almost identical to the existing _df_from_dlc_csv function. The only difference is the ValidDeepLabCutCSV validator, which you are not using here. But I would say that you can simply use _df_from_dlc_csv as is. The ValidDeepLabCutCSV validator won't hinder you, as it only checks that the expected multi-index levels exist. As these are the same between DLC and EKS, it's actually better to use the validator, so you are completely fine using _df_from_dlc_csv and getting rid of _df_from_eks_csv.
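
Concretely, the start of from_eks_file() would then reduce to something like the sketch below (assuming _df_from_dlc_csv keeps its current signature):

    # Validate the file, then reuse the existing DLC .csv parser;
    # the multi-index levels (scorer/individuals/bodyparts/coords)
    # are the same for DLC and EKS files
    file = ValidFile(file_path, expected_suffix=[".csv"])
    df = _df_from_dlc_csv(file.path)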

"""Create an EKS-style DataFrame from a .csv file.

EKS CSV files have a similar structure to DeepLabCut CSV files but
with additional columns for ensemble statistics.

Parameters
----------
file_path : pathlib.Path
Path to the EKS-style .csv file with pose tracks and ensemble stats.

Returns
-------
pandas.DataFrame
EKS-style DataFrame with multi-index columns.

"""
# EKS CSV has the same header structure as DeepLabCut CSV
possible_level_names = ["scorer", "individuals", "bodyparts", "coords"]
with open(file_path) as f:
# if line starts with a possible level name, split it into a list
# of strings, and add it to the list of header lines
header_lines = [
line.strip().split(",")
for line in f.readlines()
if line.split(",")[0] in possible_level_names
]
# Form multi-index column names from the header lines
level_names = [line[0] for line in header_lines]
column_tuples = list(
zip(*[line[1:] for line in header_lines], strict=False)
)
columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names)
# Import the EKS poses as a DataFrame
df = pd.read_csv(
file_path,
skiprows=len(header_lines),
index_col=0,
names=np.array(columns),
)
df.columns.rename(level_names, inplace=True)
return df


def _df_from_dlc_h5(file_path: Path) -> pd.DataFrame:
"""Create a DeepLabCut-style DataFrame from a .h5 file.

65 changes: 65 additions & 0 deletions tests/test_unit/test_io/test_load_poses.py
@@ -314,3 +314,68 @@ def test_load_from_nwb_file(input_type, kwargs, request):
if input_type == "nwb_file":
expected_attrs["source_file"] = nwb_file
assert ds_from_file_path.attrs == expected_attrs


def test_load_from_eks_file():
"""Test that loading pose tracks from an EKS CSV file
returns a proper Dataset with ensemble statistics.
"""
file_path = DATA_PATHS.get("EKS_IBL-paw_multicam_right.predictions.csv")
if file_path is None:
pytest.skip("EKS example file not found")
Comment on lines +324 to +325
We no longer need these two lines.


try:
# Load the EKS file
ds = load_poses.from_eks_file(file_path, fps=30)

# Check that it's a valid dataset with the expected structure
assert isinstance(ds, xr.Dataset)
assert "position" in ds.data_vars
assert "confidence" in ds.data_vars

# Check ensemble statistics are present
ensemble_vars = ["ens_median", "ens_var", "posterior_var"]
for var in ensemble_vars:
assert var in ds.data_vars

# Check dimensions
assert "time" in ds.dims
assert "individuals" in ds.dims
assert "keypoints" in ds.dims
assert "space" in ds.dims

# Check basic attributes
assert ds.attrs["source_software"] == "EKS"
assert ds.attrs["fps"] == 30

# Check shapes are consistent
n_time, n_space, n_keypoints, n_individuals = ds.position.shape
assert ds.confidence.shape == (n_time, n_keypoints, n_individuals)
for var in ensemble_vars:
assert ds[var].shape == (
n_time,
n_space,
n_keypoints,
n_individuals,
)

except ImportError:
pytest.skip("Required dependencies for EKS loading not available")
Comment on lines +362 to +363
I would also skip the try-except block here. If something is missing, the test will fail and let us know that way.



def test_load_from_file_eks():
You could make this test much more succinct by re-using one of our helper fixtures:

def test_load_from_eks_file(helpers):
    """Test that loading pose tracks from an EKS CSV file
    returns a proper Dataset with ensemble statistics.
    """
    file_path = DATA_PATHS.get("EKS_IBL-paw_multicam_right.predictions.csv")
    # Load the EKS file
    ds = load_poses.from_eks_file(file_path, fps=30)

    expected_values_eks = {
        "vars_dims": {
            "position": 4,
            "confidence": 3,
            "ens_median": 4,
            "ens_var": 4,
            "posterior_var": 4
        },
        "dim_names": ValidPosesDataset.DIM_NAMES,
        "source_software": "EKS",
        "fps": 30,
    }

    helpers.assert_valid_dataset(ds, expected_values_eks)

The helpers.assert_valid_dataset() call will do the exact same checks you were doing anyway.

"""Test that loading EKS files via from_file() works correctly."""
file_path = DATA_PATHS.get("EKS_IBL-paw_multicam_right.predictions.csv")
if file_path is None:
pytest.skip("EKS example file not found")

try:
# Test loading via from_file
ds = load_poses.from_file(file_path, source_software="EKS", fps=30)

# Should be identical to from_eks_file
ds_direct = load_poses.from_eks_file(file_path, fps=30)
xr.testing.assert_identical(ds, ds_direct)

except ImportError:
pytest.skip("Required dependencies for EKS loading not available")
Comment on lines +366 to +381
Instead of adding this extra test, I recommend adding the EKS option to the existing test_from_file_delegates_correctly test in the same file. That's the one we use for ensuring that from_file() delegates to the correct software-specific loader.
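
For illustration only, a hypothetical sketch of what that could look like, assuming the existing test is parametrized over (source_software, file) pairs (the actual parametrization in the file may differ):

    # Hypothetical: extend the existing parametrization with an EKS case
    @pytest.mark.parametrize(
        "source_software, fname",
        [
            # ... existing cases ...
            ("EKS", "EKS_IBL-paw_multicam_right.predictions.csv"),
        ],
    )
    def test_from_file_delegates_correctly(source_software, fname):
        ...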
