diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index ccca503e5..cd45433c8 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -385,7 +385,13 @@ def copy(
         return copy_dmp
 
     def _init_dmp(self, module: nn.Module) -> nn.Module:
-        return self._shard_modules_impl(module)
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:enable_module_id_cache_for_dmp_shard_modules"
+        ):
+            module_id_cache: Dict[int, ShardedModule] = {}
+        else:
+            module_id_cache = None
+        return self._shard_modules_impl(module, module_id_cache=module_id_cache)
 
     def _init_delta_tracker(
         self, delta_tracker_config: DeltaTrackerConfig, module: nn.Module
@@ -435,28 +441,37 @@ def _shard_modules_impl(
         self,
         module: nn.Module,
         path: str = "",
+        module_id_cache: Optional[Dict[int, ShardedModule]] = None,
     ) -> nn.Module:
         # pre-sharded module
         if isinstance(module, ShardedModule):
             return module
 
+        if module_id_cache is not None:
+            module_id = id(module)
+            if module_id in module_id_cache:
+                return module_id_cache[module_id]
+
         # shardable module
         module_sharding_plan = self._plan.get_plan_for_module(path)
         if module_sharding_plan:
             sharder_key = type(module)
-            module = self._sharder_map[sharder_key].shard(
+            sharded_module = self._sharder_map[sharder_key].shard(
                 module,
                 module_sharding_plan,
                 self._env,
                 self.device,
                 path,
             )
-            return module
+            if module_id_cache is not None:
+                module_id_cache[module_id] = sharded_module
+            return sharded_module
 
         for name, child in module.named_children():
             child = self._shard_modules_impl(
                 child,
                 path + "." + name if path else name,
+                module_id_cache,
             )
             setattr(module, name, child)
@@ -1001,12 +1016,18 @@ def _shard_modules_impl(
         self,
         module: nn.Module,
         path: str = "",
+        module_id_cache: Optional[Dict[int, ShardedModule]] = None,
     ) -> nn.Module:
         # pre-sharded module
         if isinstance(module, ShardedModule):
             return module
 
+        if module_id_cache is not None:
+            module_id = id(module)
+            if module_id in module_id_cache:
+                return module_id_cache[module_id]
+
         # shardable module
         module_sharding_plan = self._plan.get_plan_for_module(path)
         if module_sharding_plan:
@@ -1025,19 +1046,22 @@ def _shard_modules_impl(
                     )
                     break
 
-            module = self._sharder_map[sharder_key].shard(
+            sharded_module = self._sharder_map[sharder_key].shard(
                 module,
                 module_sharding_plan,
                 env,
                 self.device,
                 path,
             )
-            return module
+            if module_id_cache is not None:
+                module_id_cache[module_id] = sharded_module
+            return sharded_module
 
         for name, child in module.named_children():
             child = self._shard_modules_impl(
                 child,
                 path + "." + name if path else name,
+                module_id_cache,
             )
             setattr(module, name, child)
diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py b/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py
index 0ea359f89..1c607e487 100644
--- a/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py
+++ b/torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py
@@ -7,15 +7,218 @@
 
 # pyre-strict
 
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.test_utils.test_model_parallel_base import (
     ModelParallelSparseOnlyBase,
     ModelParallelStateDictBase,
 )
+from torchrec.distributed.types import ShardedModule
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 class ModelParallelStateDictTestNccl(ModelParallelStateDictBase):
     pass
 
 
+class SparseArch(nn.Module):
+    def __init__(
+        self,
+        ebc: EmbeddingBagCollection,
+        ec: EmbeddingCollection,
+    ) -> None:
+        super().__init__()
+        self.ebc = ebc
+        self.ec = ec
+
+    def forward(
+        self, features: KeyedJaggedTensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        ebc_out = self.ebc(features)
+        ec_out = self.ec(features)
+        return ebc_out.values(), ec_out.values()
+
+
+# Create a model with two sparse architectures sharing the same modules
+class TwoSparseArchModel(nn.Module):
+    def __init__(
+        self,
+        sparse1: SparseArch,
+        sparse2: SparseArch,
+    ) -> None:
+        super().__init__()
+        # Both architectures share the same EBC and EC instances
+        self.sparse1 = sparse1
+        self.sparse2 = sparse2
+
+    def forward(
+        self, features: KeyedJaggedTensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        ebc1_out, ec1_out = self.sparse1(features)
+        ebc2_out, ec2_out = self.sparse2(features)
+
+        return ebc1_out, ec1_out, ebc2_out, ec2_out
+
+
 class ModelParallelSparseOnlyTestNccl(ModelParallelSparseOnlyBase):
-    pass
+    def test_shared_sparse_module_in_multiple_parents(self) -> None:
+        """
+        Test that the module ID cache correctly handles the same sparse module
+        being used in multiple parent modules. This tests the caching behavior
+        when a single EmbeddingBagCollection and EmbeddingCollection are shared
+        across two different parent sparse architectures.
+        """
+
+        # Setup: Create shared embedding modules that will be reused
+        ebc = EmbeddingBagCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingBagConfig(
+                    name="ebc_table",
+                    embedding_dim=64,
+                    num_embeddings=100,
+                    feature_names=["ebc_feature"],
+                ),
+            ],
+        )
+        ec = EmbeddingCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingConfig(
+                    name="ec_table",
+                    embedding_dim=32,
+                    num_embeddings=50,
+                    feature_names=["ec_feature"],
+                ),
+            ],
+        )
+
+        # Create the model with shared modules
+        sparse1 = SparseArch(ebc, ec)
+        sparse2 = SparseArch(ebc, ec)
+        model = TwoSparseArchModel(sparse1, sparse2)
+
+        # Execute: Shard the model with DistributedModelParallel
+        dmp = DistributedModelParallel(model, device=self.device)
+
+        # Assert: Verify that the shared modules are properly handled
+        self.assertIsNotNone(dmp.module)
+
+        # Verify that the same module instances are reused (cached behavior)
+        wrapped_module = dmp.module
+        self.assertIs(
+            wrapped_module.sparse1.ebc,
+            wrapped_module.sparse2.ebc,
+            "ebc1 and ebc2 should be the same sharded instance",
+        )
+        self.assertIs(
+            wrapped_module.sparse1.ec,
+            wrapped_module.sparse2.ec,
+            "ec1 and ec2 should be the same sharded instance",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ebc,
+            ShardedModule,
+            "ebc1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ec,
+            ShardedModule,
+            "ec1 should be sharded",
+        )
+
+    def test_shared_sparse_module_in_multiple_parents_negative(self) -> None:
+        """
+        Test that when module ID caching is disabled (module_id_cache=None),
+        the same module instance gets sharded multiple times, resulting in
+        different sharded instances. This validates the behavior without
+        caching.
+        """
+
+        def mock_init_dmp(
+            self_dmp: DistributedModelParallel, module: nn.Module
+        ) -> nn.Module:
+            """Override _init_dmp to always set module_id_cache to None."""
+            # Call _shard_modules_impl with module_id_cache=None (caching disabled)
+            return self_dmp._shard_modules_impl(module, module_id_cache=None)
+
+        # Setup: Create shared embedding modules that will be reused
+        ebc = EmbeddingBagCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingBagConfig(
+                    name="ebc_table",
+                    embedding_dim=64,
+                    num_embeddings=100,
+                    feature_names=["ebc_feature"],
+                ),
+            ],
+        )
+        ec = EmbeddingCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingConfig(
+                    name="ec_table",
+                    embedding_dim=32,
+                    num_embeddings=50,
+                    feature_names=["ec_feature"],
+                ),
+            ],
+        )
+
+        # Create the model with shared modules
+        sparse1 = SparseArch(ebc, ec)
+        sparse2 = SparseArch(ebc, ec)
+        model = TwoSparseArchModel(sparse1, sparse2)
+
+        # Execute: Mock _init_dmp to disable caching, then shard the model
+        with patch.object(
+            DistributedModelParallel,
+            "_init_dmp",
+            mock_init_dmp,
+        ):
+            dmp = DistributedModelParallel(model, device=self.device)
+
+        # Assert: Verify that modules are NOT cached (different instances)
+        self.assertIsNotNone(dmp.module)
+        wrapped_module = dmp.module
+
+        # Without caching, the same module should be sharded twice,
+        # resulting in different sharded instances
+        self.assertIsNot(
+            wrapped_module.sparse1.ebc,
+            wrapped_module.sparse2.ebc,
+            "Without caching, ebc1 and ebc2 should be different sharded instances",
+        )
+        self.assertIsNot(
+            wrapped_module.sparse1.ec,
+            wrapped_module.sparse2.ec,
+            "Without caching, ec1 and ec2 should be different sharded instances",
+        )
+
+        # Both should still be properly sharded, just not cached
+        self.assertIsInstance(
+            wrapped_module.sparse1.ebc,
+            ShardedModule,
+            "ebc1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ec,
+            ShardedModule,
+            "ec1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse2.ebc,
+            ShardedModule,
+            "ebc2 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse2.ec,
+            ShardedModule,
+            "ec2 should be sharded",
+        )