@@ -37,7 +37,12 @@ def clone_symm_mem_tensor(tensor: torch.Tensor) -> torch.Tensor:
         device=tensor.device,
     )
     assert dist.group.WORLD is not None
-    symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
+    try:
+        symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Failed to rendezvous symmetric memory tensor of shape {tensor.shape}"
+        ) from e
     symm_mem_tensor.copy_(tensor)
     return symm_mem_tensor
 
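The try/except above is the standard exception-chaining pattern: catch the low-level failure, re-raise with the tensor's shape attached, and keep the original traceback via `from e`. A minimal standalone sketch of the same pattern, where `allocate` is a hypothetical stand-in for the symm_mem allocation-plus-rendezvous path:

    import torch

    def allocate(shape: tuple[int, ...]) -> torch.Tensor:
        # Stand-in for symm_mem.empty(...) + symm_mem.rendezvous(...);
        # simulate a collective failing on one rank.
        raise RuntimeError("rendezvous timed out")

    def clone_with_context(tensor: torch.Tensor) -> torch.Tensor:
        try:
            return allocate(tuple(tensor.shape))
        except RuntimeError as e:
            # "from e" chains the tracebacks, so the underlying timeout stays
            # visible beneath the shape-annotated error.
            raise RuntimeError(
                f"Failed to rendezvous symmetric memory tensor of shape {tensor.shape}"
            ) from e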
@@ -56,7 +61,7 @@ def clone_inputs(args: tuple[object]) -> tuple[object]:
 
 @dataclass
 class ExperimentConfig:
-    shape: tuple[int]
+    shape: tuple[int, ...]
     dtype: torch.dtype
     backends: list[str]
     device: torch.device | None = None
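The shape annotation fix is worth spelling out: to a static checker such as pyright, `tuple[int]` means a tuple of exactly one int, while `tuple[int, ...]` means a homogeneous tuple of any length. A quick illustration:

    shape_1d: tuple[int] = (1024,)          # OK: exactly one element
    shape_nd: tuple[int, ...] = (16, 1024)  # OK: any number of elements
    # bad: tuple[int] = (16, 1024)          # flagged by pyright: expected 1 item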
@@ -96,7 +101,7 @@ class BenchmarkOperator:
 \
 \
 \
-
+<op>
 """
 
     experiments: list[Experiment]
@@ -131,6 +136,12 @@ def parse_args(self) -> argparse.Namespace:
             description=f"Run benchmark for {self.__name__}" + self.help_str
         )
 
+        parser.add_argument(
+            "op",
+            type=str,
+            help="Operator to benchmark.",
+        )
+
         parser.add_argument(
             "--backend",
             type=str,
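With the positional argument in place, the operator name travels on the command line ahead of the flags. A runnable sketch of the mechanism (the `all_reduce` op name and `nccl` backend value are illustrative, and parse_args is fed an explicit argv for demonstration):

    import argparse

    parser = argparse.ArgumentParser(description="Run benchmark for MyBench")
    parser.add_argument("op", type=str, help="Operator to benchmark.")
    parser.add_argument("--backend", type=str)

    args = parser.parse_args(["all_reduce", "--backend", "nccl"])
    assert args.op == "all_reduce"  # the positional value lands on args.op;
                                    # the later assert against op_name guards
                                    # against launching the wrong operator script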
@@ -153,6 +164,8 @@ def parse_args(self) -> argparse.Namespace:
         self.args = parser.parse_args()
         self.args.dtype = getattr(torch, self.args.dtype)
 
+        assert self.args.op == self.op_name
+
         return self.args
 
     def __init__(self) -> None:
@@ -168,7 +181,6 @@ def __init__(self) -> None:
 
         self.device = torch.device(f"cuda:{self.local_rank}")
         torch.cuda.set_device(self.device)
-        dist.init_process_group("nccl")
         torch.manual_seed(42 + self.local_rank)
 
         self.experiments = []
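Dropping dist.init_process_group("nccl") from __init__ moves process-group ownership to the caller, which avoids double initialization when several operators run in one process. A minimal sketch of the caller-side pattern this assumes (the _Bench class is a placeholder for a BenchmarkOperator subclass; run under torchrun so RANK, WORLD_SIZE, and MASTER_ADDR are set):

    import torch.distributed as dist

    class _Bench:
        def __init__(self) -> None:
            # __init__ now assumes the default group already exists.
            assert dist.is_initialized()

    def main() -> None:
        dist.init_process_group("nccl")  # caller owns setup, once per process
        try:
            _Bench()
        finally:
            dist.destroy_process_group()  # symmetric teardown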
@@ -292,35 +304,43 @@ def get_results(self, metric: str = "speedup") -> defaultdict | None:
 
     def run_experiment(self, config: ExperimentConfig) -> dict[str, float]:
         if self.baseline not in config.backends:
-            backends = config.backends.append(self.baseline)
+            backends = [*config.backends, self.baseline]
         else:
             backends = config.backends
 
         gloden_inp = self.gen_inputs(config)
-        inputs = {backend: clone_inputs(gloden_inp) for backend in backends}  # pyright: ignore[reportOptionalIterable]
 
         gloden_fn = self.fn_dict[self.baseline]
         assert gloden_fn is not None
 
+        inp_og = clone_inputs(gloden_inp)
         gloden_o = gloden_fn(*gloden_inp)
 
         results = {}
-        for backend in backends:  # pyright: ignore[reportOptionalIterable]
+        for backend in backends:
             fn = self.fn_dict[backend]
             if fn is None:
                 results[backend] = float("nan")
                 continue
-            inp = inputs[backend]
+            inp = clone_inputs(inp_og)
             target_fn = functools.partial(fn, *inp)
             try:
                 test_o = target_fn()
             except RuntimeError:
                 results[backend] = float("nan")
                 continue
+            except AssertionError:
+                results[backend] = float("nan")
+                continue
             torch.testing.assert_close(test_o, gloden_o, atol=1e-1, rtol=1e-1)
 
             results[backend] = benchmark_distributed(
                 target_fn, profile_ranks=[self.MASTER_RANK]
             )
+            del test_o
+            del inp
+
+        del gloden_inp
+        del gloden_o
 
         return results
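Two of the changes in this hunk deserve a note. First, list.append mutates in place and returns None, so the old assignment left `backends` as None (hence the pyright suppressions) while silently growing config.backends; the spread form builds a fresh list and leaves the config untouched. A standalone illustration with made-up backend names:

    backends = ["nccl_symm", "triton"]   # illustrative backend names

    result = backends.append("nccl")     # append returns None...
    assert result is None
    assert backends == ["nccl_symm", "triton", "nccl"]  # ...and mutates in place

    original = ["nccl_symm", "triton"]
    merged = [*original, "nccl"]         # fixed form: new list, original intact
    assert original == ["nccl_symm", "triton"]
    assert merged == ["nccl_symm", "triton", "nccl"]

Second, cloning inp_og freshly for each backend means a backend that writes to its inputs in place cannot leak that mutation into the next backend's run, and the trailing dels drop each iteration's tensors before the next allocation, trimming peak device memory.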