diff --git a/mplang/v1/core/async_comm.py b/mplang/v1/core/async_comm.py new file mode 100644 index 00000000..1f829107 --- /dev/null +++ b/mplang/v1/core/async_comm.py @@ -0,0 +1,359 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +from abc import ABC, abstractmethod +from typing import Any + +from mplang.v1.core.comm import ICommunicator +from mplang.v1.core.mask import Mask + + +class IAsyncCommunicator(ICommunicator): + """Base class for asynchronous communicators.""" + + @abstractmethod + async def async_send(self, to: int, key: str, data: Any) -> None: + """Send data to peer with the given key asynchronously""" + + @abstractmethod + async def async_recv(self, frm: int, key: str) -> Any: + """Receive data from peer with the given key asynchronously""" + + +class IAsyncCollective(ABC): + """Interface for asynchronous collective communication""" + + @abstractmethod + async def p2p(self, frm: int, to: int, data: Any) -> Any: + """Perform point-to-point communication""" + + @abstractmethod + async def gather(self, root: int, data: Any) -> list[Any]: + """Gather data from all processes to root""" + + @abstractmethod + async def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]: + """Gather data from parties in pmask to root""" + + @abstractmethod + async def scatter(self, root: int, args: list[Any]) -> Any: + """Scatter data from root to all processes""" + + @abstractmethod + async def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any: + """Scatter data from root to parties in pmask""" + + @abstractmethod + async def allgather(self, arg: Any) -> list[Any]: + """Gather data from all processes to all processes""" + + @abstractmethod + async def allgather_m(self, pmask: int, arg: Any) -> list[Any]: + """Gather data from parties in pmask to all processes""" + + @abstractmethod + async def bcast(self, root: int, arg: Any) -> Any: + """Broadcast data from root to all processes""" + + @abstractmethod + async def bcast_m(self, pmask: int, root: int, arg: Any) -> Any: + """Broadcast data from root to parties in pmask""" + + +class AsyncCollectiveMixin(IAsyncCommunicator, IAsyncCollective): + """Mixin class providing default implementations of asynchronous collective communication algorithms""" + + # Note: These will be provided by mixing classes as properties + @property + def rank(self) -> int: + raise NotImplementedError + + @property + def world_size(self) -> int: + raise NotImplementedError + + def send(self, to: int, key: str, data: Any) -> None: + raise NotImplementedError + + def recv(self, frm: int, key: str) -> Any: + raise NotImplementedError + + async def async_send(self, to: int, key: str, data: Any) -> None: + raise NotImplementedError + + async def async_recv(self, frm: int, key: str) -> Any: + raise NotImplementedError + + def new_id(self) -> str: + raise NotImplementedError + + async def p2p(self, frm: int, to: int, data: Any) -> Any: + assert 0 <= frm < self.world_size + assert 0 <= to < self.world_size + + cid = self.new_id() + + send_coro = None + if self.rank == frm: + send_coro = self.async_send(to, cid, data) + + recv_coro = None + if self.rank == to: + recv_coro = self.async_recv(frm, cid) + + if send_coro and recv_coro: + _, res = await asyncio.gather(send_coro, recv_coro) + return res + elif send_coro: + await send_coro + return None + elif recv_coro: + return await recv_coro + else: + return None + + async def gather_m(self, pmask: int, root: int, data: Any) -> list[Any]: + assert 0 <= root < self.world_size + cid = self.new_id() + mask = Mask(pmask) + + # 1. Send if we are in mask + if self.rank in mask: + await self.async_send(root, cid, data) + + # 2. Recv if we are root + if self.rank == root: + # Create futures for all expected receives + futures = [] + for idx in mask: + futures.append(self.async_recv(idx, cid)) + + # Wait for all concurrently + results = await asyncio.gather(*futures) + return results + else: + return [None] * mask.num_parties() + + async def gather(self, root: int, data: Any) -> list[Any]: + pmask = Mask.all(self.world_size) + return await self.gather_m(pmask.value, root, data) + + async def scatter_m(self, pmask: int, root: int, args: list[Any]) -> Any: + logging.debug( + f"[{self.rank}]: scatter_m: pmask={pmask}, root={root}, args={args}" + ) + assert 0 <= root < self.world_size + mask = Mask(pmask) + assert len(args) == mask.num_parties(), f"{len(args)} != {mask.num_parties()}" + + cid = self.new_id() + + if self.rank == root: + # Send to all targets concurrently + send_futures = [] + for idx, arg in zip(mask, args, strict=True): + send_futures.append(self.async_send(idx, cid, arg)) + await asyncio.gather(*send_futures) + + if self.rank in mask: + data = await self.async_recv(root, cid) + else: + data = None + + return data + + async def scatter(self, root: int, args: list[Any]) -> Any: + pmask = Mask.all(self.world_size) + return await self.scatter_m(pmask.value, root, args) + + async def allgather_m(self, pmask: int, arg: Any) -> list[Any]: + logging.debug(f"allgather_m: pmask={pmask}, arg={arg}") + cid = self.new_id() + mask = Mask(pmask) + + # 1. Send to all other parties in mask + if self.rank in mask: + send_futures = [] + for idx in mask: + send_futures.append(self.async_send(idx, cid, arg)) + await asyncio.gather(*send_futures) + + # 2. Recv from all parties in mask + recv_futures = [] + for idx in mask: + recv_futures.append(self.async_recv(idx, cid)) + + res = await asyncio.gather(*recv_futures) + return res + else: + return [None] * mask.num_parties() + + async def allgather(self, arg: Any) -> list[Any]: + pmask = Mask.all(self.world_size) + return await self.allgather_m(pmask.value, arg) + + async def bcast_m(self, pmask: int, root: int, arg: Any) -> Any: + logging.debug(f"bcast_m: pmask={pmask}, root={root}, arg={arg}") + assert 0 <= root < self.world_size + mask = Mask(pmask) + cid = self.new_id() + + if self.rank == root: + send_futures = [] + for idx in mask: + send_futures.append(self.async_send(idx, cid, arg)) + await asyncio.gather(*send_futures) + + if self.rank in mask: + return await self.async_recv(root, cid) + else: + return None + + async def bcast(self, root: int, arg: Any) -> Any: + pmask = Mask.all(self.world_size) + return await self.bcast_m(pmask.value, root, arg) + + +class AsyncCommunicatorBase(IAsyncCommunicator): + """Base implementation providing message box functionality for local communication using asyncio""" + + def __init__( + self, rank: int, world_size: int, loop: asyncio.AbstractEventLoop | None = None + ): + self._rank = rank + self._world_size = world_size + # Map (frm, key) -> Future or Data + self._msgboxes: dict[tuple[int, str], Any | asyncio.Future] = {} + self._counter = 0 + self._loop = loop + + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + + def _get_loop(self) -> asyncio.AbstractEventLoop: + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError as e: + raise RuntimeError( + "AsyncCommunicatorBase must be used within an asyncio event loop or loop must be provided in init" + ) from e + return self._loop + + def new_id(self) -> str: + # Simple counter, assuming single-threaded access to this method within the loop + res = self._counter + self._counter += 1 + return str(res) + + async def async_recv(self, frm: int, key: str) -> Any: + """Wait until the key is set, returns the value""" + mkey = (frm, key) + + # Check if data is already there + if mkey in self._msgboxes: + val = self._msgboxes[mkey] + if isinstance(val, asyncio.Future): + # Already waiting? This shouldn't happen in normal logic unless multiple recvs for same key + return await val + else: + # Data arrived before recv + del self._msgboxes[mkey] + return val + + # Not there, create a future + loop = self._get_loop() + fut = loop.create_future() + self._msgboxes[mkey] = fut + try: + return await fut + finally: + if mkey in self._msgboxes and self._msgboxes[mkey] is fut: + del self._msgboxes[mkey] + + def onSent(self, frm: int, key: str, data: Any) -> None: + """Called when a key is sent to self. + + This method must be thread-safe as it might be called from network threads. + """ + loop = self._get_loop() + # Use call_soon_threadsafe to handle calls from other threads (e.g. network callbacks) + # If called from the same loop, it just schedules it for next iteration. + loop.call_soon_threadsafe(self._on_sent_internal, frm, key, data) + + def _on_sent_internal(self, frm: int, key: str, data: Any) -> None: + mkey = (frm, key) + if mkey in self._msgboxes: + val = self._msgboxes[mkey] + if isinstance(val, asyncio.Future): + if not val.done(): + val.set_result(data) + # Future is done, we can remove it from msgboxes? + # No, recv needs to await it. But recv will remove it after await. + # Wait, if we remove it here, recv might fail if it hasn't awaited yet? + # Actually, once set_result is called, the future holds the value. + # We should remove it from _msgboxes so it doesn't grow forever? + # But recv uses mkey to find the future. + # So we leave it there. recv will remove it. + else: + raise RuntimeError(f"Duplicate message for {mkey}") + else: + self._msgboxes[mkey] = data + + async def async_send(self, to: int, key: str, data: Any) -> None: + # Base implementation for local simulation: directly call peer's onSent + # In a real distributed setting, this would put data on wire. + raise NotImplementedError( + "Must be implemented by subclass or mixin with peer awareness" + ) + + def send(self, to: int, key: str, data: Any) -> None: + raise NotImplementedError( + "Synchronous send not supported in AsyncCommunicatorBase" + ) + + def recv(self, frm: int, key: str) -> Any: + raise NotImplementedError( + "Synchronous recv not supported in AsyncCommunicatorBase" + ) + + +class AsyncThreadCommunicator(AsyncCommunicatorBase, AsyncCollectiveMixin): + """Thread-based async communicator for in-memory communication (simulation)""" + + def __init__( + self, rank: int, world_size: int, loop: asyncio.AbstractEventLoop | None = None + ): + super().__init__(rank, world_size, loop) + self.peers: list[AsyncThreadCommunicator] = [] + + def set_peers(self, peers: list[AsyncThreadCommunicator]) -> None: + assert self.world_size == len(peers) + self.peers = peers + + async def async_send(self, to: int, key: str, data: Any) -> None: + assert 0 <= to < self.world_size + # In local simulation, we can directly call peer's onSent. + # Since we are all in the same process (and likely same loop for simulation), + # we can just call it. + self.peers[to].onSent(self.rank, key, data) diff --git a/mplang/v1/core/expr/ast.py b/mplang/v1/core/expr/ast.py index 39413af3..09548b99 100644 --- a/mplang/v1/core/expr/ast.py +++ b/mplang/v1/core/expr/ast.py @@ -34,7 +34,7 @@ from mplang.v1.core.tensor import TensorType if TYPE_CHECKING: - from mplang.v1.core.expr.visitor import ExprVisitor + from mplang.v1.core.expr.visitor import AsyncExprVisitor, ExprVisitor class Expr(ABC): @@ -84,6 +84,10 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: """Accept a visitor for the visitor pattern.""" + @abstractmethod + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + """Accept an async visitor with environment.""" + # ============================================================================ # Concrete Expression Classes @@ -161,6 +165,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_eval(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_eval(self, env) + class TupleExpr(Expr): """Expression for creating a tuple from multiple single-output expressions. @@ -204,6 +211,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_tuple(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_tuple(self, env) + class CondExpr(Expr): """Expression for conditional execution. @@ -240,6 +250,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_cond(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_cond(self, env) + class WhileExpr(Expr): """Expression for while loop.""" @@ -266,6 +279,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_while(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_while(self, env) + class ConvExpr(Expr): """Expression for convergence of multiple variables.""" @@ -321,6 +337,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_conv(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_conv(self, env) + class ShflSExpr(Expr): """Expression for static shuffle operation. @@ -403,6 +422,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_shfl_s(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_shfl_s(self, env) + class ShflExpr(Expr): """Expression for dynamic shuffle operation.""" @@ -427,6 +449,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_shfl(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_shfl(self, env) + class AccessExpr(Expr): """Expression for accessing a specific output of a multi-output expression. @@ -457,6 +482,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_access(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_access(self, env) + class VariableExpr(Expr): """Expression for variable reference/lookup.""" @@ -473,6 +501,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_variable(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_variable(self, env) + class FuncDefExpr(Expr): """Expression representing a function definition with parameters and body. @@ -522,6 +553,9 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_func_def(self) + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_func_def(self, env) + class CallExpr(Expr): """Expression for function call.""" @@ -540,3 +574,6 @@ def _compute_mptypes(self) -> list[MPType]: def accept(self, visitor: ExprVisitor) -> Any: return visitor.visit_call(self) + + async def accept_async(self, visitor: AsyncExprVisitor, env: dict[str, Any]) -> Any: + return await visitor.visit_call(self, env) diff --git a/mplang/v1/core/expr/async_evaluator.py b/mplang/v1/core/expr/async_evaluator.py new file mode 100644 index 00000000..da1c7fbc --- /dev/null +++ b/mplang/v1/core/expr/async_evaluator.py @@ -0,0 +1,413 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from concurrent.futures import Executor +from dataclasses import dataclass +from typing import Any + +from mplang.v1.core.async_comm import IAsyncCommunicator +from mplang.v1.core.expr.ast import ( + AccessExpr, + CallExpr, + CondExpr, + ConvExpr, + EvalExpr, + Expr, + FuncDefExpr, + ShflExpr, + ShflSExpr, + TupleExpr, + VariableExpr, + WhileExpr, +) +from mplang.v1.core.expr.evaluator import EvalSemantic +from mplang.v1.core.expr.walk import walk_dataflow +from mplang.v1.core.mask import Mask +from mplang.v1.core.pfunc import PFunction +from mplang.v1.kernels.context import RuntimeContext +from mplang.v1.kernels.value import Value + + +@dataclass +class AsyncEvalSemantic(EvalSemantic): + """Async version of EvalSemantic. + + Reuses pure computation logic from EvalSemantic + """ + + executor: Executor | None = None + + def __post_init__(self) -> None: + if not isinstance(self.comm, IAsyncCommunicator): + raise TypeError("AsyncEvalSemantic requires an IAsyncCommunicator instance") + + async def _exec_pfunc_async(self, pfunc: PFunction, args: list[Any]) -> list[Any]: + # Check if any args are None - if so, this rank shouldn't participate + # This prevents None values from reaching kernel validation + if any(arg is None for arg in args): + return [None] * len(pfunc.outs_info) + + if self.executor: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + self.executor, self._exec_pfunc, pfunc, args + ) + else: + return self._exec_pfunc(pfunc, args) + + async def _eval_eval_node_async( + self, expr: EvalExpr, arg_vals: list[Any] + ) -> list[Any]: + assert isinstance(expr.pfunc, PFunction) + if not self._should_run(expr.rmask, arg_vals): + return [None] * len(expr.mptypes) + return await self._exec_pfunc_async(expr.pfunc, arg_vals) + + async def _eval_shfl_s_node_async( + self, expr: ShflSExpr, src_value: Any + ) -> list[Any]: + pmask = expr.pmask + src_ranks = expr.src_ranks + dst_ranks = list(Mask(pmask)) + assert len(src_ranks) == len(dst_ranks) + cid = self.comm.new_id() + + # Prepare send and recv operations separately + send_tasks = [] + recv_futures = [] + + # Send phase + for src, dst in zip(src_ranks, dst_ranks, strict=True): + if self.comm.rank == src: + send_tasks.append(self.comm.async_send(dst, cid, src_value)) + + # Recv phase + for src, dst in zip(src_ranks, dst_ranks, strict=True): + if self.comm.rank == dst: + recv_futures.append(self.comm.async_recv(src, cid)) + + # Execute all operations concurrently to avoid deadlock + all_tasks = send_tasks + recv_futures + if all_tasks: + results = await asyncio.gather(*all_tasks) + # Return only the recv results + recv_results = results[len(send_tasks) :] + if self.comm.rank in dst_ranks: + assert len(recv_results) == 1 + return recv_results + else: + # Should not happen, but handle gracefully + return [None] + else: + # This party is neither sending nor receiving + if self.comm.rank in dst_ranks: + # Destination rank but no src_ranks match? + return [None] + else: + # Not involved in this shuffle + return [None] + + async def _eval_shfl_node_async( + self, expr: ShflExpr, data: Any, idx: Any + ) -> list[Any]: + # Async version of shuffle implementation + # allgather index via send/recv + indices = [None] * self.comm.world_size + cid = self.comm.new_id() + + # Send index to all other ranks + send_tasks = [] + for dst_rank in range(self.comm.world_size): + if dst_rank != self.comm.rank: + send_tasks.append(self.comm.async_send(dst_rank, cid, idx)) + + # Receive index from all ranks + recv_tasks = [] + for src_rank in range(self.comm.world_size): + if src_rank != self.comm.rank: + recv_tasks.append(self.comm.async_recv(src_rank, cid)) + + # Wait for all operations + if send_tasks: + await asyncio.gather(*send_tasks) + if recv_tasks: + recv_results = await asyncio.gather(*recv_tasks) + for i, src_rank in enumerate([ + r for r in range(self.comm.world_size) if r != self.comm.rank + ]): + indices[src_rank] = recv_results[i] + + # Set own index + indices[self.comm.rank] = idx + + # Process indices + indices_int: list[int | None] = [self._as_optional_int(val) for val in indices] + send_pairs: list[tuple[int, int]] = [] + for dst_idx, src_idx in enumerate(indices_int): + if src_idx is not None: + send_pairs.append((src_idx, dst_idx)) + send_pairs.sort() + + # Second phase: send data according to pairs + cid = self.comm.new_id() + received_data = None + + # Send data + data_send_tasks = [] + for src_rank, dst_rank in send_pairs: + if self.comm.rank == src_rank: + data_send_tasks.append(self.comm.async_send(dst_rank, cid, data)) + + # Receive data + data_recv_tasks = [] + for src_rank, dst_rank in send_pairs: + if self.comm.rank == dst_rank: + data_recv_tasks.append(self.comm.async_recv(src_rank, cid)) + + # Wait for data operations + if data_send_tasks: + await asyncio.gather(*data_send_tasks) + if data_recv_tasks: + recv_data = await asyncio.gather(*data_recv_tasks) + # Should receive exactly one data item + received_data = recv_data[0] + + return [received_data] + + async def _simple_allgather_async(self, value: Any) -> list[Any]: + """Async all-gather emulation using async send/recv. + + This implements an O(P^2) pairwise exchange (each rank sends its value to all + other ranks) and collects values in rank order. Suitable for small P (typical + controller / simulation sizes) and control metadata like a single bool. + + Returns a list of length world_size with entries ordered by rank. + """ + ws = self.comm.world_size + value = self._unwrap_value(value) + # Trivial fast-path + if ws == 1: + return [value] + cid = self.comm.new_id() + gathered: list[Any] = [None] * ws # type: ignore + gathered[self.comm.rank] = value + + # Create async tasks for all send and receive operations + tasks = [] + # Fan-out: send to all other ranks + for dst in range(ws): + if dst != self.comm.rank: + tasks.append(self.comm.async_send(dst, cid, value)) + # Fan-in: receive from all other ranks + for src in range(ws): + if src != self.comm.rank: + tasks.append(self.comm.async_recv(src, cid)) + + # Wait for all operations to complete + results = await asyncio.gather(*tasks) + + # Process results: first half are sends (which return None), second half are receives + recv_results = results[len(results) // 2:] + for i, src in enumerate([r for r in range(ws) if r != self.comm.rank]): + gathered[src] = recv_results[i] + + return gathered + + async def _verify_uniform_predicate_async(self, pred: Any) -> None: + """Async version of uniform predicate verification using async collective communication. + + Verifies that the predicate value is uniform across all parties by performing + an async all-gather operation and checking that all values are identical. + """ + # Use Value.to_bool() if available, otherwise unwrap and convert + if isinstance(pred, Value): + pred_bool = pred.to_bool() + else: + pred_bool = bool(self._unwrap_value(pred)) + + # Use async allgather to collect predicate values from all parties + vals = await self._simple_allgather_async(pred_bool) + + if not vals: + raise ValueError("uniform_cond: empty gather for predicate") + + first = vals[0] + for v in vals[1:]: + if v != first: + raise ValueError( + "uniform_cond: predicate is not uniform across parties" + ) + + +class AsyncIterativeEvaluator(AsyncEvalSemantic): + """Async evaluator using iterative traversal to avoid stack overflow. + + This evaluator follows the same pattern as the synchronous IterativeEvaluator: + 1. Uses local symbols dictionary instead of instance state + 2. Directly recurses via method calls (not Python call stack) + 3. Processes nodes in dependency order + """ + + def __init__( + self, + rank: int, + env: dict[str, Any], + comm: IAsyncCommunicator, + runtime: RuntimeContext, + executor: Executor, + ): + super().__init__(rank, env, comm, runtime, executor) + + async def evaluate(self, expr: Expr, env: dict[str, Any] | None = None) -> Any: + """Entry point for evaluation.""" + evaluation_env = env if env is not None else self.env + result = await self._iter_eval_graph(expr, evaluation_env) + return result + + async def _iter_eval_graph(self, root: Expr, env: dict[str, Any]) -> list[Any]: + """Main evaluation loop using iterative traversal (async version of sync pattern).""" + symbols: dict[int, list[Any]] = {} + + # Process all nodes in dependency order + for node in walk_dataflow(root, traversal="dfs_post_iter"): + if isinstance(node, VariableExpr): + if node.name not in env: + raise ValueError( + f"Variable '{node.name}' not found in evaluator environment" + ) + symbols[id(node)] = [env[node.name]] + + elif isinstance(node, TupleExpr): + vals = [self._first(symbols[id(a)]) for a in node.args] + symbols[id(node)] = vals + + elif isinstance(node, AccessExpr): + src_vals = symbols[id(node.src)] + symbols[id(node)] = [src_vals[node.index]] + + elif isinstance(node, CallExpr): + arg_vals = [self._first(symbols[id(a)]) for a in node.args] + assert isinstance(node.fn, FuncDefExpr) + sub_env = dict(zip(node.fn.params, arg_vals, strict=True)) + # Recursive method call - not Python call stack recursion! + res = await self._iter_eval_graph(node.fn.body, {**env, **sub_env}) + symbols[id(node)] = res + + elif isinstance(node, CondExpr): + pred_val = self._first(symbols[id(node.pred)]) + arg_vals = [self._first(symbols[id(a)]) for a in node.args] + + if pred_val is None: + symbols[id(node)] = [None] * len(node.mptypes) + else: + # Optional uniform verification + if node.verify_uniform: + await self._verify_uniform_predicate_async(pred_val) + + # Convert to bool + if isinstance(pred_val, Value): + pred = pred_val.to_bool() + else: + pred = bool(self._unwrap_value(pred_val)) + + if pred: + sub_env = dict(zip(node.then_fn.params, arg_vals, strict=True)) + # Recursive method call + res = await self._iter_eval_graph( + node.then_fn.body, {**env, **sub_env} + ) + symbols[id(node)] = res + else: + sub_env = dict(zip(node.else_fn.params, arg_vals, strict=True)) + # Recursive method call + res = await self._iter_eval_graph( + node.else_fn.body, {**env, **sub_env} + ) + symbols[id(node)] = res + + elif isinstance(node, WhileExpr): + state = [self._first(symbols[id(a)]) for a in node.args] + while True: + cond_env = dict(zip(node.cond_fn.params, state, strict=True)) + # Recursive method call for condition + cond_vals = await self._iter_eval_graph( + node.cond_fn.body, {**env, **cond_env} + ) + cond_val = self._check_while_predicate(cond_vals) + if not bool(cond_val): + break + + body_env = dict(zip(node.body_fn.params, state, strict=True)) + # Recursive method call for body + new_state = await self._iter_eval_graph( + node.body_fn.body, {**env, **body_env} + ) + state = self._merge_state(state, new_state) + symbols[id(node)] = state[0 : len(node.body_fn.mptypes)] + + elif isinstance(node, EvalExpr): + arg_vals = [self._first(symbols[id(a)]) for a in node.args] + symbols[id(node)] = await self._eval_eval_node_async(node, arg_vals) + + elif isinstance(node, ConvExpr): + vars_vals = [self._first(symbols[id(v)]) for v in node.vars] + # ConvExpr needs async implementation + symbols[id(node)] = await self._eval_conv_node_async(node, vars_vals) + + elif isinstance(node, ShflSExpr): + value = self._first(symbols[id(node.src_val)]) + symbols[id(node)] = await self._eval_shfl_s_node_async(node, value) + + elif isinstance(node, ShflExpr): + data = self._first(symbols[id(node.src)]) + index = self._first(symbols[id(node.index)]) + symbols[id(node)] = await self._eval_shfl_node_async(node, data, index) + + elif isinstance(node, FuncDefExpr): + # FuncDefExpr should not be directly evaluated + raise RuntimeError("FuncDefExpr should not be directly evaluated") + else: + raise NotImplementedError(f"Unsupported expression type: {type(node)}") + + return symbols[id(root)] + + @staticmethod + def _first(vals: list[Any]) -> Any: + """Get first value from list (matches sync evaluator).""" + if not isinstance(vals, list): + return vals + if len(vals) == 0: + return None + return vals[0] + + def _merge_state(self, old: list[Any], new: list[Any]) -> list[Any]: + """Merge state for while loops (matches sync evaluator).""" + assert len(new) <= len(old) + return new + old[len(new) :] + + async def _eval_conv_node_async( + self, _expr: ConvExpr, vars_vals: list[Any] + ) -> list[Any]: + """Async version of conv node evaluation.""" + # Implement the same logic as sync _eval_conv_node + assert len(vars_vals) > 0, "pconv called with empty vars list." + filtered = [v for v in vars_vals if v is not None] + if len(filtered) == 0: + return [None] + if len(filtered) == 1: + return [filtered[0]] + raise ValueError(f"pconv called with multiple vars={filtered}.") diff --git a/mplang/v1/runtime/communicator.py b/mplang/v1/runtime/communicator.py index bc51a3b6..1bad4d7e 100644 --- a/mplang/v1/runtime/communicator.py +++ b/mplang/v1/runtime/communicator.py @@ -23,6 +23,7 @@ import httpx +from mplang.v1.core.async_comm import AsyncCommunicatorBase from mplang.v1.core.comm import CommunicatorBase from mplang.v1.kernels.value import Value, decode_value, encode_value @@ -105,3 +106,85 @@ def recv(self, frm: int, key: str) -> Any: f"Received data: from_rank={frm}, to_rank={self._rank}, key={key}" ) return result + + +class AsyncHttpCommunicator(AsyncCommunicatorBase): + """Async version of HttpCommunicator.""" + + def __init__( + self, + session_name: str, + rank: int, + endpoints: list[str], + loop=None, + ): + # Validate endpoints + if not endpoints: + raise ValueError("endpoints cannot be empty") + + if not all(endpoint for endpoint in endpoints): + raise ValueError("endpoints cannot contain empty elements") + + super().__init__(rank, len(endpoints), loop) + self.session_name = session_name + # Ensure all endpoints have protocol prefix + self.endpoints = [ + endpoint + if endpoint.startswith(("http://", "https://")) + else f"http://{endpoint}" + for endpoint in endpoints + ] + logging.info( + f"AsyncHttpCommunicator initialized: session={session_name}, rank={rank}, endpoints={self.endpoints}" + ) + + async def async_send(self, to: int, key: str, data: Any) -> None: + """Sends data to a peer party by PUTing to its /comm/{key}/from/{from_rank} endpoint.""" + target_endpoint = self.endpoints[to] + url = f"{target_endpoint}/sessions/{self.session_name}/comm/{key}/from/{self._rank}" + logging.debug( + f"Async sending data: from_rank={self._rank}, to_rank={to}, key={key}, target_url={url}" + ) + + try: + # Serialize data using Value envelope. + if not isinstance(data, Value): + raise TypeError( + f"Communicator requires Value instance, got {type(data).__name__}. " + "Wrap data in TensorValue or custom Value subclass." + ) + data_bytes = encode_value(data) + data_b64 = base64.b64encode(data_bytes).decode("utf-8") + + request_data = { + "data": data_b64, + } + + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.put(url, json=request_data) + logging.debug(f"Async send response: status={response.status_code}") + if response.status_code != 200: + logging.error(f"Async send failed: {response.text}") + response.raise_for_status() + + except httpx.RequestError as e: + logging.error( + f"Async send failed with exception: from_rank={self._rank}, to_rank={to}, key={key}, error={e}" + ) + raise OSError(f"Failed to send data to rank {to}") from e + + async def async_recv(self, frm: int, key: str) -> Any: + """Wait until the key is set, returns the value.""" + logging.debug( + f"Async waiting to receive: from_rank={frm}, to_rank={self._rank}, key={key}" + ) + data_b64 = await super().async_recv(frm, key) + + data_bytes = base64.b64decode(data_b64) + # Deserialize using Value envelope + result = decode_value(data_bytes) + + logging.debug( + f"Async received data: from_rank={frm}, to_rank={self._rank}, key={key}" + ) + return result diff --git a/mplang/v1/runtime/server.py b/mplang/v1/runtime/server.py index 0062aedb..48fd5299 100644 --- a/mplang/v1/runtime/server.py +++ b/mplang/v1/runtime/server.py @@ -18,7 +18,9 @@ """ import base64 +from concurrent.futures import Executor, ThreadPoolExecutor import logging +import os import re from typing import Any @@ -55,6 +57,7 @@ # per-server global state _sessions: dict[str, Session] = {} _global_symbols: dict[str, Symbol] = {} +_executor: Executor = ThreadPoolExecutor(max_workers=os.cpu_count()) def register_session(session: Session) -> Session: # pragma: no cover - test helper @@ -271,7 +274,9 @@ def create_session(session_name: str, request: CreateSessionRequest) -> SessionR sess = _sessions[session_name] else: spec = ClusterSpec.from_dict(request.cluster_spec) - sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec) + sess = create_session_from_spec( + name=session_name, rank=request.rank, spec=spec, async_mode=True + ) _sessions[session_name] = sess return SessionResponse(name=sess.name) @@ -300,7 +305,7 @@ def delete_session(session_name: str) -> dict[str, str]: "/sessions/{session_name}/computations/{computation_id}", response_model=ComputationResponse, ) -def create_and_execute_computation( +async def create_and_execute_computation( session_name: str, computation_id: str, request: CreateComputationRequest ) -> ComputationResponse: graph_proto = mpir_pb2.GraphProto() @@ -325,12 +330,14 @@ def create_and_execute_computation( if not comp: comp = Computation(name=computation_id, expr=expr) sess.add_computation(comp) - sess.execute(comp, request.input_names, request.output_names) + await sess.async_execute( + comp, request.input_names, request.output_names, executor=_executor + ) return ComputationResponse(name=computation_id) @app.delete("/sessions/{session_name}/computations/{computation_id}") -def delete_computation(session_name: str, computation_id: str) -> dict[str, str]: +async def delete_computation(session_name: str, computation_id: str) -> dict[str, str]: """Delete a specific computation.""" sess = _sessions.get(session_name) if sess and sess.delete_computation(computation_id): diff --git a/mplang/v1/runtime/session.py b/mplang/v1/runtime/session.py index 27e3a369..45fa3247 100644 --- a/mplang/v1/runtime/session.py +++ b/mplang/v1/runtime/session.py @@ -27,6 +27,7 @@ import logging import time +from concurrent.futures import Executor from dataclasses import dataclass, field from functools import cached_property from typing import TYPE_CHECKING, Any, cast @@ -34,15 +35,17 @@ import spu.libspu as libspu +from mplang.v1.core.async_comm import IAsyncCommunicator from mplang.v1.core.cluster import ClusterSpec from mplang.v1.core.comm import ICommunicator from mplang.v1.core.expr.ast import Expr +from mplang.v1.core.expr.async_evaluator import AsyncIterativeEvaluator from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator from mplang.v1.core.mask import Mask from mplang.v1.kernels.context import RuntimeContext from mplang.v1.kernels.spu import PFunction # type: ignore from mplang.v1.kernels.value import Value -from mplang.v1.runtime.communicator import HttpCommunicator +from mplang.v1.runtime.communicator import AsyncHttpCommunicator, HttpCommunicator from mplang.v1.runtime.exceptions import ResourceNotFound from mplang.v1.runtime.link_comm import LinkCommunicator from mplang.v1.utils.spu_utils import parse_field, parse_protocol @@ -192,7 +195,6 @@ def ensure_spu_env(self) -> None: spu_addrs: list[str] = [] for r, addr in enumerate(self.cluster_spec.endpoints): if r in self.spu_mask: - # TODO(oeqqwq): addr may contain other schema like grpc:// if not addr.startswith(("http://", "https://")): addr = f"http://{addr}" parsed = urlparse(addr) @@ -281,17 +283,68 @@ def execute( ) self.add_symbol(Symbol(name=name, mptype={}, data=val)) + async def async_execute( + self, + computation: Computation, + input_names: list[str], + output_names: list[str], + executor: Executor, + ) -> None: + if not isinstance(self.communicator, IAsyncCommunicator): + raise RuntimeError("Session.async_execute requires an async communicator") + + env: dict[str, Any] = {} + for in_name in input_names: + sym = self.get_symbol(in_name) + if sym is None: + raise ResourceNotFound( + f"Input symbol '{in_name}' not found in session '{self.name}'" + ) + env[in_name] = sym.data + rt = self.ensure_runtime() + self.ensure_spu_env() + evaluator = AsyncIterativeEvaluator( + rank=self.rank, + env=env, + comm=self.communicator, + runtime=rt, + executor=executor, + ) + results = await evaluator.evaluate(computation.expr, env) + if results and len(results) != len(output_names): + raise RuntimeError( + f"Expected {len(output_names)} results, got {len(results)}" + ) + for name, val in zip(output_names, results, strict=True): + # In pure SIMP model, all nodes should have the same symbol table. + # Non-participating nodes get None values. + if val is not None and not isinstance(val, Value): + raise TypeError( + "Session executions must produce kernel Value outputs; " + f"got {type(val).__name__} for symbol '{name}'" + ) + self.add_symbol(Symbol(name=name, mptype={}, data=val)) + # --- Convenience constructor use HttpCommunicator--- -def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session: +def create_session_from_spec( + name: str, rank: int, spec: ClusterSpec, async_mode: bool = False +) -> Session: if len(spec.get_devices_by_kind("SPU")) == 0: raise RuntimeError("No SPU device found in cluster_spec") # Create HttpCommunicator for the session - communicator = HttpCommunicator( - session_name=name, - rank=rank, - endpoints=spec.endpoints, - ) + if async_mode: + communicator: ICommunicator = AsyncHttpCommunicator( + session_name=name, + rank=rank, + endpoints=spec.endpoints, + ) + else: + communicator = HttpCommunicator( + session_name=name, + rank=rank, + endpoints=spec.endpoints, + ) return Session(name=name, rank=rank, cluster_spec=spec, communicator=communicator) diff --git a/mplang/v1/runtime/simulation.py b/mplang/v1/runtime/simulation.py index 56ed912e..1b212717 100644 --- a/mplang/v1/runtime/simulation.py +++ b/mplang/v1/runtime/simulation.py @@ -14,12 +14,11 @@ from __future__ import annotations -import concurrent.futures -import faulthandler +import asyncio import logging -import sys -import traceback +import os from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor from typing import Any, cast import spu.libspu as libspu @@ -38,8 +37,12 @@ PFunction, # for spu.seed_env kernel seeding TensorLike, ) +from mplang.v1.core.async_comm import AsyncThreadCommunicator from mplang.v1.core.expr.ast import Expr -from mplang.v1.core.expr.evaluator import IEvaluator, create_evaluator +from mplang.v1.core.expr.async_evaluator import ( + AsyncIterativeEvaluator, +) +from mplang.v1.core.expr.evaluator import IEvaluator from mplang.v1.kernels.context import RuntimeContext from mplang.v1.runtime.link_comm import LinkCommunicator from mplang.v1.utils.spu_utils import parse_field, parse_protocol @@ -146,6 +149,9 @@ def __init__( self._spu_world = spu_mask.num_parties() self._spu_mask = spu_mask + # Executor for CPU-bound tasks + self._executor = ThreadPoolExecutor(max_workers=os.cpu_count()) + # Persistent per-rank RuntimeContext instances (reused across evaluates). # We no longer pre-create evaluators since each evaluate has different env bindings. # Build per-rank runtime contexts. @@ -210,90 +216,90 @@ def fetch(self, obj: MPObject) -> list[TensorLike]: raise ValueError(f"Expected SimVar, got {type(obj)}") return [v.to_numpy() if hasattr(v, "to_numpy") else v for v in obj._values] + def _ensure_spu_init(self, rank: int) -> None: + """Ensure SPU environment is initialized for the given rank.""" + runtime = self._runtimes[rank] + spu_meta = runtime.state.setdefault("_spu", {}) + if not spu_meta.get("inited", False): + link_ctx = self._spu_link_ctxs[rank] + seed_fn = PFunction( + fn_type="spu.seed_env", + ins_info=(), + outs_info=(), + config=self._spu_runtime_cfg, + world=self._spu_world, + link=link_ctx, + ) + runtime.run_kernel(seed_fn, []) # type: ignore[arg-type] + spu_meta["inited"] = True + # override def evaluate(self, expr: Expr, bindings: dict[str, MPObject]) -> Sequence[MPObject]: - # sanity check for bindings. + return asyncio.run(self._evaluate_async(expr, bindings)) + + async def _evaluate_async( + self, expr: Expr, bindings: dict[str, MPObject] + ) -> Sequence[MPObject]: + """Async evaluation entry point.""" + # 1. Setup Async Communicators + world_size = self.world_size() + async_comms = [ + AsyncThreadCommunicator(rank, world_size) for rank in range(world_size) + ] + for comm in async_comms: + comm.set_peers(async_comms) + + # 2. Prepare Environment + # Validate that all variables belong to this simulator context for name, var in bindings.items(): + if not isinstance(var, SimVar): + raise ValueError( + f"Expected SimVar for variable '{name}', got {type(var)}" + ) if var.ctx is not self: - raise ValueError(f"Variable {name} not in this context, got {var.ctx}.") + raise ValueError(f"Variable '{name}' not in this context") pts_env = [ {name: cast(SimVar, var)._values[rank] for name, var in bindings.items()} - for rank in range(self.world_size()) + for rank in range(world_size) ] - # Build per-rank evaluators with the per-party environment (runtime reused) - pts_evaluators: list[IEvaluator] = [] - for rank in range(self.world_size()): + # 3. Create Evaluators + evaluators = [] + for rank in range(world_size): runtime = self._runtimes[rank] - ev = create_evaluator( - rank, - pts_env[rank], - self._comms[rank], - runtime, - None, + # Initialize SPU if needed (same logic as sync) + self._ensure_spu_init(rank) + + ev = AsyncIterativeEvaluator( + rank=rank, + env=pts_env[rank], + comm=async_comms[rank], + runtime=runtime, + executor=self._executor, ) - # Seed SPU once per runtime (idempotent logical requirement) - # Use setdefault to both retrieve and create metadata dict in one step. - spu_meta = runtime.state.setdefault("_spu", {}) - if not spu_meta.get("inited", False): - link_ctx = self._spu_link_ctxs[rank] - seed_fn = PFunction( - fn_type="spu.seed_env", - ins_info=(), - outs_info=(), - config=self._spu_runtime_cfg, - world=self._spu_world, - link=link_ctx, - ) - ev.runtime.run_kernel(seed_fn, []) # type: ignore[arg-type] - spu_meta["inited"] = True - pts_evaluators.append(ev) - - # Collect evaluation results from all parties - pts_results: list[Any] = [] - - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(self._do_evaluate, expr, evaluator) - for evaluator in pts_evaluators - ] - - # Collect results with proper exception handling - for i, future in enumerate(futures): - try: - result = future.result(100) # 100 second timeout - pts_results.append(result) - except concurrent.futures.TimeoutError: - faulthandler.dump_traceback(file=sys.stderr, all_threads=True) - raise - except Exception as e: - print( - f"Exception in party {i}: {type(e).__name__}: {e}", - file=sys.stderr, - ) - traceback.print_exc(file=sys.stderr) - executor.shutdown(wait=False, cancel_futures=True) - raise - - # Convert results to SimVar objects - # pts_results is a list of party results, where each party result is a list of values - # We need to transpose this to get (n_outputs, n_parties) structure - assert len(pts_results) == self.world_size() - - # Ensure all parties returned the same number of outputs (matrix validation) + evaluators.append(ev) + + # 4. Run Evaluation concurrently + # We need to run all evaluators.evaluate(expr) concurrently. + tasks = [ev.evaluate(expr) for ev in evaluators] + pts_results = await asyncio.gather(*tasks) + + # Ensure results are lists if expr has single output + if expr.num_outputs == 1: + # If each evaluator already returns a list (as async evaluators do), don't wrap again + if pts_results and not isinstance(pts_results[0], list): + pts_results = [[res] for res in pts_results] + + # 5. Process Results (Transpose and Wrap) + assert len(pts_results) == world_size if pts_results and not all( len(row) == len(pts_results[0]) for row in pts_results ): raise ValueError("Inconsistent number of outputs across parties") - # Transpose: (n_parties, n_outputs) -> (n_outputs, n_parties) output_values = list(zip(*pts_results, strict=False)) - - # Get the output types from the expression output_types = expr.mptypes - - # Create SimVar objects for each output sim_vars = [] for values, mptype in zip(output_values, output_types, strict=False): sim_var = SimVar(self, mptype, list(values)) diff --git a/tests/v1/core/test_async_comm.py b/tests/v1/core/test_async_comm.py new file mode 100644 index 00000000..71fecad1 --- /dev/null +++ b/tests/v1/core/test_async_comm.py @@ -0,0 +1,102 @@ +# Copyright 2025 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest + +from mplang.v1.core.async_comm import AsyncThreadCommunicator + + +@pytest.mark.asyncio +async def test_async_p2p(): + world_size = 2 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + # P0 sends to P1 + async def p0_task(): + await comms[0].p2p(0, 1, "hello") + return "done" + + async def p1_task(): + data = await comms[1].p2p(0, 1, None) + return data + + results = await asyncio.gather(p0_task(), p1_task()) + assert results[1] == "hello" + + +@pytest.mark.asyncio +async def test_async_gather(): + world_size = 3 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + async def task(rank): + data = f"data-{rank}" + return await comms[rank].gather(0, data) + + results = await asyncio.gather(*[task(i) for i in range(world_size)]) + + # Rank 0 should get all data + assert results[0] == ["data-0", "data-1", "data-2"] + # Others get None list + assert results[1] == [None, None, None] + assert results[2] == [None, None, None] + + +@pytest.mark.asyncio +async def test_async_scatter(): + world_size = 3 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + data_to_scatter = ["d0", "d1", "d2"] + + async def task(rank): + if rank == 0: + return await comms[rank].scatter(0, data_to_scatter) + else: + return await comms[rank].scatter(0, [None] * 3) # args ignored for non-root + + results = await asyncio.gather(*[task(i) for i in range(world_size)]) + + assert results[0] == "d0" + assert results[1] == "d1" + assert results[2] == "d2" + + +@pytest.mark.asyncio +async def test_async_bcast(): + world_size = 3 + comms = [AsyncThreadCommunicator(i, world_size) for i in range(world_size)] + for comm in comms: + comm.set_peers(comms) + + async def task(rank): + if rank == 0: + return await comms[rank].bcast(0, "broadcast_data") + else: + return await comms[rank].bcast(0, None) + + results = await asyncio.gather(*[task(i) for i in range(world_size)]) + + # bcast returns the data for everyone in the mask, including the root + assert results[0] == "broadcast_data" + assert results[1] == "broadcast_data" + assert results[2] == "broadcast_data" diff --git a/tutorials/v1/device/02_simulation_and_driver.py b/tutorials/v1/device/02_simulation_and_driver.py index b04466a7..df118405 100644 --- a/tutorials/v1/device/02_simulation_and_driver.py +++ b/tutorials/v1/device/02_simulation_and_driver.py @@ -127,13 +127,13 @@ def cmd_main(): Usage: 1. Simulator (local, no setup needed): - uv run tutorials/device/02_simulation_and_driver.py sim + uv run tutorials/v1/device/02_simulation_and_driver.py sim 2. Driver (distributed): Step 1: Start cluster in separate terminal: - uv run python -m mplang.runtime.cli up -c examples/v1/conf/3pc.yaml + uv run python -m mplang.v1.runtime.cli up -c examples/v1/conf/3pc.yaml Step 2: Run computation: - uv run tutorials/device/02_simulation_and_driver.py run + uv run tutorials/v1/device/02_simulation_and_driver.py run """ cmd_main()