diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index 9ee61ce6e2e..c193c8689b2 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -119,7 +119,7 @@ struct Args { #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, /// Enable JSON output format. diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp index 2151466be6e..de975ad1d6e 100644 --- a/backends/trtllm/csrc/backend.cpp +++ b/backends/trtllm/csrc/backend.cpp @@ -26,7 +26,7 @@ namespace huggingface::tgi::backends::trtllm { } - tle::ExecutorConfig backend_workspace_t::executor_config() const { + tle::ExecutorConfig backend_workspace_t::executor_config(const std::vector& encoded_vocab, std::string_view tokenizer_str) const { // Retrieve the compute capabilities to enable some options at runtime const auto compute_capabilities = hardware::cuda::compute_capabilities_t(); @@ -40,32 +40,50 @@ namespace huggingface::tgi::backends::trtllm { executor_config.setKvCacheConfig(tle::KvCacheConfig(true)); executor_config.setEnableChunkedContext(compute_capabilities.is_at_least_ampere()); executor_config.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION)); + executor_config.setGuidedDecodingConfig(tle::GuidedDecodingConfig( + tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, + encoded_vocab, + std::string(tokenizer_str), + generation_config().eos_token_ids + )); return executor_config; } - backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) - : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {} - - size_t backend_t::num_tokens_ready() const noexcept { - return executor_.getNumResponsesReady(); - } + backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path, const std::vector &encoded_vocab, std::string_view tokenizer_str) + : workspace(engines_folder, executor_worker_path), + executor_(executor_factory_initializer(workspace, encoded_vocab, tokenizer_str)) {} std::expected backend_t::submit(std::span token_ids, const generation_params_t g_params, const sampling_params_t s_params) noexcept { SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params); - return executor_.enqueueRequest(tle::Request{ + tle::Request req { {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens static_cast(g_params.max_new_tokens), true, (tle::SamplingConfig) s_params, - tle::OutputConfig{ /* returnLogProbs= */ true}, + tle::OutputConfig{ + /* returnLogProbs= */ true, + false, + false, + false, + false, + /* returnPerfMetrics=*/ true, + }, std::nullopt, std::nullopt, std::nullopt, std::nullopt, workspace.generation_config().stop_words - }); + }; + + if (g_params.guide_type.has_value()) { + req.setGuidedDecodingParams(tle::GuidedDecodingParams( + g_params.guide_type.value(), + g_params.guide + )); + } + return executor_.enqueueRequest(req); } std::vector backend_t::pull_tokens() noexcept { diff --git a/backends/trtllm/csrc/backend.hpp b/backends/trtllm/csrc/backend.hpp index 40b44a842b3..184bf26f9be 100644 --- a/backends/trtllm/csrc/backend.hpp +++ b/backends/trtllm/csrc/backend.hpp @@ -25,6 +25,8 @@ namespace huggingface::tgi::backends::trtllm { */ struct generation_params_t { uint32_t max_new_tokens; + std::optional 
guide_type; + std::string guide; }; /** @@ -66,17 +68,31 @@ namespace huggingface::tgi::backends::trtllm { float_t top_p; float_t temperature; std::list> stop_words; + std::vector eos_token_ids; constexpr explicit generation_config_t(const json &config) : - top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) { - if (config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) { + top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0), eos_token_ids{} { + if (!config.contains("/eos_token_id"_json_pointer)) { + return; + } + if (config["/eos_token_id"_json_pointer].is_array()) { + SPDLOG_DEBUG("generation config eos_token_id is array"); const auto &eos_token_id = config["/eos_token_id"_json_pointer]; std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) { - stop_words.emplace_back(1, token_id.template get()); + const auto token = token_id.template get(); + stop_words.emplace_back(1, token); + eos_token_ids.emplace_back(token); }); + } - SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size()); + if (config["/eos_token_id"_json_pointer].is_number()) { + SPDLOG_DEBUG("generation config eos_token_id is number"); + const auto token = config["/eos_token_id"_json_pointer].get(); + stop_words.emplace_back(1, token); + eos_token_ids.emplace_back(token); } + + SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size()); } }; @@ -134,7 +150,7 @@ namespace huggingface::tgi::backends::trtllm { * to initialize `tensorrt_llm::executor::Executor` * @return `tensorrt_llm::executor::ExecutorConfig` instance */ - [[nodiscard]] tle::ExecutorConfig executor_config() const; + [[nodiscard]] tle::ExecutorConfig executor_config(const std::vector& encoded_vocab, std::string_view tokenizer_str) const; }; /** @@ -158,10 +174,10 @@ namespace huggingface::tgi::backends::trtllm { tle::Executor executor_; public: - backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path); + backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path, const std::vector &encoded_vocab, std::string_view tokenizer_str); - backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path) - : backend_t(engines_folder, executor_worker_path) {}; + backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path, const std::vector &encoded_vocab, std::string_view tokenizer_str) + : backend_t(engines_folder, executor_worker_path, encoded_vocab, tokenizer_str) {}; /** * Submit a new request to the executor @@ -175,13 +191,6 @@ namespace huggingface::tgi::backends::trtllm { submit(std::span token_ids, generation_params_t generation_params, sampling_params_t sampling_params) noexcept; - /** - * Query the number of tokens available across all in-flight generations - * @return - */ - [[nodiscard("Pulling out the number of tokens")]] - size_t num_tokens_ready() const noexcept; - /** * Pull out newly generated tokens from the executor * @return @@ -199,9 +208,9 @@ namespace huggingface::tgi::backends::trtllm { /** * Create a TensorRT-LLM executor from a workspace */ - const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor { + const auto executor_factory_initializer = [](const backend_workspace_t &workspace, const std::vector &encoded_vocab, 
std::string_view tokenizer_str) -> tle::Executor { return {workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY, - workspace.executor_config()}; + workspace.executor_config(encoded_vocab, tokenizer_str)}; }; } diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index 840614bbcfe..3a16630c017 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -1,7 +1,10 @@ #ifndef TGI_BACKEND_TRTLLM_FFI #define TGI_BACKEND_TRTLLM_FFI +#include +#include #include +#include #include #include @@ -17,7 +20,7 @@ namespace rust::behavior { template static void trycatch(Try &&func, Fail &&fail) noexcept try { func(); - } catch (tensorrt_llm::common::TllmException &e) { + } catch (const std::exception &e) { fail(e.what()); } } @@ -42,22 +45,46 @@ namespace huggingface::tgi::backends::trtllm { return finish_reason_t::kEND_ID; case tle::FinishReason::kLENGTH: return finish_reason_t::kLENGTH; + case tle::FinishReason::kTIMED_OUT: + return finish_reason_t::kTIMED_OUT; + case tle::FinishReason::kCANCELLED: + return finish_reason_t::kCANCELLED; default: std::unreachable(); } } - static auto as_generation_step = [](const tle::Response &r) { + static auto as_generation_step = [](const tle::Response &r, const std::chrono::time_point created) { const auto reqId = r.getRequestId(); if (!r.hasError()) [[likely]] { const auto result = r.getResult(); - const auto logits = result.logProbs.value()[0]; + std::optional token_id = std::nullopt; + if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) { + token_id = static_cast(result.outputTokenIds[0][0]); + } + + std::optional log_prob = std::nullopt; + if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) { + log_prob = result.logProbs.value()[0].back(); + } + + std::optional first_scheduled_time_ns = std::nullopt; + if (result.requestPerfMetrics) { + const auto &t = result.requestPerfMetrics->timingMetrics; + const auto ns = std::chrono::duration_cast(t.firstScheduledTime - created).count(); + first_scheduled_time_ns = static_cast(ns); + } + return generation_step_t{ reqId, - static_cast(result.outputTokenIds[0][0]), - logits.back(), + token_id.value_or(0), + log_prob.value_or(0.0), + first_scheduled_time_ns.value_or(0), result.isFinal, as_finish_reason_t(result.finishReasons[0]), + token_id.has_value(), + log_prob.has_value(), + first_scheduled_time_ns.has_value(), false, std::string() }; @@ -66,8 +93,12 @@ namespace huggingface::tgi::backends::trtllm { reqId, 0, 0.0, + 0, true, finish_reason_t::kNOT_FINISHED, + false, + false, + false, true, std::move(r.getErrorMsg()) }; @@ -77,13 +108,18 @@ namespace huggingface::tgi::backends::trtllm { class tensorrt_llm_backend_t { private: - backend_t inner_; + mutable backend_t inner_; + + // m_created_time is a reference point to convert time from c++ time_point + // to rust Instant. 
+ std::chrono::time_point m_created_time; - public: - tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path) - : inner_(engine_folder, executor_worker_path) {} - size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); } + public: + tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point& created_time, const std::vector& encoded_vocab, std::string_view tokenizer_str) + : inner_(engine_folder, executor_worker_path, encoded_vocab, tokenizer_str), + m_created_time {created_time} + {} request_id_t submit( rust::Slice tokens, @@ -93,16 +129,31 @@ namespace huggingface::tgi::backends::trtllm { float_t temperature, float_t repetition_penalty, float_t frequency_penalty, - uint64_t seed - ) { + uint64_t seed, + grammar_type_t grammar_type, + rust::Str grammar_value + ) const { // This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE) SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor")); // Submit the request to the executor and get back a potential request_id used to track request status const auto signed_tokens = std::vector(tokens.begin(), tokens.end()); + + std::optional guide_type = std::nullopt; + switch (grammar_type) { + case grammar_type_t::kJSON: + guide_type = tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA; + break; + case grammar_type_t::kREGEX: + guide_type = tle::GuidedDecodingParams::GuideType::kREGEX; + break; + default: + break; + } + const auto maybe_request_id = inner_.submit( signed_tokens, - {max_new_tokens}, + {max_new_tokens, guide_type, std::string(grammar_value)}, {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed} ); @@ -115,28 +166,26 @@ namespace huggingface::tgi::backends::trtllm { } } - std::unique_ptr> pull_tokens() noexcept { - if (num_tokens_ready() > 0) [[likely]] { - const auto responses = inner_.pull_tokens(); + std::unique_ptr> pull_tokens() const noexcept { + const auto responses = inner_.pull_tokens(); - SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); + SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); - // Transform tle::Response to generation_step_t + auto f = [this](const tle::Response &r){ + return as_generation_step(r, m_created_time); + }; + auto steps = std::make_unique>(); + // Transform tle::Response to generation_step_t #ifdef __cpp_lib_ranges_to_container - auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to(); + *steps = responses | std::views::transform(f) | std::ranges::to(); #else - auto steps = std::vector(); - steps.reserve(responses.size()); - std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step); + steps->reserve(responses.size()); + std::transform(responses.begin(), responses.end(), std::back_inserter(steps), f); #endif - return std::make_unique>(steps); - - } else { - return std::make_unique>(); - } + return steps; } - void cancel(request_id_t request_id) noexcept { + void cancel(request_id_t request_id) const noexcept { SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id); inner_.cancel(request_id); } @@ -178,13 +227,25 @@ namespace huggingface::tgi::backends::trtllm { } std::unique_ptr - create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) { + 
create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path, const rust::Str tokenizer_str, const rust::Vec encoded_vocab) { + const auto created_time = std::chrono::steady_clock::now(); std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend); + + std::vector encoded_vocab_std{}; + encoded_vocab_std.reserve(encoded_vocab.size()); + + for (const auto& v : encoded_vocab) { + encoded_vocab_std.push_back(std::string(v)); + } + return std::make_unique( std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format), std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), - std::filesystem::path::format::auto_format) + std::filesystem::path::format::auto_format), + created_time, + encoded_vocab_std, + std::string_view(tokenizer_str) ); } } diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs index 812fd6e30d8..3e6bd7430b7 100644 --- a/backends/trtllm/src/errors.rs +++ b/backends/trtllm/src/errors.rs @@ -19,4 +19,8 @@ pub enum TensorRtLlmBackendError { WebServer(#[from] server::WebServerError), #[error("Tokio runtime failed to start: {0}")] Tokio(#[from] std::io::Error), + #[error("config.json doesn't exist in engine folder {0}")] + ConfigNotFound(PathBuf), + #[error("generation_config.json doesn't exist in engine folder {0}")] + GenerationConfigNotFound(PathBuf), } diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 085072561f1..306f9f486b2 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -24,6 +24,14 @@ mod ffi { /// The request finished because the maximum number of tokens was reached. #[cxx_name = "kLENGTH"] MaxLength = 3u8, + + #[cxx_name = "kTIMED_OUT"] + /// The request finished because it got timed out (via the mAllotedTime parameter) + TimedOut = 4u8, + + #[cxx_name = "kCANCELLED"] + /// The request was cancelled by calling cancelRequest. 
+ Cancelled = 5u8, } /// Struct used as shared type between rust and C++ to represent the result @@ -34,8 +42,14 @@ mod ffi { request_id: u64, token_id: u32, log_prob: f32, + + /// The time of first schedule since the creation of the backend + first_scheduled_time_ns: i64, is_final: bool, finish_reason: FinishReason, + token_id_valid: bool, + log_prob_valid: bool, + first_scheduled_time_ns_valid: bool, has_error: bool, error_msg: String, } @@ -64,12 +78,12 @@ mod ffi { fn create_backend_from_engine_folder( engine_folder: &str, executor_worker: &str, + tokenizer_str: &str, + encoded_vocab: Vec, ) -> Result>; - fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize; - fn submit( - self: Pin<&mut TensorRtLlmBackendImpl>, + self: &TensorRtLlmBackendImpl, tokens: &[u32], max_new_tokens: u32, top_k: u32, @@ -78,13 +92,28 @@ mod ffi { repetition_penalty: f32, frequency_penalty: f32, seed: u64, + grammar_type: GrammarType, + grammar_value: &str, ) -> Result; fn pull_tokens( - self: Pin<&mut TensorRtLlmBackendImpl>, + self: &TensorRtLlmBackendImpl, ) -> Result>>; - fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64); + fn cancel(self: &TensorRtLlmBackendImpl, request_id: u64); + } + + #[cxx_name = "grammar_type_t"] + #[derive(Debug, Clone, Copy)] + pub enum GrammarType { + #[cxx_name = "kNONE"] + None = 0u8, + + #[cxx_name = "kJSON"] + Json = 1u8, + + #[cxx_name = "kREGEX"] + Regex = 2u8, } } diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 5fed954fff7..32de99f37a1 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -1,14 +1,15 @@ use async_trait::async_trait; use cxx::UniquePtr; use hashbrown::HashMap; -use std::hint; use std::ops::Deref; -use std::path::Path; +use std::path::{Path, PathBuf}; +use std::sync::Arc; use tokenizers::Tokenizer; +use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::TryAcquireError; use tokio::task::spawn_blocking; -use tokio::time::Instant; +use tokio::time::{Duration, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, warn}; @@ -17,12 +18,13 @@ use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStr use text_generation_router::validation::ValidationError::{ EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality, }; -use text_generation_router::validation::{Chunk, ValidGenerateRequest}; +use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidGrammar}; use text_generation_router::Token; use crate::errors::TensorRtLlmBackendError; use crate::ffi::{ - create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl, + create_backend_from_engine_folder, FinishReason, GenerationStep, GrammarType, + TensorRtLlmBackendImpl, }; use crate::utils::first_line; @@ -35,6 +37,9 @@ struct GenerationContext { tokens: Vec, start: Option, queued: Instant, + + /// output_buffer stores the output for detecting stop sequences + output_buffer: Option, } #[derive(Debug, Copy, Clone)] @@ -49,132 +54,234 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { type Error = InferError; fn try_from(step: &'step GenerationStep) -> Result { - if !step.has_error { - Ok(Self { - id: step.token_id, - log_prob: step.log_prob, - is_final: step.is_final, - finish_reason: step.finish_reason, - }) - } else { - Err(GenerationError(step.error_msg.clone())) + if step.has_error { + return 
Err(GenerationError(step.error_msg.clone())); + } + + if !step.token_id_valid { + return Err(GenerationError( + "GenerationStep contains no token_id".to_string(), + )); + } + + if !step.log_prob_valid { + return Err(GenerationError( + "GenerationStep contains no log_prob".to_string(), + )); } + + Ok(Self { + id: step.token_id, + log_prob: step.log_prob, + is_final: step.is_final, + finish_reason: step.finish_reason, + }) } } -fn executor_status_looper( +struct InFlightRequest { + request_id: u64, + ctx: GenerationContext, +} + +/// request_looper reads from the backlog, sends the request to backend, +/// and then transfer the request context to the response_looper via in_flights. +fn request_looper( + backend: Arc>, + mut backlog: UnboundedReceiver, + in_flights: UnboundedSender, +) { + loop { + let Some(ctx) = backlog.blocking_recv() else { + break; + }; + // Submit all the request to the executor and move the context to the in-flight tracker + let request = &ctx.request; + let generation_params = &request.parameters; + let stopping_params = &request.stopping_parameters; + let input_ids = request.input_ids.as_deref(); + let top_k = if generation_params.do_sample { + generation_params.top_k + } else { + 1 + }; + + let (grammar_type, grammar_value): (GrammarType, &str) = + if let Some(grammar) = &generation_params.grammar { + match grammar { + ValidGrammar::Json(v) => (GrammarType::Json, v), + ValidGrammar::Regex(v) => (GrammarType::Regex, v), + } + } else { + (GrammarType::None, "") + }; + + // Submit to the TensorRT-LLM executor for scheduling + match backend.submit( + &input_ids.unwrap(), // This is checked beforehand in validate() + stopping_params.max_new_tokens, + top_k, + generation_params.top_p, + generation_params.temperature, + generation_params.repetition_penalty, + generation_params.frequency_penalty, + generation_params.seed, + grammar_type, + grammar_value, + ) { + Ok(request_id) => { + // Insert the context linked to the generated request id in the tracker + debug!("[in-flight] Added {}", request_id); + if let Err(err) = in_flights.send(InFlightRequest { request_id, ctx }) { + error!("[in-flight] Send failed {}", err); + return; + } + } + Err(e) => { + // Return to the caller + let what = e.to_string(); + error!(error = what.as_str(), "Failed to schedule request"); + + let err = Err(InferError::Overloaded(TryAcquireError::NoPermits)); + if let Err(_) = ctx.streamer.send(err) { + error!("Failed to send back error to the client"); + } + } + }; + } +} + +/// response_looper awaits requests from in_flights if there are no active ones +/// or awaits for tokens from backend. The tokens are processed and sent back. +fn response_looper( max_inflight_requests: usize, tokenizer: Tokenizer, - mut backend: UniquePtr, - mut backlog: UnboundedReceiver, + created_time: Instant, + backend: Arc>, + mut in_flight_recv: UnboundedReceiver, ) { - // Track the tuple (request_id, stream) for each request + // // Track the tuple (request_id, stream) for each request let mut in_flights = HashMap::::with_capacity(max_inflight_requests * 2); - - 'scheduler: loop { - // Is there any request pending to be scheduled? 
- let awaiting_requests = backlog.len(); - for _ in 0..awaiting_requests { - // Retrieve all the requests - if let Some(ctx) = backlog.blocking_recv() { - // Submit all the request to the executor and move the context to the in-flight tracker - let request = &ctx.request; - let generation_params = &request.parameters; - let stopping_params = &request.stopping_parameters; - let input_ids = request.input_ids.as_deref(); - - // Submit to the TensorRT-LLM executor for scheduling - match backend.pin_mut().submit( - &input_ids.unwrap(), // This is checked beforehand in validate() - stopping_params.max_new_tokens, - generation_params.top_k, - generation_params.top_p, - generation_params.temperature, - generation_params.repetition_penalty, - generation_params.frequency_penalty, - generation_params.seed, - ) { - Ok(request_id) => { - // Insert the context linked to the generated request id in the tracker - debug!("[in-flight] Added {}", request_id); - in_flights.insert(request_id, ctx); - } - Err(e) => { - // Return to the caller - let what = e.to_string(); - error!(error = what.as_str(), "Failed to schedule request"); - - let err = Err(InferError::Overloaded(TryAcquireError::NoPermits)); - if let Err(_) = ctx.streamer.send(err) { - error!("Failed to send back error to the client"); - } - } - }; - } else { - break 'scheduler; - } + loop { + if in_flights.is_empty() { + // If there are no active requests, block on Rust channel instead of C++ side. + let Some(req) = in_flight_recv.blocking_recv() else { + return; + }; + in_flights.insert(req.request_id, req.ctx); } + match backend.pull_tokens() { + Ok(responses) => { + // Fetch all pending requests, in case we are receiving tokens from them. + loop { + match in_flight_recv.try_recv() { + Ok(req) => in_flights.insert(req.request_id, req.ctx), + Err(err) => match err { + TryRecvError::Empty => break, + TryRecvError::Disconnected => return, + }, + }; + } + + // Iterate through all the decoded token + for step in responses.deref() { + if let Some(ctx) = in_flights.get_mut(&step.request_id) { + // Update the starting timestamp if not set + if ctx.start.is_none() { + if step.first_scheduled_time_ns_valid { + if step.first_scheduled_time_ns >= 0 { + ctx.start = created_time.checked_add(Duration::from_nanos( + step.first_scheduled_time_ns as u64, + )); + } else { + ctx.start = created_time.checked_sub(Duration::from_nanos( + -step.first_scheduled_time_ns as u64, + )); + } + } - if backend.num_tokens_ready() > 0 { - let mut backend = backend.pin_mut(); - match backend.as_mut().pull_tokens() { - Ok(responses) => { - // Iterate through all the decoded token - for step in responses.deref() { - if let Some(ctx) = in_flights.get_mut(&step.request_id) { - // Update the starting timestamp if not set - // This value might not be the actual real starting time of the request - // on the executor side - Need to expose more info from the executor to - // retrieve this value - // TODO : Expose actual real starting time for a request on FFI layer if ctx.start.is_none() { ctx.start = Some(Instant::now()); } + } - // Try to map the generation step to a DecodedToken - let response = match DecodedToken::try_from(step) { - Ok(decoded_token) => { - post_process_decoded_token(&tokenizer, ctx, decoded_token) - } - Err(err) => Err(err), - }; - - // Attempt to send back the response to the client - if let Err(_) = ctx.streamer.send(response) { - // Client has dropped, remove from tracked requests - debug!( - "Client dropped - removing request {} from tracked requests", - 
step.request_id - ); - backend.as_mut().cancel(step.request_id); - let _ = in_flights.remove(&step.request_id); + // Try to map the generation step to a DecodedToken + let response = match DecodedToken::try_from(step) { + Ok(decoded_token) => { + post_process_decoded_token(&tokenizer, ctx, decoded_token) + } + Err(err) => Err(err), + }; + + // Attempt to send back the response to the client + if let Err(_) = ctx.streamer.send(response) { + // Client has dropped, remove from tracked requests + debug!( + "Client dropped - removing request {} from tracked requests", + step.request_id + ); + backend.cancel(step.request_id); + let _ = in_flights.remove(&step.request_id); + } + } else { + match step.finish_reason { + FinishReason::Cancelled => { + // The client has canceled the request, so this should not generate a + // warning. + debug!("Cancelled request {}", step.request_id); + } + _ => { + warn!("Untracked request {}", step.request_id); } - } else { - warn!("Untracked request {}", step.request_id,); } } } - Err(ref err) => { - error!("Failed to get responses from the executor: {}.", err.what()); - break 'scheduler; - } + } + Err(ref err) => { + error!("Failed to get responses from the executor: {}.", err.what()); + break; } } - - // Hint the CPU we are spin-locking - hint::spin_loop(); } } fn post_process_decoded_token( tokenizer: &Tokenizer, ctx: &mut GenerationContext, - decoded_token: DecodedToken, + mut decoded_token: DecodedToken, ) -> InferResult { match tokenizer.decode(&[decoded_token.id], false) { Ok(text) => { let is_special = tokenizer.get_added_vocabulary().is_special_token(&text); + + if let Some(buf) = ctx.output_buffer.as_mut() { + if buf.len() + text.len() > buf.capacity() { + let mut start = buf.len() + text.len() - buf.capacity(); + while start <= buf.len() && !buf.is_char_boundary(start) { + start += 1; + } + buf.drain(..start); + } + buf.push_str(&text); + + for stop_seq in &ctx.request.stopping_parameters.stop_sequences { + let start = if 1 + buf.len() > text.len() + stop_seq.len() { + let mut start = 1 + buf.len() - text.len() - stop_seq.len(); + while start > 0 && !buf.is_char_boundary(start) { + start -= 1; + } + start + } else { + 0 + }; + if buf[start..].contains(stop_seq) { + decoded_token.is_final = true; + decoded_token.finish_reason = FinishReason::StopWords; + } + } + } + let token = Token { id: decoded_token.id, text, @@ -231,6 +338,26 @@ fn ensure_paths_exist, PP: AsRef>( return Err(err); } + let mut config_path = PathBuf::from(engine_folder); + config_path.push("config.json"); + + if !config_path.exists() { + let err = TensorRtLlmBackendError::ConfigNotFound(engine_folder.to_path_buf()); + + error!("Path validation failed: {}", err,); + return Err(err); + } + + let mut generation_config_path = PathBuf::from(engine_folder); + generation_config_path.push("generation_config.json"); + + if !generation_config_path.exists() { + let err = TensorRtLlmBackendError::GenerationConfigNotFound(engine_folder.to_path_buf()); + + error!("Path validation failed: {}", err,); + return Err(err); + } + // Ensure executor worker binary exists if !executor_worker_path.exists() { let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf()); @@ -255,6 +382,7 @@ fn ensure_paths_exist, PP: AsRef>( } unsafe impl Send for TensorRtLlmBackendImpl {} +unsafe impl Sync for TensorRtLlmBackendImpl {} pub struct TensorRtLlmBackendV2(UnboundedSender); @@ -271,13 +399,47 @@ impl TensorRtLlmBackendV2 { // Allocate the IPC layer to communicate with the backend let 
(executor_sender, executor_receiver) = unbounded_channel(); + let (in_flight_sender, in_flight_receiver) = unbounded_channel(); + + // This is a reference point to convert time from c++ time_point + // to rust Instant. + let created_time = Instant::now(); + + let encoded_vocab = { + let vocab = tokenizer.get_vocab(true); + let mut tokens: Vec = vocab.keys().map(|x| x.clone()).collect(); + tokens.sort_by(|a, b| vocab.get(a).cmp(&vocab.get(b))); + tokens + }; + + let tokenizer_str = tokenizer + .to_string(false) + .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?; + // Create the FFI backend - let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path) - .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?; + let backend = create_backend_from_engine_folder( + &engine_folder, + &executor_worker_path, + &tokenizer_str, + encoded_vocab, + ) + .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?; - // Executor looper is responsible for scheduling and pulling requests state at regular interval + let backend = Arc::new(backend); + let backend_response = backend.clone(); + + // Request looper is responsible for scheduling requests + spawn_blocking(move || request_looper(backend, executor_receiver, in_flight_sender)); + + // Response looper is responsible for awaiting tokens and send them back spawn_blocking(move || { - executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver) + response_looper( + max_inflight_requests, + tokenizer, + created_time, + backend_response, + in_flight_receiver, + ) }); Ok(TensorRtLlmBackendV2(executor_sender)) @@ -292,11 +454,6 @@ impl TensorRtLlmBackendV2 { return Err(ValidationError(TopNTokensDisabled)); } - // TODO: Is it really needed? How can it be validated before? - if request.parameters.grammar.is_some() { - return Err(ValidationError(Grammar)); - } - match request.inputs.len() { 0 => Err(ValidationError(EmptyInput)), 2.. => Err(GenerationError( @@ -323,12 +480,20 @@ impl Backend for TensorRtLlmBackendV2 { // Send the context to the executor for scheduling let queued = Instant::now(); + let output_buffer = request + .stopping_parameters + .stop_sequences + .iter() + .map(|x| x.len()) + .max() + .map(|m| String::with_capacity(m + 32)); // TODO: is this number enough? 
match self.0.send(GenerationContext { request, streamer, tokens: Vec::with_capacity(256), start: None, queued, + output_buffer, }) { Ok(_) => Ok(UnboundedReceiverStream::new(receiver)), Err(_) => Err(GenerationError( diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 543f8e6e352..0e28ad02f24 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -37,7 +37,7 @@ struct Args { hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, #[clap(long, env, required = true)] tokenizer_name: String, @@ -67,6 +67,8 @@ struct Args { usage_stats: UsageStatsLevel, #[clap(default_value = "2000000", long, env)] payload_limit: usize, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, } async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option { @@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { executor_worker, usage_stats, payload_limit, + disable_grammar_support, } = args; // Launch Tokio runtime @@ -321,7 +324,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { false, None, None, - true, + disable_grammar_support, max_client_batch_size, usage_stats, payload_limit, diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index 60b5d52bbe2..a0f3558c0da 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -36,7 +36,7 @@ struct Args { hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] master_shard_uds_path: String, diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 44e63853e04..75a2069124e 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -36,7 +36,7 @@ struct Args { hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] master_shard_uds_path: String, diff --git a/docs/source/reference/launcher.md b/docs/source/reference/launcher.md index 5b7321b73a3..f49cbac5a2a 100644 --- a/docs/source/reference/launcher.md +++ b/docs/source/reference/launcher.md @@ -254,7 +254,7 @@ Options: ``` ## PROMETHEUS_PORT ```shell - -p, --prometheus-port + --prometheus-port The Prometheus port to listen on [env: PROMETHEUS_PORT=] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index c727623ce47..f339cbb47e0 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -774,7 +774,7 @@ struct Args { port: u16, /// The Prometheus port to listen on. - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, /// The name of the socket for gRPC communication between the webserver
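
Note on the guided-decoding vocabulary: `executor_config` now receives `encoded_vocab`, which `TensorRtLlmBackendV2::new` builds from `tokenizer.get_vocab(true)` by sorting token strings by their token id before crossing the FFI boundary. A minimal standalone sketch of that ordering step, using a plain `HashMap` in place of the `tokenizers` vocabulary (function and variable names here are illustrative, not the crate's API):

```rust
use std::collections::HashMap;

/// Order vocabulary entries by token id so that index `i` of the returned
/// vector holds the string for token id `i`, the layout handed to the
/// guided-decoding backend as `encoded_vocab`.
fn encoded_vocab_from_map(vocab: &HashMap<String, u32>) -> Vec<String> {
    let mut tokens: Vec<String> = vocab.keys().cloned().collect();
    tokens.sort_by_key(|t| vocab[t]);
    tokens
}

fn main() {
    let vocab = HashMap::from([
        ("<s>".to_string(), 0u32),
        ("world".to_string(), 2),
        ("hello".to_string(), 1),
    ]);
    assert_eq!(encoded_vocab_from_map(&vocab), ["<s>", "hello", "world"]);
}
```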
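
Note on grammar mapping: `request_looper` translates the router's validated grammar into the `(grammar_type, grammar_value)` pair that `ffi.hpp` later maps onto `tle::GuidedDecodingParams`. A small sketch of that mapping with local stand-in enums (the real types are `text_generation_router::validation::ValidGrammar` and the cxx-generated `grammar_type_t`; only the mapping logic is taken from the patch):

```rust
// Local stand-ins mirroring ValidGrammar and the FFI grammar_type_t enum.
enum ValidGrammar {
    Json(String),
    Regex(String),
}

#[derive(Debug, PartialEq)]
enum GrammarType {
    None,
    Json,
    Regex,
}

/// Collapse the optional grammar into the (type, value) pair sent over FFI,
/// falling back to `(None, "")` when no grammar was requested.
fn to_ffi_grammar(grammar: Option<&ValidGrammar>) -> (GrammarType, &str) {
    match grammar {
        Some(ValidGrammar::Json(v)) => (GrammarType::Json, v.as_str()),
        Some(ValidGrammar::Regex(v)) => (GrammarType::Regex, v.as_str()),
        None => (GrammarType::None, ""),
    }
}

fn main() {
    let grammar = Some(ValidGrammar::Regex(r"\d+".to_string()));
    let (kind, value) = to_ffi_grammar(grammar.as_ref());
    assert_eq!(kind, GrammarType::Regex);
    println!("passing grammar {kind:?} with value {value:?} to the executor");
}
```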
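
Note on the looper split: the old polling `executor_status_looper` is replaced by a submit side (`request_looper`) and a response side (`response_looper`) joined by an `InFlightRequest` channel, so the response side can block on a Rust channel while idle and drain new registrations without blocking before each `pull_tokens` round. A dependency-free sketch of that shape, using `std::sync::mpsc` and plain threads where the backend uses tokio unbounded channels, `spawn_blocking`, and the cxx-generated `TensorRtLlmBackendImpl`; the executor calls are reduced to comments:

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
use std::thread;

// Minimal stand-ins for the real request context and the in-flight handoff.
struct GenerationContext {
    prompt: Vec<u32>,
}

struct InFlightRequest {
    request_id: u64,
    ctx: GenerationContext,
}

/// Reads queued requests, submits them, and hands the context over to the
/// response loop keyed by a (here fabricated) request id.
fn request_loop(backlog: Receiver<GenerationContext>, in_flights: Sender<InFlightRequest>) {
    let mut next_id = 0u64;
    while let Ok(ctx) = backlog.recv() {
        // Real code: call backend.submit(...) here and use its request id.
        next_id += 1;
        if in_flights.send(InFlightRequest { request_id: next_id, ctx }).is_err() {
            return; // response loop has shut down
        }
    }
}

/// Blocks on the Rust channel while nothing is in flight, then drains any
/// new registrations without blocking before each round of token pulling.
fn response_loop(in_flight_recv: Receiver<InFlightRequest>) {
    let mut tracked: HashMap<u64, GenerationContext> = HashMap::new();
    loop {
        if tracked.is_empty() {
            match in_flight_recv.recv() {
                Ok(req) => {
                    tracked.insert(req.request_id, req.ctx);
                }
                Err(_) => return, // producers gone and nothing left to track
            }
        }
        loop {
            match in_flight_recv.try_recv() {
                Ok(req) => {
                    tracked.insert(req.request_id, req.ctx);
                }
                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
            }
        }
        // Real code: pull responses from the executor here, stream tokens back,
        // and only remove entries that finished or whose client dropped.
        // The sketch just reports and clears everything.
        for (id, ctx) in tracked.drain() {
            println!("request {id}: streaming tokens for a {}-token prompt", ctx.prompt.len());
        }
    }
}

fn main() {
    let (backlog_tx, backlog_rx) = channel();
    let (in_flight_tx, in_flight_rx) = channel();
    let submitter = thread::spawn(move || request_loop(backlog_rx, in_flight_tx));
    let responder = thread::spawn(move || response_loop(in_flight_rx));

    backlog_tx.send(GenerationContext { prompt: vec![1, 2, 3] }).unwrap();
    drop(backlog_tx); // closing the backlog lets both loops wind down

    submitter.join().unwrap();
    responder.join().unwrap();
}
```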
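
Note on timing: `first_scheduled_time_ns` is reported by the C++ side as a signed offset from the backend's creation time (`m_created_time`), and `response_looper` turns it back into an `Instant` anchored at the matching Rust-side `created_time`. A sketch of that conversion under the same convention (shown with `std::time::Instant`; the backend itself uses `tokio::time::Instant`, whose `checked_add`/`checked_sub` appear in the patch):

```rust
use std::time::{Duration, Instant};

/// Map a signed nanosecond offset, measured on the C++ side against the
/// moment the backend was created, back onto a Rust `Instant` anchored at
/// that same creation point.
fn scheduled_instant(created_time: Instant, offset_ns: i64) -> Option<Instant> {
    if offset_ns >= 0 {
        created_time.checked_add(Duration::from_nanos(offset_ns as u64))
    } else {
        created_time.checked_sub(Duration::from_nanos(offset_ns.unsigned_abs()))
    }
}

fn main() {
    let created_time = Instant::now();
    // A request first scheduled 1.5 ms after the backend was created.
    let start = scheduled_instant(created_time, 1_500_000).expect("offset fits");
    println!("scheduled {:?} after creation", start.duration_since(created_time));
}
```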
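
Note on stop sequences: `post_process_decoded_token` now keeps a small rolling `output_buffer` per request so a stop sequence split across several decoded chunks can still be detected, trimming the front of the buffer at a char boundary so it stays close to the capacity derived from the longest stop sequence. A self-contained sketch of that rolling check (function and variable names are illustrative, and the front eviction is clamped slightly more defensively than in the patch):

```rust
/// Append `text` to a bounded rolling buffer and report whether any stop
/// sequence now appears in the window that `text` could have completed.
fn push_and_check_stop(buf: &mut String, text: &str, stop_sequences: &[String]) -> bool {
    // Evict from the front so the buffer does not outgrow its capacity,
    // nudging the cut forward onto the next char boundary if needed.
    if buf.len() + text.len() > buf.capacity() {
        let mut start = buf.len() + text.len() - buf.capacity();
        while start <= buf.len() && !buf.is_char_boundary(start) {
            start += 1;
        }
        buf.drain(..start.min(buf.len()));
    }
    buf.push_str(text);

    // Only the suffix long enough to hold `text` plus a stop sequence can
    // contain a newly completed match, so re-scan just that window.
    stop_sequences.iter().any(|stop| {
        let mut start = buf.len().saturating_sub(text.len() + stop.len());
        while start > 0 && !buf.is_char_boundary(start) {
            start -= 1;
        }
        buf[start..].contains(stop.as_str())
    })
}

fn main() {
    let stops = vec!["<|end|>".to_string()];
    let longest = stops.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut buf = String::with_capacity(longest + 32);
    for chunk in ["Hello", " world", "<|e", "nd|>"] {
        if push_and_check_stop(&mut buf, chunk, &stops) {
            println!("stop sequence completed by {chunk:?}");
        }
    }
}
```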