diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index a497cf9a5b..031d0eec56 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -179,6 +179,7 @@ def __init__(
         table_names: Optional[list[str]] = None,
         use_rowwise_bias_correction: bool = False,  # For Adam use
         optimizer_state_dtypes: dict[str, SparseType] = {},  # noqa: B006
+        enable_optimizer_offloading: bool = False,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()

diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h
index 8cebdef1eb..82ce53445d 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h
@@ -45,7 +45,9 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
       std::optional<at::Tensor> table_dims = std::nullopt,
       std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
       int64_t flushing_block_size = 2000000000 /*2GB*/,
-      bool disable_random_init = false)
+      bool disable_random_init = false,
+      std::optional<bool> enable_optimizer_offloading = std::nullopt,
+      std::optional<int64_t> optimizer_D = std::nullopt)
       : impl_(
             std::make_shared<ssd::EmbeddingRocksDB>(
                 path,
@@ -77,7 +79,9 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
                 table_dims,
                 hash_size_cumsum,
                 flushing_block_size,
-                disable_random_init)) {}
+                disable_random_init,
+                enable_optimizer_offloading,
+                optimizer_D)) {}

   void set_cuda(
       at::Tensor indices,
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index 0b95285a8f..fab6ad9253 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -809,7 +809,9 @@ static auto embedding_rocks_db_wrapper =
                 std::optional<at::Tensor>,
                 std::optional<at::Tensor>,
                 int64_t,
-                bool>(),
+                bool,
+                std::optional<bool>,
+                std::optional<int64_t>>(),
             "",
             {
                 torch::arg("path"),
@@ -842,6 +844,8 @@ static auto embedding_rocks_db_wrapper =
                 torch::arg("hash_size_cumsum") = std::nullopt,
                 torch::arg("flushing_block_size") = 2000000000 /* 2GB */,
                 torch::arg("disable_random_init") = false,
+                torch::arg("enable_optimizer_offloading") = std::nullopt,
+                torch::arg("optimizer_D") = std::nullopt,
             })
         .def(
             "set_cuda",
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
index 47e1d29893..d95421dce3 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h
@@ -121,7 +121,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
       std::optional<at::Tensor> table_dims = std::nullopt,
      std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
       int64_t flushing_block_size = 2000000000 /*2GB*/,
-      bool disable_random_init = false)
+      bool disable_random_init = false,
+      std::optional<bool> enable_optimizer_offloading = std::nullopt,
+      std::optional<int64_t> optimizer_D = std::nullopt)
       : kv_db::EmbeddingKVDB(
             num_shards,
             max_D,
@@ -266,7 +268,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
         uniform_init_lower,
         uniform_init_upper,
         row_storage_bitwidth,
-        disable_random_init);
+        disable_random_init,
+        enable_optimizer_offloading,
+        optimizer_D);
     executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(num_shards);
     ro_.verify_checksums = false;
     ro_.async_io = true;
@@ -421,19 +425,33 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
       float uniform_init_lower,
       float uniform_init_upper,
       int64_t row_storage_bitwidth,
-      bool disable_random_init) {
+      bool disable_random_init,
+      std::optional<bool> enable_optimizer_offloading = std::nullopt,
+      std::optional<int64_t> optimizer_D = std::nullopt) {
     for (auto i = 0; i < num_shards; ++i) {
       auto* gen = at::check_generator<at::CPUGeneratorImpl>(
           at::detail::getDefaultCPUGenerator());
       {
         std::lock_guard<std::mutex> lock(gen->mutex_);
-        initializers_.push_back(
-            std::make_unique<Initializer>(
-                gen->random64(),
-                max_D,
-                uniform_init_lower,
-                uniform_init_upper,
-                row_storage_bitwidth));
+        auto initializer = std::make_unique<Initializer>(
+            gen->random64(),
+            max_D,
+            uniform_init_lower,
+            uniform_init_upper,
+            row_storage_bitwidth);
+
+        // When optimizer offloading is enabled, initialize the last
+        // optimizer_D columns (the optimizer values) to zero
+        if (enable_optimizer_offloading.has_value() &&
+            enable_optimizer_offloading.value() && optimizer_D.has_value()) {
+          auto& tensor = initializer->row_storage_;
+          tensor
+              .index(
+                  {"...",
+                   at::indexing::Slice(max_D - optimizer_D.value(), max_D)})
+              .zero_();
+        }
+        initializers_.push_back(std::move(initializer));
       }
     }
     disable_random_init_ = disable_random_init;
@@ -1364,6 +1382,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
   std::vector<std::string> db_paths_;

   bool disable_random_init_;
+  std::optional<bool> enable_optimizer_offloading = std::nullopt;
 }; // class EmbeddingRocksDB

 /// @ingroup embedding-ssd
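
Note: the sketch below is illustrative only and is not part of the patch. It demonstrates, under made-up dimensions (the 4-row buffer, max_D = 8, optimizer_D = 2 are example values), the slicing behavior the new initializer code relies on: Tensor::index with an ellipsis plus a basic Slice returns a view, so calling zero_() on it clears the trailing optimizer_D columns of the underlying storage in place while leaving the randomly initialized weight columns untouched.

#include <torch/torch.h>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t max_D = 8;        // example: total row width (weights + optimizer state)
  const int64_t optimizer_D = 2;  // example: trailing columns reserved for optimizer state

  // Stand-in for Initializer::row_storage_: a randomly initialized row buffer.
  at::Tensor row_storage = torch::rand({4, max_D});

  // Basic indexing ("..." + Slice) yields a view, so zero_() mutates
  // row_storage itself; only the last optimizer_D columns are zeroed.
  row_storage
      .index({"...", at::indexing::Slice(max_D - optimizer_D, max_D)})
      .zero_();

  std::cout << row_storage << std::endl;
  return 0;
}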