Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,15 +441,18 @@ class DramKVInferenceEmbedding {

// Random starting cursor based on map size for good
// entropy
size_t random_start =
folly::Random::rand32(wlmap->size());
size_t map_size = wlmap->size();
size_t random_start = folly::Random::rand32(map_size);

// Try to find a used block starting from random
// position
weight_type* block = nullptr;
for (int attempts = 0; attempts < 16; ++attempts) {
// Use modulo to prevent overflow beyond map size
size_t block_index =
(random_start + attempts) % map_size;
block = pool->template get_block<weight_type>(
random_start + attempts);
block_index);
if (block != nullptr) {
// Block is used (not null)
row_storage_data_ptr =
Expand Down
58 changes: 58 additions & 0 deletions fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,61 @@ def test_zero_cache_miss_initialization_with_embedding_cache_mode(self) -> None:
f"All zero cache miss results should be identical across calls, "
f"but results[0] != results[{i}]",
)

def test_random_block_index_no_overflow(self) -> None:
    """Test that random block selection doesn't overflow when random_start + attempts exceeds map size.

    Regression test for the modulo fix in the C++ random block selection:
    with a small populated map (10 entries), ``random_start + attempts``
    can exceed the map size, so block lookup must wrap via
    ``(random_start + attempts) % map_size`` instead of reading past the end.
    """
    num_shards = 4
    uniform_init_lower: float = -0.01
    uniform_init_upper: float = 0.01

    # Setup: Create DRAM KV inference cache with small map size
    kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
        num_shards,
        uniform_init_lower,
        uniform_init_upper,
        False,  # disable_random_init=False to enable random block selection
    )
    kv_embedding_cache.init(
        [(16, 4, SparseType.FP16.as_int())],
        16,
        4,
        torch.tensor([0, 20], dtype=torch.int64),
    )

    # Populate cache with a small number of entries (10). This creates a
    # scenario where random_start can be close to the map size, so
    # random_start + 16 attempts would exceed it without the modulo fix.
    setup_indices = torch.arange(0, 10, dtype=torch.int64)
    setup_weights = torch.randint(1, 255, (10, 16), dtype=torch.uint8)
    kv_embedding_cache.set_embeddings(setup_indices, setup_weights)

    # Execute: request cache misses to trigger random block selection with
    # different random starting points.
    miss_indices = torch.tensor(
        [100, 101, 102, 103, 104, 105, 106, 107, 108, 109], dtype=torch.int64
    )

    # Run multiple times to increase probability of hitting edge cases
    # where random_start is near the end of the map.
    for iteration in range(50):
        # Keep the try body minimal: only the call under test. AssertionError
        # is a subclass of Exception, so wrapping the assertions below in this
        # try would swallow real assertion failures and misreport them as an
        # overflow crash.
        try:
            result = kv_embedding_cache.get_embeddings(miss_indices)
        except Exception as e:
            self.fail(
                f"Random block selection failed in iteration {iteration} "
                f"with small map size (10 entries). This likely indicates an "
                f"overflow issue when random_start + attempts exceeds map size. "
                f"Error: {e}"
            )

        # Assert: the operation completed; result has correct dimensions.
        self.assertEqual(result.size(0), 10)
        self.assertEqual(result.size(1), 16)

        # All values should be non-zero (randomized from existing blocks).
        self.assertTrue(
            torch.any(result != 0),
            f"Expected non-zero randomized data in iteration {iteration}",
        )
Loading