|
8 | 8 | import ray |
9 | 9 | import ray.data |
10 | 10 | from ray.data import DataContext |
| 11 | +from ray.data.block import Block |
| 12 | +from ray.data.datasource.filename_provider import FilenameProvider |
11 | 13 |
|
12 | 14 | from graphgen.bases import Config, Node |
13 | 15 | from graphgen.common import init_llm, init_storage |
14 | 16 | from graphgen.utils import logger |
15 | 17 |
|
16 | 18 |
|
class NodeFilenameProvider(FilenameProvider):
    """Names the JSONL files written for one pipeline node.

    Every block of every write task receives a unique, node-scoped name of
    the form ``{node_id}_{write_uuid}_{task_index:06}_{block_index:06}.jsonl``.
    """

    def __init__(self, node_id: str):
        # Prefix used on every filename so outputs from different nodes
        # never collide inside a shared directory.
        self.node_id = node_id

    def get_filename_for_block(
        self, block: Block, write_uuid: str, task_index: int, block_index: int
    ) -> str:
        """Return a unique ``.jsonl`` filename for a single written block."""
        stem = "_".join(
            (self.node_id, write_uuid, f"{task_index:06d}", f"{block_index:06d}")
        )
        return stem + ".jsonl"

    def get_filename_for_row(
        self,
        row: Dict[str, Any],
        write_uuid: str,
        task_index: int,
        block_index: int,
        row_index: int,
    ) -> str:
        """Row-granular naming is unsupported: write_json emits whole blocks."""
        raise NotImplementedError(
            f"Row-based filenames are not supported by write_json. "
            f"Node: {self.node_id}, write_uuid: {write_uuid}"
        )
| 42 | + |
17 | 43 | class Engine: |
18 | 44 | def __init__( |
19 | 45 | self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs |
@@ -263,13 +289,32 @@ def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]: |
263 | 289 | f"Unsupported node type {node.type} for node {node.id}" |
264 | 290 | ) |
265 | 291 |
|
266 | | - def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]: |
| 292 | + def execute( |
| 293 | + self, initial_ds: ray.data.Dataset, output_dir: str |
| 294 | + ) -> Dict[str, ray.data.Dataset]: |
267 | 295 | sorted_nodes = self._topo_sort(self.config.nodes) |
268 | 296 |
|
269 | 297 | for node in sorted_nodes: |
| 298 | + logger.info("Executing node %s of type %s", node.id, node.type) |
270 | 299 | self._execute_node(node, initial_ds) |
271 | 300 | if getattr(node, "save_output", False): |
272 | | - self.datasets[node.id] = self.datasets[node.id].materialize() |
| 301 | + node_output_path = os.path.join(output_dir, f"{node.id}") |
| 302 | + os.makedirs(node_output_path, exist_ok=True) |
| 303 | + logger.info("Saving output of node %s to %s", node.id, node_output_path) |
| 304 | + |
| 305 | + ds = self.datasets[node.id] |
| 306 | + ds.write_json( |
| 307 | + node_output_path, |
| 308 | + filename_provider=NodeFilenameProvider(node.id), |
| 309 | + pandas_json_args_fn=lambda: { |
| 310 | + "orient": "records", |
| 311 | + "lines": True, |
| 312 | + "force_ascii": False, |
| 313 | + }, |
| 314 | + ) |
| 315 | + logger.info("Node %s output saved to %s", node.id, node_output_path) |
| 316 | + |
| 317 | + # ray will lazy read the dataset |
| 318 | + self.datasets[node.id] = ray.data.read_json(node_output_path) |
273 | 319 |
|
274 | | - output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)] |
275 | | - return {node.id: self.datasets[node.id] for node in output_nodes} |
| 320 | + return self.datasets |
0 commit comments