diff --git a/qa/L0_grpc_state_cleanup/shutdown_stress_test.py b/qa/L0_grpc_state_cleanup/shutdown_stress_test.py
new file mode 100755
index 0000000000..3cf6b2cb3c
--- /dev/null
+++ b/qa/L0_grpc_state_cleanup/shutdown_stress_test.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import sys
+
+sys.path.append("../common")
+
+import os
+import queue
+import signal
+import subprocess
+import threading
+import time
+import unittest
+from functools import partial
+
+import numpy as np
+import test_util as tu
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import InferenceServerException
+
+
+class UserData:
+    def __init__(self):
+        self._response_queue = queue.Queue()
+
+
+def callback(user_data, result, error):
+    if error:
+        user_data._response_queue.put(error)
+    else:
+        user_data._response_queue.put(result)
+
+
+class ShutdownStressTest(tu.TestResultCollector):
+    """
+    Stress test for gRPC shutdown race condition (GitHub issue #6899).
+
+    This test verifies that handler threads don't block indefinitely during
+    server shutdown when alarm events are scheduled on a shutting-down
+    completion queue. The fix ensures that:
+      1. Alarms are not scheduled after NotifyCQShutdown() is called
+      2. Active alarms are cancelled during shutdown
+      3. Handler threads use deadline-based polling to detect shutdown
+    """
+
+    def setUp(self):
+        self.model_name_ = "custom_zero_1_float32"
+        self.shutdown_timeout_ = 10  # seconds
+
+    def _continuous_inference(self, duration_seconds, results):
+        """
+        Run continuous gRPC inference requests for the specified duration.
+        Track success/failure counts in results dict.
+        """
+        results["success"] = 0
+        results["timeout"] = 0
+        results["unavailable"] = 0
+        results["other_errors"] = 0
+
+        start_time = time.time()
+        while time.time() - start_time < duration_seconds:
+            try:
+                with grpcclient.InferenceServerClient(
+                    url="localhost:8001", verbose=False
+                ) as triton_client:
+                    inputs = []
+                    inputs.append(grpcclient.InferInput("INPUT0", [1, 1], "FP32"))
+                    input_data = np.array([[1.0]], dtype=np.float32)
+                    inputs[0].set_data_from_numpy(input_data)
+
+                    outputs = []
+                    outputs.append(grpcclient.InferRequestedOutput("OUTPUT0"))
+
+                    # Use a short timeout to fail fast
+                    response = triton_client.infer(
+                        model_name=self.model_name_,
+                        inputs=inputs,
+                        outputs=outputs,
+                        client_timeout=2.0,
+                    )
+                    results["success"] += 1
+
+            except InferenceServerException as ex:
+                if "Deadline Exceeded" in str(ex):
+                    results["timeout"] += 1
+                elif "UNAVAILABLE" in str(ex) or "unavailable" in str(ex):
+                    results["unavailable"] += 1
+                else:
+                    results["other_errors"] += 1
+            except Exception:
+                results["other_errors"] += 1
+
+            # Small delay between requests
+            time.sleep(0.01)
+
+    def _shutdown_server(self, server_pid, delay_seconds):
+        """
+        Wait for the specified delay, then send SIGINT to shutdown the server.
+        """
+        time.sleep(delay_seconds)
+        print(f"Sending shutdown signal to server PID {server_pid}...")
+        os.kill(int(server_pid), signal.SIGINT)
+
+    def test_shutdown_during_active_requests(self):
+        """
+        Test that server shuts down cleanly while gRPC requests are active.
+
+        This is a regression test for issue #6899 where handler threads would
+        block indefinitely waiting for completion queue events that never arrived.
+        """
+        # Start continuous inference in background thread
+        inference_results = {}
+        inference_thread = threading.Thread(
+            target=self._continuous_inference, args=(5.0, inference_results)
+        )
+        inference_thread.start()
+
+        # Wait for some requests to be in flight
+        time.sleep(2.0)
+
+        # Get server PID from environment
+        server_pid = os.environ.get("SERVER_PID")
+        if not server_pid:
+            self.fail("SERVER_PID environment variable not set")
+
+        # Shutdown server while requests are active
+        shutdown_thread = threading.Thread(
+            target=self._shutdown_server, args=(server_pid, 0)
+        )
+        shutdown_start = time.time()
+        shutdown_thread.start()
+
+        # Wait for inference thread to complete
+        inference_thread.join(timeout=self.shutdown_timeout_)
+        shutdown_duration = time.time() - shutdown_start
+
+        # Wait for shutdown thread
+        shutdown_thread.join(timeout=self.shutdown_timeout_)
+
+        # Verify shutdown completed in reasonable time (not blocked indefinitely)
+        self.assertTrue(
+            shutdown_duration < self.shutdown_timeout_,
+            f"Server shutdown took {shutdown_duration:.2f}s, "
+            f"expected < {self.shutdown_timeout_}s. "
+            "This suggests handler threads may be blocked.",
+        )
+
+        # Verify we had some successful requests before shutdown
+        total_requests = sum(inference_results.values())
+        print(f"\nInference results: {inference_results}")
+        print(f"Total requests: {total_requests}")
+        print(f"Shutdown duration: {shutdown_duration:.2f}s")
+
+        self.assertTrue(
+            inference_results.get("success", 0) > 0,
+            "Expected at least some successful requests before shutdown",
+        )
+
+        # After shutdown, unavailable errors are expected, but timeouts should be minimal
+        # (timeouts would indicate threads blocked waiting for events)
+        timeout_ratio = inference_results.get("timeout", 0) / max(total_requests, 1)
+        self.assertTrue(
+            timeout_ratio < 0.5,
+            f"High timeout ratio ({timeout_ratio:.2%}) suggests handler threads "
+            "may have blocked during shutdown",
+        )
+
+    def test_repeated_shutdown_cycles(self):
+        """
+        Test multiple server start/shutdown cycles with concurrent requests.
+
+        This stresses the shutdown path to catch intermittent race conditions.
+        Note: This test would need to be run from a shell script that can
+        restart the server between cycles.
+        """
+        # This is a placeholder - full implementation would require
+        # shell script orchestration to restart server between cycles
+        print("Note: Full repeated shutdown test requires shell script orchestration")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/qa/L0_grpc_state_cleanup/test.sh b/qa/L0_grpc_state_cleanup/test.sh
index 9def779b72..3c2ef3c4b4 100755
--- a/qa/L0_grpc_state_cleanup/test.sh
+++ b/qa/L0_grpc_state_cleanup/test.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -225,6 +225,40 @@ fi
 
 set -e
 
+# Test for gRPC shutdown race condition (issue #6899)
+TEST_NAME=test_shutdown_during_active_requests
+SHUTDOWN_STRESS_TEST=shutdown_stress_test.py
+
+SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=2"
+SERVER_LOG="./inference_server.$TEST_NAME.log"
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    cat $SERVER_LOG
+    exit 1
+fi
+
+echo "Test: $TEST_NAME" >>$CLIENT_LOG
+
+set +e
+SERVER_PID=$SERVER_PID python $SHUTDOWN_STRESS_TEST ShutdownStressTest.$TEST_NAME >>$CLIENT_LOG 2>&1
+if [ $? -ne 0 ]; then
+    echo -e "\n***\n*** Test $TEST_NAME Failed\n***" >>$CLIENT_LOG
+    echo -e "\n***\n*** Test $TEST_NAME Failed\n***"
+    RET=1
+fi
+
+wait $SERVER_PID
+
+check_state_release $SERVER_LOG
+if [ $? -ne 0 ]; then
+    cat $SERVER_LOG
+    echo -e "\n***\n*** State Verification Failed for $TEST_NAME\n***"
+    RET=1
+fi
+
+set -e
+
 if [ $RET -eq 0 ]; then
     echo -e "\n***\n*** Test Passed\n***"
 else
diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h
index 4f1bcdfac0..2e677b6990 100644
--- a/src/grpc/grpc_handler.h
+++ b/src/grpc/grpc_handler.h
@@ -1,4 +1,4 @@
-// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -33,6 +33,8 @@ class HandlerBase {
   virtual ~HandlerBase() = default;
   virtual void Start() = 0;
   virtual void Stop() = 0;
+  // Notify handler that completion queue is shutting down
+  virtual void NotifyCQShutdown() = 0;
 };
 
 class ICallData {
diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc
index 3a28963c80..673e0276ce 100644
--- a/src/grpc/grpc_server.cc
+++ b/src/grpc/grpc_server.cc
@@ -274,6 +274,9 @@ class CommonHandler : public HandlerBase {
   // Stop handling requests.
   void Stop() override;
 
+  // Notify that CQ is shutting down (no-op for CommonHandler)
+  void NotifyCQShutdown() override {}
+
  private:
   void SetUpAllRequests();
 
@@ -2603,6 +2606,17 @@ Server::Stop()
     graceful_shutdown_thread_.join();
   }
 
+  // Notify all handlers that completion queues are shutting down.
+  // This must be done BEFORE calling Shutdown() on the CQs to prevent
+  // race conditions where alarms are set on shutting-down queues.
+  common_handler_->NotifyCQShutdown();
+  for (auto& model_infer_handler : model_infer_handlers_) {
+    model_infer_handler->NotifyCQShutdown();
+  }
+  for (auto& model_stream_infer_handler : model_stream_infer_handlers_) {
+    model_stream_infer_handler->NotifyCQShutdown();
+  }
+
   // Shutdown completion queues
   common_cq_->Shutdown();
   model_infer_cq_->Shutdown();
diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc
index 2166f0c60e..639bfb6062 100644
--- a/src/grpc/infer_handler.cc
+++ b/src/grpc/infer_handler.cc
@@ -657,7 +657,7 @@ InferRequestComplete(
 void
 ModelInferHandler::StartNewRequest()
 {
-  auto context = std::make_shared<Context>(cq_);
+  auto context = CreateContext();
   context->SetCompressionLevel(compression_level_);
   State* state = StateNew(tritonserver_.get(), context);
 
@@ -1108,8 +1108,13 @@ ModelInferHandler::InferResponseComplete(
     }
     // Send state back to the queue so that state can be released
-    // in the next cycle.
-    state->context_->PutTaskBackToQueue(state);
+    // in the next cycle. If CQ is shutting down, don't enqueue.
+    if (!state->context_->PutTaskBackToQueue(state)) {
+      // CQ is shutting down, cleanup without enqueuing
+      LOG_VERBOSE(1)
+          << "InferResponseComplete: not requeueing state due to shutdown, "
+          << state->unique_id_;
+    }
     delete response_release_payload;
   }
   return;
 }
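The infer_handler.h changes that follow implement the core of the fix: an atomic "CQ is shutting down" flag that suppresses new alarm scheduling, plus a registry of in-flight alarms that is cancelled exactly once when shutdown begins. Before reading that diff, here is a self-contained sketch of the same pattern using only the standard library. FakeAlarm and AlarmRegistry are stand-ins invented for this illustration (the real code uses ::grpc::Alarm and members added to InferHandler), so treat it as a model of the approach rather than code from the patch.

#include <atomic>
#include <iostream>
#include <mutex>
#include <unordered_set>

// Stand-in for ::grpc::Alarm: Cancel() makes the pending event fire
// immediately (with ok == false) instead of waiting for its deadline.
struct FakeAlarm {
  void Cancel() { cancelled = true; }
  bool cancelled = false;
};

class AlarmRegistry {
 public:
  // Refuses (and does not record) new alarms once shutdown has begun.
  bool TryRegister(FakeAlarm* alarm)
  {
    if (shutting_down_.load(std::memory_order_acquire)) {
      return false;
    }
    std::lock_guard<std::mutex> lock(mu_);
    active_.insert(alarm);
    return true;
  }

  void Unregister(FakeAlarm* alarm)
  {
    std::lock_guard<std::mutex> lock(mu_);
    active_.erase(alarm);
  }

  // Only the first caller wins the compare-exchange and cancels the
  // alarms that are still pending; later calls are no-ops.
  void NotifyShutdown()
  {
    bool expected = false;
    if (!shutting_down_.compare_exchange_strong(
            expected, true, std::memory_order_acq_rel)) {
      return;  // already shutting down
    }
    std::unordered_set<FakeAlarm*> copy;
    {
      std::lock_guard<std::mutex> lock(mu_);
      copy.swap(active_);
    }
    for (auto* alarm : copy) {
      alarm->Cancel();
    }
  }

 private:
  std::atomic<bool> shutting_down_{false};
  std::mutex mu_;
  std::unordered_set<FakeAlarm*> active_;
};

int main()
{
  AlarmRegistry registry;
  FakeAlarm a, b;
  registry.TryRegister(&a);
  registry.TryRegister(&b);
  registry.NotifyShutdown();  // cancels a and b
  FakeAlarm late;
  std::cout << std::boolalpha << registry.TryRegister(&late)  // false
            << " " << a.cancelled << " " << b.cancelled << "\n";
  return 0;
}

Running this prints "false true true": after the first shutdown notification, new registrations are refused and previously registered alarms have been cancelled, which is the guarantee the handler changes below rely on.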
diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h
index e9bcef9eaf..7cfc37d893 100644
--- a/src/grpc/infer_handler.h
+++ b/src/grpc/infer_handler.h
@@ -29,11 +29,14 @@
 #include <grpc++/alarm.h>
 #include <grpc++/grpc++.h>
 
+#include <atomic>
 #include <condition_variable>
+#include <functional>
 #include <mutex>
 #include <queue>
 #include <regex>
 #include <thread>
+#include <unordered_set>
 
 #include "../tracer.h"
 #include "grpc_handler.h"
@@ -709,16 +712,35 @@ class InferHandlerState {
   // transaction (e.g. a stream).
   struct Context {
     explicit Context(
-        ::grpc::ServerCompletionQueue* cq, const uint64_t unique_id = 0)
-        : cq_(cq), unique_id_(unique_id), ongoing_requests_(0),
-          step_(Steps::START), finish_ok_(true), ongoing_write_(false),
-          received_notification_(false)
+        ::grpc::ServerCompletionQueue* cq, std::atomic<bool>* cq_shutting_down,
+        std::function<void(::grpc::Alarm*)> register_alarm_fn,
+        std::function<void(::grpc::Alarm*)> unregister_alarm_fn,
+        const uint64_t unique_id = 0)
+        : cq_(cq), cq_shutting_down_(cq_shutting_down),
+          register_alarm_fn_(register_alarm_fn),
+          unregister_alarm_fn_(unregister_alarm_fn), unique_id_(unique_id),
+          ongoing_requests_(0), step_(Steps::START), finish_ok_(true),
+          ongoing_write_(false), received_notification_(false)
     {
       ctx_.reset(new ::grpc::ServerContext());
      responder_.reset(new ServerResponderType(ctx_.get()));
      gRPCErrorTracker_ = std::make_unique<gRPCErrorTracker>();
    }
 
+    void RegisterAlarm(::grpc::Alarm* alarm)
+    {
+      if (register_alarm_fn_) {
+        register_alarm_fn_(alarm);
+      }
+    }
+
+    void UnregisterAlarm(::grpc::Alarm* alarm)
+    {
+      if (unregister_alarm_fn_) {
+        unregister_alarm_fn_(alarm);
+      }
+    }
+
     void SetCompressionLevel(grpc_compression_level compression_level)
     {
       ctx_->set_compression_level(compression_level);
@@ -976,16 +998,26 @@ class InferHandlerState {
     }
 
     // Adds the state object to the completion queue so
-    // that it can be processed later
-    void PutTaskBackToQueue(InferHandlerStateType* state)
+    // that it can be processed later. Returns false if the
+    // completion queue is shutting down and the alarm was not set.
+    bool PutTaskBackToQueue(InferHandlerStateType* state)
     {
+      // Check if CQ is shutting down before scheduling alarm
+      if (state->context_->cq_shutting_down_ &&
+          state->context_->cq_shutting_down_->load(std::memory_order_acquire)) {
+        LOG_VERBOSE(1) << "PutTaskBackToQueue suppressed for "
+                       << state->unique_id_ << " due to CQ shutdown";
+        return false;
+      }
+
       std::lock_guard<std::mutex> lock(mu_);
-      // FIXME: Is there a better way to put task on the
-      // completion queue rather than using alarm object?
+      // Register alarm before setting it
+      state->context_->RegisterAlarm(&state->alarm_);
       // The alarm object will add a new task to the back of the
-      // completion queue when it expires or when it’s cancelled.
+      // completion queue when it expires or when it's cancelled.
       state->alarm_.Set(
           cq_, gpr_now(gpr_clock_type::GPR_CLOCK_REALTIME), state);
+      return true;
     }
 
     // Check the state at the front of the queue and write it if
@@ -1054,6 +1086,11 @@ class InferHandlerState {
     // The grpc completion queue associated with the RPC.
     ::grpc::ServerCompletionQueue* cq_;
 
+    // Shutdown guard - pointer to handler's shutdown flag
+    std::atomic<bool>* cq_shutting_down_;
+    std::function<void(::grpc::Alarm*)> register_alarm_fn_;
+    std::function<void(::grpc::Alarm*)> unregister_alarm_fn_;
+
     // Unique ID for the context. Used only for debugging so will
     // always be 0 in non-debug builds.
     const uint64_t unique_id_;
@@ -1349,6 +1386,15 @@ class InferHandler : public HandlerBase {
     return state;
   }
 
+  // Create a new context with shutdown guard
+  std::shared_ptr<Context> CreateContext()
+  {
+    return std::make_shared<Context>(
+        cq_, &cq_shutting_down_,
+        [this](::grpc::Alarm* a) { this->RegisterAlarm(a); },
+        [this](::grpc::Alarm* a) { this->UnregisterAlarm(a); });
+  }
+
   void StateRelease(State* state)
   {
     LOG_VERBOSE(2) << "StateRelease, " << state->unique_id_ << " Step "
@@ -1418,6 +1464,55 @@
   virtual bool Process(State* state, bool rpc_ok, bool is_notification) = 0;
 
   bool ExecutePrecondition(InferHandler::State* state);
+  // Notify handler that the completion queue is shutting down.
+  // This prevents new alarms from being set and cancels active ones.
+  void NotifyCQShutdown() override
+  {
+    bool expected = false;
+    if (cq_shutting_down_.compare_exchange_strong(
+            expected, true, std::memory_order_acq_rel)) {
+      LOG_VERBOSE(1) << "NotifyCQShutdown called for " << Name();
+      CancelActiveAlarms();
+    }
+  }
+
+  // Check if the completion queue is shutting down
+  bool IsCQShuttingDown() const
+  {
+    return cq_shutting_down_.load(std::memory_order_acquire);
+  }
+
+  // Register an alarm so it can be cancelled during shutdown
+  void RegisterAlarm(::grpc::Alarm* alarm)
+  {
+    std::lock_guard<std::mutex> lock(alarms_mu_);
+    active_alarms_.insert(alarm);
+  }
+
+  // Unregister an alarm when it fires or is cancelled
+  void UnregisterAlarm(::grpc::Alarm* alarm)
+  {
+    std::lock_guard<std::mutex> lock(alarms_mu_);
+    active_alarms_.erase(alarm);
+  }
+
+  // Cancel all active alarms
+  void CancelActiveAlarms()
+  {
+    std::unordered_set<::grpc::Alarm*> alarms_copy;
+    {
+      std::lock_guard<std::mutex> lock(alarms_mu_);
+      alarms_copy.swap(active_alarms_);
+    }
+    LOG_VERBOSE(1) << "Cancelling " << alarms_copy.size()
+                   << " active alarms for " << Name();
+    for (auto* alarm : alarms_copy) {
+      if (alarm != nullptr) {
+        alarm->Cancel();
+      }
+    }
+  }
+
 
   TRITONSERVER_Error* ForwardHeadersAsParameters(
       TRITONSERVER_InferenceRequest* irequest, InferHandler::State* state);
@@ -1444,6 +1539,11 @@
   std::shared_mutex* conn_mtx_;
   std::atomic<uint32_t>* conn_cnt_;
   bool* accepting_new_conn_;
+
+  // Shutdown guard for completion queue
+  std::atomic<bool> cq_shutting_down_{false};
+  std::mutex alarms_mu_;
+  std::unordered_set<::grpc::Alarm*> active_alarms_;
 };
 
 template <
@@ -1500,8 +1600,37 @@ InferHandler<
   void* tag;
   bool ok;
 
-  while (cq_->Next(&tag, &ok)) {
+  // Use deadline-based polling to allow checking shutdown flag
+  while (true) {
+    // Poll with 100ms timeout to allow shutdown detection
+    gpr_timespec deadline = gpr_time_add(
+        gpr_now(GPR_CLOCK_MONOTONIC),
+        gpr_time_from_millis(100, GPR_TIMESPAN));
+
+    auto status = cq_->AsyncNext(&tag, &ok, deadline);
+
+    if (status == ::grpc::CompletionQueue::SHUTDOWN) {
+      LOG_VERBOSE(1) << "Completion queue shut down for " << Name();
+      break;
+    }
+
+    if (status == ::grpc::CompletionQueue::TIMEOUT) {
+      // Check if we should exit due to shutdown
+      if (IsCQShuttingDown()) {
+        LOG_VERBOSE(1) << "Handler exiting due to CQ shutdown for " << Name();
+        break;
+      }
+      continue;
+    }
+
+    // GOT_EVENT - process the event
     State* state = static_cast<State*>(tag);
+
+    // Unregister alarm if this event is from an alarm
+    if (state->step_ != Steps::WAITING_NOTIFICATION) {
+      state->context_->UnregisterAlarm(&state->alarm_);
+    }
+
     bool is_notification = false;
     if (state->step_ == Steps::WAITING_NOTIFICATION) {
       State* state_wrapper = state;
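The deadline-based Process() loop above is what finally lets a handler thread notice shutdown even when no further completion-queue event will ever arrive. The snippet below is a minimal standard-library analogue of that bounded-wait idea, using plain threads and a condition variable rather than gRPC's AsyncNext; it is illustrative only and shares no code with the patch.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

// A worker that waits with a 100 ms deadline instead of blocking forever.
// Because the wait is bounded, a shutdown flag set by another thread is
// observed within at most one poll interval, even without an extra wake-up.
int main()
{
  std::mutex mu;
  std::condition_variable cv;
  std::queue<int> events;
  std::atomic<bool> shutting_down{false};

  std::thread worker([&] {
    std::unique_lock<std::mutex> lock(mu);
    while (true) {
      // Bounded wait, mirroring cq_->AsyncNext(..., now + 100ms).
      cv.wait_for(lock, std::chrono::milliseconds(100));
      while (!events.empty()) {
        std::cout << "processed event " << events.front() << "\n";
        events.pop();
      }
      if (shutting_down.load(std::memory_order_acquire)) {
        std::cout << "worker exiting due to shutdown\n";
        return;
      }
    }
  });

  {
    std::lock_guard<std::mutex> lock(mu);
    events.push(1);
  }
  cv.notify_one();

  std::this_thread::sleep_for(std::chrono::milliseconds(300));
  shutting_down.store(true, std::memory_order_release);
  // No notify needed: the bounded wait guarantees the flag is seen soon.
  worker.join();
  return 0;
}

The trade-off is the same as in the handler loop: a small amount of idle wake-up work every 100 ms in exchange for a hard bound on how long Stop() can be held up by a blocked handler thread.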
diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc
index 3789e5f9e8..21fc7f9edf 100644
--- a/src/grpc/stream_infer_handler.cc
+++ b/src/grpc/stream_infer_handler.cc
@@ -110,7 +110,7 @@ StreamOutputBufferAttributes(
 void
 ModelStreamInferHandler::StartNewRequest()
 {
-  auto context = std::make_shared<Context>(cq_, NEXT_UNIQUE_ID);
+  auto context = CreateContext();
   context->SetCompressionLevel(compression_level_);
   State* state = StateNew(tritonserver_.get(), context);
 
@@ -695,7 +695,11 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     // that state object can be released.
     if (is_complete) {
       state->step_ = Steps::CANCELLED;
-      state->context_->PutTaskBackToQueue(state);
+      if (!state->context_->PutTaskBackToQueue(state)) {
+        LOG_VERBOSE(1) << "StreamInferResponseComplete: not requeueing state "
+                          "due to shutdown, "
+                       << state->unique_id_;
+      }
       delete response_release_payload;
     }
 
@@ -823,7 +827,11 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     // that state object can be released.
     if (is_complete) {
       state->step_ = Steps::CANCELLED;
-      state->context_->PutTaskBackToQueue(state);
+      if (!state->context_->PutTaskBackToQueue(state)) {
+        LOG_VERBOSE(1) << "StreamInferResponseComplete: not requeueing state "
+                          "due to shutdown, "
+                       << state->unique_id_;
+      }
       delete response_release_payload;
     }
 
@@ -854,7 +862,11 @@ ModelStreamInferHandler::StreamInferResponseComplete(
       // The response queue is empty and complete final flag is received, so
       // mark the state as 'WRITEREADY' so it can be cleaned up later.
       state->step_ = Steps::WRITEREADY;
-      state->context_->PutTaskBackToQueue(state);
+      if (!state->context_->PutTaskBackToQueue(state)) {
+        LOG_VERBOSE(1) << "StreamInferResponseComplete: not requeueing state "
+                          "due to shutdown, "
+                       << state->unique_id_;
+      }
     }
     state->complete_ = is_complete;
   }
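All three StreamInferResponseComplete call sites above, like InferResponseComplete earlier in the patch, now treat PutTaskBackToQueue() as fallible: they log and fall through when it reports shutdown instead of insisting on a requeue. The snippet below isolates that call-site contract. Context, State and RequeueOrDrop here are hypothetical stand-ins reduced to the relevant behaviour, assuming only what the patch states, namely that PutTaskBackToQueue() returns false once the completion queue is shutting down.

#include <cstdint>
#include <iostream>
#include <string>

struct State;

// Hypothetical stand-in for the real Context: in the patch,
// PutTaskBackToQueue() returns false once the completion queue is
// shutting down (and otherwise schedules a ::grpc::Alarm).
struct Context {
  bool shutting_down = false;
  bool PutTaskBackToQueue(State* /*state*/) { return !shutting_down; }
};

struct State {
  Context* context_;
  uint64_t unique_id_;
};

// Mirrors the call-site pattern: requeue when possible, otherwise log and
// fall through so completion callbacks never wait on a queue that will
// not be drained again.
void RequeueOrDrop(State* state, const std::string& who)
{
  if (!state->context_->PutTaskBackToQueue(state)) {
    std::cout << who << ": not requeueing state due to shutdown, "
              << state->unique_id_ << "\n";
  }
}

int main()
{
  Context ctx;
  State state{&ctx, 42};
  RequeueOrDrop(&state, "StreamInferResponseComplete");  // requeued, no output
  ctx.shutting_down = true;
  RequeueOrDrop(&state, "StreamInferResponseComplete");  // dropped, logged
  return 0;
}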