diff --git a/README.md b/README.md
index 4fcd7b6..33c9fa7 100644
--- a/README.md
+++ b/README.md
@@ -99,3 +99,21 @@ but the listed CMake argument can be used to override.
 * triton-inference-server/backend: -DTRITON_BACKEND_REPO_TAG=[tag]
 * triton-inference-server/core: -DTRITON_CORE_REPO_TAG=[tag]
 * triton-inference-server/common: -DTRITON_COMMON_REPO_TAG=[tag]
+
+## Parameters
+
+Triton exposes parameters that control how the TensorRT backend executes a model. They
+are set in the `parameters` section of the model's `config.pbtxt` file.
+
+### execution_context_allocation_strategy
+
+Selects the memory allocation behavior for IExecutionContext. An IExecutionContext requires a block of device memory for internal activation tensors during inference, and this parameter controls how that memory is managed. Supported values are "STATIC" (the default, which allocates enough device memory for the largest optimization profile at context creation) and "ON_PROFILE_CHANGE" (which reallocates when a different profile is selected).
+
+```
+parameters: {
+  key: "execution_context_allocation_strategy"
+  value: {
+    string_value: "STATIC"
+  }
+}
+```
diff --git a/src/instance_state.cc b/src/instance_state.cc
index 56208a1..8819911 100644
--- a/src/instance_state.cc
+++ b/src/instance_state.cc
@@ -1,4 +1,4 @@
-// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -1693,19 +1693,6 @@ ModelInstanceState::InitIOIndexMap()
 TRITONSERVER_Error*
 ModelInstanceState::InitOptimizationProfiles()
 {
-  // TRT sets the optimization profile index to be 0 implicitly with
-  // the first context creation. As currently triton supports one
-  // context per engine, in order to set the specified profile_index,
-  // another context is created and the previous context is destroyed.
-  std::shared_ptr<nvinfer1::IExecutionContext> default_trt_context(
-      engine_->createExecutionContext());
-  if (default_trt_context == nullptr) {
-    return TRITONSERVER_ErrorNew(
-        TRITONSERVER_ERROR_INTERNAL,
-        (std::string("unable to create TensorRT context: ") +
-         model_state_->GetTensorRTLogger().LastErrorMsg())
-            .c_str());
-  }
   std::vector<std::pair<std::string, int>> profile_name_index;
   // No optimization profile is set for this TensorRT plan
   if (ProfileNames().empty()) {
@@ -1736,17 +1723,19 @@ ModelInstanceState::InitOptimizationProfiles()
               .c_str());
       continue;
     }
-    if (profile_index == 0) {
-      res.first->second.context_ = std::move(default_trt_context);
-    } else {
-      res.first->second.context_.reset(engine_->createExecutionContext());
-      if (res.first->second.context_ == nullptr) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INTERNAL,
-            (std::string("unable to create TensorRT context: ") +
-             model_state_->GetTensorRTLogger().LastErrorMsg())
-                .c_str());
-      }
+
+    // Create a new execution context for the profile.
+    res.first->second.context_.reset(
+        engine_->createExecutionContext(model_state_->AllocationStrategy()));
+    if (res.first->second.context_ == nullptr) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("unable to create TensorRT context: ") +
+           model_state_->GetTensorRTLogger().LastErrorMsg())
+              .c_str());
+    }
+
+    if (profile_index != 0) {
       if (!res.first->second.context_->setOptimizationProfileAsync(
               profile_index, stream_)) {
         return TRITONSERVER_ErrorNew(
diff --git a/src/model_state.cc b/src/model_state.cc
index 6127989..bc72d67 100644
--- a/src/model_state.cc
+++ b/src/model_state.cc
@@ -1,4 +1,4 @@
-// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -142,7 +142,8 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 }
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : TensorRTModel(triton_model), engine_sharing_(true)
+    : TensorRTModel(triton_model), engine_sharing_(true),
+      alloc_strategy_(nvinfer1::ExecutionContextAllocationStrategy::kSTATIC)
 {
   // Obtain backend configuration
   TRITONBACKEND_Backend* backend;
@@ -288,6 +289,43 @@ ModelState::ValidateModelConfig()
 TRITONSERVER_Error*
 ModelState::ParseParameters()
 {
+  triton::common::TritonJson::Value params;
+  bool status = ModelConfig().Find("parameters", &params);
+  if (status) {
+    // If 'execution_context_allocation_strategy' is not present in
+    // 'parameters', the default strategy "STATIC" will be used.
+    std::string alloc_strategy;
+    TRITONSERVER_Error* err = GetParameterValue(
+        params, "execution_context_allocation_strategy", &alloc_strategy);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      // 'execution_context_allocation_strategy' is present in model config
+      // parameters.
+      if (alloc_strategy == "STATIC") {
+        alloc_strategy_ = nvinfer1::ExecutionContextAllocationStrategy::kSTATIC;
+      } else if (alloc_strategy == "ON_PROFILE_CHANGE") {
+        alloc_strategy_ =
+            nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE;
+      } else {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INVALID_ARG,
+            ("Invalid value for 'execution_context_allocation_strategy': '" +
+             alloc_strategy + "' for model instance '" + Name() +
+             "'. Supported values are 'STATIC' and 'ON_PROFILE_CHANGE'.")
+                .c_str());
+      }
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          ("'execution_context_allocation_strategy' set to '" +
+           alloc_strategy + "' for model instance '" + Name() + "'")
+              .c_str());
+    }
+  }
   return nullptr;  // success
 }
 
diff --git a/src/model_state.h b/src/model_state.h
index b132806..42274a3 100644
--- a/src/model_state.h
+++ b/src/model_state.h
@@ -1,4 +1,4 @@
-// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -88,6 +88,11 @@ class ModelState : public TensorRTModel {
 
   TensorRTLogger& GetTensorRTLogger() { return tensorrt_logger_; }
 
+  nvinfer1::ExecutionContextAllocationStrategy AllocationStrategy() const
+  {
+    return alloc_strategy_;
+  }
+
 private:
  ModelState(TRITONBACKEND_Model* triton_model);
 
@@ -140,6 +145,8 @@ class ModelState : public TensorRTModel {
 
   // Whether the backend should support version-compatible TensorRT models.
   static inline bool is_version_compatible_{false};
+
+  nvinfer1::ExecutionContextAllocationStrategy alloc_strategy_;
 };
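
For reference, a minimal sketch of opting into the non-default strategy via `config.pbtxt`; the model is assumed to have multiple optimization profiles, and the rest of its configuration is omitted:

```
parameters: {
  key: "execution_context_allocation_strategy"
  value: {
    string_value: "ON_PROFILE_CHANGE"
  }
}
```

With "ON_PROFILE_CHANGE" the execution context reallocates activation memory when a different optimization profile is selected, trading possible reallocation cost at profile switches for a smaller steady-state footprint than "STATIC", which sizes the allocation for the largest profile up front.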