diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1243f14db4..84fa06bdd8 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -153,6 +153,12 @@ def benchmark_cpu_requests_mp( float: The average runtime per iteration in seconds. """ + import os + strategy = os.environ.get('PYTORCH_SHARE_STRATEGY') + current_strategy = torch.multiprocessing.get_sharing_strategy() + if strategy is not None and current_strategy != strategy: + torch.multiprocessing.set_sharing_strategy(strategy) + cpu_bm_barrier.create_barrier(num_copies) worker_pool = torch.multiprocessing.Pool(num_copies)