src/llama-memory-recurrent.cpp: 4 additions & 1 deletion
```diff
@@ -861,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
 bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence
 
         seq_rm(dest_seq_id, -1, -1);
 
+        if (cell_count == 0) {
+            return true;
+        }
+
         llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
         llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
```
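For context, the restore path previously fell straight through to `balloc.ubatch_reserve(cell_count, 1)` even when `cell_count == 0`; the early return makes restoring an empty sequence a clean no-op. A minimal sketch of the save/restore round trip this guards, using the public state API (assumes `ctx` is a valid context for a recurrent or hybrid model):

```cpp
#include "llama.h"

#include <vector>

// Round-trip one sequence's state through a buffer. For an empty sequence the
// stored cell count is 0, and state_read_meta now returns early on restore
// instead of reserving a zero-token ubatch.
static bool roundtrip_seq_state(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size(ctx, seq_id);

    std::vector<uint8_t> buf(size);
    if (llama_state_seq_get_data(ctx, buf.data(), buf.size(), seq_id) != size) {
        return false;
    }

    // llama_state_seq_set_data returns 0 on failure
    return llama_state_seq_set_data(ctx, buf.data(), buf.size(), seq_id) != 0;
}
```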
tools/server/server.cpp: 52 additions & 34 deletions
```diff
@@ -3676,6 +3676,20 @@ struct server_context {
                     alora_disabled_id = enabled_loras[0];
                 }
 
+                bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+                // make a checkpoint of the parts of the memory that cannot be rolled back.
+                // checkpoints are created only if:
+                // - the model uses SWA and we are not using `swa_full`
+                // - the model architecture is marked as recurrent or hybrid
+                //
+                // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                do_checkpoint = do_checkpoint && (
+                    llama_model_is_recurrent(model) ||
+                    llama_model_is_hybrid(model) ||
+                    (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+                );
+
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                     // get next token to process
```
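The new gate as a standalone predicate, for readability (hypothetical helper, not part of the PR, which computes this inline):

```cpp
// Hypothetical standalone form of the inline gate above. Checkpoints only
// matter when part of the memory cannot be rolled back: recurrent or hybrid
// state, or an SWA cache when `--swa-full` is not used.
static bool checkpoints_enabled(const llama_model * model, bool swa_full, int32_t n_ctx_checkpoints) {
    if (n_ctx_checkpoints <= 0) {
        return false; // checkpointing disabled
    }

    return llama_model_is_recurrent(model) ||
           llama_model_is_hybrid(model)    ||
           (llama_model_n_swa(model) > 0 && !swa_full);
}
```

Folding `n_ctx_checkpoints > 0` into the flag up front lets the batching loop below consult a single boolean instead of re-checking the parameter at the use site, as the removed version did (see the last hunk).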
```diff
@@ -3700,6 +3714,11 @@

                     slot.n_prompt_tokens_processed++;
                     slot.n_past++;
+
+                    // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+                    if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                        break;
+                    }
                 }
 
                 // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
```
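A small worked example of where that break lands; the intent is to force a decode boundary exactly 64 tokens before the end of the prompt, so the checkpoint taken when the prompt finishes captures the memory state just before those final tokens (numbers are illustrative):

```cpp
// Illustrative only: a 1000-token prompt that fits in a single batch.
int n_prompt_tokens = 1000;
int n_past          = 0;

while (n_past < n_prompt_tokens) {
    n_past++;

    if (/* do_checkpoint && */ n_prompt_tokens - n_past == 64) {
        break; // same condition as above
    }
}
// n_past == 936: this batch is decoded first, the checkpoint is taken at that
// boundary, and the remaining 64 tokens are processed in the next batch.
```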
```diff
@@ -3730,6 +3749,39 @@
                     slot.i_batch = batch.n_tokens - 1;
 
                     SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+                    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                    const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                    // no need for empty or small checkpoints
+                    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+                    // no need to create checkpoints that are too close together
+                    do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+                    if (do_checkpoint) {
+                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.front();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                        }
+
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                            /*.pos_min = */ pos_min,
+                            /*.pos_max = */ pos_max,
+                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
+                        });
+
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                    }
                 }
             }
```
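For reference, the capture step in isolation. `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY` requests only the parts of the sequence state that cannot be rolled back (per the comment in the first hunk), which keeps these checkpoints much smaller than a full sequence dump. A sketch mirroring the two calls above:

```cpp
#include "llama.h"

#include <vector>

// Sketch: capture only the non-rollbackable part of a sequence's state,
// mirroring the size/get pair the server uses for each checkpoint.
static std::vector<uint8_t> capture_checkpoint(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    std::vector<uint8_t> buf(size);
    llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return buf;
}
```

Note the two `do_checkpoint` filters above also enforce spacing: no checkpoint for sequences shorter than 64 positions, and at least 64 positions between consecutive checkpoints.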

```diff
@@ -3853,40 +3905,6 @@

                 // prompt evaluated for next-token prediction
                 slot.state = SLOT_STATE_GENERATING;
-
-                // make a checkpoint of the parts of the memory that cannot be rolled back.
-                // checkpoints are created only if:
-                // - the model uses SWA and we are not using `swa_full`
-                // - the model architecture is marked as recurrent or hybrid
-                //
-                // TODO: try to make this conditional on the context or the memory module, instead of the model type
-                const bool do_checkpoint =
-                    (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                    (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-                if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                    while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                        // make room for the new checkpoint, if needed
-                        const auto & cur = slot.ctx_checkpoints.front();
-                        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                    }
-
-                    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                    auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                        /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                        /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                        /*.data    = */ std::vector<uint8_t>(checkpoint_size),
-                    });
-
-                    llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                    SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                            (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                }
             } else if (slot.state != SLOT_STATE_GENERATING) {
                 continue; // continue loop of slots
             }
```
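The removed block is the same checkpoint logic, previously run only when the slot entered `SLOT_STATE_GENERATING`; it moved into the prompt-processing path above, where `pos_min`/`pos_max` are queried once and reused. The restore side is not part of this diff; a sketch assuming the matching `llama_state_seq_set_data_ext` setter (an assumption: only the getter variants appear in this excerpt):

```cpp
// Sketch: apply a saved checkpoint back onto a sequence. Assumes the setter
// counterpart of the _ext getters used above; after restoring, the tokens
// past the checkpoint's pos_max must be re-decoded.
static bool restore_checkpoint(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
    const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    return n != 0; // 0 indicates failure
}
```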