185 changes: 171 additions & 14 deletions docs/user-guide/parameter-tables.ipynb
@@ -229,7 +229,11 @@
"metadata": {},
"outputs": [],
"source": [
"graph = pipeline.reduce(func=lambda *result: sum(result), name='merged').get('merged')\n",
"def merge(*data):\n",
" return sum(data)\n",
"\n",
"\n",
"graph = pipeline.reduce(func=merge, name='merged').get('merged')\n",
"graph.visualize()"
]
},
@@ -294,8 +298,6 @@
"source": [
"## Grouping intermediate results based on secondary parameters\n",
"\n",
"**Cyclebane and Sciline do not support `groupby` yet, this is work in progress so this example is not functional yet.**\n",
"\n",
"This chapter illustrates how to implement *groupby* operations with Sciline.\n",
"\n",
"Continuing from the examples for *map* and *reduce*, we can introduce a secondary parameter in the table, such as the material of the sample:"
@@ -307,13 +309,11 @@
"metadata": {},
"outputs": [],
"source": [
"Material = NewType('Material', str)\n",
"\n",
"run_ids = [102, 103, 104, 105]\n",
"sample = ['diamond', 'graphite', 'graphite', 'graphite']\n",
"filenames = [f'file{i}.txt' for i in run_ids]\n",
"param_table = pd.DataFrame(\n",
" {Filename: filenames, Material: sample}, index=run_ids\n",
" {Filename: filenames, 'Material': sample}, index=run_ids\n",
").rename_axis(index='run_id')\n",
"param_table"
]
@@ -322,15 +322,172 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Future releases of Sciline will support a `groupby` operation, roughly as follows:\n",
"We then group the results by sample material,\n",
"and reduce the data using the same merge function as before.\n",
"\n",
"```python\n",
"pipeline = base.map(param_table).groupby(Material).reduce(func=merge)\n",
"```\n",
"The goal is to obtain two final results, one for each material."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"grouped = (\n",
" base.map(param_table)\n",
" .groupby('Material')\n",
" .reduce(key=Result, func=merge, name='merged')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph = grouped.get(sciline.get_mapped_node_names(grouped, 'merged'))\n",
"graph.visualize()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sciline.compute_mapped(grouped, 'merged')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Grouping early in the graph\n",
"\n",
"Sometimes, it is also desirable to apply grouping earlier in the pipeline graph.\n",
"In this example, we wish to combine the raw data before cleaning and computing the result:\n",
"the files for each run ID are first grouped by their material and merged, and only then processed further.\n",
"Note how the initial steps of the computation then depend on the `run_id` index name, while later steps depend on `Material`, a new index name defined by the `groupby` operation.\n",
"\n",
"The function that merges the raw data needs to merge both data and metadata:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def merge_raw(*das):\n",
"    # Merge both the data and the metadata of the raw inputs\n",
"    out = {'data': [], 'meta': {}}\n",
"    for da in das:\n",
"        out['data'].extend(da['data'])\n",
"        for k, v in da['meta'].items():\n",
"            out['meta'].setdefault(k, []).append(v)\n",
"    return out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grouped graph now takes only the part of the pipeline that leads to `RawData` (via `base[RawData]`), dropping everything downstream."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"grouped = (\n",
" base[RawData]\n",
" .map(param_table)\n",
" .groupby('Material')\n",
" .reduce(key=RawData, func=merge_raw, name='merged')\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualizing the graph shows two `merged` nodes, one for each material:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"grouped.visualize(sciline.get_mapped_node_names(grouped, 'merged'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The task is now to reattach this grouped graph to the lower part of our pipeline.\n",
"Since the base graph has a single `RawData` node, we first need to map it over the materials that remain after grouping."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"new = base.copy()\n",
"new[RawData] = None\n",
"\n",
"# Get the list of materials left after grouping\n",
"unique_materials = sciline.get_mapped_node_names(grouped, 'merged').index\n",
"\n",
"mapped = new.map(\n",
" pd.DataFrame(\n",
" {RawData: [None] * len(unique_materials), 'Material': unique_materials}\n",
" ).set_index('Material')\n",
")\n",
"\n",
"mapped.visualize(sciline.get_mapped_node_names(mapped, Result))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now attach the top part of the graph to the bottom one:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mapped[RawData] = grouped['merged']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mapped.visualize(sciline.get_mapped_node_names(mapped, Result))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sciline.compute_mapped(mapped, Result)"
]
},
{
@@ -519,7 +676,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.7"
}
},
"nbformat": 4,
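Outside the notebook, the map → groupby → reduce flow shown in the diff above can be mimicked in plain Python to make the semantics concrete. This is a minimal sketch only: `load` and `clean` are made-up stand-ins for the notebook's providers, not Sciline API.

```python
from collections import defaultdict

run_ids = [102, 103, 104, 105]
materials = ['diamond', 'graphite', 'graphite', 'graphite']
filenames = [f'file{i}.txt' for i in run_ids]


def load(filename):
    # Stand-in provider: pretend each file yields one number.
    return len(filename)


def clean(raw):
    # Stand-in provider: a trivial "cleaning" step.
    return raw * 2


def merge(*data):
    return sum(data)


# map: run every file through the load/clean chain, keyed by run_id
results = {rid: clean(load(f)) for rid, f in zip(run_ids, filenames)}

# groupby + reduce: collect results per material, then merge each group
groups = defaultdict(list)
for rid, mat in zip(run_ids, materials):
    groups[mat].append(results[rid])
merged = {mat: merge(*values) for mat, values in groups.items()}
# merged now holds one result per material:
# {'diamond': 22, 'graphite': 66}
```

Sciline builds the same structure lazily as a task graph rather than executing it eagerly; `reduce(key=..., name=...)` names the per-group sink nodes so they can be retrieved with `get_mapped_node_names` and `compute_mapped`.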
81 changes: 80 additions & 1 deletion src/sciline/data_graph.py
@@ -4,12 +4,13 @@
from __future__ import annotations

import itertools
from collections.abc import Callable, Generator, Iterable, Mapping
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from types import NoneType
from typing import TYPE_CHECKING, Any, TypeVar, get_args

import cyclebane as cb
import networkx as nx
from cyclebane.graph import GroupbyGraph as CbGroupbyGraph
from cyclebane.node_values import IndexName, IndexValue

from ._provider import ArgSpec, Provider, ToProvider, _bind_free_typevars
@@ -207,6 +208,25 @@ def map(self: T, node_values: dict[Key, Any]) -> T:
        """
        return self._from_cyclebane(self._cbgraph.map(node_values))

    def groupby(self, node: Key) -> GroupbyGraph:
        """Group the graph by a specific node.

        Parameters
        ----------
        node:
            Node key to group by.

        Returns
        -------
        :
            A new graph that groups mapped nodes by the given key. This graph is
            not meant to be executed directly, but to be further processed via
            :meth:`GroupbyGraph.reduce`.
        """
        return GroupbyGraph(
            graph=self._cbgraph.groupby(node), graph_maker=self._from_cyclebane
        )

    def reduce(self: T, *, func: Callable[..., Any], **kwargs: Any) -> T:
        """Reduce the outputs of a mapped graph into a single value and provider.

Expand Down Expand Up @@ -250,6 +270,65 @@ def visualize_data_graph(self, **kwargs: Any) -> graphviz.Digraph:
        return dot


class GroupbyGraph:
    """
    Intermediate result of a groupby operation.

    This graph is not meant to be executed directly; its purpose is to apply a
    per-group aggregation of the grouped data via :meth:`reduce`.
    """

    def __init__(
        self, graph: CbGroupbyGraph, graph_maker: Callable[..., DataGraph]
    ) -> None:
        self._cbgraph = graph
        # We forward the constructor so this can be used by both DataGraph and Pipeline
        self._graph_maker = graph_maker

    def reduce(
        self,
        *,
        func: Callable[..., Any],
        key: None | Hashable = None,
        name: None | Hashable = None,
        attrs: None | dict[str, Any] = None,
    ) -> DataGraph:
        """Reduce the grouped node group by group, producing one value and
        provider per group.

        Parameters
        ----------
        func:
            Function that takes the values to reduce and returns a single value.
        key:
            The name of the source node to reduce. This is the original name prior
            to mapping. If not given, tries to find a unique sink node.
            See :meth:`cyclebane.Graph.reduce`.
        name:
            The name of the new node. If not given, a unique name is generated.
            See :meth:`cyclebane.Graph.reduce`.
        attrs:
            Attributes to set on the new node(s). See :meth:`cyclebane.Graph.reduce`.

        Returns
        -------
        :
            A new :class:`DataGraph` with reduced grouped nodes.
        """
        cbattrs = {'reduce': func}
        if attrs is not None:
            if 'reduce' in attrs:
                raise ValueError(
                    "The reduction function cannot be set via 'attrs'. "
                    "Use the 'func' argument instead."
                )
            cbattrs.update(attrs)

        return self._graph_maker(
            self._cbgraph.reduce(attrs=cbattrs, key=key, name=name)
        )


_no_value = object()


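The two-step pattern added in this file — `groupby` returning an intermediate object whose only useful operation is a subsequent `reduce` — can be illustrated with a self-contained toy, independent of Cyclebane. `TinyGraph` and `TinyGroupby` are invented names for illustration, not part of Sciline.

```python
from collections import defaultdict
from typing import Any, Callable, Hashable


class TinyGroupby:
    """Toy analogue of GroupbyGraph: holds groups and only supports reduce."""

    def __init__(self, groups: dict[Hashable, list[Any]]) -> None:
        self._groups = groups

    def reduce(self, *, func: Callable[..., Any]) -> dict[Hashable, Any]:
        # One reduced value per group, like GroupbyGraph.reduce.
        return {key: func(*values) for key, values in self._groups.items()}


class TinyGraph:
    """Toy analogue of a mapped graph: one value per index, each with a label."""

    def __init__(
        self, values: dict[Hashable, Any], labels: dict[Hashable, Hashable]
    ) -> None:
        self._values = values
        self._labels = labels

    def groupby(self) -> TinyGroupby:
        # Like DataGraph.groupby, return an intermediate object that is not
        # meant to be executed directly.
        groups: dict[Hashable, list[Any]] = defaultdict(list)
        for index, value in self._values.items():
            groups[self._labels[index]].append(value)
        return TinyGroupby(dict(groups))


graph = TinyGraph(
    values={102: 1, 103: 2, 104: 3, 105: 4},
    labels={102: 'diamond', 103: 'graphite', 104: 'graphite', 105: 'graphite'},
)
reduced = graph.groupby().reduce(func=lambda *x: sum(x))
# reduced == {'diamond': 1, 'graphite': 9}
```

Splitting the operation across two objects keeps a half-grouped graph from being executed by accident, mirroring the design of pandas' `DataFrameGroupBy`, which is likewise only useful once an aggregation is applied.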