Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __init__(
table_names: Optional[list[str]] = None,
use_rowwise_bias_correction: bool = False, # For Adam use
optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006
enable_optimizer_offloading: bool = False,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
std::optional<at::Tensor> table_dims = std::nullopt,
std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
int64_t flushing_block_size = 2000000000 /*2GB*/,
bool disable_random_init = false)
bool disable_random_init = false,
std::optional<bool> enable_optimizer_offloading = std::nullopt,
std::optional<int64_t> optimizer_D = std::nullopt)
: impl_(
std::make_shared<ssd::EmbeddingRocksDB>(
path,
Expand Down Expand Up @@ -77,7 +79,9 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
table_dims,
hash_size_cumsum,
flushing_block_size,
disable_random_init)) {}
disable_random_init,
enable_optimizer_offloading,
optimizer_D)) {}

void set_cuda(
at::Tensor indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,9 @@ static auto embedding_rocks_db_wrapper =
std::optional<at::Tensor>,
std::optional<at::Tensor>,
int64_t,
bool>(),
bool,
std::optional<bool>,
std::optional<int64_t>>(),
"",
{
torch::arg("path"),
Expand Down Expand Up @@ -842,6 +844,8 @@ static auto embedding_rocks_db_wrapper =
torch::arg("hash_size_cumsum") = std::nullopt,
torch::arg("flushing_block_size") = 2000000000 /* 2GB */,
torch::arg("disable_random_init") = false,
torch::arg("enable_optimizer_offloading") = std::nullopt,
torch::arg("optimizer_D") = std::nullopt,
})
.def(
"set_cuda",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
std::optional<at::Tensor> table_dims = std::nullopt,
std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
int64_t flushing_block_size = 2000000000 /*2GB*/,
bool disable_random_init = false)
bool disable_random_init = false,
std::optional<bool> enable_optimizer_offloading = std::nullopt,
std::optional<int64_t> optimizer_D = std::nullopt)
: kv_db::EmbeddingKVDB(
num_shards,
max_D,
Expand Down Expand Up @@ -266,7 +268,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth,
disable_random_init);
disable_random_init,
enable_optimizer_offloading,
optimizer_D);
executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(num_shards);
ro_.verify_checksums = false;
ro_.async_io = true;
Expand Down Expand Up @@ -421,19 +425,33 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
float uniform_init_lower,
float uniform_init_upper,
int64_t row_storage_bitwidth,
bool disable_random_init) {
bool disable_random_init,
std::optional<bool> enable_optimizer_offloading = std::nullopt,
std::optional<int64_t> optimizer_D = std::nullopt) {
for (auto i = 0; i < num_shards; ++i) {
auto* gen = at::check_generator<at::CPUGeneratorImpl>(
at::detail::getDefaultCPUGenerator());
{
std::lock_guard<std::mutex> lock(gen->mutex_);
initializers_.push_back(
std::make_unique<Initializer>(
gen->random64(),
max_D,
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth));
auto initializer = std::make_unique<Initializer>(
gen->random64(),
max_D,
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth);

// When Optimizer offloading is enabled, we want to initialize the last
// optimizer_D columns(optimizer values) to zero
if (enable_optimizer_offloading.has_value() &&
enable_optimizer_offloading.value() && optimizer_D.has_value()) {
auto& tensor = initializer->row_storage_;
tensor
.index(
{"...",
at::indexing::Slice(max_D - optimizer_D.value(), max_D)})
.zero_();
}
initializers_.push_back(std::move(initializer));
}
}
disable_random_init_ = disable_random_init;
Expand Down Expand Up @@ -1364,6 +1382,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
std::vector<std::string> db_paths_;

bool disable_random_init_;
std::optional<bool> enable_optimizer_offloading = std::nullopt;
}; // class EmbeddingRocksDB

/// @ingroup embedding-ssd
Expand Down
Loading