diff --git a/package/CHANGELOG b/package/CHANGELOG
index 5d15333ae67..723289c14e8 100644
--- a/package/CHANGELOG
+++ b/package/CHANGELOG
@@ -14,14 +14,18 @@ The rules for this file:

 -------------------------------------------------------------------------------
-??/??/?? IAlibay, orbeckst
-
+??/??/?? IAlibay, orbeckst, yuxuanzhuang
 * 2.10.0

 Fixes

 Enhancements
+  * Added `run_state.slicer` to make it possible to retrieve information
+    about the full trajectory slice in AnalysisBase (Issue #4891, PR #4892)
+  * Added `run_state.frame_index` to keep track of the frame iteration
+    number for the full trajectory slice in `_single_frame` in AnalysisBase
+    (Issue #4891, PR #4892)

 Changes

@@ -64,6 +68,7 @@ Changes
   * Codebase is now formatted with black (version `24`) (PR #4886)


+
 11/11/24 IAlibay, HeetVekariya, marinegor, lilyminium, RMeli, ljwoods2,
          aditya292002, pstaerk, PicoCentauri, BFedder, tyler.je.reddy,
          SampurnaM, leonwehrhan, kainszs, orionarcher,
diff --git a/package/MDAnalysis/analysis/base.py b/package/MDAnalysis/analysis/base.py
index 675c6d6967b..2d8bc405ed6 100644
--- a/package/MDAnalysis/analysis/base.py
+++ b/package/MDAnalysis/analysis/base.py
@@ -152,7 +152,8 @@ def _get_aggregator(self):
 import logging
 import warnings
 from functools import partial
-from typing import Iterable, Union
+from typing import Iterable, Union, Optional, List
+from dataclasses import dataclass

 import numpy as np

 from .. import coordinates
@@ -170,6 +171,29 @@ def _get_aggregator(self):
 logger = logging.getLogger(__name__)


+@dataclass(frozen=True)
+class RunConfig:
+    """Stores user-provided arguments for `run()`."""
+    start: Optional[int] = None
+    stop: Optional[int] = None
+    step: Optional[int] = None
+    frames: Optional[np.ndarray] = None
+    backend: Optional[Union[str, BackendBase]] = None
+    n_workers: Optional[int] = None
+    n_parts: Optional[int] = None
+    unsupported_backend: bool = False
+
+
+@dataclass
+class RunState:
+    """Stores runtime-generated attributes that can be used
+    during the analysis."""
+    slicer: Optional[Union[slice, np.ndarray]] = None
+    n_frames: Optional[int] = None
+    computation_groups: Optional[List[np.ndarray]] = None
+    frame_index: Optional[int] = None
+
+
 class AnalysisBase(object):
     r"""Base class for defining multi-frame analysis
@@ -441,8 +465,14 @@ def _setup_frames(
         each of the workers and gets executed twice: one time in
         :meth:`_setup_frames` for the whole trajectory, second time in
         :meth:`_compute` for each of the computation groups.
+
+        .. versionchanged:: 2.10.0
+            Store the slicer for the whole trajectory being analyzed
+            in ``self.run_state.slicer``.
         """
         slicer = self._define_run_frames(trajectory, start, stop, step, frames)
+        self.run_state.slicer = slicer
+        self.run_state.n_frames = len(trajectory[slicer])
         self._prepare_sliced_trajectory(slicer)
@@ -452,6 +482,10 @@ def _single_frame(self):
         Attributes accessible during your calculations:

         - ``self._frame_index``: index of the frame in results array
+          Note that this is not the same as the frame number in the trajectory.
+        - ``self.run_state.frame_index``: index of the frame in the full sliced
+          trajectory. Useful for parallel runs, where ``self._frame_index``
+          only counts frames within the current computation group.
         - ``self._ts`` -- Timestep instance
         - ``self._sliced_trajectory`` -- trajectory that you're iterating over
         - ``self.results`` -- :class:`MDAnalysis.analysis.results.Results` instance
@@ -537,6 +571,7 @@ def _compute(
             )
         ):
             self._frame_index = idx  # accessed later by subclasses
+            self.run_state.frame_index = indexed_frames[idx, 0]
             self._ts = ts
             self.frames[idx] = ts.frame
             self.times[idx] = ts.time
@@ -778,7 +813,7 @@ def run(
         By default, performs calculations in a serial fashion.
         Otherwise, user can choose a backend: ``str`` is matched to a
         builtin backend (one of ``serial``, ``multiprocessing`` and
-        ``dask``), or a :class:`MDAnalysis.analysis.results.BackendBase`
+        ``dask``), or a :class:`MDAnalysis.analysis.backends.BackendBase`
         subclass.

         .. versionadded:: 2.8.0
@@ -854,6 +889,17 @@ def run(
                     f"{executor.n_workers=} is greater than {n_parts=}"
                 )
             )
+
+        self._run_config = RunConfig(
+            start=start,
+            stop=stop,
+            step=step,
+            frames=frames,
+            backend=backend,
+            n_workers=n_workers,
+            n_parts=n_parts,
+            unsupported_backend=unsupported_backend,
+        )

         # start preparing the run
         worker_func = partial(
@@ -871,6 +917,7 @@ def run(
         computation_groups = self._setup_computation_groups(
             start=start, stop=stop, step=step, frames=frames, n_parts=n_parts
         )
+        self.run_state.computation_groups = computation_groups

         # get all results from workers in other processes.
         # we need `AnalysisBase` classes
@@ -902,6 +949,34 @@ def _get_aggregator(self) -> ResultsGroup:
         .. versionadded:: 2.8.0
         """
         return ResultsGroup(lookup=None)
+
+    @property
+    def run_config(self) -> RunConfig:
+        """Stores the user-provided arguments of the last :meth:`run` call.
+        It includes the ``start``, ``stop``, ``step``, ``frames``, ``backend``,
+        ``n_workers``, ``n_parts`` and ``unsupported_backend`` attributes.
+        """
+        return self._run_config
+
+    @property
+    def run_state(self) -> RunState:
+        """
+        Stores runtime-generated attributes that can be used
+        during the analysis.
+
+        It includes the ``slicer``, ``n_frames``, ``computation_groups`` and
+        ``frame_index`` attributes.
+
+        The ``slicer``, ``n_frames`` and ``frame_index`` attributes describe
+        the whole trajectory slice being analyzed. They differ from e.g.
+        ``self.n_frames``, which refers to the computation group that is
+        currently being analyzed.
+        """
+        # lazy initialization for OldAPIAnalysis, which does not share the
+        # `__init__` of the current AnalysisBase
+        if not hasattr(self, "_run_state"):
+            self._run_state = RunState()
+        return self._run_state


 class AnalysisFromFunction(AnalysisBase):
diff --git a/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst
index 3070614b5a3..6271a780103 100644
--- a/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst
+++ b/package/doc/sphinx/source/documentation_pages/analysis/parallelization.rst
@@ -118,22 +118,29 @@ For MDAnalysis developers
 From a developer point of view, there are a few methods that are important
 in order to understand how parallelization is implemented:

-#. :meth:`MDAnalysis.analysis.base.AnalysisBase._define_run_frames`
+#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_frames`
 #. :meth:`MDAnalysis.analysis.base.AnalysisBase._prepare_sliced_trajectory`
 #. :meth:`MDAnalysis.analysis.base.AnalysisBase._configure_backend`
 #. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_computation_groups`
 #. :meth:`MDAnalysis.analysis.base.AnalysisBase._compute`
 #. :meth:`MDAnalysis.analysis.base.AnalysisBase._get_aggregator`

-The first two methods share the functionality of :meth:`_setup_frames`.
-:meth:`_define_run_frames` is run once during analysis, as it checks that input
-parameters `start`, `stop`, `step` or `frames` are consistent with the given
-trajectory and prepares the ``slicer`` object that defines the iteration
-pattern through the trajectory. :meth:`_prepare_sliced_trajectory` assigns to
+:meth:`_setup_frames` is run once at the beginning of :meth:`run`. It checks
+that the input parameters ``start``, ``stop``, ``step`` or ``frames`` are
+consistent with the given trajectory and prepares the ``slicer`` object that
+defines the iteration pattern through the trajectory with
+:meth:`_define_run_frames`. The slicer is stored in
+:attr:`self.run_state.slicer`, so the full sliced trajectory being analyzed
+can later be accessed via ``self._trajectory[self.run_state.slicer]``.
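+
+For instance, the new attribute can be inspected after a run (an illustrative
+sketch using :class:`~MDAnalysis.analysis.rms.RMSF` and the test files; any
+:class:`~MDAnalysis.analysis.base.AnalysisBase` subclass behaves the same way):
+
+.. code-block:: python
+
+    import MDAnalysis as mda
+    from MDAnalysis.analysis.rms import RMSF
+    from MDAnalysis.tests.datafiles import PSF, DCD
+
+    u = mda.Universe(PSF, DCD)
+    rmsf = RMSF(u.atoms).run(step=10)
+
+    rmsf.run_state.slicer                        # slice object covering the whole run
+    full_slice = u.trajectory[rmsf.run_state.slicer]
+    len(full_slice) == rmsf.run_state.n_frames   # True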
+
+:meth:`_prepare_sliced_trajectory` assigns to
 the :attr:`self._sliced_trajectory` attribute, computes the number of frames
 in it, and fills the :attr:`self.frames` and :attr:`self.times` arrays. In case
 the computation will be later split between other processes, this method will
-be called again on each of the computation groups.
+be called again on each of the computation groups. In parallel analysis,
+:attr:`self._sliced_trajectory` represents a split of the original sliced
+trajectory, and :attr:`self.n_frames` is the number of frames in the current
+computation group (not the total number of frames in the sliced trajectory).

 The method :meth:`_configure_backend` performs basic health checks for a given
 analysis class -- namely, it compares a given backend (if it's a :class:`str`
@@ -155,7 +162,14 @@ analysis get initialized with the :meth:`_prepare` method. Then the function
 iterates over :attr:`self._sliced_trajectory`, assigning
 :attr:`self._frame_index` and :attr:`self._ts` as frame index (within a
 computation group) and timestamp, and also setting respective
-:attr:`self.frames` and :attr:`self.times` array values.
+:attr:`self.frames` and :attr:`self.times` array values. Additionally,
+:attr:`self.run_state.frame_index` is assigned the run frame index within
+the full sliced trajectory (``self._trajectory[self.run_state.slicer]``)
+that is being analyzed. This run frame index is particularly useful for
+analyses that need to know the position of the current frame within the
+whole sliced trajectory, such as
+:class:`MDAnalysis.analysis.diffusionmap.DistanceMatrix`.
+See :ref:`retrieving-correct-frame-index` for more details.
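+
+For instance, an analysis that stores one value per frame in an array spanning
+the whole run could use it as follows (a minimal sketch; the class and result
+names are made up for illustration):
+
+.. code-block:: python
+
+    import numpy as np
+
+    from MDAnalysis.analysis.base import AnalysisBase
+    from MDAnalysis.analysis.results import ResultsGroup
+
+    class PerFrameValue(AnalysisBase):
+        _analysis_algorithm_is_parallelizable = True
+
+        @classmethod
+        def get_supported_backends(cls):
+            return ('serial', 'multiprocessing', 'dask')
+
+        def _prepare(self):
+            # one slot per frame of the *full* sliced trajectory,
+            # not only for the frames handled by this worker
+            n_total = len(self._trajectory[self.run_state.slicer])
+            self.results.per_frame = np.zeros(n_total)
+
+        def _single_frame(self):
+            # run_state.frame_index addresses the slot of the current frame
+            # in the full-length array, even when the run is split
+            self.results.per_frame[self.run_state.frame_index] = self._ts.time
+
+        def _get_aggregator(self):
+            # each worker fills only its own slots of the zero-initialised
+            # array, so summing the partial arrays restores the full series
+            return ResultsGroup(lookup={'per_frame': ResultsGroup.ndarray_sum})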

 After :meth:`_compute` has finished, the main analysis instance calls the
 :meth:`_get_aggregator` method, which merges the :attr:`self.results`
@@ -357,6 +371,82 @@ In this way, you will override the check for supported backends.
     with a supported backend. When reporting *always mention if you used*
     ``unsupported_backend=True``.

+.. _retrieving-correct-frame-index:
+
+Retrieving the correct frame index in parallel analysis
+========================================================
+
+To retrieve the correct frame index during parallel analysis, use the
+:attr:`self.run_state.frame_index` attribute. It holds the index of the
+current frame within the full sliced trajectory
+(``self._trajectory[self.run_state.slicer]``).
+
+For an example illustrating when to use :attr:`_frame_index` versus
+:attr:`run_state.frame_index` and :attr:`run_state.slicer`,
+see the following code snippet:
+
+.. code-block:: python
+
+    from MDAnalysis.analysis.base import AnalysisBase
+    from MDAnalysis.analysis.results import ResultsGroup
+
+    class MyAnalysis(AnalysisBase):
+        _analysis_algorithm_is_parallelizable = True
+
+        @classmethod
+        def get_supported_backends(cls):
+            """Define the supported backends for the analysis."""
+            return ('serial', 'multiprocessing', 'dask')
+
+        def _prepare(self):
+            """Initialize result attributes and compute frame count."""
+            self.results.frame_index = []
+            self.results.run_frame_index = []
+            self.results.n_frames = []
+            self.results.run_n_frames = []
+            self.run_n_frames = len(self._trajectory[self.run_state.slicer])
+
+        def _single_frame(self):
+            """Process a single frame during the analysis."""
+            frame_index = self._frame_index
+            run_frame_index = self.run_state.frame_index
+
+            # Append results for the current frame
+            self.results.frame_index.append(frame_index)
+            self.results.run_frame_index.append(run_frame_index)
+            self.results.n_frames.append(self.n_frames)
+            self.results.run_n_frames.append(self.run_n_frames)
+
+        def _get_aggregator(self):
+            """Return an aggregator to combine results from multiple workers."""
+            return ResultsGroup(
+                lookup={
+                    'frame_index': ResultsGroup.flatten_sequence,
+                    'run_frame_index': ResultsGroup.flatten_sequence,
+                    'n_frames': ResultsGroup.flatten_sequence,
+                    'run_n_frames': ResultsGroup.flatten_sequence,
+                }
+            )
+
+    # Example usage: serial analysis (u.trajectory is assumed to have 10 frames)
+    ana = MyAnalysis(u.trajectory)
+    ana.run(step=2)
+    print(ana.results)
+    # Output:
+    # {'frame_index': [0, 1, 2, 3, 4],
+    #  'run_frame_index': [0, 1, 2, 3, 4],
+    #  'n_frames': [5, 5, 5, 5, 5],
+    #  'run_n_frames': [5, 5, 5, 5, 5]}
+
+    # Example usage: parallel analysis
+    ana = MyAnalysis(u.trajectory)
+    ana.run(step=2, backend='dask', n_workers=2)
+    print(ana.results)
+    # Output:
+    # {'frame_index': [0, 1, 2, 0, 1],
+    #  'run_frame_index': [0, 1, 2, 3, 4],
+    #  'n_frames': [3, 3, 3, 2, 2],
+    #  'run_n_frames': [5, 5, 5, 5, 5]}
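+
+After :meth:`run` has returned, the arguments of the call and the derived run
+state remain available through the :attr:`run_config` and :attr:`run_state`
+properties (an illustrative sketch, continuing the parallel example above):
+
+.. code-block:: python
+
+    ana.run_config.step         # 2
+    ana.run_config.backend      # 'dask'
+    ana.run_config.n_workers    # 2
+    ana.run_state.n_frames      # 5, frames in the full sliced trajectory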
+
 .. rubric:: References

 .. footbibliography::
diff --git a/testsuite/MDAnalysisTests/analysis/test_base.py b/testsuite/MDAnalysisTests/analysis/test_base.py
index e369c4c6021..042a13ca44b 100644
--- a/testsuite/MDAnalysisTests/analysis/test_base.py
+++ b/testsuite/MDAnalysisTests/analysis/test_base.py
@@ -48,16 +48,36 @@ def __init__(self, reader, **kwargs):

     def _prepare(self):
         self.results.found_frames = []
+        self.results.frame_index = []
+        self.results.run_frame_index = []
+        self.results.n_frames = []
+        self.results.run_n_frames = []
+
+        # self.n_frames (frames in the current computation group) is set by
+        # the base class; run_n_frames counts all frames of the full run
+        self.run_n_frames = len(self._trajectory[self.run_state.slicer])

     def _single_frame(self):
+        frame_index = self._frame_index
+        run_frame_index = self.run_state.frame_index
+
         self.results.found_frames.append(self._ts.frame)
+        self.results.frame_index.append(frame_index)
+        self.results.run_frame_index.append(run_frame_index)
+        self.results.n_frames.append(self.n_frames)
+        self.results.run_n_frames.append(self.run_n_frames)

     def _conclude(self):
         self.found_frames = list(self.results.found_frames)

     def _get_aggregator(self):
         return base.ResultsGroup(
-            {"found_frames": base.ResultsGroup.ndarray_hstack}
+            {
+                "found_frames": base.ResultsGroup.ndarray_hstack,
+                "frame_index": base.ResultsGroup.ndarray_hstack,
+                "run_frame_index": base.ResultsGroup.ndarray_hstack,
+                "n_frames": base.ResultsGroup.ndarray_hstack,
+                "run_n_frames": base.ResultsGroup.ndarray_hstack,
+            }
+        )
@@ -450,12 +470,17 @@ def test_frames_times(client_FrameAnalysis):
         start=1, stop=8, step=2, **client_FrameAnalysis
     )
     frames = np.array([1, 3, 5, 7])
-    assert an.n_frames == len(frames)
+    n_frames = len(frames)
+    frame_indices = np.arange(n_frames)
+
+    assert an.n_frames == n_frames
     assert_equal(an.found_frames, frames)
     assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
     assert_allclose(
         an.times, frames * 100, rtol=0, atol=1.5e-4, err_msg=TIMES_ERR
     )
+    assert_equal(an.results.run_frame_index, frame_indices)
+    assert_equal(an.results.run_n_frames, [n_frames] * n_frames)


 def test_verbose(u):