
Commit 88ea2c4

joyang-nv, tongyuantongyu, and hchings authored
[TRTLLM-7349][feat] Adding new orchestrator type -- ray (#7520)
Signed-off-by: Erin Ho <[email protected]>
Co-authored-by: Yuan Tong <[email protected]>
Co-authored-by: Erin Ho <[email protected]>
1 parent 9d098e3 commit 88ea2c4

File tree

91 files changed: +5541 −606 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ tensorrt_llm/deep_ep_cpp_tllm.pyi
 tensorrt_llm/deep_gemm/
 tensorrt_llm/deep_gemm_cpp_tllm.*.so
 tensorrt_llm/deep_gemm_cpp_tllm.pyi
+tensorrt_llm/pg_utils_bindings.*.so
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 142 additions & 4 deletions
@@ -23,9 +23,17 @@
 #include "tensorrt_llm/executor/cacheCommunicator.h"
 #include "tensorrt_llm/executor/dataTransceiverState.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
+#include "tensorrt_llm/runtime/utils/pgUtils.h"
 #include <future>
-#include <map>
 #include <memory>
+#include <mutex>
+#include <optional>
+#include <pybind11/pybind11.h>
+#include <torch/csrc/jit/python/pybind_utils.h>
+#include <torch/custom_class.h>
+#include <torch/python.h>
+#include <type_traits>
+#include <vector>
 
 using SizeType32 = tensorrt_llm::runtime::SizeType32;
 
@@ -43,6 +51,134 @@ class BaseKVCacheManager;
 class CacheSender;
 class CacheReceiver;
 
+class CacheTransceiverComm
+{
+public:
+    // Construct from a non-owning raw pointer, won't take ownership of the pointer
+    explicit CacheTransceiverComm(mpi::MpiComm const* mpiComm)
+        : mMpiComm(std::shared_ptr<mpi::MpiComm const>(nullptr), mpiComm)
+    {
+    }
+
+    // Construct from a shared_ptr with shared ownership
+    explicit CacheTransceiverComm(std::shared_ptr<mpi::MpiComm const> mpiComm)
+        : mMpiComm(std::move(mpiComm))
+    {
+    }
+
+    // Construct from a ProcessGroup communicator
+    explicit CacheTransceiverComm(c10::intrusive_ptr<c10d::ProcessGroup> pgComm)
+        : mPgComm(std::move(pgComm))
+    {
+    }
+
+    ~CacheTransceiverComm() = default;
+
+    bool isMpi() const noexcept
+    {
+        return mMpiComm != nullptr;
+    }
+
+    int getRank() const
+    {
+        if (isMpi())
+        {
+            return mMpiComm->getRank();
+        }
+        return mPgComm->getRank();
+    }
+
+    int getSize() const
+    {
+        if (isMpi())
+        {
+            return mMpiComm->getSize();
+        }
+        return mPgComm->getSize();
+    }
+
+    void allgather(void const* sendbuf, void* recvbuf, int count, mpi::MpiType dtype) const
+    {
+        if (isMpi())
+        {
+            mMpiComm->allgather(sendbuf, recvbuf, count, dtype);
+            return;
+        }
+        TLLM_THROW("Input arguments only supported in mpi");
+    }
+
+    template <typename Input, typename Output>
+    bool allgather(Input input, Output output, c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
+    {
+        if (isMpi())
+        {
+            TLLM_THROW("Input arguments only supported in pg");
+        }
+        tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};
+
+        PGCHECK_THROW(pgh.allgather(input, output, options));
+        return true;
+    }
+
+    template <typename Input, typename Output>
+    bool allgatherv(Input input, Output output, std::vector<int> const& sizes,
+        c10d::AllgatherOptions options = c10d::AllgatherOptions()) const
+    {
+        if (isMpi())
+        {
+            TLLM_THROW("Input arguments only supported in pg");
+        }
+        tensorrt_llm::pg_utils::PgHelper pgh{mPgComm};
+        PGCHECK_THROW(pgh.allgatherv(input, output, sizes, options));
+        return true;
+    }
+
+    bool allgatherv(void const* sendbuf, int sendcount, mpi::MpiType sendtype, void* recvbuf,
+        std::vector<int> const& recvcounts, std::vector<int> const& displs, mpi::MpiType recvtype) const
+    {
+        if (isMpi())
+        {
+            mMpiComm->allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype);
+            return true;
+        }
+        TLLM_THROW("Input arguments only supported in mpi");
+    }
+
+    CacheTransceiverComm split(int color, int key)
+    {
+        if (isMpi())
+        {
+            auto subgroup = mMpiComm->split(color, key);
+            return CacheTransceiverComm(std::make_shared<mpi::MpiComm const>(std::move(subgroup)));
+        }
+        bool const initialized = Py_IsInitialized();
+        TLLM_CHECK_WITH_INFO(initialized, "Trying to use ProcessGroup communicator but Python is not initialized");
+        try
+        {
+            c10::intrusive_ptr<c10d::ProcessGroup> pgSub;
+            {
+                pybind11::gil_scoped_acquire gil;
+                auto const m = pybind11::module::import("tensorrt_llm._torch.distributed.pg_utils");
+                // Properly box the existing intrusive_ptr ProcessGroup into an IValue
+                // and convert to a Python object without constructing a new instance.
+                auto const py_pg = torch::jit::toPyObject(c10::IValue(mPgComm));
+
+                auto const py_sub_pg = m.attr("split")(color, key, py_pg);
+                pgSub = torch::jit::toCustomClass<c10d::ProcessGroup>(py_sub_pg);
+            }
+            return CacheTransceiverComm(pgSub);
+        }
+        catch (...)
+        {
+            TLLM_THROW("Failed to split process group");
+        }
+    }
+
+private:
+    std::shared_ptr<mpi::MpiComm const> mMpiComm;
+    c10::intrusive_ptr<c10d::ProcessGroup> mPgComm;
+};
+
 class CacheTransceiverFactory
 {
 public:
@@ -124,9 +260,11 @@ class CacheTransceiver : public BaseCacheTransceiver
     std::unique_ptr<CacheReceiver> mCacheReceiver;
     std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
     std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
-    mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
-    std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,
-        mMpiGroupTPInDPComm;
+    mpi::MpiComm const* mMpiWorldComm{nullptr};
+
+    std::shared_ptr<CacheTransceiverComm> mGroupComm;
+    std::shared_ptr<CacheTransceiverComm> mGroupTensorParaComm, mGroupPipeParaComm, mGroupDataComm, mGroupTPInDPComm;
+
     executor::kv_cache::CommState const* mCommState;
     std::unique_ptr<executor::kv_cache::CacheState> mCacheState;
     std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;
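
The new CacheTransceiverComm wraps either an MPI communicator or a torch.distributed ProcessGroup behind one interface, which is what lets the new Ray orchestrator path avoid MPI entirely. A minimal usage sketch follows; it is illustrative only and not part of the commit, it assumes the class exactly as added above, and the MpiType::kINT32 enumerator name is an assumption for the example.

// Illustrative sketch, not from this commit: rank/size queries dispatch to
// whichever backend the CacheTransceiverComm was constructed with, and
// split() yields a sub-communicator on both the MPI and ProcessGroup paths.
void exampleUse(CacheTransceiverComm& comm)
{
    int const rank = comm.getRank();
    int const size = comm.getSize();

    if (comm.isMpi())
    {
        // MPI backend: the raw-buffer allgather overload is available.
        std::vector<int> allRanks(size);
        // MpiType::kINT32 is an assumed enumerator name for illustration.
        comm.allgather(&rank, allRanks.data(), 1, mpi::MpiType::kINT32);
    }
    // On the ProcessGroup backend, the templated allgather/allgatherv
    // overloads (backed by pg_utils::PgHelper) must be used instead;
    // calling the raw-buffer overload there throws via TLLM_THROW.

    // Splitting by color/key works uniformly: MPI splits natively, while
    // the ProcessGroup path re-enters Python through
    // tensorrt_llm._torch.distributed.pg_utils under the GIL.
    auto sub = comm.split(/*color=*/rank % 2, /*key=*/rank);
}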
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "c10/util/intrusive_ptr.h"
+#include <Python.h>
+
+namespace tensorrt_llm::common
+{
+
+// Adapted from pybind11's example implementation:
+// https://github.com/pybind/pybind11/blob/master/include/pybind11/conduit/pybind11_conduit_v1.h
+// Copyright (c) 2024 The pybind Community.
+
+inline void* get_raw_pointer_ephemeral(
+    PyObject* py_obj, std::type_info const* cpp_type_info, std::string const& pybind11_abi)
+{
+    PyObject* cpp_type_info_capsule = PyCapsule_New(
+        const_cast<void*>(static_cast<void const*>(cpp_type_info)), typeid(std::type_info).name(), nullptr);
+    if (cpp_type_info_capsule == nullptr)
+    {
+        return nullptr;
+    }
+    PyObject* cpp_conduit = PyObject_CallMethod(
+        py_obj, "_pybind11_conduit_v1_", "yOy", pybind11_abi.c_str(), cpp_type_info_capsule, "raw_pointer_ephemeral");
+    Py_DECREF(cpp_type_info_capsule);
+    if (cpp_conduit == nullptr)
+    {
+        return nullptr;
+    }
+    void* raw_ptr = PyCapsule_GetPointer(cpp_conduit, cpp_type_info->name());
+    Py_DECREF(cpp_conduit);
+    if (PyErr_Occurred())
+    {
+        return nullptr;
+    }
+    return raw_ptr;
+}
+
+template <typename T, typename E>
+T* get_type_pointer_ephemeral(PyObject* py_obj, std::string pybind11_abi)
+{
+    void* raw_ptr = get_raw_pointer_ephemeral(py_obj, &typeid(T), pybind11_abi);
+    if (raw_ptr == nullptr)
+    {
+        throw E();
+    }
+    return static_cast<T*>(raw_ptr);
+}
+
+template <typename T, typename E>
+c10::intrusive_ptr<T> get_intrusive_ptr(PyObject* py_obj, std::string pybind11_abi)
+{
+    auto* const p = get_type_pointer_ephemeral<T, E>(py_obj, pybind11_abi);
+    return c10::intrusive_ptr<T>::reclaim_copy(p);
+}
+
+} // namespace tensorrt_llm::common
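
These conduit helpers let C++ code recover the object held by a pybind11-wrapped Python instance without depending on pybind11's internal type casters. A hedged sketch of how a caller might pull a c10d::ProcessGroup back out of a PyObject* follows; it is not part of the commit, and both the ConduitError type and the ABI-string argument are assumptions (the caller would pass through pybind11's platform ABI tag).

// Illustrative sketch, not from this commit. ConduitError is a hypothetical
// default-constructible error type; get_intrusive_ptr throws E() on failure.
#include <stdexcept>

struct ConduitError : std::runtime_error
{
    ConduitError()
        : std::runtime_error("pybind11 conduit lookup failed")
    {
    }
};

c10::intrusive_ptr<c10d::ProcessGroup> processGroupFromPython(PyObject* pyPg, std::string const& pybind11Abi)
{
    // reclaim_copy() bumps the refcount, so the returned intrusive_ptr
    // co-owns the ProcessGroup with the Python-side holder.
    return tensorrt_llm::common::get_intrusive_ptr<c10d::ProcessGroup, ConduitError>(pyPg, pybind11Abi);
}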

cpp/include/tensorrt_llm/runtime/utils/mpiUtils.h

Lines changed: 23 additions & 0 deletions
@@ -35,6 +35,7 @@
 #include <cstdlib>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <thread>
 
 #if ENABLE_MULTI_DEVICE
@@ -425,7 +426,29 @@ class MpiComm
         return !(rhs == *this);
     }
 
+    bool couldUseMPI() const
+    {
+        if (!mDisableMPI.has_value())
+        {
+            char* val = std::getenv("TLLM_DISABLE_MPI");
+            if (val != NULL && std::string(val) == "1")
+            {
+                mDisableMPI = true;
+            }
+            else
+            {
+                mDisableMPI = false;
+            }
+        }
+        if (mDisableMPI.value())
+        {
+            throw std::runtime_error("MPI is disabled, DON\'T USE MPI");
+        }
+        return true;
+    }
+
 private:
+    mutable std::optional<bool> mDisableMPI;
     //! \brief Corresponds to `world()` by default, but can be overridden per process.
     static MpiComm& mutableSession();
 
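The couldUseMPI() guard backs the new orchestrator type: with TLLM_DISABLE_MPI=1 in the environment, as a Ray-orchestrated deployment would set, any code path that still reaches MPI fails fast, and the result is cached so later checks are cheap. A small call-site sketch follows; it is illustrative only, and the barrier() call is a stand-in for any MpiComm operation.

// Illustrative sketch, not from this commit: guarding an MPI call site.
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

void syncIfMpiAllowed(tensorrt_llm::mpi::MpiComm const& comm)
{
    // Throws std::runtime_error when TLLM_DISABLE_MPI=1; returns true otherwise.
    comm.couldUseMPI();
    comm.barrier(); // stand-in for any MPI-backed operation on this comm
}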