diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb
index 08ca020a..dde7e028 100644
--- a/docs/user-guide/parameter-tables.ipynb
+++ b/docs/user-guide/parameter-tables.ipynb
@@ -229,7 +229,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "graph = pipeline.reduce(func=lambda *result: sum(result), name='merged').get('merged')\n",
+    "def merge(*data):\n",
+    "    return sum(data)\n",
+    "\n",
+    "\n",
+    "graph = pipeline.reduce(func=merge, name='merged').get('merged')\n",
     "graph.visualize()"
    ]
   },
@@ -294,8 +298,6 @@
    "source": [
     "## Grouping intermediate results based on secondary parameters\n",
     "\n",
-    "**Cyclebane and Sciline do not support `groupby` yet, this is work in progress so this example is not functional yet.**\n",
-    "\n",
     "This chapter illustrates how to implement *groupby* operations with Sciline.\n",
     "\n",
     "Continuing from the examples for *map* and *reduce*, we can introduce a secondary parameter in the table, such as the material of the sample:"
    ]
   },
@@ -307,13 +309,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "Material = NewType('Material', str)\n",
-    "\n",
     "run_ids = [102, 103, 104, 105]\n",
     "sample = ['diamond', 'graphite', 'graphite', 'graphite']\n",
     "filenames = [f'file{i}.txt' for i in run_ids]\n",
     "param_table = pd.DataFrame(\n",
-    "    {Filename: filenames, Material: sample}, index=run_ids\n",
+    "    {Filename: filenames, 'Material': sample}, index=run_ids\n",
     ").rename_axis(index='run_id')\n",
     "param_table"
    ]
   },
@@ -322,15 +322,172 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Future releases of Sciline will support a `groupby` operation, roughly as follows:\n",
+    "We then group the results by sample material,\n",
+    "and reduce the data using the same merge function as before.\n",
     "\n",
-    "```python\n",
-    "pipeline = base.map(param_table).groupby(Material).reduce(func=merge)\n",
-    "```\n",
+    "The end goal is to obtain two final results: one for each material."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "grouped = (\n",
+    "    base.map(param_table)\n",
+    "    .groupby('Material')\n",
+    "    .reduce(key=Result, func=merge, name='merged')\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "graph = grouped.get(sciline.get_mapped_node_names(grouped, 'merged'))\n",
+    "graph.visualize()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sciline.compute_mapped(grouped, 'merged')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Grouping early in the graph\n",
+    "\n",
+    "Sometimes it is desirable to apply the grouping earlier in the pipeline graph.\n",
+    "In this example, we wish to combine the raw data before cleaning it and computing the result.\n",
-    "We can then compute the merged result, grouped by the value of `Material`.\n",
-    "Note how the initial steps of the computation depend on the `run_id` index name, while later steps depend on `Material`, a new index name defined by the `groupby` operation.\n",
-    "The files for each run ID have been grouped by their material and then merged."
+ "The function that merges the raw data needs to merge both data and metadata: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define function to merge RawData\n", + "def merge_raw(*das):\n", + " out = {'data': [], 'meta': {}}\n", + " for da in das:\n", + " out['data'].extend(da['data'])\n", + " for k, v in da['meta'].items():\n", + " if k not in out['meta']:\n", + " out['meta'][k] = []\n", + " out['meta'][k].append(v)\n", + " return out" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The grouped graph is now just taking the part that leads to `RawData` (via `base[RawData]`) and dropping everything downstream." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grouped = (\n", + " base[RawData]\n", + " .map(param_table)\n", + " .groupby('Material')\n", + " .reduce(key=RawData, func=merge_raw, name='merged')\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Visualizing the graph shows two `merged` nodes; one for each material:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grouped.visualize(sciline.get_mapped_node_names(grouped, 'merged'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The task is now to reattach this grouped graph to the lower part of our pipeline.\n", + "Since the base graph has a single `RawData`, we first need to map it to the two possible materials that are left after grouping." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new = base.copy()\n", + "new[RawData] = None\n", + "\n", + "# Get the list of materials left after grouping\n", + "unique_materials = sciline.get_mapped_node_names(grouped, 'merged').index\n", + "\n", + "mapped = new.map(\n", + " pd.DataFrame(\n", + " {RawData: [None] * len(unique_materials), 'Material': unique_materials}\n", + " ).set_index('Material')\n", + ")\n", + "\n", + "mapped.visualize(sciline.get_mapped_node_names(mapped, Result))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now attach the top part of the graph to the bottom one:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mapped[RawData] = grouped['merged']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mapped.visualize(sciline.get_mapped_node_names(mapped, Result))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sciline.compute_mapped(mapped, Result)" ] }, { @@ -519,7 +676,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/src/sciline/data_graph.py b/src/sciline/data_graph.py index 401a07ce..cf2cb78c 100644 --- a/src/sciline/data_graph.py +++ b/src/sciline/data_graph.py @@ -4,12 +4,13 @@ from __future__ import annotations import itertools -from collections.abc import Callable, Generator, Iterable, Mapping +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping from types import NoneType from typing import TYPE_CHECKING, Any, TypeVar, get_args import cyclebane as cb import networkx as nx +from cyclebane.graph import GroupbyGraph as CbGroupbyGraph from 
diff --git a/src/sciline/data_graph.py b/src/sciline/data_graph.py
index 401a07ce..cf2cb78c 100644
--- a/src/sciline/data_graph.py
+++ b/src/sciline/data_graph.py
@@ -4,12 +4,13 @@
 from __future__ import annotations
 
 import itertools
-from collections.abc import Callable, Generator, Iterable, Mapping
+from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
 from types import NoneType
 from typing import TYPE_CHECKING, Any, TypeVar, get_args
 
 import cyclebane as cb
 import networkx as nx
+from cyclebane.graph import GroupbyGraph as CbGroupbyGraph
 from cyclebane.node_values import IndexName, IndexValue
 
 from ._provider import ArgSpec, Provider, ToProvider, _bind_free_typevars
@@ -207,6 +208,25 @@ def map(self: T, node_values: dict[Key, Any]) -> T:
         """
         return self._from_cyclebane(self._cbgraph.map(node_values))
 
+    def groupby(self, node: Key) -> GroupbyGraph:
+        """Group the graph by a specific node.
+
+        Parameters
+        ----------
+        node:
+            Node key to group by.
+
+        Returns
+        -------
+        :
+            A new graph that groups mapped nodes by the given key. This graph is not
+            meant to be executed directly, but to be further processed via
+            :meth:`GroupbyGraph.reduce`.
+        """
+        return GroupbyGraph(
+            graph=self._cbgraph.groupby(node), graph_maker=self._from_cyclebane
+        )
+
     def reduce(self: T, *, func: Callable[..., Any], **kwargs: Any) -> T:
         """Reduce the outputs of a mapped graph into a single value and provider.
@@ -250,6 +270,65 @@ def visualize_data_graph(self, **kwargs: Any) -> graphviz.Digraph:
         return dot
 
 
+class GroupbyGraph:
+    """
+    A graph that has been grouped by a specific index.
+
+    This is a specialized graph representing the result of a groupby operation.
+    It supports operations on the grouped data, currently per-group reduction
+    via :meth:`reduce`.
+    """
+
+    def __init__(
+        self, graph: CbGroupbyGraph, graph_maker: Callable[..., DataGraph]
+    ) -> None:
+        self._cbgraph = graph
+        # We forward the constructor so this can be used by both DataGraph and Pipeline
+        self._graph_maker = graph_maker
+
+    def reduce(
+        self,
+        *,
+        func: Callable[..., Any],
+        key: None | Hashable = None,
+        name: None | Hashable = None,
+        attrs: None | dict[str, Any] = None,
+    ) -> DataGraph:
+        """Reduce the grouped node group by group, yielding a single value and
+        provider per group.
+
+        Parameters
+        ----------
+        func:
+            Function that takes the values to reduce and returns a single value.
+        key:
+            The name of the source node to reduce. This is the original name prior to
+            mapping. If not given, tries to find a unique sink node.
+            See :meth:`cyclebane.Graph.reduce`.
+        name:
+            The name of the new node. If not given, a unique name is generated.
+            See :meth:`cyclebane.Graph.reduce`.
+        attrs:
+            Attributes to set on the new node(s). See :meth:`cyclebane.Graph.reduce`.
+
+        Returns
+        -------
+        :
+            A new :class:`DataGraph` with reduced grouped nodes.
+        """
+        cbattrs = {'reduce': func}
+        if attrs is not None:
+            if 'reduce' in attrs:
+                raise ValueError(
+                    "The 'reduce' attribute is used to store 'func' and cannot "
+                    "be set via 'attrs'. Use the 'func' argument instead."
+                )
+            cbattrs.update(attrs)
+
+        return self._graph_maker(
+            self._cbgraph.reduce(attrs=cbattrs, key=key, name=name)
+        )
+
+
 _no_value = object()
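Note: a minimal sketch of the intended call sequence for the new `groupby`, using hypothetical toy types (`X`, `Y`, `double`). `groupby` returns a `GroupbyGraph`, which is deliberately not executable; `reduce` is the only way back to a regular `DataGraph`/`Pipeline`, and it rejects `attrs` that would clobber the reserved `'reduce'` attribute (see the guard above).

```python
from typing import NewType

import pandas as pd

import sciline

X = NewType('X', int)
Y = NewType('Y', float)


def double(x: X) -> Y:
    return Y(2.0 * x)


pipeline = sciline.Pipeline([double])
table = pd.DataFrame(
    {X: [1, 2, 3, 4], 'group': ['a', 'a', 'b', 'b']}
).rename_axis(index='row')

# groupby yields an intermediate GroupbyGraph; reduce turns it back into a
# pipeline with one 'total' node per unique value of 'group'.
grouped = pipeline.map(table).groupby('group')
reduced = grouped.reduce(key=Y, func=lambda *ys: sum(ys), name='total')
print(sciline.compute_mapped(reduced, 'total'))  # a -> 6.0, b -> 14.0

# The 'reduce' attribute is reserved for storing func, so this raises:
try:
    grouped.reduce(key=Y, func=max, attrs={'reduce': min})
except ValueError as err:
    print(err)
```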
diff --git a/tests/groupby_test.py b/tests/groupby_test.py
new file mode 100644
index 00000000..fa86f017
--- /dev/null
+++ b/tests/groupby_test.py
@@ -0,0 +1,146 @@
+from typing import NewType
+
+import numpy as np
+import pandas as pd
+
+import sciline as sl
+
+_fake_filesystem = {
+    'file102.txt': [1, 2, float('nan'), 3],
+    'file103.txt': [4, 5, 6, 7],
+    'file104.txt': [8, 9, 10, 11, 12],
+    'file105.txt': [13, 14, 15],
+}
+
+# 1. Define domain types
+
+Filename = NewType('Filename', str)
+RawData = NewType('RawData', dict)
+CleanedData = NewType('CleanedData', list)
+ScaleFactor = NewType('ScaleFactor', float)
+Result = NewType('Result', float)
+Material = NewType('Material', str)
+
+
+# 2. Define providers
+
+
+def load(filename: Filename) -> RawData:
+    """Load the data from the filename."""
+
+    data = _fake_filesystem[filename]
+    return RawData({'data': data, 'meta': {'filename': filename}})
+
+
+def clean(raw_data: RawData) -> CleanedData:
+    """Clean the data, removing NaNs."""
+    import math
+
+    return CleanedData([x for x in raw_data['data'] if not math.isnan(x)])
+
+
+def process(data: CleanedData, param: ScaleFactor) -> Result:
+    """Process the data, multiplying the sum by the scale factor."""
+    return Result(sum(data) * param)
+
+
+def merge(*data):
+    return sum(data)
+
+
+def test_groupby_material_at_result():
+    # Create pipeline
+    providers = [load, clean, process]
+    params = {ScaleFactor: 2.0}
+    base = sl.Pipeline(providers, params=params)
+
+    # Make parameter table
+    run_ids = [102, 103, 104, 105]
+    sample = ['diamond', 'graphite', 'graphite', 'graphite']
+    filenames = [f'file{i}.txt' for i in run_ids]
+    param_table = pd.DataFrame(
+        {Filename: filenames, Material: sample}, index=run_ids
+    ).rename_axis(index='run_id')
+
+    # Group by Material and merge Result
+    grouped = (
+        base.map(param_table)
+        .groupby(Material)
+        .reduce(key=Result, func=merge, name='merged')
+    )
+
+    result = sl.compute_mapped(grouped, 'merged')
+    assert result['diamond'] == 12.0
+    assert result['graphite'] == 228.0
+
+
+def test_groupby_material_at_rawdata():
+    # Create pipeline
+    providers = [load, clean, process]
+    params = {ScaleFactor: 2.0}
+    base = sl.Pipeline(providers, params=params)
+
+    # Make parameter table
+    run_ids = [102, 103, 104, 105]
+    sample = ['diamond', 'graphite', 'graphite', 'graphite']
+    filenames = [f'file{i}.txt' for i in run_ids]
+    param_table = pd.DataFrame(
+        {Filename: filenames, 'Material': sample}, index=run_ids
+    ).rename_axis(index='run_id')
+
+    # Define function to merge RawData
+    def merge_raw(*das):
+        out = {'data': [], 'meta': {}}
+        for da in das:
+            out['data'].extend(da['data'])
+            for k, v in da['meta'].items():
+                if k not in out['meta']:
+                    out['meta'][k] = []
+                out['meta'][k].append(v)
+        return out
+
+    # Group by Material and merge RawData
+    grouped = (
+        base[RawData]
+        .map(param_table)
+        .groupby('Material')
+        .reduce(key=RawData, func=merge_raw, name='merged')
+    )
+
+    # Join back to base pipeline
+    new = base.copy()
+    new[RawData] = None
+
+    mapped = new.map(
+        # Need dummy RawData column to allow re-attaching
+        pd.DataFrame({RawData: [1, 2], 'Material': ['diamond', 'graphite']}).set_index(
+            'Material'
+        )
+    )
+
+    # Attach the grouped merged data to the lower part of the pipeline
+    mapped[RawData] = grouped['merged']
+
+    clean_data = sl.compute_mapped(mapped, CleanedData)
+    assert np.array_equal(clean_data['diamond'], [1, 2, 3])
+    assert np.array_equal(
+        clean_data['graphite'], [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+    )
+
+    result = sl.compute_mapped(mapped, Result)
+    assert result['diamond'] == 12.0
+    assert result['graphite'] == 228.0
+
+    raw_data = sl.compute_mapped(mapped, RawData, index_names=['Material'])
+    assert np.array_equal(
+        raw_data['diamond']['data'], [1, 2, float('nan'), 3], equal_nan=True
+    )
+    assert np.array_equal(
+        raw_data['graphite']['data'],
+        [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+        equal_nan=True,
+    )
+    assert raw_data['diamond']['meta'] == {'filename': ['file102.txt']}
+    assert raw_data['graphite']['meta'] == {
+        'filename': ['file103.txt', 'file104.txt', 'file105.txt']
+    }
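Note: the asserted values can be checked by hand. `clean` drops the NaN from `file102.txt`, and `process` scales each sum by `ScaleFactor = 2.0`; since both `merge` and the scaling are linear, grouping at `Result` and grouping at `RawData` must produce identical totals, which is why both tests assert the same numbers.

```python
# Plain-Python sanity check of the expected values, independent of Sciline.
diamond = [1, 2, 3]  # file102.txt with the NaN removed
graphite = [4, 5, 6, 7] + [8, 9, 10, 11, 12] + [13, 14, 15]  # files 103-105
assert sum(diamond) * 2.0 == 12.0
assert sum(graphite) * 2.0 == 228.0
```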