Skip to content

Commit 91159d7

Browse files
authored
chore: move representation utils to workflows core (#224)
* chore: move representation utils around * chore: changesets * chore: use fixture in tests
1 parent 300fd05 commit 91159d7

File tree

5 files changed

+152
-183
lines changed

5 files changed

+152
-183
lines changed

.changeset/fine-deer-flow.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"llama-index-utils-workflow": patch
3+
"llama-index-workflows": patch
4+
---
5+
6+
Moving `_extract_workflow_structure` to its own module in workflow core

packages/llama-index-utils-workflow/src/llama_index/utils/workflow/__init__.py

Lines changed: 9 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from __future__ import annotations
55

6-
from dataclasses import dataclass
76
from typing import Any, Dict, List, Tuple, Union, cast
87

98
from llama_index.core.agent.workflow import (
@@ -15,194 +14,25 @@
1514
from llama_index.core.tools import AsyncBaseTool, BaseTool
1615
from pyvis.network import Network
1716
from workflows import Workflow
18-
from workflows.decorators import StepConfig
1917
from workflows.events import (
2018
Event,
21-
HumanResponseEvent,
22-
InputRequiredEvent,
2319
StartEvent,
2420
StopEvent,
2521
)
2622
from workflows.handler import WorkflowHandler
23+
from workflows.representation_utils import (
24+
DrawWorkflowEdge,
25+
DrawWorkflowGraph,
26+
DrawWorkflowNode,
27+
_truncate_label,
28+
)
29+
from workflows.representation_utils import (
30+
extract_workflow_structure as _extract_workflow_structure,
31+
)
2732
from workflows.runtime.types.results import AddCollectedEvent, StepWorkerResult
2833
from workflows.runtime.types.ticks import TickAddEvent, TickStepResult, WorkflowTick
2934

3035

31-
@dataclass
32-
class DrawWorkflowNode:
33-
"""Represents a node in the workflow graph."""
34-
35-
id: str
36-
label: str
37-
node_type: str # 'step', 'event', 'external'
38-
title: str | None = None
39-
event_type: type | None = None # Store the actual event type for styling decisions
40-
41-
42-
@dataclass
43-
class DrawWorkflowEdge:
44-
"""Represents an edge in the workflow graph."""
45-
46-
source: str
47-
target: str
48-
49-
50-
@dataclass
51-
class DrawWorkflowGraph:
52-
"""Intermediate representation of workflow structure."""
53-
54-
nodes: List[DrawWorkflowNode]
55-
edges: List[DrawWorkflowEdge]
56-
57-
58-
def _truncate_label(label: str, max_length: int) -> str:
59-
"""Helper to truncate long labels."""
60-
return label if len(label) <= max_length else f"{label[: max_length - 1]}*"
61-
62-
63-
def _extract_workflow_structure(
64-
workflow: Workflow, max_label_length: int | None = None
65-
) -> DrawWorkflowGraph:
66-
"""Extract workflow structure into an intermediate representation."""
67-
# Get workflow steps
68-
steps = workflow._get_steps()
69-
nodes = []
70-
edges = []
71-
added_nodes = set() # Track added node IDs to avoid duplicates
72-
73-
step_config: StepConfig | None = None
74-
75-
# Only one kind of `StopEvent` is allowed in a `Workflow`.
76-
# Assuming that `Workflow` is validated before drawing, it's enough to find the first one.
77-
current_stop_event = None
78-
for step_name, step_func in steps.items():
79-
step_config = step_func._step_config
80-
81-
for return_type in step_config.return_types:
82-
if issubclass(return_type, StopEvent):
83-
current_stop_event = return_type
84-
break
85-
86-
if current_stop_event:
87-
break
88-
89-
# First pass: Add all nodes
90-
for step_name, step_func in steps.items():
91-
step_config = step_func._step_config
92-
# Add step node
93-
step_label = (
94-
_truncate_label(step_name, max_label_length)
95-
if max_label_length
96-
else step_name
97-
)
98-
step_title = (
99-
step_name
100-
if max_label_length and len(step_name) > max_label_length
101-
else None
102-
)
103-
104-
if step_name not in added_nodes:
105-
nodes.append(
106-
DrawWorkflowNode(
107-
id=step_name,
108-
label=step_label,
109-
node_type="step",
110-
title=step_title,
111-
)
112-
)
113-
added_nodes.add(step_name)
114-
115-
# Add event nodes for accepted events
116-
for event_type in step_config.accepted_events:
117-
if event_type == StopEvent and event_type != current_stop_event:
118-
continue
119-
120-
event_label = (
121-
_truncate_label(event_type.__name__, max_label_length)
122-
if max_label_length
123-
else event_type.__name__
124-
)
125-
event_title = (
126-
event_type.__name__
127-
if max_label_length and len(event_type.__name__) > max_label_length
128-
else None
129-
)
130-
131-
if event_type.__name__ not in added_nodes:
132-
nodes.append(
133-
DrawWorkflowNode(
134-
id=event_type.__name__,
135-
label=event_label,
136-
node_type="event",
137-
title=event_title,
138-
event_type=event_type,
139-
)
140-
)
141-
added_nodes.add(event_type.__name__)
142-
143-
# Add event nodes for return types
144-
for return_type in step_config.return_types:
145-
if return_type is type(None):
146-
continue
147-
148-
return_label = (
149-
_truncate_label(return_type.__name__, max_label_length)
150-
if max_label_length
151-
else return_type.__name__
152-
)
153-
return_title = (
154-
return_type.__name__
155-
if max_label_length and len(return_type.__name__) > max_label_length
156-
else None
157-
)
158-
159-
if return_type.__name__ not in added_nodes:
160-
nodes.append(
161-
DrawWorkflowNode(
162-
id=return_type.__name__,
163-
label=return_label,
164-
node_type="event",
165-
title=return_title,
166-
event_type=return_type,
167-
)
168-
)
169-
added_nodes.add(return_type.__name__)
170-
171-
# Add external_step node when InputRequiredEvent is found
172-
if (
173-
issubclass(return_type, InputRequiredEvent)
174-
and "external_step" not in added_nodes
175-
):
176-
nodes.append(
177-
DrawWorkflowNode(
178-
id="external_step",
179-
label="external_step",
180-
node_type="external",
181-
)
182-
)
183-
added_nodes.add("external_step")
184-
185-
# Second pass: Add edges
186-
for step_name, step_func in steps.items():
187-
step_config = step_func._step_config
188-
# Edges from steps to return types
189-
for return_type in step_config.return_types:
190-
if return_type is not type(None):
191-
edges.append(DrawWorkflowEdge(step_name, return_type.__name__))
192-
193-
if issubclass(return_type, InputRequiredEvent):
194-
edges.append(DrawWorkflowEdge(return_type.__name__, "external_step"))
195-
196-
# Edges from events to steps
197-
for event_type in step_config.accepted_events:
198-
edges.append(DrawWorkflowEdge(event_type.__name__, step_name))
199-
200-
if issubclass(event_type, HumanResponseEvent):
201-
edges.append(DrawWorkflowEdge("external_step", event_type.__name__))
202-
203-
return DrawWorkflowGraph(nodes=nodes, edges=edges)
204-
205-
20636
def _get_node_color(node: DrawWorkflowNode) -> str:
20737
"""Determine color for a node based on its type and event_type."""
20838
if node.node_type == "step":

packages/llama-index-workflows/src/workflows/server/representation_utils.py renamed to packages/llama-index-workflows/src/workflows/representation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _truncate_label(label: str, max_length: int) -> str:
7474
return label if len(label) <= max_length else f"{label[: max_length - 1]}*"
7575

7676

77-
def _extract_workflow_structure(
77+
def extract_workflow_structure(
7878
workflow: Workflow, max_label_length: Optional[int] = None
7979
) -> DrawWorkflowGraph:
8080
"""Extract workflow structure into an intermediate representation."""

packages/llama-index-workflows/src/workflows/server/server.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
EventEnvelopeWithMetadata,
5151
EventValidationError,
5252
)
53+
from workflows.representation_utils import extract_workflow_structure
5354
from workflows.server.abstract_workflow_store import (
5455
AbstractWorkflowStore,
5556
HandlerQuery,
@@ -62,8 +63,6 @@
6263
# Protocol models are used on the client side; server responds with plain dicts
6364
from workflows.utils import _nanoid as nanoid
6465

65-
from .representation_utils import _extract_workflow_structure
66-
6766
logger = logging.getLogger()
6867

6968

@@ -623,7 +622,7 @@ async def _get_workflow_representation(self, request: Request) -> JSONResponse:
623622
"""
624623
workflow = self._extract_workflow(request)
625624
try:
626-
workflow_graph = _extract_workflow_structure(workflow.workflow)
625+
workflow_graph = extract_workflow_structure(workflow.workflow)
627626
except Exception as e:
628627
raise HTTPException(
629628
detail=f"Error while getting JSON workflow representation: {e}",
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import pytest
2+
from workflows.events import StartEvent, StopEvent
3+
from workflows.representation_utils import (
4+
DrawWorkflowEdge,
5+
DrawWorkflowGraph,
6+
DrawWorkflowNode,
7+
extract_workflow_structure,
8+
)
9+
10+
from .conftest import DummyWorkflow, LastEvent, OneTestEvent # type: ignore[import]
11+
12+
13+
@pytest.fixture()
14+
def ground_truth_repr() -> DrawWorkflowGraph:
15+
return DrawWorkflowGraph(
16+
nodes=[
17+
DrawWorkflowNode(
18+
id="end_step",
19+
label="end_step",
20+
node_type="step",
21+
title=None,
22+
event_type=None,
23+
),
24+
DrawWorkflowNode(
25+
id="LastEvent",
26+
label="LastEvent",
27+
node_type="event",
28+
title=None,
29+
event_type=LastEvent,
30+
),
31+
DrawWorkflowNode(
32+
id="StopEvent",
33+
label="StopEvent",
34+
node_type="event",
35+
title=None,
36+
event_type=StopEvent,
37+
),
38+
DrawWorkflowNode(
39+
id="middle_step",
40+
label="middle_step",
41+
node_type="step",
42+
title=None,
43+
event_type=None,
44+
),
45+
DrawWorkflowNode(
46+
id="OneTestEvent",
47+
label="OneTestEvent",
48+
node_type="event",
49+
title=None,
50+
event_type=OneTestEvent,
51+
),
52+
DrawWorkflowNode(
53+
id="start_step",
54+
label="start_step",
55+
node_type="step",
56+
title=None,
57+
event_type=None,
58+
),
59+
DrawWorkflowNode(
60+
id="StartEvent",
61+
label="StartEvent",
62+
node_type="event",
63+
title=None,
64+
event_type=StartEvent,
65+
),
66+
],
67+
edges=[
68+
DrawWorkflowEdge(source="end_step", target="StopEvent"),
69+
DrawWorkflowEdge(source="LastEvent", target="end_step"),
70+
DrawWorkflowEdge(source="middle_step", target="LastEvent"),
71+
DrawWorkflowEdge(source="OneTestEvent", target="middle_step"),
72+
DrawWorkflowEdge(source="start_step", target="OneTestEvent"),
73+
DrawWorkflowEdge(source="StartEvent", target="start_step"),
74+
],
75+
)
76+
77+
78+
def test_extract_workflow_structure(ground_truth_repr: DrawWorkflowGraph) -> None:
79+
wf = DummyWorkflow()
80+
graph = extract_workflow_structure(workflow=wf)
81+
assert isinstance(graph, DrawWorkflowGraph)
82+
assert sorted(
83+
[node.id for node in ground_truth_repr.nodes if node.node_type == "step"]
84+
) == sorted([node.id for node in graph.nodes if node.node_type == "step"])
85+
assert sorted(
86+
[node.id for node in ground_truth_repr.nodes if node.node_type == "event"]
87+
) == sorted([node.id for node in graph.nodes if node.node_type == "event"])
88+
expected_edges = ground_truth_repr.edges
89+
for edge in expected_edges:
90+
assert edge in graph.edges
91+
92+
93+
def test_extract_workflow_structure_trim_label() -> None:
94+
wf = DummyWorkflow()
95+
graph = extract_workflow_structure(workflow=wf, max_label_length=2)
96+
assert sorted(["e*", "m*", "s*"]) == sorted(
97+
[node.label for node in graph.nodes if node.node_type == "step"]
98+
)
99+
assert sorted(["S*", "S*", "O*", "L*"]) == sorted(
100+
[node.label for node in graph.nodes if node.node_type == "event"]
101+
)
102+
103+
104+
def test_graph_to_response_model() -> None:
105+
graph = DrawWorkflowGraph(
106+
nodes=[
107+
DrawWorkflowNode(
108+
id="test", label="test", node_type="step", title=None, event_type=None
109+
),
110+
DrawWorkflowNode(
111+
id="OneTestEvent",
112+
label="OneTestEvent",
113+
node_type="event",
114+
title=None,
115+
event_type=OneTestEvent,
116+
),
117+
],
118+
edges=[DrawWorkflowEdge(source="test", target="OneTestEvent")],
119+
)
120+
res = graph.to_response_model()
121+
assert len(res.nodes) == 2
122+
assert res.nodes[0].event_type is None
123+
assert res.nodes[0].title is None
124+
assert res.nodes[0].node_type == "step"
125+
assert res.nodes[0].label == "test"
126+
assert res.nodes[0].id == "test"
127+
assert res.nodes[1].event_type == OneTestEvent.__name__
128+
assert res.nodes[1].title is None
129+
assert res.nodes[1].node_type == "event"
130+
assert res.nodes[1].label == "OneTestEvent"
131+
assert res.nodes[1].id == "OneTestEvent"
132+
assert len(res.edges) == 1
133+
assert res.edges[0].source == "test"
134+
assert res.edges[0].target == "OneTestEvent"

0 commit comments

Comments
 (0)