import os
import subprocess
import sys
-from typing import Any, cast, Dict, FrozenSet, List, Optional, Sequence
+from typing import Any, Dict, FrozenSet, List, Optional, Sequence

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure

from monarch._src.actor.bootstrap import attach_to_workers
-from monarch._src.actor.host_mesh import HostMesh
from monarch._src.job.job import JobState, JobTrait


@@ -55,6 +54,8 @@ def __init__(
        log_dir: Optional[str] = None,
        exclusive: bool = True,
        gpus_per_node: Optional[int] = None,
+        cpus_per_task: Optional[int] = None,
+        mem: Optional[str] = None,
    ) -> None:
        """
        Args:
@@ -84,6 +85,8 @@ def __init__(
        self._log_dir: str = log_dir if log_dir is not None else os.getcwd()
        self._exclusive = exclusive
        self._gpus_per_node = gpus_per_node
+        self._cpus_per_task = cpus_per_task
+        self._mem = mem
        # Track the single SLURM job ID and all allocated hostnames
        self._slurm_job_id: Optional[str] = None
        self._all_hostnames: List[str] = []
@@ -128,6 +131,12 @@ def _submit_slurm_job(self, num_nodes: int) -> str:
        if self._gpus_per_node is not None:
            sbatch_directives.append(f"#SBATCH --gpus-per-node={self._gpus_per_node}")

+        if self._cpus_per_task is not None:
+            sbatch_directives.append(f"#SBATCH --cpus-per-task={self._cpus_per_task}")
+
+        if self._mem is not None:
+            sbatch_directives.append(f"#SBATCH --mem={self._mem}")
+
        if self._exclusive:
            sbatch_directives.append("#SBATCH --exclusive")

@@ -297,6 +306,8 @@ def can_run(self, spec: "JobTrait") -> bool:
            and spec._time_limit == self._time_limit
            and spec._partition == self._partition
            and spec._gpus_per_node == self._gpus_per_node
+            and spec._cpus_per_task == self._cpus_per_task
+            and spec._mem == self._mem
            and self._jobs_active()
        )

@@ -318,6 +329,28 @@ def _jobs_active(self) -> bool:

        return True

+    def share_node(
+        self, tasks_per_node: int, gpus_per_task: int, partition: str
+    ) -> None:
+        """
+        Share nodes with other jobs: disable --exclusive and size cpus/mem via clusterscope.
+        """
+        try:
+            import clusterscope
+        except ImportError:
+            raise RuntimeError(
+                "please install clusterscope to use share_node. `pip install clusterscope`"
+            )
+        self._exclusive = False
+
+        slurm_args = clusterscope.job_gen_task_slurm(
+            partition=partition,
+            gpus_per_task=gpus_per_task,
+            tasks_per_node=tasks_per_node,
+        )
+        self._cpus_per_task = slurm_args["cpus_per_task"]
+        self._mem = slurm_args["memory"]
+
    def _kill(self) -> None:
        """Cancel the SLURM job."""
        if self._slurm_job_id is not None:
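
For reference, a minimal usage sketch of the options this diff adds. The class name SlurmJob, its import path, and the partition value are assumptions inferred from the surrounding code (can_run compares spec._partition), not shown in this patch:

    # A sketch, not part of the diff: names marked "assumed" are not confirmed here.
    from monarch._src.job.slurm import SlurmJob  # assumed import path

    job = SlurmJob(
        partition="gpu",      # assumed pre-existing argument (referenced in can_run)
        gpus_per_node=8,
        cpus_per_task=16,     # new: emitted as "#SBATCH --cpus-per-task=16"
        mem="128G",           # new: emitted as "#SBATCH --mem=128G"
    )

    # Alternatively, let clusterscope size the per-task resources and drop
    # --exclusive so the job can share nodes:
    job.share_node(tasks_per_node=8, gpus_per_task=1, partition="gpu")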