From 99bc254427cbcc2bc7b22ae484250a495f9e8410 Mon Sep 17 00:00:00 2001
From: Emma Lin
Date: Sat, 25 Oct 2025 21:28:02 -0700
Subject: [PATCH] support prefetch pipeline (#5032)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2045

Fix the direct_write for prefetch pipeline

Reviewed By: kausv, steven1327

Differential Revision: D85021220
---
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py    |  19 +-
 .../tbe/ssd/ssd_split_tbe_training_test.py   | 460 +++++++++++++++++-
 2 files changed, 472 insertions(+), 7 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index f7b6df66f6..a497cf9a5b 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -1984,12 +1984,13 @@ def _prefetch( # noqa C901
             # Store info for evicting the previous iteration's
             # scratch pad after the corresponding backward pass is
             # done
-            self.ssd_location_update_data.append(
-                (
-                    sp_curr_prev_map_gpu,
-                    inserted_rows,
+            if self.training:
+                self.ssd_location_update_data.append(
+                    (
+                        sp_curr_prev_map_gpu,
+                        inserted_rows,
+                    )
                 )
-            )

             # Ensure the previous iterations eviction is complete
             current_stream.wait_event(self.ssd_event_sp_evict)
@@ -2173,7 +2174,7 @@ def _prefetch( # noqa C901

         # Store scratch pad info for post backward eviction only for training
         # for eval job, no backward pass, so no need to store this info
-        if self.training and not self._embedding_cache_mode:
+        if self.training:
             self.ssd_scratch_pad_eviction_data.append(
                 (
                     inserted_rows,
@@ -4548,6 +4549,12 @@ def direct_write_embedding(
         if len(self.ssd_scratch_pad_eviction_data) > 0:
             self.ssd_scratch_pad_eviction_data.pop(0)
         if len(self.ssd_scratch_pad_eviction_data) > 0:
+            # Wait for any pending backend reads to the next scratch pad
+            # to complete before we write to it. Otherwise, stale backend data
+            # will overwrite our direct_write updates.
+            # The ssd_event_get marks completion of backend fetch operations.
+            current_stream.wait_event(self.ssd_event_get)
+
             # if scratch pad exists, write to next batch scratch pad
             sp = self.ssd_scratch_pad_eviction_data[0][0]
             sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
index c3f12903fa..a2869f38d4 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -1119,7 +1119,7 @@ def test_direct_write_embedding(
         torch.cuda.synchronize()

         # Execute forward
-        output_ref_list, output = self.execute_ssd_forward_(
+        self.execute_ssd_forward_(
             emb,
             emb_ref,
             indices_list,
@@ -1311,3 +1311,461 @@ def test_direct_write_embedding(
                 else 1e-4
             ),
         )
+
+    @given(
+        T=st.integers(min_value=1, max_value=3),
+        D=st.just(16),
+        B=st.just(8),
+        log_E=st.just(9),
+        L=st.just(4),
+        weighted=st.booleans(),
+        cache_set_scale=st.just(0.1),
+        pooling_mode=st.sampled_from([PoolingMode.SUM]),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        share_table=st.just(False),
+        trigger_bounds_check=st.just(False),
+        mixed_B=st.just(False),
+        num_buckets=st.just(1),
+        backend_type=st.sampled_from([BackendType.DRAM]),
+        enable_optimizer_offloading=st.just(True),
+        prefetch_pipeline=st.just(True),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_direct_write_with_async_backend_fetch_synchronization(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        cache_set_scale: float,
+        pooling_mode: PoolingMode,
+        weights_precision: SparseType,
+        output_dtype: SparseType,
+        share_table: bool,
+        trigger_bounds_check: bool,
+        mixed_B: bool,
+        num_buckets: int,
+        backend_type: BackendType,
+        enable_optimizer_offloading: bool,
+        prefetch_pipeline: bool,
+    ) -> None:
+        """
+        Test Gap A fix: Verify that direct_write_embedding properly waits for
+        async backend fetch to complete before writing to scratch pad.
+
+        This test ensures no race condition occurs between:
+        - Prefetch's async backend read to scratch pad (on ssd_eviction_stream)
+        - direct_write_embedding's writes to the same scratch pad (on main stream)
+        """
+
+        # Constants
+        lr = 0.5
+        eps = 0.2
+        ssd_shards = 2
+
+        (
+            emb,
+            emb_ref,
+            Es,
+            _,
+            bucket_offsets,
+            bucket_sizes,
+        ) = self.generate_kvzch_tbes(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            lr=lr,
+            eps=eps,
+            ssd_shards=ssd_shards,
+            cache_set_scale=cache_set_scale,
+            pooling_mode=pooling_mode,
+            weights_precision=weights_precision,
+            output_dtype=output_dtype,
+            share_table=share_table,
+            prefetch_pipeline=prefetch_pipeline,
+            backend_type=backend_type,
+            num_buckets=num_buckets,
+            enable_optimizer_offloading=enable_optimizer_offloading,
+            embedding_cache_mode=True,
+        )
+
+        # Generate inputs for iteration 1
+        (
+            indices_list_1,
+            per_sample_weights_list_1,
+            indices_1,
+            offsets_1,
+            per_sample_weights_1,
+            batch_size_per_feature_per_rank_1,
+        ) = self.generate_inputs_(
+            B,
+            L,
+            Es,
+            emb.feature_table_map,
+            weights_precision=weights_precision,
+            trigger_bounds_check=trigger_bounds_check,
+            mixed_B=mixed_B,
+            bucket_offsets=bucket_offsets,
+            bucket_sizes=bucket_sizes,
+            is_kv_tbes=True,
+        )
+
+        # Run iteration 1: prefetch + forward
+        emb.prefetch(
+            indices_1,
+            offsets_1,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank_1,
+        )
+        torch.cuda.synchronize()
+
+        self.execute_ssd_forward_(
+            emb,
+            emb_ref,
+            indices_list_1,
+            per_sample_weights_list_1,
+            indices_1,
+            offsets_1,
+            per_sample_weights_1,
+            B,
+            L,
+            weighted,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank_1,
+        )
+        torch.cuda.synchronize()
+
+        # Generate inputs for iteration 2
+        (
+            _,
+            _,
+            indices_2,
+            offsets_2,
+            _,
+            batch_size_per_feature_per_rank_2,
+        ) = self.generate_inputs_(
+            B,
+            L,
+            Es,
+            emb.feature_table_map,
+            weights_precision=weights_precision,
+            trigger_bounds_check=trigger_bounds_check,
+            mixed_B=mixed_B,
+            bucket_offsets=bucket_offsets,
+            bucket_sizes=bucket_sizes,
+            is_kv_tbes=True,
+        )
+
+        # Run iteration 2: prefetch (creates scratch pad for iter 2)
+        emb.prefetch(
+            indices_2,
+            offsets_2,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank_2,
+        )
+        torch.cuda.synchronize()
+
+        # Store scratch pad queue lengths before direct_write
+        eviction_data_before = len(emb.ssd_scratch_pad_eviction_data)
+
+        # Now call direct_write_embedding - this should wait for ssd_event_get
+        # to ensure backend fetch to next scratch pad completes
+        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+            emb.hash_size_cumsum,
+            indices_1,
+            offsets_1,
+            None,
+            -1,
+        )
+        unique_indices, inverse_indices = torch.unique(
+            linear_cache_indices, return_inverse=True, sorted=True
+        )
+        unique_weights = torch.randn(
+            unique_indices.numel(),
+            emb.cache_row_dim,
+            device=emb.current_device,
+            dtype=emb.weights_precision.as_dtype(),
+        )
+        custom_weights = unique_weights[inverse_indices]
+
+        emb.direct_write_embedding(
+            indices=indices_1,
+            offsets=offsets_1,
+            weights=custom_weights,
+        )
+        torch.cuda.synchronize()
+
+        # Assert: Verify that direct_write_embedding properly synchronizes
+        # The main goal of this test is to ensure that direct_write waits for
+        # the ssd_event_get (backend fetch completion) before writing to scratch pads.
+        # This prevents race conditions between async backend reads and direct writes.
+
+        scratch_pads_after = len(emb.ssd_scratch_pads)
+        eviction_data_after = len(emb.ssd_scratch_pad_eviction_data)
+
+        if prefetch_pipeline:
+            # Note: Queues may not be in sync before direct_write because the forward pass
+            # consumes eviction_data but scratch_pads are only popped by the backward hook.
+            # This is expected behavior in embedding_cache_mode.
+
+            # Verify direct_write processed the queues correctly
+            # When prefetch_pipeline is enabled, direct_write calls _update_cache_counter_and_pointers
+            # which pops from both queues if there's eviction data
+            if eviction_data_before > 0:
+                # Eviction data should have been popped by 1
+                self.assertEqual(
+                    eviction_data_after,
+                    eviction_data_before - 1,
+                    f"direct_write should pop from eviction_data queue "
+                    f"(before={eviction_data_before}, after={eviction_data_after})",
+                )
+
+                # After direct_write processes the queues, they should be in sync IF there was
+                # eviction data to process.
+                self.assertEqual(
+                    scratch_pads_after,
+                    eviction_data_after,
+                    f"Scratch pads and eviction data queues should be in sync after direct_write "
+                    f"(scratch_pads={scratch_pads_after}, eviction_data={eviction_data_after})",
+                )
+
+        # Verify the written weights are correct (no corruption from race condition)
+        lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
+            linear_cache_indices,
+            emb.lxu_cache_state,
+            emb.total_hash_size,
+        )
+        cache_location_mask = lxu_cache_locations >= 0
+        cache_locations = lxu_cache_locations[cache_location_mask]
+        cache_weights = custom_weights[cache_location_mask]
+
+        if cache_locations.numel() > 0:
+            actual_cache_weights = emb.lxu_cache_weights[cache_locations]
+            torch.testing.assert_close(
+                actual_cache_weights,
+                cache_weights,
+                rtol=1e-2 if weights_precision == SparseType.FP16 else 1e-4,
+                atol=1e-2 if weights_precision == SparseType.FP16 else 1e-4,
+            )
+
+    @given(
+        T=st.integers(min_value=1, max_value=3),
+        D=st.just(16),
+        B=st.just(8),
+        log_E=st.just(9),
+        L=st.just(4),
+        weighted=st.booleans(),
+        cache_set_scale=st.just(0.1),
+        pooling_mode=st.sampled_from([PoolingMode.SUM]),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        share_table=st.just(False),
+        trigger_bounds_check=st.just(False),
+        mixed_B=st.just(False),
+        num_buckets=st.just(1),
+        backend_type=st.sampled_from([BackendType.DRAM]),
+        enable_optimizer_offloading=st.just(True),
+        prefetch_pipeline=st.just(True),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_direct_write_clears_location_update_queue(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        cache_set_scale: float,
+        pooling_mode: PoolingMode,
+        weights_precision: SparseType,
+        output_dtype: SparseType,
+        share_table: bool,
+        trigger_bounds_check: bool,
+        mixed_B: bool,
+        num_buckets: int,
+        backend_type: BackendType,
+        enable_optimizer_offloading: bool,
+        prefetch_pipeline: bool,
+    ) -> None:
+        """
+        Test Gap B fix: Verify that direct_write_embedding properly clears
+        the ssd_location_update_data queue to prevent stale mappings.
+
+        This test ensures that when direct_write modifies a scratch pad,
+        any pending location update data referencing that scratch pad is
+        cleared to prevent cache pointer corruption.
+ """ + + # Constants + lr = 0.5 + eps = 0.2 + ssd_shards = 2 + + ( + emb, + emb_ref, + Es, + _, + bucket_offsets, + bucket_sizes, + ) = self.generate_kvzch_tbes( + T, + D, + B, + log_E, + L, + weighted, + lr=lr, + eps=eps, + ssd_shards=ssd_shards, + cache_set_scale=cache_set_scale, + pooling_mode=pooling_mode, + weights_precision=weights_precision, + output_dtype=output_dtype, + share_table=share_table, + prefetch_pipeline=prefetch_pipeline, + backend_type=backend_type, + num_buckets=num_buckets, + enable_optimizer_offloading=enable_optimizer_offloading, + embedding_cache_mode=True, + ) + + # Generate inputs + ( + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank, + ) = self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + trigger_bounds_check=trigger_bounds_check, + mixed_B=mixed_B, + bucket_offsets=bucket_offsets, + bucket_sizes=bucket_sizes, + is_kv_tbes=True, + ) + + # Run prefetch + forward to populate queues + emb.prefetch( + indices, + offsets, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + torch.cuda.synchronize() + + self.execute_ssd_forward_( + emb, + emb_ref, + indices_list, + per_sample_weights_list, + indices, + offsets, + per_sample_weights, + B, + L, + weighted, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + torch.cuda.synchronize() + + # Run another prefetch to create location update data + ( + _, + _, + indices_2, + offsets_2, + _, + batch_size_per_feature_per_rank_2, + ) = self.generate_inputs_( + B, + L, + Es, + emb.feature_table_map, + weights_precision=weights_precision, + trigger_bounds_check=trigger_bounds_check, + mixed_B=mixed_B, + bucket_offsets=bucket_offsets, + bucket_sizes=bucket_sizes, + is_kv_tbes=True, + ) + + emb.prefetch( + indices_2, + offsets_2, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank_2, + ) + torch.cuda.synchronize() + + # Check location update queue length before direct_write + location_update_queue_before = ( + len(emb.ssd_location_update_data) if prefetch_pipeline else 0 + ) + + # Prepare weights for direct_write + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + emb.hash_size_cumsum, + indices, + offsets, + None, + -1, + ) + unique_indices, inverse_indices = torch.unique( + linear_cache_indices, return_inverse=True, sorted=True + ) + unique_weights = torch.randn( + unique_indices.numel(), + emb.cache_row_dim, + device=emb.current_device, + dtype=emb.weights_precision.as_dtype(), + ) + custom_weights = unique_weights[inverse_indices] + + # Call direct_write_embedding + emb.direct_write_embedding( + indices=indices, + offsets=offsets, + weights=custom_weights, + ) + torch.cuda.synchronize() + + # Assert: Verify location update queue was properly cleared + location_update_queue_after = ( + len(emb.ssd_location_update_data) if prefetch_pipeline else 0 + ) + + if prefetch_pipeline and location_update_queue_before > 0: + # After direct_write, one entry should be popped from location_update_data + self.assertEqual( + location_update_queue_after, + location_update_queue_before - 1, + "direct_write should clear stale location update data", + ) + + # Verify no data corruption occurred + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + emb.lxu_cache_state, + emb.total_hash_size, + ) + cache_location_mask = lxu_cache_locations >= 0 + cache_locations = lxu_cache_locations[cache_location_mask] + cache_weights = 
custom_weights[cache_location_mask] + + if cache_locations.numel() > 0: + actual_cache_weights = emb.lxu_cache_weights[cache_locations] + torch.testing.assert_close( + actual_cache_weights, + cache_weights, + rtol=1e-2 if weights_precision == SparseType.FP16 else 1e-4, + atol=1e-2 if weights_precision == SparseType.FP16 else 1e-4, + )
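
For readers unfamiliar with the stream-ordering pattern behind the one-line Gap A fix, here is a minimal standalone sketch in plain PyTorch. `direct_write_with_sync`, `fetch_stream`, and `fetch_done` are hypothetical stand-ins for `direct_write_embedding`, the prefetch-side stream, and `ssd_event_get`; none of this is FBGEMM code.

```python
# Minimal sketch of the Gap A hazard: a side stream asynchronously fills a
# scratch pad and records an event; a write on the current stream must
# wait_event() on that event, or the async fill can land last and clobber it.
import torch


def direct_write_with_sync(scratch_pad, new_rows, fetch_done, wait):
    if wait:
        # The fix: order this write after the async backend fetch completes.
        torch.cuda.current_stream().wait_event(fetch_done)
    scratch_pad.copy_(new_rows)


if torch.cuda.is_available():
    fetch_stream = torch.cuda.Stream()  # stand-in for the prefetch stream
    fetch_done = torch.cuda.Event()     # stand-in for ssd_event_get
    scratch_pad = torch.zeros(1024, 128, device="cuda")
    backend_rows = torch.ones_like(scratch_pad)   # "stale" rows from the backend
    new_rows = torch.full_like(scratch_pad, 2.0)  # rows the direct write publishes

    with torch.cuda.stream(fetch_stream):
        # Async backend read into the scratch pad on the side stream.
        scratch_pad.copy_(backend_rows, non_blocking=True)
        fetch_done.record(fetch_stream)

    # Without wait=True, this write races the async copy above.
    direct_write_with_sync(scratch_pad, new_rows, fetch_done, wait=True)
    torch.cuda.synchronize()
    assert bool((scratch_pad == 2.0).all())
```

Ordering via an event keeps the dependency entirely on the GPU timeline: unlike `torch.cuda.synchronize()`, `wait_event` never blocks the host, which is why the wait can be added without stalling the prefetch pipeline.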
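The two new tests also assert a queue discipline, not just tensor values. The sketch below is a schematic model of that discipline under a simplified assumption drawn from the tests: one training prefetch appends one entry per queue, and one direct write consumes the head entry of each, so no stale scratch-pad mapping survives the write. `PipelineQueues` and its methods are hypothetical illustrations, not the FBGEMM API.

```python
# Schematic model of the paired-queue bookkeeping the tests check.
from collections import deque


class PipelineQueues:
    def __init__(self):
        self.scratch_pad_eviction_data = deque()  # mirrors ssd_scratch_pad_eviction_data
        self.location_update_data = deque()       # mirrors ssd_location_update_data

    def on_prefetch(self, sp_entry, loc_entry):
        # _prefetch appends one entry per iteration (training only).
        self.scratch_pad_eviction_data.append(sp_entry)
        self.location_update_data.append(loc_entry)

    def on_direct_write(self):
        # A direct write consumes the head entry of each queue: the scratch
        # pad it just overwrote must not be evicted or remapped from stale data.
        if self.scratch_pad_eviction_data:
            self.scratch_pad_eviction_data.popleft()
        if self.location_update_data:
            self.location_update_data.popleft()


q = PipelineQueues()
q.on_prefetch("sp_iter1", "loc_iter1")
q.on_prefetch("sp_iter2", "loc_iter2")
before = len(q.location_update_data)
q.on_direct_write()
# What test_direct_write_clears_location_update_queue asserts:
assert len(q.location_update_data) == before - 1
```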