Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
std::optional<at::Tensor> table_dims = std::nullopt,
std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
int64_t flushing_block_size = 2000000000 /*2GB*/,
bool disable_random_init = false)
bool disable_random_init = false,
bool enable_optimizer_offloading = false,
int64_t optimizer_D = 0)
: kv_db::EmbeddingKVDB(
num_shards,
max_D,
Expand Down Expand Up @@ -266,7 +268,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth,
disable_random_init);
disable_random_init,
enable_optimizer_offloading,
optimizer_D);
executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(num_shards);
ro_.verify_checksums = false;
ro_.async_io = true;
Expand Down Expand Up @@ -421,19 +425,29 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
float uniform_init_lower,
float uniform_init_upper,
int64_t row_storage_bitwidth,
bool disable_random_init) {
bool disable_random_init,
bool enable_optimizer_offloading = false,
int64_t optimizer_D = 0) {
for (auto i = 0; i < num_shards; ++i) {
auto* gen = at::check_generator<at::CPUGeneratorImpl>(
at::detail::getDefaultCPUGenerator());
{
std::lock_guard<std::mutex> lock(gen->mutex_);
initializers_.push_back(
std::make_unique<Initializer>(
gen->random64(),
max_D,
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth));
auto initializer = std::make_unique<Initializer>(
gen->random64(),
max_D,
uniform_init_lower,
uniform_init_upper,
row_storage_bitwidth);

// When optimizer offloading is enabled, initialize the last optimizer_D
// columns (the optimizer values) to zero.
if (enable_optimizer_offloading) {
auto& tensor = initializer->row_storage_;
tensor.index({"...", at::indexing::Slice(max_D - optimizer_D, max_D)})
.zero_();
}
initializers_.push_back(std::move(initializer));
}
}
disable_random_init_ = disable_random_init;
Expand Down
Loading