diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 47e1d29893..c440eec69b 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -121,7 +121,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { std::optional table_dims = std::nullopt, std::optional hash_size_cumsum = std::nullopt, int64_t flushing_block_size = 2000000000 /*2GB*/, - bool disable_random_init = false) + bool disable_random_init = false, + bool enable_optimizer_offloading = false, + int64_t optimizer_D = 0) : kv_db::EmbeddingKVDB( num_shards, max_D, @@ -266,7 +268,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { uniform_init_lower, uniform_init_upper, row_storage_bitwidth, - disable_random_init); + disable_random_init, + enable_optimizer_offloading, + optimizer_D); executor_ = std::make_unique(num_shards); ro_.verify_checksums = false; ro_.async_io = true; @@ -421,19 +425,29 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { float uniform_init_lower, float uniform_init_upper, int64_t row_storage_bitwidth, - bool disable_random_init) { + bool disable_random_init, + bool enable_optimizer_offloading = false, + int64_t optimizer_D = 0) { for (auto i = 0; i < num_shards; ++i) { auto* gen = at::check_generator( at::detail::getDefaultCPUGenerator()); { std::lock_guard lock(gen->mutex_); - initializers_.push_back( - std::make_unique( - gen->random64(), - max_D, - uniform_init_lower, - uniform_init_upper, - row_storage_bitwidth)); + auto initializer = std::make_unique( + gen->random64(), + max_D, + uniform_init_lower, + uniform_init_upper, + row_storage_bitwidth); + + // When Optimizer offloading is enabled, we want to initialize the last + // optimizer_D columns(optimizer values) to zero + if (enable_optimizer_offloading) { + auto& tensor = initializer->row_storage_; + tensor.index({"...", at::indexing::Slice(max_D - optimizer_D, max_D)}) + .zero_(); + } + initializers_.push_back(std::move(initializer)); } } disable_random_init_ = disable_random_init;