|
3 | 3 |
|
4 | 4 | from __future__ import annotations |
5 | 5 |
|
6 | | -from dataclasses import dataclass |
7 | 6 | from typing import Any, Dict, List, Tuple, Union, cast |
8 | 7 |
|
9 | 8 | from llama_index.core.agent.workflow import ( |
|
15 | 14 | from llama_index.core.tools import AsyncBaseTool, BaseTool |
16 | 15 | from pyvis.network import Network |
17 | 16 | from workflows import Workflow |
18 | | -from workflows.decorators import StepConfig |
19 | 17 | from workflows.events import ( |
20 | 18 | Event, |
21 | | - HumanResponseEvent, |
22 | | - InputRequiredEvent, |
23 | 19 | StartEvent, |
24 | 20 | StopEvent, |
25 | 21 | ) |
26 | 22 | from workflows.handler import WorkflowHandler |
| 23 | +from workflows.representation_utils import ( |
| 24 | + DrawWorkflowEdge, |
| 25 | + DrawWorkflowGraph, |
| 26 | + DrawWorkflowNode, |
| 27 | + _truncate_label, |
| 28 | +) |
| 29 | +from workflows.representation_utils import ( |
| 30 | + extract_workflow_structure as _extract_workflow_structure, |
| 31 | +) |
27 | 32 | from workflows.runtime.types.results import AddCollectedEvent, StepWorkerResult |
28 | 33 | from workflows.runtime.types.ticks import TickAddEvent, TickStepResult, WorkflowTick |
29 | 34 |
|
30 | 35 |
|
31 | | -@dataclass |
32 | | -class DrawWorkflowNode: |
33 | | - """Represents a node in the workflow graph.""" |
34 | | - |
35 | | - id: str |
36 | | - label: str |
37 | | - node_type: str # 'step', 'event', 'external' |
38 | | - title: str | None = None |
39 | | - event_type: type | None = None # Store the actual event type for styling decisions |
40 | | - |
41 | | - |
42 | | -@dataclass |
43 | | -class DrawWorkflowEdge: |
44 | | - """Represents an edge in the workflow graph.""" |
45 | | - |
46 | | - source: str |
47 | | - target: str |
48 | | - |
49 | | - |
50 | | -@dataclass |
51 | | -class DrawWorkflowGraph: |
52 | | - """Intermediate representation of workflow structure.""" |
53 | | - |
54 | | - nodes: List[DrawWorkflowNode] |
55 | | - edges: List[DrawWorkflowEdge] |
56 | | - |
57 | | - |
58 | | -def _truncate_label(label: str, max_length: int) -> str: |
59 | | - """Helper to truncate long labels.""" |
60 | | - return label if len(label) <= max_length else f"{label[: max_length - 1]}*" |
61 | | - |
62 | | - |
63 | | -def _extract_workflow_structure( |
64 | | - workflow: Workflow, max_label_length: int | None = None |
65 | | -) -> DrawWorkflowGraph: |
66 | | - """Extract workflow structure into an intermediate representation.""" |
67 | | - # Get workflow steps |
68 | | - steps = workflow._get_steps() |
69 | | - nodes = [] |
70 | | - edges = [] |
71 | | - added_nodes = set() # Track added node IDs to avoid duplicates |
72 | | - |
73 | | - step_config: StepConfig | None = None |
74 | | - |
75 | | - # Only one kind of `StopEvent` is allowed in a `Workflow`. |
76 | | - # Assuming that `Workflow` is validated before drawing, it's enough to find the first one. |
77 | | - current_stop_event = None |
78 | | - for step_name, step_func in steps.items(): |
79 | | - step_config = step_func._step_config |
80 | | - |
81 | | - for return_type in step_config.return_types: |
82 | | - if issubclass(return_type, StopEvent): |
83 | | - current_stop_event = return_type |
84 | | - break |
85 | | - |
86 | | - if current_stop_event: |
87 | | - break |
88 | | - |
89 | | - # First pass: Add all nodes |
90 | | - for step_name, step_func in steps.items(): |
91 | | - step_config = step_func._step_config |
92 | | - # Add step node |
93 | | - step_label = ( |
94 | | - _truncate_label(step_name, max_label_length) |
95 | | - if max_label_length |
96 | | - else step_name |
97 | | - ) |
98 | | - step_title = ( |
99 | | - step_name |
100 | | - if max_label_length and len(step_name) > max_label_length |
101 | | - else None |
102 | | - ) |
103 | | - |
104 | | - if step_name not in added_nodes: |
105 | | - nodes.append( |
106 | | - DrawWorkflowNode( |
107 | | - id=step_name, |
108 | | - label=step_label, |
109 | | - node_type="step", |
110 | | - title=step_title, |
111 | | - ) |
112 | | - ) |
113 | | - added_nodes.add(step_name) |
114 | | - |
115 | | - # Add event nodes for accepted events |
116 | | - for event_type in step_config.accepted_events: |
117 | | - if event_type == StopEvent and event_type != current_stop_event: |
118 | | - continue |
119 | | - |
120 | | - event_label = ( |
121 | | - _truncate_label(event_type.__name__, max_label_length) |
122 | | - if max_label_length |
123 | | - else event_type.__name__ |
124 | | - ) |
125 | | - event_title = ( |
126 | | - event_type.__name__ |
127 | | - if max_label_length and len(event_type.__name__) > max_label_length |
128 | | - else None |
129 | | - ) |
130 | | - |
131 | | - if event_type.__name__ not in added_nodes: |
132 | | - nodes.append( |
133 | | - DrawWorkflowNode( |
134 | | - id=event_type.__name__, |
135 | | - label=event_label, |
136 | | - node_type="event", |
137 | | - title=event_title, |
138 | | - event_type=event_type, |
139 | | - ) |
140 | | - ) |
141 | | - added_nodes.add(event_type.__name__) |
142 | | - |
143 | | - # Add event nodes for return types |
144 | | - for return_type in step_config.return_types: |
145 | | - if return_type is type(None): |
146 | | - continue |
147 | | - |
148 | | - return_label = ( |
149 | | - _truncate_label(return_type.__name__, max_label_length) |
150 | | - if max_label_length |
151 | | - else return_type.__name__ |
152 | | - ) |
153 | | - return_title = ( |
154 | | - return_type.__name__ |
155 | | - if max_label_length and len(return_type.__name__) > max_label_length |
156 | | - else None |
157 | | - ) |
158 | | - |
159 | | - if return_type.__name__ not in added_nodes: |
160 | | - nodes.append( |
161 | | - DrawWorkflowNode( |
162 | | - id=return_type.__name__, |
163 | | - label=return_label, |
164 | | - node_type="event", |
165 | | - title=return_title, |
166 | | - event_type=return_type, |
167 | | - ) |
168 | | - ) |
169 | | - added_nodes.add(return_type.__name__) |
170 | | - |
171 | | - # Add external_step node when InputRequiredEvent is found |
172 | | - if ( |
173 | | - issubclass(return_type, InputRequiredEvent) |
174 | | - and "external_step" not in added_nodes |
175 | | - ): |
176 | | - nodes.append( |
177 | | - DrawWorkflowNode( |
178 | | - id="external_step", |
179 | | - label="external_step", |
180 | | - node_type="external", |
181 | | - ) |
182 | | - ) |
183 | | - added_nodes.add("external_step") |
184 | | - |
185 | | - # Second pass: Add edges |
186 | | - for step_name, step_func in steps.items(): |
187 | | - step_config = step_func._step_config |
188 | | - # Edges from steps to return types |
189 | | - for return_type in step_config.return_types: |
190 | | - if return_type is not type(None): |
191 | | - edges.append(DrawWorkflowEdge(step_name, return_type.__name__)) |
192 | | - |
193 | | - if issubclass(return_type, InputRequiredEvent): |
194 | | - edges.append(DrawWorkflowEdge(return_type.__name__, "external_step")) |
195 | | - |
196 | | - # Edges from events to steps |
197 | | - for event_type in step_config.accepted_events: |
198 | | - edges.append(DrawWorkflowEdge(event_type.__name__, step_name)) |
199 | | - |
200 | | - if issubclass(event_type, HumanResponseEvent): |
201 | | - edges.append(DrawWorkflowEdge("external_step", event_type.__name__)) |
202 | | - |
203 | | - return DrawWorkflowGraph(nodes=nodes, edges=edges) |
204 | | - |
205 | | - |
206 | 36 | def _get_node_color(node: DrawWorkflowNode) -> str: |
207 | 37 | """Determine color for a node based on its type and event_type.""" |
208 | 38 | if node.node_type == "step": |
|
0 commit comments