diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index a497cf9a5b..32fb3991f7 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -2089,7 +2089,7 @@ def _prefetch( # noqa C901 torch.tensor( [weights.shape[0]], device="cpu", dtype=torch.long ), - weights.cpu().view(torch.float32).view(-1, 2), + weights.cpu(), ) # Generate row addresses (pointing to either L1 or the current diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 9738b846cc..98f3a44e35 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -770,7 +770,6 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { CHECK_EQ(indices.size(0), engege_rates.size(0)); auto indices_data_ptr = indices.data_ptr(); auto engage_rate_ptr = engege_rates.data_ptr(); - int64_t stride = 2; { auto before_write_lock_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -785,8 +784,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { index_iter++) { const auto& id_index = *index_iter; auto id = int64_t(indices_data_ptr[id_index]); - float engege_rate = - float(engage_rate_ptr[id_index * stride + 0]); + float engege_rate = float(engage_rate_ptr[id_index]); // use mempool weight_type* block = nullptr; auto before_lookup_cache_ts =