Skip to content

Commit 99bc254

Browse files
emlinfacebook-github-bot
authored andcommitted
support prefetch pipeline (#5032)
Summary: X-link: facebookresearch/FBGEMM#2045 Fix the direct_write for prefetch pipeline Reviewed By: kausv, steven1327 Differential Revision: D85021220
1 parent 51210b8 commit 99bc254

File tree

2 files changed

+472
-7
lines changed

2 files changed

+472
-7
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1984,12 +1984,13 @@ def _prefetch( # noqa C901
19841984
# Store info for evicting the previous iteration's
19851985
# scratch pad after the corresponding backward pass is
19861986
# done
1987-
self.ssd_location_update_data.append(
1988-
(
1989-
sp_curr_prev_map_gpu,
1990-
inserted_rows,
1987+
if self.training:
1988+
self.ssd_location_update_data.append(
1989+
(
1990+
sp_curr_prev_map_gpu,
1991+
inserted_rows,
1992+
)
19911993
)
1992-
)
19931994

19941995
# Ensure the previous iterations eviction is complete
19951996
current_stream.wait_event(self.ssd_event_sp_evict)
@@ -2173,7 +2174,7 @@ def _prefetch( # noqa C901
21732174

21742175
# Store scratch pad info for post backward eviction only for training
21752176
# for eval job, no backward pass, so no need to store this info
2176-
if self.training and not self._embedding_cache_mode:
2177+
if self.training:
21772178
self.ssd_scratch_pad_eviction_data.append(
21782179
(
21792180
inserted_rows,
@@ -4548,6 +4549,12 @@ def direct_write_embedding(
45484549
if len(self.ssd_scratch_pad_eviction_data) > 0:
45494550
self.ssd_scratch_pad_eviction_data.pop(0)
45504551
if len(self.ssd_scratch_pad_eviction_data) > 0:
4552+
# Wait for any pending backend reads to the next scratch pad
4553+
# to complete before we write to it. Otherwise, stale backend data
4554+
# will overwrite our direct_write updates.
4555+
# The ssd_event_get marks completion of backend fetch operations.
4556+
current_stream.wait_event(self.ssd_event_get)
4557+
45514558
# if scratch pad exists, write to next batch scratch pad
45524559
sp = self.ssd_scratch_pad_eviction_data[0][0]
45534560
sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(

0 commit comments

Comments
 (0)