diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h index f723a14e2b..b2ed9723d8 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h @@ -441,15 +441,18 @@ class DramKVInferenceEmbedding { // Random starting cursor based on map size for good // entropy - size_t random_start = - folly::Random::rand32(wlmap->size()); + size_t map_size = wlmap->size(); + size_t random_start = folly::Random::rand32(map_size); // Try to find a used block starting from random // position weight_type* block = nullptr; for (int attempts = 0; attempts < 16; ++attempts) { + // Use modulo to prevent overflow beyond map size + size_t block_index = + (random_start + attempts) % map_size; block = pool->template get_block( - random_start + attempts); + block_index); if (block != nullptr) { // Block is used (not null) row_storage_data_ptr = diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py index 796d453681..6ee31f91d4 100644 --- a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py +++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py @@ -385,3 +385,61 @@ def test_zero_cache_miss_initialization_with_embedding_cache_mode(self) -> None: f"All zero cache miss results should be identical across calls, " f"but results[0] != results[{i}]", ) + + def test_random_block_index_no_overflow(self) -> None: + """Test that random block selection doesn't overflow when random_start + attempts exceeds map size.""" + num_shards = 4 + uniform_init_lower: float = -0.01 + uniform_init_upper: float = 0.01 + + # Setup: Create DRAM KV inference cache with small map size + kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, + uniform_init_lower, + uniform_init_upper, + False, # disable_random_init=False to enable random block selection + ) + kv_embedding_cache.init( + [(16, 4, SparseType.FP16.as_int())], + 16, + 4, + torch.tensor([0, 20], dtype=torch.int64), + ) + + # Populate cache with a small number of entries (e.g., 10) + # This creates a scenario where random_start could be close to map size + # and random_start + 16 attempts would exceed the map size + setup_indices = torch.arange(0, 10, dtype=torch.int64) + setup_weights = torch.randint(1, 255, (10, 16), dtype=torch.uint8) + kv_embedding_cache.set_embeddings(setup_indices, setup_weights) + + # Execute: Request many cache misses to trigger random block selection + # This exercises the random block selection code path multiple times + # with different random starting points + miss_indices = torch.tensor( + [100, 101, 102, 103, 104, 105, 106, 107, 108, 109], dtype=torch.int64 + ) + + # Run multiple times to increase probability of hitting edge cases + # where random_start is near the end of the map + for iteration in range(50): + try: + result = kv_embedding_cache.get_embeddings(miss_indices) + + # Assert: Verify the operation completes without crash or exception + # The result should have correct dimensions + self.assertEqual(result.size(0), 10) + self.assertEqual(result.size(1), 16) + + # All values should be non-zero (randomized from existing blocks) + self.assertTrue( + torch.any(result != 0), + f"Expected non-zero randomized data in iteration {iteration}", + ) + except Exception as e: + self.fail( + f"Random block selection failed in iteration {iteration} " + f"with small map size (10 entries). This likely indicates an " + f"overflow issue when random_start + attempts exceeds map size. " + f"Error: {e}" + )