Skip to content

Commit f3cd0ac

Browse files
emlinfacebook-github-bot
authored andcommitted
support prefetch pipeline
Differential Revision: D85021220
1 parent fcd6ab4 commit f3cd0ac

File tree

2 files changed

+450
-0
lines changed

2 files changed

+450
-0
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4547,7 +4547,21 @@ def direct_write_embedding(
45474547
):
45484548
if len(self.ssd_scratch_pad_eviction_data) > 0:
45494549
self.ssd_scratch_pad_eviction_data.pop(0)
4550+
if self.prefetch_pipeline:
4551+
if len(self.ssd_scratch_pads) > 0:
4552+
self.ssd_scratch_pads.pop(0) # Keep in sync
4553+
4554+
# Clear location update data since there is no backward flow for embedding cache.
4555+
if len(self.ssd_location_update_data) > 0:
4556+
self.ssd_location_update_data.pop(0)
4557+
45504558
if len(self.ssd_scratch_pad_eviction_data) > 0:
4559+
# Wait for any pending backend reads to the next scratch pad
4560+
# to complete before we write to it. Otherwise, stale backend data
4561+
# will overwrite our direct_write updates.
4562+
# The ssd_event_get marks completion of backend fetch operations.
4563+
current_stream.wait_event(self.ssd_event_get)
4564+
45514565
# if scratch pad exists, write to next batch scratch pad
45524566
sp = self.ssd_scratch_pad_eviction_data[0][0]
45534567
sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(

0 commit comments

Comments
 (0)