 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #pragma once
 
+#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
+#include <cuda.h>
+#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
+#include <sstream>
+
 #include "triton/backend/backend_model.h"
 
 namespace triton { namespace backend { namespace tensorrt {
@@ -34,6 +39,14 @@ class TensorRTModel : public BackendModel {
   TensorRTModel(TRITONBACKEND_Model* triton_model);
   virtual ~TensorRTModel() = default;
 
+  template <typename T>
+  TRITONSERVER_Error* GetParameter(std::string const& name, T& value)
+  {
+    // Only the explicit specializations (std::string below) are implemented.
+    assert(false);
+    return nullptr;
+  }
+
   TRITONSERVER_Error* SetTensorRTModelConfig();
 
   TRITONSERVER_Error* ParseModelConfig();
@@ -53,6 +66,65 @@ class TensorRTModel : public BackendModel {
   bool EagerBatching() { return eager_batching_; }
   bool BusyWaitEvents() { return busy_wait_events_; }
 
+  template <>
+  TRITONSERVER_Error* GetParameter<std::string>(
+      std::string const& name, std::string& str_value);
+
+  void* StringToPointer(std::string& str)
+  {
+    std::stringstream ss;
+    ss << str;
+
+    void* ctx_ptr;
+    ss >> ctx_ptr;
+    return ctx_ptr;
+  }
+
+  //! The following functions are related to custom CUDA context (CUDA in
+  //! Graphics) sharing for the gaming use case. Creating a shared context
+  //! reduces context-switching overhead and leads to better performance of
+  //! model execution alongside the graphics workload.
+
+  bool IsCudaContextSharingEnabled()
+  {
+#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
+    return cuda_ctx != nullptr;
+#else
+    return false;
+#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
+  }
+
+  inline TRITONSERVER_Error* PushCudaContext()
+  {
+#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
+    if (CUDA_SUCCESS != cuCtxPushCurrent(cuda_ctx)) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("unable to push Cuda context for ") + Name()).c_str());
+    }
+#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
+    return nullptr;
+  }
+
+  inline TRITONSERVER_Error* PopCudaContext()
+  {
+#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
+    CUcontext oldCtx{};
+    if (CUDA_SUCCESS != cuCtxPopCurrent(&oldCtx)) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("unable to pop Cuda context for ") + Name()).c_str());
+    }
+    if (oldCtx != cuda_ctx) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("popping the wrong Cuda context for ") + Name())
+              .c_str());
+    }
+#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
+    return nullptr;
+  }
+
  protected:
   common::TritonJson::Value graph_specs_;
   Priority priority_;
@@ -61,6 +133,30 @@ class TensorRTModel : public BackendModel {
   bool separate_output_stream_;
   bool eager_batching_;
   bool busy_wait_events_;
+#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
+  CUcontext cuda_ctx = nullptr;
+#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
+};
+
+struct ScopedRuntimeCudaContext {
+  ScopedRuntimeCudaContext(TensorRTModel* model_state)
+      : model_state_(model_state)
+  {
+#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
+    if (model_state_->IsCudaContextSharingEnabled()) {
+      THROW_IF_BACKEND_MODEL_ERROR(model_state_->PushCudaContext());
+    }
+#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
+  }
+  ~ScopedRuntimeCudaContext()
+  {
+#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
+    if (model_state_->IsCudaContextSharingEnabled()) {
+      THROW_IF_BACKEND_MODEL_ERROR(model_state_->PopCudaContext());
+    }
+#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
+  }
+  TensorRTModel* model_state_;
 };
 
 }}}  // namespace triton::backend::tensorrt
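
Taken together, these additions wire custom CUDA context (CUDA in Graphics) sharing into the model class: the shared context handle is expected to arrive as a string-valued model-configuration parameter, StringToPointer() converts that text back into a pointer, and ScopedRuntimeCudaContext pushes and pops the shared context around any TensorRT work via RAII. Below is a minimal usage sketch; the function name RunWithSharedContext and the inference placeholder are illustrative assumptions, not part of this change.

// Minimal usage sketch (assumed caller, not part of this commit).
TRITONSERVER_Error*
RunWithSharedContext(TensorRTModel* model_state)
{
  try {
    // The constructor pushes the shared CUDA context when sharing is enabled;
    // the destructor pops it again, even if the enclosed code throws.
    ScopedRuntimeCudaContext scoped_ctx(model_state);

    // ... enqueue TensorRT inference here ...
  }
  catch (const BackendModelException& ex) {
    // THROW_IF_BACKEND_MODEL_ERROR wraps the TRITONSERVER_Error in this
    // exception type; surface it to the caller.
    return ex.err_;
  }
  return nullptr;  // success
}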