
Commit 2b060ca

Revert "Enable Cuda in Graphics Implementation for TensorRT backend (#100)" (#105)
This reverts commit ab13c10.
1 parent ab13c10 commit 2b060ca

File tree

6 files changed: +25 -189 lines

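For context on what the reverted feature did (this sketch is not part of the commit): the removed code let an application hand the backend an existing CUDA driver context by serializing its pointer into the CUDA_CONTEXT_PTR model parameter; the backend then pushed and popped that context around its work instead of calling cudaSetDevice. Below is a minimal standalone sketch of that hand-off using only the CUDA driver API; all names are illustrative, nothing here is Triton code.

// Illustrative only: not part of this commit or the Triton code base.
#include <cuda.h>

#include <iostream>
#include <sstream>
#include <string>

int main()
{
  // An application (e.g. a game engine) that already owns a CUDA driver
  // context would create or obtain it roughly like this.
  CUcontext ctx = nullptr;
  CUdevice dev = 0;
  if (cuInit(0) != CUDA_SUCCESS || cuDeviceGet(&dev, 0) != CUDA_SUCCESS ||
      cuCtxCreate(&ctx, 0, dev) != CUDA_SUCCESS) {
    std::cerr << "failed to create a CUDA driver context\n";
    return 1;
  }

  // The removed StringToPointer() parsed the pointer back with operator>>,
  // so the inverse is simply streaming the address out as text. This string
  // is what the reverted CUDA_CONTEXT_PTR model parameter carried.
  std::ostringstream out;
  out << static_cast<void*>(ctx);
  std::string cuda_context_ptr = out.str();
  std::cout << "CUDA_CONTEXT_PTR value: " << cuda_context_ptr << "\n";

  // The removed backend code made the shared context current around its
  // work (push in, pop out) instead of calling cudaSetDevice().
  cuCtxPushCurrent(ctx);
  // ... TensorRT / CUDA work would run in the shared context here ...
  CUcontext previous = nullptr;
  cuCtxPopCurrent(&previous);

  cuCtxDestroy(ctx);
  return 0;
}

The per-file diffs below remove the build option, the parameter plumbing, and the RAII guard that implemented this, restoring plain cudaSetDevice() calls.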

CMakeLists.txt (0 additions, 13 deletions)

@@ -37,8 +37,6 @@ set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which feat
 option(TRITON_ENABLE_GPU "Enable GPU support in backend." ON)
 option(TRITON_ENABLE_STATS "Include statistics collections in backend." ON)
 option(TRITON_ENABLE_NVTX "Include nvtx markers collection in backend." OFF)
-option(TRITON_ENABLE_CUDA_CTX_SHARING "Enable Cuda context sharing support in backend." OFF)
-
 set(TRITON_TENSORRT_LIB_PATHS "" CACHE PATH "Paths to TensorRT libraries. Multiple paths may be specified by separating them with a semicolon.")
 set(TRITON_TENSORRT_INCLUDE_PATHS "" CACHE PATH "Paths to TensorRT includes. Multiple paths may be specified by separating them with a semicolon.")

@@ -234,17 +232,6 @@ target_link_libraries(
     CUDA::cudart
 )

-if(${TRITON_ENABLE_CUDA_CTX_SHARING})
-  target_compile_definitions(
-    triton-tensorrt-backend
-    PRIVATE TRITON_ENABLE_CUDA_CTX_SHARING
-  )
-  target_link_libraries(
-    triton-tensorrt-backend
-    PRIVATE
-      CUDA::cuda_driver
-  )
-endif()

 #
 # Install

src/instance_state.cc (9 additions, 16 deletions)

@@ -257,9 +257,7 @@ ModelInstanceState::ModelInstanceState(

 ModelInstanceState::~ModelInstanceState()
 {
-  if (!model_state_->IsCudaContextSharingEnabled()) {
-    cudaSetDevice(DeviceId());
-  }
+  cudaSetDevice(DeviceId());
   for (auto& io_binding_infos : io_binding_infos_) {
     for (auto& io_binding_info : io_binding_infos) {
       if (!io_binding_info.IsDynamicShapeOutput() &&

@@ -426,9 +424,7 @@ ModelInstanceState::Run(
   payload_.reset(new Payload(next_set_, requests, request_count));
   SET_TIMESTAMP(payload_->compute_start_ns_);

-  if (!model_state_->IsCudaContextSharingEnabled()) {
-    cudaSetDevice(DeviceId());
-  }
+  cudaSetDevice(DeviceId());
 #ifdef TRITON_ENABLE_STATS
   {
     SET_TIMESTAMP(payload_->compute_start_ns_);

@@ -1555,16 +1551,13 @@ ModelInstanceState::EvaluateTensorRTContext(
 TRITONSERVER_Error*
 ModelInstanceState::InitStreamsAndEvents()
 {
-  if (!model_state_->IsCudaContextSharingEnabled()) {
-    // Set the device before preparing the context.
-    auto cuerr = cudaSetDevice(DeviceId());
-    if (cuerr != cudaSuccess) {
-      return TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_INTERNAL,
-          (std::string("unable to set device for ") + Name() + ": " +
-           cudaGetErrorString(cuerr))
-              .c_str());
-    }
+  // Set the device before preparing the context.
+  auto cuerr = cudaSetDevice(DeviceId());
+  if (cuerr != cudaSuccess) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL, (std::string("unable to set device for ") +
+                                          Name() + ": " + cudaGetErrorString(cuerr))
+                                         .c_str());
   }

   // Create CUDA streams associated with the instance

src/model_state.cc (16 additions, 33 deletions)

@@ -175,9 +175,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
 ModelState::~ModelState()
 {
   for (auto& device_engine : device_engines_) {
-    if (!IsCudaContextSharingEnabled()) {
-      cudaSetDevice(device_engine.first.first);
-    }
+    cudaSetDevice(device_engine.first.first);
     auto& runtime = device_engine.second.first;
     auto& engine = device_engine.second.second;
     // Need to reset explicitly to ensure proper destruction order

@@ -211,16 +209,15 @@ ModelState::CreateEngine(
   // We share the engine (for models that don't have dynamic shapes) and
   // runtime across instances that have access to the same GPU/NVDLA.
   if (eit->second.second == nullptr) {
-    if (!IsCudaContextSharingEnabled()) {
-      auto cuerr = cudaSetDevice(gpu_device);
-      if (cuerr != cudaSuccess) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INTERNAL,
-            (std::string("unable to set device for ") + Name() + ": " +
-             cudaGetErrorString(cuerr))
-                .c_str());
-      }
+    auto cuerr = cudaSetDevice(gpu_device);
+    if (cuerr != cudaSuccess) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          (std::string("unable to set device for ") + Name() + ": " +
+           cudaGetErrorString(cuerr))
+              .c_str());
     }
+
     const bool new_runtime = (eit->second.first == nullptr);
     RETURN_IF_ERROR(LoadPlan(
         model_path, dla_core_id, &eit->second.first, &eit->second.second,

@@ -324,18 +321,6 @@ ModelState::AutoCompleteConfig()
          " to auto-complete config for " + Name())
             .c_str()));

-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-  // Return failure if Cuda context sharing is enabled and
-  // if it is a multi GPU setup
-  if (IsCudaContextSharingEnabled() && device_id != 0) {
-    return TRITONSERVER_ErrorNew(
-        TRITONSERVER_ERROR_INTERNAL,
-        (std::string(
-             "Cuda context sharing is not supported on multi-GPU system."))
-            .c_str());
-  }
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-
   cuerr = cudaSetDevice(device_id);
   if (cuerr != cudaSuccess) {
     return TRITONSERVER_ErrorNew(

@@ -388,15 +373,13 @@ ModelState::AutoCompleteConfig()

   RETURN_IF_ERROR(AutoCompleteConfigHelper(model_path));

-  if (!IsCudaContextSharingEnabled()) {
-    cuerr = cudaSetDevice(current_device);
-    if (cuerr != cudaSuccess) {
-      return TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_INTERNAL,
-          (std::string("unable to revert CUDA device to GPU ") +
-           std::to_string(current_device) + " : " + cudaGetErrorString(cuerr))
-              .c_str());
-    }
+  cuerr = cudaSetDevice(current_device);
+  if (cuerr != cudaSuccess) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL,
+        (std::string("unable to revert CUDA device to GPU ") +
+         std::to_string(current_device) + " : " + cudaGetErrorString(cuerr))
+            .c_str());
   }

   if (TRITONSERVER_LogIsEnabled(TRITONSERVER_LOG_VERBOSE)) {

src/tensorrt.cc (0 additions, 8 deletions)

@@ -318,7 +318,6 @@ TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
     DeviceMemoryTracker::TrackThreadMemoryUsage(lusage.get());
   }

-  ScopedRuntimeCudaContext cuda_scope(model_state);

   // With each instance we create a ModelInstanceState object and
   // associate it with the TRITONBACKEND_ModelInstance.

@@ -354,11 +353,6 @@ TRITONBACKEND_ModelInstanceFinalize(TRITONBACKEND_ModelInstance* instance)
   LOG_MESSAGE(
       TRITONSERVER_LOG_INFO,
       "TRITONBACKEND_ModelInstanceFinalize: delete instance state");
-  if (!instance_state) {
-    return nullptr;
-  }
-
-  ScopedRuntimeCudaContext cuda_scope(instance_state->StateForModel());

   delete instance_state;

@@ -383,8 +377,6 @@ TRITONBACKEND_ModelInstanceExecute(
       instance, reinterpret_cast<void**>(&instance_state)));
   ModelState* model_state = instance_state->StateForModel();

-  ScopedRuntimeCudaContext cuda_scope(model_state);
-
   // For TensorRT backend, the executing instance may not closely tie to
   // TRITONBACKEND_ModelInstance, the instance will be assigned based on
   // execution policy.

src/tensorrt_model.cc (0 additions, 23 deletions)

@@ -90,14 +90,6 @@ TensorRTModel::ParseModelConfig()
     }
   }

-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-  std::string ptr_str = "";
-  RETURN_IF_ERROR(GetParameter("CUDA_CONTEXT_PTR", ptr_str));
-  cuda_ctx = static_cast<CUcontext>(StringToPointer(ptr_str));
-  // cuda_ctx = static_cast<CUcontext>(reinterpret_cast<void*>(ptr_str));
-  LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "Cuda Context pointer is set");
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-
   return nullptr;  // Success
 }

@@ -128,19 +120,4 @@ TensorRTModel::GetCudaStreamPriority()
   return cuda_stream_priority;
 }

-template <>
-TRITONSERVER_Error*
-TensorRTModel::GetParameter<std::string>(
-    std::string const& name, std::string& str_value)
-{
-  triton::common::TritonJson::Value parameters;
-  RETURN_IF_ERROR(model_config_.MemberAsObject("parameters", &parameters));
-
-  triton::common::TritonJson::Value value;
-  RETURN_IF_ERROR(parameters.MemberAsObject(name.c_str(), &value));
-
-  value.MemberAsString("string_value", &str_value);
-  return nullptr;
-}
-
 }}}  // namespace triton::backend::tensorrt

src/tensorrt_model.h (0 additions, 96 deletions)

@@ -25,11 +25,6 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #pragma once

-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-#include <cuda.h>
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-#include <sstream>
-
 #include "triton/backend/backend_model.h"

 namespace triton { namespace backend { namespace tensorrt {

@@ -39,14 +34,6 @@ class TensorRTModel : public BackendModel {
   TensorRTModel(TRITONBACKEND_Model* triton_model);
   virtual ~TensorRTModel() = default;

-  template <typename T>
-  TRITONSERVER_Error* GetParameter(std::string const& name, T& value)
-  {
-    assert(false);
-    auto dummy = T();
-    return dummy;
-  }
-
   TRITONSERVER_Error* SetTensorRTModelConfig();

   TRITONSERVER_Error* ParseModelConfig();

@@ -66,65 +53,6 @@ class TensorRTModel : public BackendModel {
   bool EagerBatching() { return eager_batching_; }
   bool BusyWaitEvents() { return busy_wait_events_; }

-  template <>
-  TRITONSERVER_Error* GetParameter<std::string>(
-      std::string const& name, std::string& str_value);
-
-  void* StringToPointer(std::string& str)
-  {
-    std::stringstream ss;
-    ss << str;
-
-    void* ctx_ptr;
-    ss >> ctx_ptr;
-    return ctx_ptr;
-  }
-
-  //! Following functions are related to custom Cuda context (Cuda in Graphics)
-  //! sharing for gaming use case. Creating a shared contexts reduces context
-  //! switching overhead and leads to better performance of model execution
-  //! along side Graphics workload.
-
-  bool IsCudaContextSharingEnabled()
-  {
-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-    return cuda_ctx != nullptr;
-#else
-    return false;
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-  }
-
-  inline TRITONSERVER_Error* PushCudaContext()
-  {
-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-    if (CUDA_SUCCESS != cuCtxPushCurrent(cuda_ctx)) {
-      return TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_INTERNAL,
-          (std::string("unable to push Cuda context for ") + Name()).c_str());
-    }
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-    return nullptr;
-  }
-
-  inline TRITONSERVER_Error* PopCudaContext()
-  {
-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-    CUcontext oldCtx{};
-    if (CUDA_SUCCESS != cuCtxPopCurrent(&oldCtx)) {
-      return TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_INTERNAL,
-          (std::string("unable to pop Cuda context for ") + Name()).c_str());
-    }
-    if (oldCtx != cuda_ctx) {
-      return TRITONSERVER_ErrorNew(
-          TRITONSERVER_ERROR_INTERNAL,
-          (std::string("popping the wrong Cuda context for ") + Name())
-              .c_str());
-    }
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-    return nullptr;
-  }
-
  protected:
   common::TritonJson::Value graph_specs_;
   Priority priority_;

@@ -133,30 +61,6 @@ class TensorRTModel : public BackendModel {
   bool separate_output_stream_;
   bool eager_batching_;
   bool busy_wait_events_;
-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-  CUcontext cuda_ctx = nullptr;
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-};
-
-struct ScopedRuntimeCudaContext {
-  ScopedRuntimeCudaContext(TensorRTModel* model_state)
-      : model_state_(model_state)
-  {
-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-    if (model_state_->IsCudaContextSharingEnabled()) {
-      THROW_IF_BACKEND_MODEL_ERROR(model_state_->PushCudaContext());
-    }
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-  }
-  ~ScopedRuntimeCudaContext()
-  {
-#ifdef TRITON_ENABLE_CUDA_CTX_SHARING
-    if (model_state_->IsCudaContextSharingEnabled()) {
-      THROW_IF_BACKEND_MODEL_ERROR(model_state_->PopCudaContext());
-    }
-#endif  // TRITON_ENABLE_CUDA_CTX_SHARING
-  }
-  TensorRTModel* model_state_;
 };

 }}}  // namespace triton::backend::tensorrt
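The removed ScopedRuntimeCudaContext above follows the usual RAII pattern for a borrowed driver context: push in the constructor, pop in the destructor, so the previous context is restored even on early return or exception. A generic, self-contained sketch of the same pattern follows; the names are illustrative and not from the backend, and with this revert the backend instead selects the device with plain cudaSetDevice() at each call site shown in src/instance_state.cc and src/model_state.cc.

// Illustrative RAII guard for a borrowed CUDA driver context; not part of
// this commit. Mirrors what the deleted ScopedRuntimeCudaContext did.
#include <cuda.h>

#include <stdexcept>

class CudaContextGuard {
 public:
  explicit CudaContextGuard(CUcontext ctx) : ctx_(ctx)
  {
    // Make the shared context current for this thread.
    if (cuCtxPushCurrent(ctx_) != CUDA_SUCCESS) {
      throw std::runtime_error("unable to push CUDA context");
    }
  }

  ~CudaContextGuard()
  {
    // Restore whatever was current before, even if the scope exits early.
    CUcontext previous = nullptr;
    cuCtxPopCurrent(&previous);
  }

  CudaContextGuard(const CudaContextGuard&) = delete;
  CudaContextGuard& operator=(const CudaContextGuard&) = delete;

 private:
  CUcontext ctx_;
};

// Usage: work inside the scope runs in the shared context.
// {
//   CudaContextGuard scope(shared_ctx);
//   // launch kernels / TensorRT execution here
// }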
