diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 3250e6edc4..8b8f01b292 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -165,7 +165,8 @@ def add_parser_api_server(): max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group) quant_policy = ArgumentHelper.quant_policy(pt_group) model_format = ArgumentHelper.model_format(pt_group) - ArgumentHelper.dp(pt_group) + dp_act = ArgumentHelper.dp(pt_group) + num_nodes_act = ArgumentHelper.num_nodes(pt_group) ArgumentHelper.ep(pt_group) ArgumentHelper.enable_microbatch(pt_group) ArgumentHelper.enable_eplb(pt_group) @@ -173,14 +174,14 @@ def add_parser_api_server(): ArgumentHelper.role(pt_group) ArgumentHelper.migration_backend(pt_group) # multi-node serving args - ArgumentHelper.node_rank(parser) - ArgumentHelper.num_nodes(parser) + node_rank_act = ArgumentHelper.node_rank(pt_group) # turbomind args tb_group = parser.add_argument_group('TurboMind engine arguments') # common engine args tb_group._group_actions.append(dtype_act) tb_group._group_actions.append(tp_act) + tb_group._group_actions.append(dp_act) tb_group._group_actions.append(session_len_act) tb_group._group_actions.append(max_batch_size_act) tb_group._group_actions.append(cache_max_entry_act) @@ -189,10 +190,13 @@ def add_parser_api_server(): tb_group._group_actions.append(max_prefill_token_num_act) tb_group._group_actions.append(quant_policy) tb_group._group_actions.append(model_format) + tb_group._group_actions.append(num_nodes_act) + tb_group._group_actions.append(node_rank_act) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.num_tokens_per_iter(tb_group) ArgumentHelper.max_prefill_iters(tb_group) ArgumentHelper.communicator(tb_group) + ArgumentHelper.ngpus_per_node(tb_group) # vlm args vision_group = parser.add_argument_group('Vision model arguments') @@ -342,6 +346,10 @@ def api_server(args): from lmdeploy.messages import TurbomindEngineConfig backend_config = TurbomindEngineConfig(dtype=args.dtype, tp=args.tp, + dp=args.dp, + nnodes=args.nnodes, + ngpus_per_node=args.ngpus_per_node, + node_rank=args.node_rank, max_batch_size=max_batch_size, session_len=args.session_len, model_format=args.model_format, diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 5c6ca0b478..8dc43918fd 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -192,6 +192,12 @@ def num_nodes(parser): return parser.add_argument('--nnodes', type=int, default=1, help='The total node nums') + @staticmethod + def ngpus_per_node(parser): + """Add argument ngpus_per_node to parser.""" + + return parser.add_argument('--ngpus-per-node', type=int, default=None, help='The total gpu nums per node') + @staticmethod def session_id(parser): """Add argument session_id to parser.""" diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index f8efcff8a2..b009c4072f 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -229,12 +229,17 @@ class TurbomindEngineConfig: model_format: Optional[str] = None tp: int = 1 dp: int = 1 + pp: int = 1 device_num: int = None attn_tp_size: int = None attn_dp_size: int = None mlp_tp_size: int = None mlp_dp_size: int = None outer_dp_size: int = None + nnodes: int = 1 + node_rank: int = 0 + ngpus_per_node: Optional[int] = None + devices: List[int] = None session_len: Optional[int] = None max_batch_size: int = None cache_max_entry_count: float = 0.8 diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 0c5632bc94..7144ba7988 100644 --- a/lmdeploy/turbomind/turbomind.py +++ 
b/lmdeploy/turbomind/turbomind.py @@ -5,6 +5,7 @@ import copy import json import math +import os import os.path as osp import sys from collections import defaultdict @@ -84,14 +85,23 @@ def complete_parallel_config(cfg: TurbomindEngineConfig): def update_parallel_config(cfg: TurbomindEngineConfig): + if cfg.nnodes > 1: + assert cfg.ngpus_per_node is not None or cfg.devices is not None + cfg.devices = cfg.devices or list(range(cfg.ngpus_per_node)) + cfg.ngpus_per_node = cfg.ngpus_per_node or len(cfg.devices) + cfg.device_num = cfg.device_num or len(cfg.devices) * cfg.nnodes + if not complete_parallel_config(cfg): - total = cfg.dp * cfg.tp + total = cfg.dp * cfg.tp * cfg.pp if not cfg.device_num: count = torch.cuda.device_count() if total < count: count = total cfg.device_num = count + assert cfg.device_num % cfg.pp == 0 assert total % cfg.device_num == 0 + if cfg.dp > 1: + total = cfg.device_num // cfg.pp overlap = total // cfg.device_num attn_dp_size = overlap mlp_tp_size = overlap @@ -102,10 +112,19 @@ def update_parallel_config(cfg: TurbomindEngineConfig): cfg.mlp_dp_size = 1 cfg.mlp_tp_size = mlp_tp_size * inner_tp_size assert cfg.attn_dp_size * cfg.attn_tp_size == cfg.mlp_dp_size * cfg.mlp_tp_size - assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num + assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size * cfg.pp == cfg.device_num + assert cfg.outer_dp_size > 0 and cfg.attn_tp_size > 0 cfg.devices = cfg.devices or list(range(cfg.device_num)) + # update devices + if cfg.nnodes == 1: + cfg.devices = cfg.devices if cfg.devices else list(range(cfg.device_num)) + cfg.ngpus_per_node = cfg.ngpus_per_node or len(cfg.devices) + # for simplicity, each node has dp + assert cfg.outer_dp_size * cfg.attn_dp_size % cfg.nnodes == 0 + + class TurboMind: """LMDeploy's inference engine. @@ -141,8 +160,15 @@ def __init__(self, f' greater than 0, but got {_engine_config.max_batch_size}' update_parallel_config(_engine_config) - - self.gpu_count = _engine_config.device_num + if _engine_config.nnodes > 1 and _engine_config.node_rank == 0: + from torch.distributed import TCPStore + master_addr = os.environ.get('LMDEPLOY_DP_MASTER_ADDR') + master_port = os.environ.get('LMDEPLOY_DP_MASTER_PORT') + assert master_addr is not None and master_port is not None, \ + 'LMDEPLOY_DP_MASTER_ADDR and LMDEPLOY_DP_MASTER_PORT should be set when using multi-node' + self.store = TCPStore(host_name=master_addr, port=int(master_port), is_master=True) + + self.gpu_count = len(_engine_config.devices) self.devices = _engine_config.devices self.tokenizer = tokenizer @@ -196,10 +222,8 @@ def _create_engine(self): def _create_weight(self, model_comm): """Allocate weight buffer, load params if from_workspace.""" - # TODO: support mpi - self.node_id = 0 - self.node_num = 1 - torch.cuda.synchronize() + engine_cfg = self.config_dict['engine_config'] + self.node_id = engine_cfg['node_rank'] # create weight def _create_weight_func(device_id): @@ -394,6 +418,8 @@ def close(self): del self._export_iter if self.model_comm is not None: self.model_comm = None + if hasattr(self, 'store'): + del self.store def create_instance(self, cuda_stream_id=0): """Create a turbomind instance. 
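The hunk above has node_rank 0 host a torch.distributed.TCPStore for multi-node rendezvous, with the endpoint taken from LMDEPLOY_DP_MASTER_ADDR / LMDEPLOY_DP_MASTER_PORT, while the new TurbomindEngineConfig fields (nnodes, node_rank, ngpus_per_node) describe the node topology; the same fields are exposed on the CLI as --nnodes, --node-rank and --ngpus-per-node. A minimal sketch of how one node of a two-node deployment might be configured from Python — only the environment variables and the new config fields come from this diff; the parallel sizes, address, and model path are placeholders:

```python
import os

from lmdeploy import pipeline
from lmdeploy.messages import TurbomindEngineConfig

# Every node must export the same rendezvous endpoint; node_rank 0 hosts the
# TCPStore (see the hunk above), the other nodes connect to it.
os.environ["LMDEPLOY_DP_MASTER_ADDR"] = "10.0.0.1"
os.environ["LMDEPLOY_DP_MASTER_PORT"] = "29555"

node_rank = int(os.environ.get("NODE_RANK", "0"))   # e.g. set by the launcher

backend_config = TurbomindEngineConfig(
    tp=8,                  # total TP degree across the cluster (placeholder)
    dp=2,                  # DP degree (placeholder)
    nnodes=2,              # new field: total number of nodes
    node_rank=node_rank,   # new field: this node's rank
    ngpus_per_node=8,      # new field: GPUs contributed by each node
)

pipe = pipeline("/path/to/model", backend_config=backend_config)  # placeholder path
```

On the C++ side, gloo_comm.cc (later in this diff) reads the same two environment variables to build its rendezvous store, so every node must export identical values.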
@@ -500,11 +526,6 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea self.tm_model = tm_model self.cuda_stream_id = cuda_stream_id - self.node_id = tm_model.node_id - self.gpu_count = tm_model.gpu_count - - self.session_len = tm_model.session_len - # create model instances lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False) self._model_inst = None if lazy_init else self._create_model_instance(0) diff --git a/src/turbomind/comm/CMakeLists.txt b/src/turbomind/comm/CMakeLists.txt index 6e5c772c46..73490146df 100644 --- a/src/turbomind/comm/CMakeLists.txt +++ b/src/turbomind/comm/CMakeLists.txt @@ -20,6 +20,14 @@ if (BUILD_MULTI_GPU) target_link_libraries(device_comm INTERFACE nccl_comm) endif () + add_subdirectory(gloo) + target_link_libraries(host_comm INTERFACE gloo_comm) + + add_library(serialize STATIC serialize.cc) + target_link_libraries(serialize PRIVATE core) + set_property(TARGET serialize PROPERTY POSITION_INDEPENDENT_CODE ON) + target_link_libraries(host_comm INTERFACE serialize) + if (BUILD_TEST) add_executable(test_comm test_comm.cu) target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils) diff --git a/src/turbomind/comm/device_comm.h b/src/turbomind/comm/device_comm.h index d68ebdc4da..7618169d04 100644 --- a/src/turbomind/comm/device_comm.h +++ b/src/turbomind/comm/device_comm.h @@ -106,6 +106,16 @@ class DeviceCommImpl { { throw std::runtime_error("not implemented"); } + + virtual void Send(const void* sendbuff, size_t count, DataType type, int dst, int group, cudaStream_t stream) + { + throw std::runtime_error("not implemented"); + } + + virtual void Recv(void* recvbuff, size_t count, DataType type, int src, int group, cudaStream_t stream) + { + throw std::runtime_error("not implemented"); + } }; class DeviceComm { diff --git a/src/turbomind/comm/gloo/CMakeLists.txt b/src/turbomind/comm/gloo/CMakeLists.txt new file mode 100644 index 0000000000..cb3bf80278 --- /dev/null +++ b/src/turbomind/comm/gloo/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +cmake_minimum_required(VERSION 3.8) + +include(FetchContent) +FetchContent_Declare( + gloo + GIT_REPOSITORY https://github.com/pytorch/gloo.git + GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4 +) + +# some settings of gloo, +set(GLOO_INSTALL OFF CACHE BOOL "" FORCE) +set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE) +set(USE_NCCL OFF) +set(BUILD_TEST OFF) +FetchContent_MakeAvailable(gloo) + +# gloo build doesn't add include directories as a target property... +target_include_directories(gloo PUBLIC + $ + $ # config.h generated at cmake config time +) + +add_library(gloo_comm STATIC + gloo_comm.cc + tcp_store.cc +) +set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON) +target_link_libraries(gloo_comm PUBLIC gloo logger) diff --git a/src/turbomind/comm/gloo/gloo_comm.cc b/src/turbomind/comm/gloo/gloo_comm.cc new file mode 100644 index 0000000000..2391170d02 --- /dev/null +++ b/src/turbomind/comm/gloo/gloo_comm.cc @@ -0,0 +1,333 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "src/turbomind/comm/gloo/tcp_store.h" +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind::comm { + +const char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; +const char STORE_INFO_DELIM = ','; + +std::shared_ptr<::gloo::transport::Device> createGlooDevice() +{ + ::gloo::transport::tcp::attr attr; + if (auto ifname = std::getenv(GLOO_SOCKET_IFNAME_ENV); ifname) { + attr.iface = ifname; + } + else { + attr.hostname = ::gloo::getHostname(); + } + return ::gloo::transport::tcp::CreateDevice(attr); +} + +class Store: public ::gloo::rendezvous::PrefixStore { +public: + explicit Store(const std::string& host, int port, const std::string& prefix): + host_(host), port_(port), ::gloo::rendezvous::PrefixStore(prefix, nullptr) + { + store_ = std::make_shared(host_, port_); + }; + + ~Store() = default; + + std::shared_ptr New(const std::string& prefix) + { + std::string new_prefix = prefix + "/" + prefix_; + return std::make_shared(host_, port_, new_prefix); + } + +public: + std::string host_; + int port_; + + using ::gloo::rendezvous::PrefixStore::store_; + using ::gloo::rendezvous::PrefixStore::prefix_; +}; + +class GlobalStoreFactory { +public: + static GlobalStoreFactory& Instance() + { + static GlobalStoreFactory instance; + return instance; + } + + std::string New() + { + std::lock_guard lock(mutex_); + + std::string host = std::getenv("LMDEPLOY_DP_MASTER_ADDR"); + int port = std::stoi(std::getenv("LMDEPLOY_DP_MASTER_PORT")); + + std::stringstream ss; + ss << host << STORE_INFO_DELIM << port << STORE_INFO_DELIM << prefix_++; + return ss.str(); + } + + std::shared_ptr Load(const std::string& info) + { + std::stringstream ss(info); + std::vector keys; + std::string local; + while (getline(ss, local, STORE_INFO_DELIM)) { + keys.push_back(std::move(local)); + } + FT_CHECK(keys.size() == 3); + + std::string host = keys[0]; + int port = stoi(keys[1]); + std::string prefix = keys[2]; + + return std::make_shared(host, port, prefix); + } + +private: + GlobalStoreFactory() {} + + std::mutex mutex_; + int prefix_{0}; +}; + +typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); + +struct GlooCommImpl: public HostCommImpl { + + struct SplitInfo { + int color; + int rank; + + bool operator<(const SplitInfo& other) const + { + return (color < other.color) || (color == other.color && rank < other.rank); + } + + bool operator==(const SplitInfo& other) const + { + return (color == other.color) && (rank == other.rank); + } + }; + + GlooCommImpl(std::shared_ptr store, int n_ranks, int rank): + store_{std::move(store)}, rank_{rank}, n_ranks_{n_ranks} + { + // TM_LOG_INFO("[GlooCommImpl] rank=%d, n_ranks=%d, prefix=%s", rank_, n_ranks_, store_->prefix_.c_str()); + device_ = createGlooDevice(); + context_ = std::make_shared<::gloo::rendezvous::Context>(rank_, n_ranks_); + context_->connectFullMesh(store_, device_); + } + + ~GlooCommImpl() {} + + int rank() const override + { + return rank_; + } + + int n_ranks() const override + { + return n_ranks_; + } + + bool is_same_process() const override + { + return false; + } + + std::shared_ptr Split(int color, int key) override + { + // don't know why key was set to 0 + auto vec = comm::AllGather(this, SplitInfo{color, rank_}); + auto last = std::stable_partition(vec.begin(), vec.end(), [&](auto x) { // + return x.color == color; + }); + 
vec.erase(last, vec.end()); + std::stable_sort(vec.begin(), vec.end(), [](auto& a, auto& b) { // + return a < b; + }); + + auto new_prefix = std::to_string(color) + ":" + std::to_string(n_split_++); + auto new_store = store_->New(new_prefix); + int new_n_ranks = vec.size(); + int new_rank = std::find(vec.begin(), vec.end(), SplitInfo{color, rank_}) - vec.begin(); + return std::make_shared(new_store, new_n_ranks, new_rank); + } + + void Sync() override + { + if (n_ranks_ == 1) { + return; + } + ::gloo::BarrierOptions opts(context_); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + ::gloo::barrier(opts); + } + + void Broadcast(void* data, int count, DataType dtype, int root, copy_fn copy) override + { + if (n_ranks_ == 1) { + return; + } + ::gloo::BroadcastOptions opts(context_); + opts.setRoot(root); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + opts.setOutput((char*)data, count); + ::gloo::broadcast(opts); + } + + void AllGather(void* data, int count, DataType dtype, copy_fn copy) override + { + if (n_ranks_ == 1) { + return; + } + ::gloo::AllgatherOptions opts(context_); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + opts.setOutput((char*)data, count * n_ranks_); + ::gloo::allgather(opts); + } + + static ReduceFunc getReduceFunc(DataType dtype, RedOp red_op) + { + + auto dispatch_op = [&](auto t) -> ReduceFunc { + using T = decltype(t); + switch (red_op) { + case RedOp::kSum: + return ::gloo::sum; + case RedOp::kMax: + return ::gloo::max; + case RedOp::kMin: + return ::gloo::min; + default: + return {}; + } + }; + + auto dispatch = [&]() -> ReduceFunc { + switch (dtype) { + case kInt32: + return dispatch_op(int32_t{}); + case kInt64: + return dispatch_op(int64_t{}); + case kUint32: + return dispatch_op(uint32_t{}); + case kUint64: + return dispatch_op(uint64_t{}); + default: + return {}; + } + }; + + if (auto fn = dispatch()) { + return fn; + } + else { + throw std::runtime_error("not implemented"); + return {}; + } + } + + void AllReduce(void* data, int count, DataType dtype, RedOp red_op) override + { + if (n_ranks_ == 1) { + return; + } + ::gloo::AllreduceOptions opts(context_); + opts.setTimeout(std::chrono::milliseconds(1000 * 60 * 30)); + opts.setReduceFunction(getReduceFunc(dtype, red_op)); + switch (dtype) { + case kInt32: + opts.setOutput((int32_t*)data, count); + break; + case kInt64: + opts.setOutput((int64_t*)data, count); + break; + case kUint32: + opts.setOutput((uint32_t*)data, count); + break; + case kUint64: + opts.setOutput((uint64_t*)data, count); + break; + default: + throw std::runtime_error("not implemented"); + } + ::gloo::allreduce(opts); + } + + void Send(void* data, int count, DataType dtype, int dst) override + { + auto buf = context_->createUnboundBuffer(const_cast(data), byte_size(dtype) * count); + buf->send(dst, 0); + buf->waitSend(std::chrono::milliseconds(1000 * 60 * 30)); + } + + void Recv(void* data, int count, DataType dtype, int src, copy_fn copy) override + { + auto buf = context_->createUnboundBuffer(data, byte_size(dtype) * count); + buf->recv(src, 0); + buf->waitRecv(std::chrono::milliseconds(1000 * 60 * 30)); + } + + int n_split_{}; + std::shared_ptr<::gloo::transport::Device> device_; + std::shared_ptr<::gloo::rendezvous::Context> context_; + std::shared_ptr store_; + int rank_; + int n_ranks_; + uint32_t tag_{}; +}; + +class GlooGroupId: public HostGroupId { + + void Initialize() override + { + info_ = GlobalStoreFactory::Instance().New(); + // TM_LOG_ERROR("[GlooGroupId][Initialize] info=%s", 
info_.c_str()); + } + + void Export(std::ostream& os) override + { + os << info_; + } + + void Import(std::istream& is) override + { + std::stringstream ss; + ss << is.rdbuf(); + info_ = ss.str(); + } + + HostComm CreateCommunicator(int n_ranks, int rank) override + { + FT_CHECK(info_ != ""); + auto impl = std::make_shared(GlobalStoreFactory::Instance().Load(info_), n_ranks, rank); + return std::static_pointer_cast(impl); + } + +private: + std::string info_; // ip,port,prefix + std::shared_ptr<::gloo::rendezvous::Store> store_; +}; + +std::unique_ptr CreateGlooGroupId() +{ + return std::make_unique(); +} + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/gloo/tcp_store.cc b/src/turbomind/comm/gloo/tcp_store.cc new file mode 100644 index 0000000000..54706aba0f --- /dev/null +++ b/src/turbomind/comm/gloo/tcp_store.cc @@ -0,0 +1,220 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include +#include +#include + +#include +#include + +#include "src/turbomind/comm/gloo/tcp_store.h" +#include "src/turbomind/utils/logger.h" + +namespace turbomind::comm { + +namespace { + +// copy from pytorch https://github.com/pytorch/pytorch/blob/v2.8.0-rc4/torch/csrc/distributed/c10d/TCPStoreBackend.hpp + +static const uint32_t validationMagicNumber = 0x3C85F7CE; + +enum class CheckResponseType : uint8_t +{ + READY, + NOT_READY +}; + +enum class QueryType : uint8_t +{ + VALIDATE, + SET, + COMPARE_SET, + GET, + ADD, + CHECK, + WAIT, + GETNUMKEYS, + DELETE_KEY, + APPEND, + MULTI_GET, + MULTI_SET, + CANCEL_WAIT, + PING, + QUEUE_PUSH, + QUEUE_POP, + QUEUE_LEN, +}; + +} // namespace + +struct Buffer { + std::vector buffer; + + template>> + void append(T val) + { + char* ptr = (char*)&val; + buffer.insert(buffer.end(), ptr, ptr + sizeof(T)); + } + + void append(const std::vector& vec) + { + append((uint64_t)vec.size()); + buffer.insert(buffer.end(), vec.begin(), vec.end()); + } + + void append(const std::string& str) + { + append((uint64_t)str.size()); + buffer.insert(buffer.end(), str.begin(), str.end()); + } + + const char* data() const + { + return buffer.data(); + } + + size_t count() const + { + return buffer.size(); + } +}; + +void validate(std::shared_ptr<::gloo::transport::tcp::Socket>& socket) +{ + Buffer buffer; + buffer.append(QueryType::VALIDATE); + buffer.append(validationMagicNumber); + socket->write(buffer.data(), buffer.count()); +} + +void ping(std::shared_ptr<::gloo::transport::tcp::Socket>& socket) +{ + Buffer buffer; + buffer.append(QueryType::PING); + uint32_t nonce = getpid(); + uint32_t returnedNonce = -1; + buffer.append(nonce); + socket->write(buffer.data(), buffer.count()); + int r = socket->read(&returnedNonce, sizeof(returnedNonce)); + if (nonce != returnedNonce) { + std::stringstream ss; + ss << "Ping failed, nonce=" << nonce << ", returnedNonce=" << returnedNonce << ", socket read=" << r; + throw std::runtime_error(ss.str()); + } +} + +TCPStore::TCPStore(const std::string& host, int port) +{ + auto retry = 0; + do { + try { + ::addrinfo hints{}, *res{}; + hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + int status = getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints, &res); + + std::shared_ptr holder(res, [](addrinfo* p) { + if (p != nullptr) { + freeaddrinfo(p); + } + }); + + if (status != 0) { + throw std::runtime_error("getaddrinfo failed: " + std::string(gai_strerror(status))); + } + + for (::addrinfo* addr = res; addr != nullptr; addr = addr->ai_next) { + int fd 
= ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (fd == -1) { + continue; + } + auto socket = std::make_shared<::gloo::transport::tcp::Socket>(fd); + socket->connect(addr->ai_addr, addr->ai_addrlen); + socket->noDelay(true); + socket->recvTimeout(std::chrono::milliseconds(5000)); + socket->sendTimeout(std::chrono::milliseconds(5000)); + validate(socket); // validate the connection + ping(socket); // check send/recv + socket_ = std::move(socket); + break; + } + + if (socket_ == nullptr) { + throw std::runtime_error("unable to connect to " + host + ":" + std::to_string(port)); + } + } + catch (const std::exception& e) { + TM_LOG_WARNING("[TCPStore] Failed to connect to store after %d retries: %s", retry, e.what()); + std::this_thread::sleep_for(std::chrono::seconds(1)); + retry += 1; + } + } while (socket_ == nullptr); +} + +void TCPStore::set(const std::string& key, const std::vector& data) +{ + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::SET); + buffer.append(key); + buffer.append(data); + socket_->write(buffer.data(), buffer.count()); +} + +std::vector TCPStore::get(const std::string& key) +{ + wait({key}); + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::GET); + buffer.append(key); + socket_->write(buffer.data(), buffer.count()); + + uint64_t vec_size; + socket_->read(&vec_size, sizeof(vec_size)); + std::vector value(vec_size); + socket_->read(value.data(), value.size()); + return value; +} + +bool TCPStore::check(const std::vector& keys) +{ + std::lock_guard lock(mutex_); + Buffer buffer; + buffer.append(QueryType::CHECK); + buffer.append((uint64_t)keys.size()); + for (const auto& key : keys) { + buffer.append(key); + } + socket_->write(buffer.data(), buffer.count()); + + CheckResponseType response; + socket_->read(&response, sizeof(response)); + return response == CheckResponseType::READY; +} + +void TCPStore::wait(const std::vector& keys, const std::chrono::milliseconds& timeout) +{ + const auto start = std::chrono::steady_clock::now(); + while (!check(keys)) { + const auto elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + if (elapsed > timeout) { + std::stringstream ss; + ss << "Wait timeout for key(s): ["; + for (const auto& key : keys) { + ss << key << " "; + } + ss << "]"; + throw std::runtime_error("Wait timeout for key(s): " + ss.str()); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } +} + +TCPStore::~TCPStore() = default; + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/gloo/tcp_store.h b/src/turbomind/comm/gloo/tcp_store.h new file mode 100644 index 0000000000..35dd1c05bf --- /dev/null +++ b/src/turbomind/comm/gloo/tcp_store.h @@ -0,0 +1,37 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#pragma once + +#include +#include + +#include +#include + +namespace turbomind::comm { + +class TCPStore: public gloo::rendezvous::Store { +public: + explicit TCPStore(const std::string& host, int port); + + ~TCPStore(); + + void set(const std::string& key, const std::vector& data) override; + + std::vector get(const std::string& key) override; + + bool check(const std::vector& keys); + + void wait(const std::vector& keys) override + { + wait(keys, std::chrono::seconds(30)); + } + + void wait(const std::vector& keys, const std::chrono::milliseconds& timeout) override; + +private: + std::shared_ptr<::gloo::transport::tcp::Socket> socket_; + std::mutex mutex_; +}; + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/host_comm.cc b/src/turbomind/comm/host_comm.cc index 0d3cf367e2..37c45732df 100644 --- a/src/turbomind/comm/host_comm.cc +++ b/src/turbomind/comm/host_comm.cc @@ -8,8 +8,16 @@ HostCommImpl::~HostCommImpl() = default; std::unique_ptr CreateThreadGroupId(); +std::unique_ptr CreateGlooGroupId(); + std::unique_ptr CreateHostGroupId(const std::string& backend) { +#ifdef BUILD_MULTI_GPU + if (backend == "gloo") { + return CreateGlooGroupId(); + } +#endif + return CreateThreadGroupId(); } diff --git a/src/turbomind/comm/host_comm.h b/src/turbomind/comm/host_comm.h index b036142264..c6b6a37f5a 100644 --- a/src/turbomind/comm/host_comm.h +++ b/src/turbomind/comm/host_comm.h @@ -3,12 +3,15 @@ #pragma once #include +#include #include #include #include #include +#include "src/turbomind/comm/serialize.h" #include "src/turbomind/core/data_type.h" +#include "src/turbomind/utils/logger.h" namespace turbomind::comm { @@ -42,6 +45,10 @@ class HostCommImpl { virtual void AllGather(void* data, int count, DataType dtype, copy_fn copy) = 0; virtual void AllReduce(void* data, int count, DataType dtype, RedOp red_op) = 0; + + virtual void Send(void* data, int count, DataType dtype, int dst) = 0; + + virtual void Recv(void* data, int count, DataType dtype, int src, copy_fn copy) = 0; }; class HostComm { @@ -88,7 +95,23 @@ void Broadcast(HostCommImpl* comm, T* data, int n, int root) comm->Broadcast(data, n, kNull, root, detail::copy_fn); } else { - throw std::runtime_error("not implemented"); + try { + // buf may have different size on different ranks + std::vector buf; + serialize(data, n, buf); + size_t size = buf.size(); + Broadcast(comm, &size, 1, root); + buf.resize(size); + comm->Broadcast(buf.data(), buf.size(), data_type_v, root, detail::copy_fn); + if (comm->rank() != root) { + // some field in data may be not shared by all rank + deserialize(data, n, buf); + } + } + catch (const std::invalid_argument& e) { + TM_LOG_ERROR("Broadcast failed: %s", e.what()); + throw; + } } } } @@ -105,8 +128,31 @@ void AllGather(HostCommImpl* comm, T* data, int n) comm->AllGather(data, n, kNull, detail::copy_fn); } else { - /// serialize data - throw std::runtime_error("not implemented"); + try { + // buf may have different size on different ranks + std::vector rbuf; + for (int i = 0; i < n; ++i) { + std::vector ibuf; + serialize(data + n * comm->rank() + i, 1, ibuf); + rbuf.insert(rbuf.end(), ibuf.begin(), ibuf.end()); + } + int size = rbuf.size(); + comm->AllReduce(&size, 1, data_type_v, RedOp::kMax); + std::vector buf(size * comm->n_ranks()); + std::memcpy(buf.data() + comm->rank() * size, rbuf.data(), rbuf.size()); + comm->AllGather(buf.data(), size, data_type_v, detail::copy_fn); + for (int i = 0; i < comm->n_ranks(); ++i) { + if (i != comm->rank()) { + // some field in data may be not shared 
by all rank + deserialize( + data + n * i, n, std::vector(buf.begin() + i * size, buf.begin() + (i + 1) * size)); + } + } + } + catch (const std::invalid_argument& e) { + TM_LOG_ERROR("AllGather failed: %s", e.what()); + throw; + } } } } @@ -114,9 +160,64 @@ void AllGather(HostCommImpl* comm, T* data, int n) template void AllReduce(HostCommImpl* comm, T* data, int n, RedOp red_op) { + static_assert(std::is_trivially_copyable_v, "AllReduce only supports trivially copyable types"); comm->AllReduce(data, n, data_type_v, red_op); } +template +void Send(HostCommImpl* comm, T* data, int n, int dst) +{ + if constexpr (std::is_trivially_copyable_v) { + comm->Send(data, sizeof(T) * n, data_type_v, dst); + } + else { + if (comm->is_same_process()) { + comm->Send(data, n, kNull, dst); + } + else { + try { + std::vector buf; + for (int i = 0; i < n; ++i) { + std::vector ibuf; + serialize(data + i, 1, ibuf); + buf.insert(buf.end(), ibuf.begin(), ibuf.end()); + } + uint64_t size = buf.size(); + comm->Send(&size, 1, data_type_v, dst); + comm->Send(buf.data(), buf.size(), data_type_v, dst); + } + catch (const std::invalid_argument& e) { + TM_CHECK(0) << "Send failed: " << e.what(); + } + } + } +} + +template +void Recv(HostCommImpl* comm, T* data, int n, int src) +{ + if constexpr (std::is_trivially_copyable_v) { + comm->Recv(data, sizeof(T) * n, data_type_v, src, detail::copy_fn); + } + else { + if (comm->is_same_process()) { + comm->Recv(data, n, kNull, src, detail::copy_fn); + } + else { + try { + uint64_t size; + comm->Recv(&size, 1, data_type_v, src, detail::copy_fn); + std::vector buf(size); + comm->Recv(buf.data(), size, data_type_v, src, detail::copy_fn); + deserialize(data, n, buf); + } + catch (const std::invalid_argument& e) { + TM_CHECK(0) << "Recv failed: " << e.what(); + } + } + } +} + ////////////////////////////////////////////////////////////////////////////////// // Typed value interface template @@ -142,6 +243,18 @@ T AllReduce(HostCommImpl* comm, const T& value, RedOp red_op) return tmp; } +template +void Send(HostCommImpl* comm, T& value, int dst) +{ + Send(comm, &value, 1, dst); +} + +template +void Recv(HostCommImpl* comm, T& value, int src) +{ + Recv(comm, &value, 1, src); +} + class HostGroupId { public: virtual ~HostGroupId() = default; diff --git a/src/turbomind/comm/nccl/nccl.cu b/src/turbomind/comm/nccl/nccl.cu index 804dfaaa46..78e5febbda 100644 --- a/src/turbomind/comm/nccl/nccl.cu +++ b/src/turbomind/comm/nccl/nccl.cu @@ -302,6 +302,50 @@ public: } } + int CreateOrGetP2PGroupIndex(int src, int dst, int group) + { + int low_rank = src < dst ? src : dst; + int high_rank = src < dst ? dst : src; + std::string key = std::to_string(group) + ":" + std::to_string(low_rank) + ":" + std::to_string(high_rank); + + if (p2p_group_index_map_.count(key) == 0) { + ncclUniqueId uid{}; + static_assert(std::is_trivially_copyable_v); + if (src == rank(group)) { + NCCLCHECK(ncclGetUniqueId(&uid)); + ::turbomind::comm::Send(h_comm_, uid, dst); + } + else { + ::turbomind::comm::Recv(h_comm_, uid, src); + } + + int new_rank = low_rank == rank(group) ? 0 : 1; + ncclComm_t comm{}; + NCCLCHECK(ncclCommInitRank(&comm, 2, uid, new_rank)); + groups_.push_back(comm); + p2p_group_index_map_[key] = groups_.size() - 1; + } + return p2p_group_index_map_[key]; + } + + void Send(const void* sendbuff, size_t count, DataType type, int dst, int group, cudaStream_t stream) override + { + int peer = rank(group) < dst ? 
1 : 0; + ncclComm_t comm = groups_.at(CreateOrGetP2PGroupIndex(rank(group), dst, group)); + NCCLCHECK(ncclGroupStart()); + NCCLCHECK(ncclSend(sendbuff, count, to_nccl_dtype(type), peer, comm, stream)); + NCCLCHECK(ncclGroupEnd()); + } + + void Recv(void* recvbuff, size_t count, DataType type, int src, int group, cudaStream_t stream) override + { + int peer = rank(group) < src ? 1 : 0; + ncclComm_t comm = groups_.at(CreateOrGetP2PGroupIndex(src, rank(group), group)); + NCCLCHECK(ncclGroupStart()); + NCCLCHECK(ncclRecv(recvbuff, count, to_nccl_dtype(type), peer, comm, stream)); + NCCLCHECK(ncclGroupEnd()); + } + private: HostComm h_comm_; @@ -312,6 +356,8 @@ private: std::unordered_map handles_; std::unordered_map buffers_; + + std::unordered_map p2p_group_index_map_; }; DeviceComm CreateNcclCommunicator(int n_ranks, int rank, HostComm h_comm) diff --git a/src/turbomind/comm/serialize.cc b/src/turbomind/comm/serialize.cc new file mode 100644 index 0000000000..8441f4be45 --- /dev/null +++ b/src/turbomind/comm/serialize.cc @@ -0,0 +1,195 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include +#include + +#include "src/turbomind/comm/serialize.h" +#include "src/turbomind/engine/request.h" + +namespace turbomind::comm { + +std::vector streambuf_to_vector(std::streambuf* sb) +{ + auto start = sb->pubseekoff(0, std::ios::beg, std::ios::in); + auto end = sb->pubseekoff(0, std::ios::end, std::ios::in); + auto size = end - start; + + std::vector buffer(size); + sb->pubseekpos(start); + sb->sgetn(buffer.data(), size); + return buffer; +} + +void serialize(std::ostream& os, const std::string& s) +{ + int size = s.length(); + serialize(os, size); + os << s; +} + +void deserialize(std::istream& is, std::string& s) +{ + int size; + deserialize(is, size); + s.resize(size); + is.read(s.data(), size); +} + +void serialize(std::ostream& os, const GenerationConfig& gen) +{ + serialize(os, gen.max_new_tokens); + serialize(os, gen.min_new_tokens); + serialize(os, gen.eos_ids); + serialize(os, gen.stop_ids[0]); + serialize(os, gen.stop_ids[1]); + serialize(os, gen.bad_ids[0]); + serialize(os, gen.bad_ids[1]); + serialize(os, gen.top_k); + serialize(os, gen.top_p); + serialize(os, gen.min_p); + serialize(os, gen.temperature); + serialize(os, gen.repetition_penalty); + serialize(os, gen.random_seed); + serialize(os, gen.output_logprobs); + serialize(os, gen.output_last_hidden_state); + serialize(os, gen.output_logits); +} + +void deserialize(std::istream& is, GenerationConfig& gen) +{ + deserialize(is, gen.max_new_tokens); + deserialize(is, gen.min_new_tokens); + deserialize(is, gen.eos_ids); + deserialize(is, gen.stop_ids[0]); + deserialize(is, gen.stop_ids[1]); + deserialize(is, gen.bad_ids[0]); + deserialize(is, gen.bad_ids[1]); + deserialize(is, gen.top_k); + deserialize(is, gen.top_p); + deserialize(is, gen.min_p); + deserialize(is, gen.temperature); + deserialize(is, gen.repetition_penalty); + deserialize(is, gen.random_seed); + deserialize(is, gen.output_logprobs); + deserialize(is, gen.output_last_hidden_state); + deserialize(is, gen.output_logits); +} + +void serialize(std::ostream& os, const SessionParam& sess) +{ + serialize(os, sess.id); + serialize(os, sess.step); + serialize(os, sess.start_flag); + serialize(os, sess.end_flag); + serialize(os, sess.kill_flag); +} + +void deserialize(std::istream& is, SessionParam& sess) +{ + deserialize(is, sess.id); + deserialize(is, sess.step); + deserialize(is, sess.start_flag); + deserialize(is, sess.end_flag); + deserialize(is, sess.kill_flag); +} + 
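serialize.cc pairs with serialize.h to define a flat wire format for host-side communication: trivially-copyable scalars are written as raw bytes, strings and vectors as a length prefix followed by the payload, and composite types (GenerationConfig, SessionParam, Layout, Buffer, Tensor, Request) simply serialize their fields in declaration order. A hedged Python sketch of that framing, assuming host (little-endian) byte order and using only a subset of SessionParam's fields for illustration:

```python
import struct


def pack_string(s: str) -> bytes:
    """int32 length prefix + bytes, like serialize(std::ostream&, const std::string&)."""
    data = s.encode()
    return struct.pack("<i", len(data)) + data


def unpack_string(buf: bytes, off: int = 0):
    """Inverse of pack_string; returns (value, new_offset)."""
    (n,) = struct.unpack_from("<i", buf, off)
    off += 4
    return buf[off:off + n].decode(), off + n


def pack_session(session_id: int, step: int, start_flag: bool) -> bytes:
    # Mirrors the SessionParam pattern above: fields written back-to-back in
    # declaration order (only three of the five fields shown; the 64-bit id
    # width is an assumption).
    return (struct.pack("<Q", session_id)     # id
            + struct.pack("<i", step)         # step
            + struct.pack("<?", start_flag))  # start_flag


payload = pack_string("output_ids") + pack_session(7, 0, True)
key, off = unpack_string(payload)
assert key == "output_ids" and off == 14
```

Because both ends run the same build on the same architecture, the format carries no type tags or endianness markers, so deserialize must read fields in exactly the order serialize wrote them.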
+void serialize(std::ostream& os, const Layout& layout) +{ + serialize(os, layout.shape()); + serialize(os, layout.stride()); +} + +void deserialize(std::istream& is, Layout& layout) +{ + std::vector shape; + std::vector stride; + deserialize(is, shape); + deserialize(is, stride); + layout = Layout(std::move(shape), std::move(stride)); +} + +void serialize(std::ostream& os, const Buffer& buffer) +{ + FT_CHECK(buffer.device() == turbomind::core::Device(kCPU)); + serialize(os, buffer.size()); + serialize(os, buffer.dtype()); + os.write((char*)buffer.raw_data(), buffer.byte_size()); +} + +void deserialize(std::istream& is, Buffer& buffer) +{ + ssize_t size; + DataType dtype; + deserialize(is, size); + deserialize(is, dtype); + buffer = Buffer(size, dtype, turbomind::core::Device(kCPU)); + is.read((char*)buffer.raw_data(), buffer.byte_size()); +} + +void serialize(std::ostream& os, const Tensor& tensor) +{ + FT_CHECK(tensor.is_contiguous()); + serialize(os, tensor.layout()); + serialize(os, tensor.buffer()); +} + +void deserialize(std::istream& is, Tensor& tensor) +{ + Layout layout; + Buffer buffer; + deserialize(is, layout); + deserialize(is, buffer); + tensor = Tensor(std::move(buffer), std::move(layout)); +} + +void serialize(std::ostream& os, const TensorMap& map) +{ + int size = map.size(); + serialize(os, size); + for (const auto& [key, tensor] : map) { + serialize(os, key); + serialize(os, tensor); + } +} + +void deserialize(std::istream& is, TensorMap& map) +{ + int size; + deserialize(is, size); + for (int i = 0; i < size; ++i) { + std::string key; + deserialize(is, key); + Tensor tensor; + deserialize(is, tensor); + map.emplace(key, tensor); + } +} + +void serialize(std::ostream& os, const Request& req) +{ + serialize(os, req.id); + serialize(os, req.unique_id); + serialize(os, req.session); + serialize(os, req.gen_cfg); + serialize(os, req.stream_output); + serialize(os, req.inputs); + serialize(os, req.outputs); + serialize(os, req.ec); +} + +void deserialize(std::istream& is, Request& req) +{ + deserialize(is, req.id); + deserialize(is, req.unique_id); + deserialize(is, req.session); + deserialize(is, req.gen_cfg); + deserialize(is, req.stream_output); + deserialize(is, req.inputs); + deserialize(is, req.outputs); + deserialize(is, req.ec); + + req.output_ids = req.outputs.at("output_ids"); + req.sequence_length = req.outputs.at("sequence_length"); +} + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/serialize.h b/src/turbomind/comm/serialize.h new file mode 100644 index 0000000000..dded661e34 --- /dev/null +++ b/src/turbomind/comm/serialize.h @@ -0,0 +1,92 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "src/turbomind/core/tensor.h" +#include "src/turbomind/engine/request.h" + +namespace turbomind::comm { + +std::vector streambuf_to_vector(std::streambuf* sb); + +template +inline void serialize(const T*, int n, std::vector&) +{ + throw std::invalid_argument("not implemented"); +} + +template +inline void deserialize(T*, int n, const std::vector&) +{ + throw std::invalid_argument("not implemented"); +} + +template>> +inline void serialize(std::ostream& os, const T& v) +{ + os.write((char*)&v, sizeof(v)); +} + +template>> +inline void deserialize(std::istream& is, T& v) +{ + is.read((char*)&v, sizeof(v)); +} + +void serialize(std::ostream& os, const std::string& s); + +void deserialize(std::istream& is, std::string& s); + +template>> +inline void serialize(std::ostream& os, const std::vector& vec) +{ + int size = vec.size(); + os.write((char*)&size, sizeof(int)); + os.write((char*)vec.data(), sizeof(T) * size); +} + +template>> +inline void deserialize(std::istream& is, std::vector& vec) +{ + int size; + is.read((char*)&size, sizeof(int)); + vec.resize(size); + is.read((char*)vec.data(), sizeof(T) * size); +} + +void serialize(std::ostream& os, const GenerationConfig& gen); + +void deserialize(std::istream& is, GenerationConfig& gen); + +void serialize(std::ostream& os, const SessionParam& sess); + +void deserialize(std::istream& is, SessionParam& sess); + +void serialize(std::ostream& os, const Layout& layout); + +void deserialize(std::istream& is, Layout& layout); + +void serialize(std::ostream& os, const Buffer& buffer); + +void deserialize(std::istream& is, Buffer& buffer); + +void serialize(std::ostream& os, const Tensor& tensor); + +void deserialize(std::istream& is, Tensor& tensor); + +void serialize(std::ostream& os, const TensorMap& map); + +void deserialize(std::istream& is, TensorMap& map); + +void serialize(std::ostream& os, const Request& req); + +void deserialize(std::istream& is, Request& req); + +} // namespace turbomind::comm diff --git a/src/turbomind/comm/thread_comm.cc b/src/turbomind/comm/thread_comm.cc index 017d83abb0..a39b536844 100644 --- a/src/turbomind/comm/thread_comm.cc +++ b/src/turbomind/comm/thread_comm.cc @@ -280,6 +280,40 @@ struct ThreadCommImpl: public HostCommImpl { } } } + + void Send(void* data, int count, DataType dtype, int dst) override + { + if (l2g_.at(dst) == rank_) { // src == dst + return; + } + + // transform dst to global rank + dst = l2g_.at(dst); + + auto& c = channel(rank_, dst); + void* expected{}; + while (!c.compare_exchange_weak(expected, data, std::memory_order_release)) { + expected = {}; + } + while (c.load(std::memory_order_relaxed)) {} + } + + void Recv(void* data, int count, DataType dtype, int src, copy_fn copy) override + { + TM_CHECK(copy); + if (l2g_.at(src) == rank_) { // src == dst + return; + } + + // transform src to global rank + src = l2g_.at(src); + + auto& c = channel(src, rank_); + void* incoming{}; + while (!(incoming = c.load(std::memory_order_acquire))) {} + copy(incoming, count, data, 0); + c.store(nullptr, std::memory_order_relaxed); + } }; class ThreadGroupId: public HostGroupId { diff --git a/src/turbomind/core/buffer.cc b/src/turbomind/core/buffer.cc index 6971e63482..39d37d13a7 100644 --- a/src/turbomind/core/buffer.cc +++ b/src/turbomind/core/buffer.cc @@ -22,6 +22,10 @@ Buffer Buffer::view(DataType dtype) const Buffer Buffer::slice(ssize_t base, ssize_t size) const { + if (size_ == 0) { + TM_CHECK(this->base_ == 0 && 
size == 0); + return *this; + } TM_CHECK_LE(base + size, size_); auto b = *this; b.base_ += base; diff --git a/src/turbomind/core/layout.cc b/src/turbomind/core/layout.cc index 995f2a1fbf..261845ef84 100644 --- a/src/turbomind/core/layout.cc +++ b/src/turbomind/core/layout.cc @@ -29,7 +29,7 @@ Layout::Layout(vector shape, vector stride): shape_{std::move( ssize_t Layout::cosize() const noexcept { - if (rank() == 0) { + if (rank() == 0 || size() == 0) { return 0; } ssize_t value{1}; diff --git a/src/turbomind/core/layout.h b/src/turbomind/core/layout.h index f7c18d8d41..d682f998b4 100644 --- a/src/turbomind/core/layout.h +++ b/src/turbomind/core/layout.h @@ -64,6 +64,9 @@ class Layout { bool is_contiguous() const noexcept { + if (size() == 0) { + return true; + } if (stride_.back() != 1) { return false; } diff --git a/src/turbomind/engine/gateway.cc b/src/turbomind/engine/gateway.cc index 3dd8c4b4cb..ff7846bff7 100644 --- a/src/turbomind/engine/gateway.cc +++ b/src/turbomind/engine/gateway.cc @@ -7,9 +7,13 @@ namespace turbomind { -Gateway::Gateway(int groups, int group_size, std::function()> ctx_factory): +Gateway::Gateway(int groups, + int group_size, + std::vector node_dp_ranks, + std::function()> ctx_factory): size_{groups * group_size}, group_size_{group_size}, + node_dp_ranks_{std::move(node_dp_ranks)}, queues_(size_), flags_(groups), ctx_factory_{ctx_factory}, diff --git a/src/turbomind/engine/gateway.h b/src/turbomind/engine/gateway.h index 8350822046..12a8b6b6dd 100644 --- a/src/turbomind/engine/gateway.h +++ b/src/turbomind/engine/gateway.h @@ -60,7 +60,10 @@ class SeqId2Rank { class Gateway { public: - Gateway(int groups, int group_size, std::function()> ctx_factory); + Gateway(int groups, + int group_size, + std::vector node_dp_ranks, + std::function()> ctx_factory); void shutdown(); @@ -73,7 +76,8 @@ class Gateway { rank = seqid2rank_.find(r->session.id); } else { - rank = next_.fetch_add(1, std::memory_order_relaxed) % size_; + rank = next_.fetch_add(1, std::memory_order_relaxed) % node_dp_ranks_.size(); + rank = node_dp_ranks_[rank]; } if (rank >= 0) { @@ -188,6 +192,7 @@ class Gateway { std::vector> queues_; std::vector>> flags_; + std::vector node_dp_ranks_; std::function()> ctx_factory_; diff --git a/src/turbomind/engine/request_queue.h b/src/turbomind/engine/request_queue.h index 590578bf8a..4d6ee641b9 100644 --- a/src/turbomind/engine/request_queue.h +++ b/src/turbomind/engine/request_queue.h @@ -78,10 +78,10 @@ class RequestQueue { || flag_->load(std::memory_order_relaxed) == expected_ // || closed_; }); - if (closed_) { - abort = true; - return false; - } + } + if (closed_) { + abort = true; + return false; } bool is_first = false; diff --git a/src/turbomind/kernels/activation_kernels.cu b/src/turbomind/kernels/activation_kernels.cu index 77373a090c..0c9e36e356 100644 --- a/src/turbomind/kernels/activation_kernels.cu +++ b/src/turbomind/kernels/activation_kernels.cu @@ -221,6 +221,10 @@ void invokeGenericActivation_v2( template class Activation> void invokeGenericActivation_v3(Ref inter_, const Tensor& gate, cudaStream_t stream) { + if (inter_.get().size() == 0) { + return; + } + auto& inter = inter_.get(); TM_CHECK_EQ(inter.ndim(), 2); TM_CHECK_EQ(gate.ndim(), 2); diff --git a/src/turbomind/kernels/norm/rms_norm.cu b/src/turbomind/kernels/norm/rms_norm.cu index ee826c4105..4b6aaba20f 100644 --- a/src/turbomind/kernels/norm/rms_norm.cu +++ b/src/turbomind/kernels/norm/rms_norm.cu @@ -86,15 +86,15 @@ __global__ void RMSNorm(T* dst, void invokeRMSNorm(Tensor& out, const 
Tensor& x, const Tensor& w, float eps, cudaStream_t st) { + if (x.size() == 0) { + return; + } + TM_CHECK(x.ndim() == 2); TM_CHECK(out.shape() == x.shape()); TM_CHECK(out.dtype() == x.dtype()); TM_CHECK(w.dtype() == x.dtype() && w.shape(-1) == x.shape(-1)); - if (x.size() == 0) { - return; - } - auto invoke = [&](auto t) { using T = decltype(t); diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index ccdb46012e..8e1b300aff 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -44,6 +44,7 @@ #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/comm/serialize.h" #include "src/turbomind/utils/anomaly_handler.h" #include "src/turbomind/utils/constant.h" #include "src/turbomind/utils/cuda_utils.h" @@ -541,7 +542,7 @@ void LlamaBatch::Initialize(GenerationState& g) } CopyState(cpys); // Swap the buffers - std::swap(state_, back_); + SwapState(state_, back_); ClearState(*back_); ClearState(*incoming_); @@ -603,9 +604,9 @@ void LlamaBatch::Initialize(GenerationState& g) } // Real-time context length that will change during generation - Copy_(state_->h_context_length, batch_size, context_length_buf_); - Copy_(state_->h_finished, batch_size, finished_buf_); - Copy_(state_->h_rope_theta, batch_size, rope_theta_); + Copy_(state_->h_context_length, batch_size, state_->context_length_buf); + Copy_(state_->h_finished, batch_size, state_->finished_buf); + Copy_(state_->h_rope_theta, batch_size, state_->rope_theta); bool skip_init_sampling = std::equal(g.unique_ids.begin(), // g.unique_ids.end() - g.partial, @@ -621,8 +622,9 @@ void LlamaBatch::Initialize(GenerationState& g) // TM_LOG_ERROR("[Initialize] batch size: %d, active size: %d", state_->size, state_->active_size); if (!skip_init_sampling) { - g.max_init_ctx_len = max_context_len; - g.step = max_context_len; + g.max_init_ctx_len = max_context_len; + g.step = max_context_len; + state_->pp_init_sampling = true; } } @@ -683,6 +685,38 @@ void LlamaBatch::CopyState(const std::vectorseq_len_limit[di] = s->seq_len_limit[si]; d->sequences[di] = s->sequences[si]; d->requests[di] = s->requests[si]; + d->errors[di] = s->errors[si]; + } +} + +void LlamaBatch::SwapState(BatchState*& a, BatchState*& b) +{ + std::swap(a, b); + + if (param_.pp_size > 1) { + ClearState(*b); + std::vector> cpys; + FT_CHECK(b->size == 0 && b->active_size == 0); + for (int i = 0; i < a->size; ++i) { + cpys.emplace_back(a, b, b->size, b->size); + b->size++; + } + b->active_size = a->active_size; + CopyState(cpys); + std::swap(a, b); + } + else { + // shared buffers between state_ and back_ + a->init_context_length = b->init_context_length; + a->context_length_buf = b->context_length_buf; + a->sequence_lengths = b->sequence_lengths; + a->rope_theta = b->rope_theta; + a->h_seq_limit_len = b->h_seq_limit_len; + a->seq_limit_len = b->seq_limit_len; + a->finished_buf = b->finished_buf; + a->token_ids_buf = b->token_ids_buf; + a->input_ids_buf = b->input_ids_buf; + a->h_input_length_buf = b->h_input_length_buf; } } @@ -699,16 +733,8 @@ void LlamaBatch::AllocateBuffer(ssize_t batch_size, ssize_t session_len, int cac const ssize_t max_batch_block_count = batch_size * ((session_len + cache_block_seq_len - 1) / cache_block_seq_len) + 1; - input_ids_buf_ = {max_forward_token_num_, kDEVICE}; - decoder_output_buf_ = {{batchxbeam, hidden_units}, data_type_, kDEVICE}; - input_length_buf_ = {batchxbeam, kDEVICE}; - 
context_length_buf_ = {batchxbeam, kDEVICE}; - init_context_length_ = {batchxbeam, kDEVICE}; - - sequence_lengths_ = {batchxbeam, kDEVICE}; - cu_block_counts_ = {batch_size + 1, kDEVICE}; block_ptrs_ = {max_batch_block_count, kDEVICE}; @@ -716,15 +742,8 @@ void LlamaBatch::AllocateBuffer(ssize_t batch_size, ssize_t session_len, int cac sampled_indexes_ = {batchxbeam * kMaxLogProb, kDEVICE}; sampled_nums_ = {batchxbeam, kDEVICE}; - token_ids_buf_ = {ssize_t(session_len * 2 * batchxbeam), kDEVICE}; - sampling_logits_ = {{(ssize_t)max_batch_size_, (ssize_t)model_->vocab_size_padded_}, kDEVICE}; - finished_buf_ = {(int)batchxbeam, kDEVICE}; - seq_limit_len_ = {batch_size, kDEVICE}; - - rope_theta_ = {batch_size, kDEVICE}; - h_random_seed_ = {batch_size, kCPUpinned}; Clear(h_random_seed_); @@ -745,9 +764,8 @@ void LlamaBatch::AllocateBuffer(ssize_t batch_size, ssize_t session_len, int cac Clear(s.curand_state.buffer()); } - h_input_length_buf_ = {batch_size, kCPUpinned}; - h_cu_block_counts_ = {batch_size + 1, kCPUpinned}; - h_block_ptrs_ = {(ssize_t)max_batch_block_count, kCPUpinned}; + h_cu_block_counts_ = {batch_size + 1, kCPUpinned}; + h_block_ptrs_ = {(ssize_t)max_batch_block_count, kCPUpinned}; for (auto& s : states_) { s.h_prompt_length = {batch_size, kCPUpinned}; @@ -756,8 +774,19 @@ void LlamaBatch::AllocateBuffer(ssize_t batch_size, ssize_t session_len, int cac s.h_rope_theta = {batch_size, kCPUpinned}; } - h_seq_limit_len_ = {batch_size, kCPUpinned}; - std::fill_n(h_seq_limit_len_.data(), batch_size, 0); + for (int i = 0; i < states_.size() - 2; ++i) { + auto& s = state_[i]; + s.context_length_buf = {batchxbeam, kDEVICE}; + s.init_context_length = {batchxbeam, kDEVICE}; + s.sequence_lengths = {batchxbeam, kDEVICE}; + s.rope_theta = {batch_size, kDEVICE}; + s.token_ids_buf = {ssize_t(session_len * 2 * batchxbeam), kDEVICE}; + s.finished_buf = {(int)batchxbeam, kDEVICE}; + s.h_seq_limit_len = {batch_size, kCPUpinned}; + s.seq_limit_len = {batch_size, kDEVICE}; + s.input_ids_buf = {max_forward_token_num_, kDEVICE}; + s.h_input_length_buf = {batch_size, kCPUpinned}; + } h_output_ids_ = {batch_size * session_len_, kCPUpinned}; @@ -776,12 +805,18 @@ void LlamaBatch::AllocSymmBuffers() symm_hidden_states_buf_ = {{max_forward_token_num_ * param_.attn_dp_size, hidden_units}, data_type_, symm_alloc_}; symm_logits_buf_ = {{max_batch_size_, vocab_size_padded}, data_type_, symm_alloc_}; + + if (param_.pp_size > 1) { + symm_residual_buf_ = {{max_forward_token_num_ * param_.attn_dp_size, hidden_units}, data_type_, symm_alloc_}; + } } void LlamaBatch::FreeSymmBuffers() { symm_hidden_states_buf_ = {}; symm_logits_buf_ = {}; + + symm_residual_buf_ = {}; } LlamaBatch::~LlamaBatch() @@ -843,10 +878,15 @@ LlamaBatch::LlamaBatch(DataType data_type, const auto get_free_size = [&] { // size_t free{}, total{}; check_cuda_error(cudaMemGetInfo(&free, &total)); - return AllReduce(model_->comm_->h_tp_group, free, comm::RedOp::kMin); + free = AllReduce(model_->comm_->h_tp_group, free, comm::RedOp::kMin); + if (param_.pp_size > 1) { + free = AllReduce(model_->comm_->h_pp_group, free, comm::RedOp::kMin); + } + return free; }; - sequence_manager_.reset(new SequenceManager{model_->layer_num_, + const size_t layer_num = model_->layer_num_ / param_.pp_size + (model_->layer_num_ % param_.pp_size != 0); + sequence_manager_.reset(new SequenceManager{layer_num, block_config, param.cache_max_block_count, param.cache_chunk_size, @@ -868,6 +908,8 @@ LlamaBatch::LlamaBatch(DataType data_type, FT_CHECK(max_context_token_num_ >= 
session_len_); FT_CHECK(max_forward_token_num_ >= max_batch_size_); + const int state_size = param_.pp_size + 2; + states_.resize(state_size); for (auto& s : states_) { s.requests.resize(max_batch_size_); s.sequences.resize(max_batch_size_); @@ -876,8 +918,13 @@ LlamaBatch::LlamaBatch(DataType data_type, } state_ = &states_[0]; - back_ = &states_[1]; - incoming_ = &states_[2]; + back_ = &states_[state_size - 2]; + incoming_ = &states_[state_size - 1]; + + gs_.resize(param_.pp_size); + for (int i = 0; i < param_.pp_size; ++i) { + slots_.push_back({&states_[i], &gs_[i]}); + } symm_alloc_ = core::SimpleAllocator::Create([this](ssize_t size) { return SymmAlloc(size, true); }, [this](void* p, ssize_t size) { return SymmFree(p, size, true); }, @@ -901,35 +948,41 @@ void LlamaBatch::InitializeSampling(const GenerationState& g) return; } - // Context length at initialization, will stay constant until re-initialziation - Copy(context_length_buf_, batch_size, init_context_length_); + if (param_.pp_size == 0 || state_->pp_init_sampling) { + state_->pp_init_sampling = false; // updated by Initialize(g) - Copy(context_length_buf_, batch_size, sequence_lengths_); - // `sequence_lengths_` will be increased by dynamic decode - // note that in decoder and in output "sequence length" has different semantic - // - in decoder it means length of sequence that has kv cache already computed - // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet) - invokePlusScalar(sequence_lengths_.data(), -1, batch_size, stream_); - sync_check_cuda_error(); + // Context length at initialization, will stay constant until re-initialziation + Copy(state_->context_length_buf, batch_size, state_->init_context_length); - Clear(token_ids_buf_.slice(0, batch_size * session_len_)); - invokeTranspose2D(token_ids_buf_.data(), state_->output_ids.data(), batch_size, session_len_, stream_); - sync_check_cuda_error(); + Copy(state_->context_length_buf, batch_size, state_->sequence_lengths); + // `sequence_lengths_` will be increased by dynamic decode + // note that in decoder and in output "sequence length" has different semantic + // - in decoder it means length of sequence that has kv cache already computed + // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet) + invokePlusScalar(state_->sequence_lengths.data(), -1, batch_size, stream_); + sync_check_cuda_error(); - // token_ids_buf_[s, b] - // ABCDe ABCDe e - // ABCDEFGHIJk ABCDEFGHIJk - // ABCDEFGHi -> ABCDEFGHi i - // ABCDEFGh ABCDEFGh h - // ABCd ABCd d - invokePadLastTokenIds(token_ids_buf_.data(), init_context_length_.data(), g.max_init_ctx_len, batch_size, stream_); - sync_check_cuda_error(); + Clear(state_->token_ids_buf.slice(0, batch_size * session_len_)); + invokeTranspose2D(state_->token_ids_buf.data(), state_->output_ids.data(), batch_size, session_len_, stream_); + sync_check_cuda_error(); - // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted for - for (int i = 0; i < batch_size; ++i) { - h_seq_limit_len_[i] = state_->seq_len_limit[i] + (g.max_init_ctx_len - state_->h_context_length[i]); + // token_ids_buf_[s, b] + // ABCDe ABCDe e + // ABCDEFGHIJk ABCDEFGHIJk + // ABCDEFGHi -> ABCDEFGHi i + // ABCDEFGh ABCDEFGh h + // ABCd ABCd d + invokePadLastTokenIds( + state_->token_ids_buf.data(), state_->init_context_length.data(), g.max_init_ctx_len, batch_size, stream_); + sync_check_cuda_error(); + + // 
seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted + // for + for (int i = 0; i < batch_size; ++i) { + state_->h_seq_limit_len[i] = state_->seq_len_limit[i] + (g.max_init_ctx_len - state_->h_context_length[i]); + } + Copy(state_->h_seq_limit_len, batch_size, state_->seq_limit_len); } - Copy(h_seq_limit_len_, batch_size, seq_limit_len_); std::vector rs; rs.reserve(batch_size); @@ -949,7 +1002,7 @@ void LlamaBatch::ComputeAndOutputLogits(const Tensor& hidden_states, int first, if (state_->requests[i]->gen_cfg.output_logits == GenerationConfig::kAll) { const auto& s = *state_->sequences[i]; // Skip when the seq is filling missed cache only - if (s.cache_len + h_input_length_buf_[i] > s.tokens.size()) { + if (s.cache_len + state_->h_input_length_buf[i] > s.tokens.size()) { return true; } } @@ -994,7 +1047,7 @@ void LlamaBatch::OutputLogits(const Tensor& logits, int first, int last, Generat for (int i = first; i < last; ++i) { - const int input_len = h_input_length_buf_[i]; // input lenght for this iter + const int input_len = state_->h_input_length_buf[i]; // input lenght for this iter if (state_->requests[i]->gen_cfg.output_logits == out_type) { @@ -1051,12 +1104,16 @@ void LlamaBatch::OutputLogits(const Tensor& logits, int first, int last, Generat void LlamaBatch::OutputLastHiddenState(const Tensor& hidden_states, int first, int last) { + if (tp_rank_ != 0) { + return; + } + const auto& src_buf = hidden_states.buffer(); const auto data_type = src_buf.dtype(); int base = 0; for (int i = first; i < last; ++i) { - const int input_len = h_input_length_buf_[i]; // input lenght for this iter + const int input_len = state_->h_input_length_buf[i]; // input lenght for this iter if (auto out_type = state_->requests[i]->gen_cfg.output_last_hidden_state) { @@ -1103,8 +1160,8 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) // [s,b] -> [b,s] and skip padding in [context_len, max_context_len) invokeGatherOutput(state_->output_ids.data(), - token_ids_buf_.data(), - init_context_length_.data(), + state_->token_ids_buf.data(), + state_->init_context_length.data(), g.max_init_ctx_len, g.step, session_len_, @@ -1113,9 +1170,11 @@ void LlamaBatch::Finish(GenerationState& g, std::vector& signals) sync_check_cuda_error(); } - Copy(token_ids_buf_.slice((g.step - 1) * (batch_size - g.partial), -1), batch_size - g.partial, h_output_ids_); - Copy(finished_buf_, batch_size, state_->h_finished); - Copy(sequence_lengths_, batch_size, state_->h_context_length); + Copy(state_->token_ids_buf.slice((g.step - 1) * (batch_size - g.partial), -1), + batch_size - g.partial, + h_output_ids_); + Copy(state_->finished_buf, batch_size, state_->h_finished); + Copy(state_->sequence_lengths, batch_size, state_->h_context_length); bool output_logprobs = false; for (int i = 0; i < batch_size - g.partial; ++i) { @@ -1313,6 +1372,161 @@ struct RequestData { } // namespace +#ifdef BUILD_MULTI_GPU +namespace comm { + +void serialize(std::ostream& os, const RequestData& req) +{ + // std::vector> infer; + serialize(os, (int)req.infer.size()); + for (const auto& r : req.infer) { + serialize(os, *r); + } + // std::vector> kill; + serialize(os, (int)req.kill.size()); + for (const auto& r : req.kill) { + serialize(os, *r); + } + + serialize(os, req.cancel); // std::vector cancel; + serialize(os, req.abort); // bool abort; +} + +template<> +void serialize(const std::shared_ptr* req, int n, std::vector& vec) +{ + std::stringstream ss; + for (int i = 0; i < n; ++i) { + const 
auto& r = req[i]; + if (r != nullptr) { + serialize(ss, *r); + } + } + vec = streambuf_to_vector(ss.rdbuf()); +} + +void deserialize(std::istream& is, RequestData& req) +{ + auto process = [](std::istream& is, std::vector>& vec) { + int size; + deserialize(is, size); + vec.resize(size); + for (auto& r : vec) { + r = std::make_shared(); + deserialize(is, *r); + } + }; + process(is, req.infer); + process(is, req.kill); + deserialize(is, req.cancel); + deserialize(is, req.abort); +} + +template<> +void deserialize(std::shared_ptr* req, int n, const std::vector& vec) +{ + std::stringstream ss; + ss.write(vec.data(), vec.size()); + for (int i = 0; i < n; ++i) { + auto& r = req[i]; + if (r == nullptr) { + r = std::make_shared(); + } + deserialize(ss, *r); + } +} + +template +void serialize(std::ostream& os, const Buffer_& buf) +{ + serialize(os, buf.size()); + serialize(os, buf.device().type); + os.write((char*)buf.raw_data(), sizeof(T) * buf.size()); +} + +template +void deserialize(std::istream& is, Buffer_& buf) +{ + ssize_t size; + DeviceType dev_type; + deserialize(is, size); + deserialize(is, dev_type); + buf = {size, kCPU}; + is.read((char*)buf.raw_data(), sizeof(T) * size); +} + +void serialize(std::ostream& os, const IntermediateData& inter) +{ + serialize(os, inter.abort); + serialize(os, inter.dc_batch_size); + serialize(os, inter.pf_batch_size); + serialize(os, inter.local_token_nums); + serialize(os, inter.global_token_num); + + if (inter.dc_batch_size + inter.pf_batch_size) { + int sz = inter.blocks.size(); + serialize(os, sz); + for (const auto& b : inter.blocks) { + serialize(os, b); + } + serialize(os, inter.h_cu_block_counts); + serialize(os, inter.h_input_length_buf); + serialize(os, inter.h_context_length); + serialize(os, inter.h_rope_theta); + serialize(os, inter.h_finished); + } +} + +void deserialize(std::istream& is, IntermediateData& inter) +{ + deserialize(is, inter.abort); + deserialize(is, inter.dc_batch_size); + deserialize(is, inter.pf_batch_size); + deserialize(is, inter.local_token_nums); + deserialize(is, inter.global_token_num); + + if (inter.dc_batch_size + inter.pf_batch_size) { + int sz; + deserialize(is, sz); + inter.blocks.resize(sz); + for (auto& b : inter.blocks) { + deserialize(is, b); + } + deserialize(is, inter.h_cu_block_counts); + deserialize(is, inter.h_input_length_buf); + deserialize(is, inter.h_context_length); + deserialize(is, inter.h_rope_theta); + deserialize(is, inter.h_finished); + } +} + +template<> +void serialize(const IntermediateData* inter, int n, std::vector& vec) +{ + + std::stringstream ss; + for (int i = 0; i < n; ++i) { + const auto& r = inter[i]; + serialize(ss, r); + } + vec = streambuf_to_vector(ss.rdbuf()); +} + +template<> +void deserialize(IntermediateData* inter, int n, const std::vector& vec) +{ + std::stringstream ss; + ss.write(vec.data(), vec.size()); + for (int i = 0; i < n; ++i) { + auto& r = inter[i]; + deserialize(ss, r); + } +} + +} // namespace comm + +#endif // BUILD_MULTI_GPU + void LlamaBatch::InternalThreadEntry() { // TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_); @@ -1323,20 +1537,32 @@ void LlamaBatch::InternalThreadEntry() // Initialize `AnomalyHandler` AnomalyHandler::instance().Init(tp_rank_, model_->vocab_size_padded_, 0, max_batch_size_, stream_); - GenerationState g{}; + GenerationState* g = &gs_[0]; while (1) { + if (param_.pp_size > 1 && param_.pp_rank == 0) { + std::tie(state_, g) = slots_.front(); + } - std::shared_ptr req; + auto req = std::make_shared(); - if (tp_rank_ == 0) { + if 
(tp_rank_ == 0 && param_.pp_rank == 0) { req = std::make_shared(); { NvtxScope _("pop"); - const int free_slot_count = max_batch_size_ - state_->size + g.finished_count; + const int free_slot_count = max_batch_size_ - state_->size + g->finished_count; const bool is_empty = (free_slot_count == max_batch_size_); - // Block if batch is empty AND no silbings are ready - gateway_->pop(req->infer, req->kill, free_slot_count, is_empty, req->abort, dp_rank_); + // Block if batch is empty AND no silbings are ready AND comm in same node + const bool blocking = is_empty && comm_.h_comm->is_same_process() && param_.pp_size == 1; + int wait = 0; + do { + gateway_->pop(req->infer, req->kill, free_slot_count, blocking, req->abort, dp_rank_); + if (!comm_.h_comm->is_same_process() && param_.pp_size == 1) { + bool empty_pop = req->infer.size() == 0 && req->kill.size() == 0 && req->abort == false; + wait = is_empty && empty_pop; + wait = AllReduce(comm_.h_comm, wait, comm::RedOp::kSum) == comm_.h_comm->n_ranks(); + } + } while (wait); } // Mark reqs to the same session_id as invalid (which are dangerous to the engine) DisableInvalidRequests(req->infer, req->kill); @@ -1349,12 +1575,20 @@ void LlamaBatch::InternalThreadEntry() // 2. Broadcast `ec` from rank-0 // shared_state_->barrier->wait(); // comm_.h_comm->Sync(comm_.h_comm_tp_group); + if (comm_.h_tp_group->n_ranks() > 1 && param_.pp_rank == 0) { + Broadcast(comm_.h_tp_group, req, 0); + } - Broadcast(comm_.h_tp_group, req, 0); + if (!comm_.h_comm->is_same_process() && param_.pp_rank == 0) { + req->abort = AllReduce(comm_.h_comm, (int)req->abort, comm::RedOp::kSum) > 0; + } - if (req->abort) { - TM_LOG_INFO("[InternalThreadEntry] stop requested."); - break; + if (req->abort || pp_abort_) { + if (param_.pp_size == 1 || (batch_que_.empty() && pp_abort_)) { + TM_LOG_ERROR("[InternalThreadEntry] stop requested."); + break; + } + pp_abort_ = true; } std::vector signals; @@ -1370,21 +1604,23 @@ void LlamaBatch::InternalThreadEntry() ProcessCancelRequests(req->cancel, signals); - if (tp_rank_ == 0) { + if (tp_rank_ == 0 && param_.pp_rank == 0) { gateway_->notify(std::move(signals)); } - Initialize(g); + Initialize(*g); const int n_active = AllReduce(comm_.h_dp_group, state_->active_size, comm::RedOp::kSum); - if (n_active) { + if (n_active || param_.pp_size > 1) { // - Forward(g); + if (!Forward(g)) { + continue; + } - Finish(g, signals); + Finish(*g, signals); - if (g.finished_count) { + if (g->finished_count) { // Finished requests and corresponding output tensors will be released when notified // wait for all ranks to ensure no rank (except for output thread) will access related // resources @@ -1415,17 +1651,17 @@ void LlamaBatch::Start() }); } -bool LlamaBatch::Forward(GenerationState& g) +bool LlamaBatch::Forward(GenerationState*& g) { NvtxScope _("Forward"); FT_CHECK(max_context_token_num_ >= max_batch_size_); - const int active_size = state_->active_size; + int active_size = state_->active_size; constexpr int kLogInterval = 10; - if (tp_rank_ == 0 && (g.step - 1) % kLogInterval == 0) { - TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1); + if (tp_rank_ == 0 && (g->step - 1) % kLogInterval == 0 && param_.pp_rank == 0) { + TM_LOG_INFO("------------------------- step = %d -------------------------", g->step - 1); } int pf_offset = -1; @@ -1435,8 +1671,8 @@ bool LlamaBatch::Forward(GenerationState& g) const auto& seq = *state_->sequences[i]; // const int missing = state_->h_context_length[i] - seq.cache_len; 
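// Illustrative sketch (not part of this patch). The pop/wait loop added to
// InternalThreadEntry above keeps multi-process TP ranks in lockstep: each
// rank polls the gateway without blocking, then all ranks vote via an
// all-reduce of an "idle and popped nothing" flag, and only when every rank
// voted idle do they poll again, so no rank blocks inside gateway_->pop()
// while a sibling already holds work. The self-contained snippet below shows
// just that voting step; `spin_again` and `all_reduce_sum` are hypothetical
// names standing in for AllReduce(comm_.h_comm, wait, comm::RedOp::kSum).

#include <functional>

// True when every rank is idle and popped nothing, i.e. all ranks may safely
// poll once more before re-entering the forward pass together.
bool spin_again(bool batch_empty,
                bool popped_nothing,
                int  n_ranks,
                const std::function<int(int)>& all_reduce_sum)
{
    const int my_vote = (batch_empty && popped_nothing) ? 1 : 0;
    return all_reduce_sum(my_vote) == n_ranks;
}

int main()
{
    // Single-process stand-in: the "all-reduce" is just the local vote.
    auto local_sum = [](int v) { return v; };
    return spin_again(true, true, /*n_ranks=*/1, local_sum) ? 0 : 1;
}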
FT_CHECK(seq.input_length >= 1); - h_input_length_buf_[i] = seq.input_length; - input_d_ptrs[i] = state_->output_ids.data() + i * session_len_ + seq.cache_len; + state_->h_input_length_buf[i] = seq.input_length; + input_d_ptrs[i] = state_->output_ids.data() + i * session_len_ + seq.cache_len; if (seq.input_length > 1 && pf_offset < 0) { pf_offset = i; } @@ -1445,12 +1681,6 @@ bool LlamaBatch::Forward(GenerationState& g) pf_offset = active_size; } - // These buffers are only accessed when there are prefill workloads - if (pf_offset != active_size) { - Copy(state_->h_context_length, active_size, context_length_buf_); - Copy(h_input_length_buf_, active_size, input_length_buf_); - } - // Find mini-batch offsets: input length > 1 ? prefill() : decode() // Constraints on mini-batches // sum(Q) <= `max_forward_token_num` && sum(K) <= `max_context_token_num` @@ -1459,8 +1689,8 @@ bool LlamaBatch::Forward(GenerationState& g) int sum_q = pf_offset; int sum_k = 0; // only for prefill for (int i = pf_offset; i < active_size; ++i) { - FT_CHECK(h_input_length_buf_[i] <= max_forward_token_num_); - const int q = sum_q + h_input_length_buf_[i]; + FT_CHECK(state_->h_input_length_buf[i] <= max_forward_token_num_); + const int q = sum_q + state_->h_input_length_buf[i]; const int k = sum_k + state_->h_context_length[i]; if (q <= max_forward_token_num_ && k <= max_context_token_num_) { sum_q = q; @@ -1468,7 +1698,7 @@ bool LlamaBatch::Forward(GenerationState& g) } else { offsets.push_back(i); - sum_q = h_input_length_buf_[i]; + sum_q = state_->h_input_length_buf[i]; sum_k = state_->h_context_length[i]; } } @@ -1482,39 +1712,45 @@ bool LlamaBatch::Forward(GenerationState& g) offsets.push_back(offsets.back()); } - // forward on mini-batches - for (int p = 0; p < (int)offsets.size() - 1; ++p) { - const int first = offsets[p]; - const int last = offsets[p + 1]; - const int mini_batch_size = last - first; - int* input_ids = input_ids_buf_.data(); + // prepare inputs + Buffer_ input_ids_buf; + Tensor hidden_states; + Tensor residual; + int batch_size{}; + IntermediateData inter{}; // pipeline parallel intermediate data + + if (param_.pp_rank == 0) { + TM_CHECK(n_batches == 2) << "pipeline parallel only support n_batches=1"; // TODO:This modification relies on + // the removal of mini-batch. + const int first = offsets[0]; + const int last = offsets[1]; + TM_CHECK(last - first == state_->active_size); + int* input_ids = state_->input_ids_buf.data(); BatchedCopy batched_copy; int sum_k = 0; for (int i = first; i < last; ++i) { - input_ids = batched_copy.Add(input_d_ptrs[i], h_input_length_buf_[i], input_ids); - if (h_input_length_buf_[i] > 1) { + input_ids = batched_copy.Add(input_d_ptrs[i], state_->h_input_length_buf[i], input_ids); + if (state_->h_input_length_buf[i] > 1) { sum_k += state_->h_context_length[i]; } } - int sum_q = input_ids - input_ids_buf_.data(); - + int sum_q = input_ids - state_->input_ids_buf.data(); batched_copy.Submit(stream_); - - const int dc_batch_size = p ? 
0 : pf_offset; - const int pf_batch_size = mini_batch_size - dc_batch_size; + state_->dc_batch_size = pf_offset; + state_->pf_batch_size = state_->active_size - state_->dc_batch_size; if (tp_rank_ == 0) { - if (pf_batch_size) { - const auto max_q = - *std::max_element(h_input_length_buf_.data() + first, h_input_length_buf_.data() + last); + if (state_->pf_batch_size) { + const auto max_q = *std::max_element(state_->h_input_length_buf.data() + first, + state_->h_input_length_buf.data() + last); const auto max_k = *std::max_element(state_->h_context_length.data() + first, state_->h_context_length.data() + last); TM_LOG_INFO("[Forward] [%d, %d), dc=%d, pf=%d, sum_q=%d, sum_k=%d, max_q=%d, max_k=%d", first, last, - dc_batch_size, - pf_batch_size, + state_->dc_batch_size, + state_->pf_batch_size, sum_q, sum_k, max_q, @@ -1523,31 +1759,91 @@ bool LlamaBatch::Forward(GenerationState& g) } // Synchronize batch token num with sync DP ranks - auto local_token_nums = AllGather(comm_.h_dp_group, sum_q); - auto global_token_num = std::accumulate(local_token_nums.begin(), local_token_nums.end(), 0); + state_->local_token_nums = AllGather(comm_.h_dp_group, sum_q); + state_->global_token_num = std::accumulate(state_->local_token_nums.begin(), state_->local_token_nums.end(), 0); + + input_ids_buf = state_->input_ids_buf.slice(0, sum_q); + } + else { + RecvIntermediateData(inter); + PostProcessIntermediateData(inter); + } - auto hidden_states = symm_hidden_states_buf_.slice(0, global_token_num); + batch_size = state_->pf_batch_size + state_->dc_batch_size; - model_->Forward(input_ids_buf_.slice(0, sum_q), // temp - hidden_states, // temp - decoder_output_buf_.slice(first, mini_batch_size), + // forward logits + if (batch_size) { + state_->hidden_states = symm_hidden_states_buf_.slice(0, state_->global_token_num).borrow(); + state_->residual = + (param_.pp_size > 1) ? 
symm_residual_buf_.slice(0, state_->global_token_num).borrow() : Tensor{}; + + model_->Forward(input_ids_buf, // temp + state_->hidden_states, // temp + state_->residual, // used by pipeline parallel + decoder_output_buf_.slice(0, batch_size), block_ptrs_, - cu_block_counts_.slice(first, mini_batch_size + 1), - h_input_length_buf_.slice(first, mini_batch_size), - state_->h_context_length.slice(first, mini_batch_size), - rope_theta_.slice(first, mini_batch_size), - finished_buf_.slice(first, mini_batch_size), - Buffer(local_token_nums.data(), local_token_nums.size(), kCPU), + cu_block_counts_.slice(0, batch_size + 1), + state_->h_input_length_buf.slice(0, batch_size), + state_->h_context_length.slice(0, batch_size), + state_->rope_theta.slice(0, batch_size), + state_->finished_buf.slice(0, batch_size), + Buffer(state_->local_token_nums.data(), state_->local_token_nums.size(), kCPU), lora_mask_buf_, - dc_batch_size, - pf_batch_size, - state_->sequences.data() + first); + state_->dc_batch_size, + state_->pf_batch_size, + state_->sequences.data()); + } + + if (param_.pp_size > 1) { + // for pipeline parallel + // - pp_rank 0 ~ pp_size - 1 should send intermediate data to next pp_rank + // - pp_rank 0 should receive intermediate data from last pp_rank and output logits/hidden states + if (param_.pp_rank == 0) { + IntermediateData last_inter{}; + BatchState* last_state{}; + GenerationState* last_g{}; + + // receive + if (batch_que_.size() == param_.pp_size - 1 || (batch_que_.size() > 0 && batch_size == 0)) { + RecvIntermediateData(last_inter); // logits + std::tie(last_state, last_g) = batch_que_.front(); + batch_que_.pop(); + } + + if (batch_size > 0 || (batch_size == 0 && batch_que_.empty())) { + // send, maybe dummy + PreProcessIntermediateData(inter); + SendIntermediateData(inter); + } + + if (batch_size > 0) { + batch_que_.push({state_, g}); + slots_.pop_front(); + } + + if (!last_state) { + return false; + } + + state_ = last_state; + g = last_g; + slots_.push_front({state_, g}); + } + else { + if (param_.pp_rank != param_.pp_size - 1 || batch_size > 0) { + SendIntermediateData(inter); + } + return false; + } + } - ComputeAndOutputLogits(hidden_states, first, last); - OutputLastHiddenState(hidden_states, first, last); + // forward logits & dynamic decode + if (const auto bsz = state_->pf_batch_size + state_->dc_batch_size; bsz > 0) { + ComputeAndOutputLogits(state_->hidden_states, 0, bsz); + OutputLastHiddenState(state_->hidden_states, 0, bsz); } - if (const auto bsz = active_size - g.partial; bsz > 0) { + if (const auto bsz = state_->active_size - g->partial; bsz > 0) { auto logits = model_->postDecodeEmbedding(decoder_output_buf_.slice(0, bsz), symm_logits_buf_.buffer()); @@ -1555,10 +1851,10 @@ bool LlamaBatch::Forward(GenerationState& g) OutputLogits(logits, 0, bsz, GenerationConfig::kGeneration); - TM_CHECK_GE(g.step, 0); + TM_CHECK_GE(g->step, 0); - if (!g.skip_init_sampling) { - InitializeSampling(g); + if (!g->skip_init_sampling || param_.pp_size > 1) { + InitializeSampling(*g); } bool output_logprobs = [&] { @@ -1574,27 +1870,27 @@ bool LlamaBatch::Forward(GenerationState& g) invokeCastFloat2D(logits, sampling_logits, stream_); sync_check_cuda_error(); - // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is not supported - model_->dynamicDecode(token_ids_buf_, - finished_buf_, - sequence_lengths_, + // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is not supported. 
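// Illustrative sketch (not part of this patch). With pipeline parallelism,
// pp_rank 0 rotates a small pool of (BatchState*, GenerationState*) pairs: it
// takes the next pair from slots_, launches Forward and sends the
// intermediate data downstream, parks the pair in batch_que_, and resumes
// whichever pair has just returned from the last stage for sampling and
// Finish(). The standalone toy below mirrors that steady-state rotation with
// plain integers standing in for the state pairs; the drain path used when no
// new work arrives is omitted, and none of the names are engine API.

#include <cstdio>
#include <deque>
#include <queue>

int main()
{
    const int       pp_size = 4;
    std::deque<int> slots;      // idle slots, front = next micro-batch to schedule
    std::queue<int> in_flight;  // micro-batches travelling through the pipeline
    for (int i = 0; i < pp_size; ++i) {
        slots.push_back(i);
    }

    for (int step = 0; step < 8; ++step) {
        const int cur = slots.front();  // like: std::tie(state_, g) = slots_.front()
        slots.pop_front();

        int resumed = -1;
        if ((int)in_flight.size() == pp_size - 1) {
            // mirrors: batch_que_.size() == param_.pp_size - 1 -> RecvIntermediateData
            resumed = in_flight.front();
            in_flight.pop();
        }
        in_flight.push(cur);  // stage 0 ran its layers for `cur` and sent it on

        if (resumed >= 0) {
            slots.push_front(resumed);  // its logits are back: sample, then Finish()
            std::printf("step %d: launched %d, resumed %d\n", step, cur, resumed);
        }
    }
}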
+ model_->dynamicDecode(state_->token_ids_buf, + state_->finished_buf, + state_->sequence_lengths, state_->curand_state, sampling_logits, // <- batch size indicator - seq_limit_len_, - init_context_length_, + state_->seq_limit_len, + state_->init_context_length, state_->h_context_length, state_->h_prompt_length, output_logprobs ? sampled_logprobs_ : Buffer{}, // <- indicator sampled_indexes_, sampled_nums_, - g.step, - g.max_init_ctx_len); + g->step, + g->max_init_ctx_len); } - std::fill(h_input_length_buf_.data(), h_input_length_buf_.data() + active_size, 0); + std::fill(state_->h_input_length_buf.data(), state_->h_input_length_buf.data() + state_->active_size, 0); // `SequenceManager` needs real-time value of cache length - for (int i = 0; i < active_size; ++i) { + for (int i = 0; i < state_->active_size; ++i) { FT_CHECK((bool)state_->requests[i]); FT_CHECK(state_->sequences[i]); state_->sequences[i]->cache_len += state_->sequences[i]->input_length; @@ -1612,19 +1908,19 @@ bool LlamaBatch::Forward(GenerationState& g) AnomalyHandler::instance().Reset(); if (debug_ && tp_rank_ == 0) { - std::vector curr(active_size); - core::Copy(token_ids_buf_.data() + g.step * active_size, active_size, curr.data()); + std::vector curr(state_->active_size); + core::Copy(state_->token_ids_buf.data() + g->step * state_->active_size, state_->active_size, curr.data()); cudaStreamSynchronize(stream_); std::stringstream scurr; for (int k = 0; k < curr.size(); ++k) { scurr << std::setw(10) << curr[k]; } - TM_LOG_INFO("[Forward] step = %d, [%s]", g.step - 1, scurr.str().c_str()); + TM_LOG_INFO("[Forward] step = %d, [%s]", g->step - 1, scurr.str().c_str()); } //////////////////////////////////////////////// /// ! increase the counters - g.step += 1; + g->step += 1; return true; } @@ -1700,6 +1996,7 @@ void LlamaBatch::Warmup() for (auto& x : input_ids) { x = d(g); } + Copy(input_ids, input_ids_buf); check_cuda_error(cudaStreamSynchronize(stream_)); @@ -1719,15 +2016,16 @@ void LlamaBatch::Warmup() const auto bsz = 1; // A single sequence containing `token_num` prefill tokens - model_->Forward(input_ids_buf.slice(0, token_num), + model_->Forward(state_->input_ids_buf.slice(0, token_num), symm_hidden_states_buf_.slice(0, token_num * param_.attn_dp_size), + {}, // residual decoder_output_buf_.slice(0, bsz), block_ptrs_, cu_block_counts_.slice(0, bsz + 1), Buffer{&input_length, 1, kCPU}, Buffer{&input_length, 1, kCPU}, - rope_theta_.slice(0, bsz), - finished_buf_.slice(0, bsz), + state_->rope_theta.slice(0, bsz), + state_->finished_buf.slice(0, bsz), Buffer{local_token_nums.data(), (int)local_token_nums.size(), kCPU}, Buffer{}, 0, @@ -1796,9 +2094,126 @@ void LlamaBatch::DestroyCommunicators() // Destroy device communicator comm_.d_comm = {}; + if (param_.pp_size > 1) { + comm_.d_pp_comm = {}; + } cudaStreamSynchronize(stream_); comm_.h_comm->Sync(); } +void LlamaBatch::PreProcessIntermediateData(IntermediateData& inter) +{ + if (param_.pp_rank == 0) { + for (int i = 0; i < state_->active_size; ++i) { + const auto& seq = *state_->sequences[i]; + inter.blocks.push_back(seq.blocks); + } + + inter.abort = pp_abort_; + inter.h_cu_block_counts = h_cu_block_counts_; + inter.h_input_length_buf = state_->h_input_length_buf; + inter.h_context_length = state_->h_context_length; + inter.h_rope_theta = state_->h_rope_theta; + inter.h_finished = state_->h_finished; + inter.dc_batch_size = state_->dc_batch_size; + inter.pf_batch_size = state_->pf_batch_size; + inter.local_token_nums = state_->local_token_nums; + 
inter.global_token_num = state_->global_token_num; + } +} + +void LlamaBatch::PostProcessIntermediateData(IntermediateData& inter) +{ + // state should always copy + if (param_.pp_rank > 0) { + state_->pf_batch_size = inter.pf_batch_size; + state_->dc_batch_size = inter.dc_batch_size; + state_->local_token_nums = inter.local_token_nums; + state_->global_token_num = inter.global_token_num; + pp_abort_ = inter.abort; + } + + // early exist as there is no data to process + const int batch_size = inter.pf_batch_size + inter.dc_batch_size; + if (batch_size == 0 || param_.pp_rank == 0) { + return; + } + + // cpu + std::copy_n(inter.h_input_length_buf.data(), batch_size, state_->h_input_length_buf.data()); + std::copy_n(inter.h_context_length.data(), batch_size, state_->h_context_length.data()); + + // device + Copy(inter.h_cu_block_counts, batch_size + 1, cu_block_counts_); + Copy(inter.h_rope_theta, batch_size, state_->rope_theta); + Copy(inter.h_finished, batch_size, state_->finished_buf); + + h_cu_block_counts_[0] = 0; + auto block_ptrs = h_block_ptrs_.data(); + for (int i = 0; i < batch_size; ++i) { + const auto& seq = *state_->sequences[i]; + const auto& blocks = inter.blocks[i]; + + // cumulative num of blocks + h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + blocks.size(); + + block_ptrs = std::transform(blocks.cbegin(), blocks.cend(), block_ptrs, [&](int block_id) { + return reinterpret_cast(sequence_manager_->GetBlockPtr(block_id)); + }); + } + Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_); + Copy(h_block_ptrs_, h_cu_block_counts_[batch_size], block_ptrs_); +} + +void LlamaBatch::SendIntermediateData(IntermediateData& inter) +{ + const int dst = (param_.pp_rank + 1) % param_.pp_size; + const int batch_size = inter.dc_batch_size + inter.pf_batch_size; + + Send(comm_.h_pp_group, inter, dst); + if (batch_size == 0) { + // no device data to send + return; + } + + if (param_.pp_rank < param_.pp_size - 1) { + // for [0, pp_rank - 1), send hidden & residual + Tensor hidden = symm_hidden_states_buf_.slice(0, inter.global_token_num); + Tensor residual = symm_residual_buf_.slice(0, inter.global_token_num); + comm_.d_pp_comm->Send(hidden.raw_data(), hidden.size(), hidden.dtype(), dst, 0, stream_); + comm_.d_pp_comm->Send(residual.raw_data(), residual.size(), residual.dtype(), dst, 0, stream_); + } + else { + // for pp_rank - 1, send logits + Tensor logits = decoder_output_buf_.slice(0, batch_size); + comm_.d_pp_comm->Send(logits.raw_data(), logits.size(), logits.dtype(), dst, 0, stream_); + } +} + +void LlamaBatch::RecvIntermediateData(IntermediateData& inter) +{ + const int src = (param_.pp_rank - 1 + param_.pp_size) % param_.pp_size; + Recv(comm_.h_pp_group, inter, src); + + const int batch_size = inter.dc_batch_size + inter.pf_batch_size; + if (batch_size == 0) { + // no device data to receive + return; + } + + if (param_.pp_rank > 0) { + // for [1, pp_rank - 1], recv hidden & residual + Tensor hidden = symm_hidden_states_buf_.slice(0, inter.global_token_num); + Tensor residual = symm_residual_buf_.slice(0, inter.global_token_num); + comm_.d_pp_comm->Recv(hidden.raw_data(), hidden.size(), hidden.dtype(), src, 0, stream_); + comm_.d_pp_comm->Recv(residual.raw_data(), residual.size(), residual.dtype(), src, 0, stream_); + } + else { + // for pp_rank 0, recv logits + Tensor logits = decoder_output_buf_.slice(0, batch_size); + comm_.d_pp_comm->Recv(logits.raw_data(), logits.size(), logits.dtype(), src, 0, stream_); + } +} + } // namespace turbomind diff --git 
a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 110bb519ab..7390b47688 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -3,6 +3,9 @@ #pragma once #include +#include +#include +#include #include "src/turbomind/core/core.h" @@ -37,6 +40,34 @@ struct BatchState { std::vector errors; + Buffer_ input_ids_buf; + + // lengths + Buffer_ context_length_buf; // history length + input_length + Buffer_ sequence_lengths; // current sequence length, updated by sampling + Buffer_ init_context_length; + Buffer_ h_input_length_buf; + + bool copy_init{true}; // extra flag for pipeline parallel to control whether to copy init_context_length + bool pp_init_sampling{true}; // flag to control whether to copy required buffers when initializing sampling + + // rope theta + Buffer_ rope_theta; + + // used by dynamic decoder + Buffer_ token_ids_buf; // all token IDs in [S, B], indexed using `step` + Buffer_ finished_buf; + Buffer_ h_seq_limit_len; + Buffer_ seq_limit_len; + + // value set when model forward, no need to copy + int dc_batch_size; + int pf_batch_size; + std::vector local_token_nums; + int global_token_num; + Tensor hidden_states; + Tensor residual; + // |<-- existing -->|<-- swap-in -->| // |<----------- active ----------->|<-- inactive -->| int active_size; @@ -62,6 +93,27 @@ struct GenerationState { int finished_count; }; +// struct for pipeline parallel +struct IntermediateData { + bool abort{false}; + + // cpu + std::vector blocks{}; + Buffer_ h_cu_block_counts; + Buffer_ h_input_length_buf; + Buffer_ h_context_length; + Buffer_ h_rope_theta; + Buffer_ h_finished; + int dc_batch_size{}; + int pf_batch_size{}; + + std::vector local_token_nums{}; + int global_token_num{}; + + // gpu + // hidden, residual, logits +}; + class LlamaBatch { public: void AllocateBuffer(ssize_t batch_size, ssize_t session_len, int cache_block_seq_len); @@ -88,7 +140,7 @@ class LlamaBatch { void InitializeSampling(const GenerationState& g); - bool Forward(GenerationState& g); + bool Forward(GenerationState*& g); void Finish(GenerationState& g, std::vector& signals); @@ -135,6 +187,8 @@ class LlamaBatch { void CopyState(const std::vector>& desc); + void SwapState(BatchState*& a, BatchState*& b); + template void IndexedCopyImpl(const int* src_idx, const int* dst_idx, int count, const std::tuple&... 
cpys) { @@ -182,6 +236,14 @@ class LlamaBatch { void DestroyCommunicators(); + void SendIntermediateData(IntermediateData& inter); + + void RecvIntermediateData(IntermediateData& inter); + + void PreProcessIntermediateData(IntermediateData& inter); + + void PostProcessIntermediateData(IntermediateData& inter); + private: const EngineParam param_; @@ -221,20 +283,12 @@ class LlamaBatch { // context decoding temp buffers Tensor symm_hidden_states_buf_; Tensor symm_logits_buf_; + Tensor symm_residual_buf_; Tensor decoder_output_buf_; Tensor_ sampling_logits_; - Buffer_ input_ids_buf_; - - // lengths - Buffer_ input_length_buf_; // input + cache missed length - Buffer_ context_length_buf_; // history length + input_length - Buffer_ init_context_length_; - - Buffer_ sequence_lengths_; // current sequence length - Buffer_ init_ctx_lens_; Buffer_ lora_mask_buf_; // lora Buffer_ sampled_logprobs_; @@ -244,17 +298,8 @@ class LlamaBatch { Buffer_ h_sampled_indexes_; Buffer_ h_sampled_nums_; - Buffer_ rope_theta_; - - // used by dynamic decoder - Buffer_ token_ids_buf_; // all token IDs in [S, B], indexed using `step` - Buffer_ finished_buf_; - Buffer_ seq_limit_len_; - // pinned buffers Buffer_ h_output_ids_; - Buffer_ h_input_length_buf_; - Buffer_ h_seq_limit_len_; Buffer_ h_cu_block_counts_; Buffer_ h_block_ptrs_; @@ -265,12 +310,18 @@ class LlamaBatch { Tensor_ h_curand_state_; // [n, sizeof(curandState_t)] Tensor_ d_curand_state_; - std::array states_{}; + std::vector states_; BatchState* state_{}; BatchState* back_{}; BatchState* incoming_{}; + // pipeline parallel + std::deque> slots_; + std::queue> batch_que_; + std::vector gs_; + bool pp_abort_{false}; + // hard limits for persistent buffers static constexpr int kMaxStopBadWordsLen = 32; static constexpr int kMaxEndIdsSize = 32; diff --git a/src/turbomind/models/llama/LlamaLinear.cu b/src/turbomind/models/llama/LlamaLinear.cu index a9caebc002..557cb52566 100644 --- a/src/turbomind/models/llama/LlamaLinear.cu +++ b/src/turbomind/models/llama/LlamaLinear.cu @@ -400,7 +400,9 @@ Tensor LlamaLinear::forward(const Tensor& input, // out = Tensor({in.shape(0), output_dim}, input.dtype(), input.device()); } - impl_->forward(out, in, dense, type); + if (out.size() > 0 && in.size() > 0) { + impl_->forward(out, in, dense, type); + } auto shape = input.shape(); shape.back() = out.shape(-1); diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 7279decfd6..c361f2b523 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -60,6 +60,7 @@ LlamaV2::LlamaV2(DataType dtype, int max_batch_size, std::shared_ptr weights): dtype_{dtype}, + engine_param_(engine), param_(model), attn_param_(attn), lora_param_(lora), @@ -156,6 +157,7 @@ void LlamaV2::updateEmbedding(char* decoder_input, void LlamaV2::Forward(Buffer_ input_ids, Tensor hidden_states_out, + Tensor residual, Tensor decoder_out, Buffer kv_block_ptrs, Buffer cu_block_nums, @@ -171,18 +173,18 @@ void LlamaV2::Forward(Buffer_ input_ids, { TM_LOG_DEBUG(__PRETTY_FUNCTION__); - Tensor input_embeds; - const int token_num = input_ids.size(); if (token_num) { const auto& embedding_table = weights_->pre_decoder_embedding.weight; TM_CHECK_EQ(embedding_table.shape(1) * tp_size_, hidden_units_); - input_embeds = Tensor{{token_num, (int)hidden_units_}, dtype_, kDEVICE}; + if (!residual) { + residual = Tensor{{token_num, (int)hidden_units_}, dtype_, kDEVICE}; + } if (tp_size_ == 1) { - invokeEmbeddingLookup(input_embeds, input_ids, 
embedding_table, stream_); + invokeEmbeddingLookup(residual, input_ids, embedding_table, stream_); sync_check_cuda_error(); } else if (use_allgather_2d_) { @@ -206,7 +208,7 @@ void LlamaV2::Forward(Buffer_ input_ids, stream_); sync_check_cuda_error(); - Copy(temp.buffer(), input_embeds.buffer()); + Copy(temp.buffer(), residual.buffer()); } else { const auto local_hidden_units = embedding_table.shape(1); @@ -221,7 +223,7 @@ void LlamaV2::Forward(Buffer_ input_ids, local.raw_data(), temp.raw_data(), local.size(), dtype_, comm_->d_tp_group, stream_); sync_check_cuda_error(); - invokeInPlaceTranspose102((uint16_t*)input_embeds.raw_data(), + invokeInPlaceTranspose102((uint16_t*)residual.raw_data(), (uint16_t*)temp.raw_data(), tp_size_, token_num, @@ -231,13 +233,13 @@ void LlamaV2::Forward(Buffer_ input_ids, sync_check_cuda_error(); } - TM_DEBUG_TENSOR(input_embeds, "embeddings", 1); + TM_DEBUG_TENSOR(residual, "embeddings", 1); } bool have_embeddings = false; if (token_num) { // Copy input embeddings from corresponding sequences - updateEmbedding((char*)input_embeds.raw_data(), + updateEmbedding((char*)residual.raw_data(), h_input_length.size(), h_input_length.data(), sequences, @@ -247,7 +249,7 @@ void LlamaV2::Forward(Buffer_ input_ids, sync_check_cuda_error(); } - TensorMap args{{"decoder_input", input_embeds}, + TensorMap args{{"decoder_input", residual}, {"decoder_output", hidden_states_out.view({-1, (int)hidden_units_}).borrow()}, {"last_token_hidden_units", decoder_out}, {"output_norm_weight", weights_->output_norm_weight}, diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index e799070b3a..35bc3785cf 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -62,6 +62,7 @@ class LlamaV2 { void Forward(Buffer_ input_ids, Tensor hidden_states_out, + Tensor residual, Tensor decoder_out, Buffer kv_block_ptrs, Buffer cu_block_nums, @@ -97,6 +98,7 @@ class LlamaV2 { const DataType dtype_; + const EngineParam engine_param_; const ModelParam param_; const AttentionParam attn_param_; const LoraParam lora_param_; diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 8d51f71bd0..fcb8aa04d6 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -70,6 +70,11 @@ LlamaWeight::LlamaWeight(DataType data_type, decoder_layer_weights.reserve(num_layer_); for (int i = 0; i < num_layer_; ++i) { + if (engine_param.start_layer > i || i > engine_param.end_layer) { + // [start_layer, end_layer], using closed interval for norm + decoder_layer_weights.emplace_back(nullptr); + continue; + } decoder_layer_weights.emplace_back( new LlamaDecoderLayerWeight(data_type, i, model, engine_param, lora_param, moe_param)); register_module("layers", *decoder_layer_weights.back(), i); @@ -112,7 +117,9 @@ void LlamaWeight::prepare(const cudaDeviceProp& prop) auto stream = core::Context::stream().handle(); for (auto& layer : decoder_layer_weights) { - layer->prepare(prop, stream); + if (layer != nullptr) { + layer->prepare(prop, stream); + } } // Block until processing is done diff --git a/src/turbomind/models/llama/context.h b/src/turbomind/models/llama/context.h index 33b7be29ac..4e9a053380 100644 --- a/src/turbomind/models/llama/context.h +++ b/src/turbomind/models/llama/context.h @@ -19,9 +19,11 @@ struct Communicators { comm::HostComm h_comm; comm::HostComm h_tp_group; comm::HostComm h_dp_group; + comm::HostComm h_pp_group; comm::DeviceComm d_comm; 
int d_tp_group; + comm::DeviceComm d_pp_comm; }; // Execution context for the model diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 4ba07779c9..856bca0f49 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -96,6 +96,17 @@ struct EngineParam { int attn_tp_rank; int mlp_tp_size; int mlp_tp_rank; + int pp_size; + int pp_rank; + + // decoder layer range for pp [start, end) + int start_layer; + int end_layer; + + // multi-node + int nnodes; + int node_rank; + int ngpus_per_node; std::vector devices; }; diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 692a68997b..fe1aa14230 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -72,6 +72,7 @@ UnifiedAttentionLayer::UnifiedAttentionLayer(const ModelParam& model, local_kv_head_num_(model.kv_head_num / tp_size), param_(attn), model_param_(model), + engine_param_(engine), lora_param_(lora), context_(ctx), stream_(ctx.stream), @@ -179,9 +180,6 @@ void UnifiedAttentionLayer::Forward(ForwardParam p) const auto& weights = *p.weights; - // [L, 2, H, s, D] - const size_t layer_offset = layer_id * 2 * local_kv_head_num_ * param_.cache_block_seq_len * size_per_head_; - Tensor qkv; if (weights.qkv.output_dim) { @@ -258,7 +256,7 @@ Tensor UnifiedAttentionLayer::core_attention(Tensor& qkv, const ForwardParam& p, // Decoding use only params.block_iter_params = BlockIteratorParams{(char**)kv_block_ptrs_.data(), // cu_block_nums_.data() + offset, - layer_id, + layer_id - engine_param_.start_layer, (int)param_.cache_block_seq_len}; // Prefilling use only diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index a498b3b881..5467502c98 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -86,6 +86,7 @@ class UnifiedAttentionLayer { const int local_head_num_; const int local_kv_head_num_; + const EngineParam engine_param_; const AttentionParam param_; const ModelParam model_param_; const LoraParam lora_param_; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index c875c7852f..0bc5cd2100 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -25,6 +25,7 @@ UnifiedDecoder::UnifiedDecoder(const ModelParam& model, const Context& ctx): layer_num_(model.layer_num), hidden_units_(model.hidden_units), + param_(engine), attn_tp_size_(engine.attn_tp_size), attn_dp_size_(engine.attn_dp_size), attn_dp_rank_(engine.attn_dp_rank), @@ -59,7 +60,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, if (0) {} else if (group0 || group1) { d_comm_->AllreduceResidualBiasRMSnormEx(hidden_states.raw_data(), - residual.raw_data(), + residual.data_or((void*)nullptr), bias.data_or((void*)nullptr), weight.raw_data(), rmsnorm_eps_, @@ -73,7 +74,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, } else if (d_comm_) { d_comm_->AllreduceResidualBiasRMSnorm(hidden_states.raw_data(), - residual.raw_data(), + residual.data_or((void*)nullptr), bias.data_or((void*)nullptr), weight.raw_data(), rmsnorm_eps_, @@ -86,7 +87,7 @@ void UnifiedDecoder::AllreduceResidualRMSnorm(Tensor& hidden_states, } else { 
invokeResidualBiasRMSNorm(hidden_states.raw_data(), - residual.raw_data(), + residual.data_or((void*)nullptr), weight.raw_data(), bias.data_or((void*)nullptr), dtype, @@ -132,7 +133,7 @@ void UnifiedDecoder::Forward(TensorMap& args, const std::vector& we Tensor local_hidden_states = global_hidden_states; const auto global_token_num = global_hidden_states.shape(0); - const auto local_token_num = local_residual.shape(0); + const auto local_token_num = local_residual.size() ? local_residual.shape(0) : 0; if (attn_dp_size_ > 1) { // Offset hidden states buffer for mixed DP TM_CHECK_EQ(local_token_nums.size(), attn_dp_size_); @@ -148,12 +149,14 @@ void UnifiedDecoder::Forward(TensorMap& args, const std::vector& we TM_DEBUG_TENSOR(local_residual, "res", 1); TM_DEBUG_TENSOR(weights.at(0)->self_attn_norm, "norm_weight", 2); - invokeRMSNorm(local_hidden_states, local_residual, weights.at(0)->self_attn_norm, rmsnorm_eps_, stream_); - sync_check_cuda_error(); + if (param_.pp_rank == 0) { + invokeRMSNorm(local_hidden_states, local_residual, weights.at(0)->self_attn_norm, rmsnorm_eps_, stream_); + sync_check_cuda_error(); - TM_DEBUG_TENSOR(local_hidden_states, Concat("norm0", 0), 2); + TM_DEBUG_TENSOR(local_hidden_states, Concat("norm0", 0), 2); + } - for (int layer = 0; layer < layer_num_; ++layer) { + for (int layer = param_.start_layer; layer < param_.end_layer; ++layer) { /// TODO: do not skip the layers when they are heterogeneous if (isTuning() && layer >= tune_layer_num_) { diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index dd03293744..0864d43c9b 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -25,6 +25,8 @@ class UnifiedDecoder { void Forward(TensorMap& args, const std::vector& weights); private: + const EngineParam param_; + const size_t layer_num_; const size_t hidden_units_; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index c53ee2c4db..245c4dfbc2 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -324,9 +324,17 @@ LlamaTritonModel::LlamaTritonModel(DataType dtype, engine_param_.attn_tp_rank = 0; engine_param_.mlp_tp_size = engine_reader["mlp_tp_size"].as(); engine_param_.mlp_tp_rank = 0; + engine_param_.pp_size = engine_reader["pp"].as(); + engine_param_.pp_rank = 0; engine_param_.devices = engine_reader["devices"].as>(); + // multi-node information + engine_param_.nnodes = engine_reader["nnodes"].as(); + engine_param_.node_rank = engine_reader["node_rank"].as(); + engine_param_.ngpus_per_node = engine_reader["ngpus_per_node"].as(); + FT_CHECK(engine_param_.devices.size() == engine_param_.ngpus_per_node); + { auto tp = engine_param_.attn_tp_size; engine_param_.max_forward_token_num = ((size_t)max_forward_token_num + tp - 1) / tp * tp; @@ -361,8 +369,6 @@ LlamaTritonModel::LlamaTritonModel(DataType dtype, handleMissingParams(); - gateway_ = std::make_shared(engine_param_.outer_dp_size, engine_param_.attn_dp_size, ffi_ctx_factory); - weights_.resize(engine_param_.devices.size()); engines_.resize(engine_param_.devices.size()); @@ -400,21 +406,55 @@ LlamaTritonModel::LlamaTritonModel(DataType dtype, // NOTE: This runs on Python main thread group_ids_.resize(engine_param_.outer_dp_size); for (size_t i = 0; i < group_ids_.size(); ++i) { - group_ids_[i] = comm::CreateHostGroupId(""); + // TODO: fine-grained comm 
control + const std::string group_backend = (comm_size_ <= engine_param_.ngpus_per_node) ? "" : "gloo"; + + group_ids_[i] = comm::CreateHostGroupId(group_backend); group_ids_[i]->Initialize(); } - const int device_num = engine_param_.outer_dp_size * comm_size_; + const int device_num_per_pp = engine_param_.outer_dp_size * comm_size_; + const int device_num = engine_param_.outer_dp_size * comm_size_ * engine_param_.pp_size; engine_params_.resize(device_num, engine_param_); for (int i = 0; i < device_num; ++i) { auto& e = engine_params_[i]; - e.outer_dp_rank = i / comm_size_; + e.outer_dp_rank = i % device_num_per_pp / comm_size_; e.attn_tp_rank = i % comm_size_ % e.attn_tp_size; e.attn_dp_rank = i % comm_size_ / e.attn_tp_size; e.mlp_tp_rank = i % comm_size_; + e.pp_rank = i / device_num_per_pp; + } + + std::vector decoder_layers = {0, (int)model_param_.layer_num}; + if (engine_param_.pp_size > 1) { + decoder_layers.resize(engine_param_.pp_size + 1); + for (int i = 1; i <= engine_param_.pp_size; ++i) { + int layer_num_i = model_param_.layer_num / engine_param_.pp_size; + if (i <= model_param_.layer_num % engine_param_.pp_size) { + layer_num_i++; + } + decoder_layers[i] = decoder_layers[i - 1] + layer_num_i; + } + } + for (auto& e : engine_params_) { + e.start_layer = decoder_layers[e.pp_rank]; + e.end_layer = decoder_layers[e.pp_rank + 1]; } + std::vector node_dp_ranks; + for (int local_rank = 0, offset = engine_param_.ngpus_per_node * engine_param_.node_rank; + local_rank < engine_param_.ngpus_per_node; + ++local_rank) { + auto& e = engine_params_[offset + local_rank]; + if (e.attn_tp_rank == 0 && e.pp_rank == 0) { + node_dp_ranks.push_back(e.outer_dp_rank * e.attn_dp_size + e.attn_dp_rank); + } + } + + gateway_ = std::make_shared( + engine_param_.outer_dp_size, engine_param_.attn_dp_size, std::move(node_dp_ranks), ffi_ctx_factory); + TM_LOG_INFO("%s", toString().c_str()); } @@ -429,7 +469,7 @@ std::unique_ptr LlamaTritonModel::createModelInstance(int device_i void LlamaTritonModel::createSharedWeights(int device_id, int rank) { CudaDeviceGuard dev_guard(engine_param_.devices[device_id]); - weights_[rank] = + weights_[rank % engine_param_.devices.size()] = std::make_shared(dtype_, model_param_, engine_params_.at(rank), lora_param_, moe_param_); // model inited with model_dir // if (model_dir_ != "") { @@ -439,7 +479,7 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) TensorMap LlamaTritonModel::getParams(int device_id, int rank) { - return TM_CHECK_NOTNULL(weights_[rank])->get_parameters(); + return TM_CHECK_NOTNULL(weights_[rank % engine_param_.devices.size()])->get_parameters(); } void LlamaTritonModel::processWeights(int device_id, int rank) @@ -458,22 +498,31 @@ Communicators LlamaTritonModel::createCommSplits(int rank) { Communicators comm{}; - const int outer_rank = rank / comm_size_; - const int inner_rank = rank % comm_size_; + const int device_num_per_pp = engine_param_.outer_dp_size * comm_size_; - comm.h_comm = group_ids_[outer_rank]->CreateCommunicator(comm_size_, inner_rank); + const int outer_rank = rank % device_num_per_pp / comm_size_; + const int inner_rank = (rank / device_num_per_pp) * comm_size_ + rank % comm_size_; - comm.h_tp_group = comm.h_comm->Split(inner_rank / engine_param_.attn_tp_size, 0); - comm.h_dp_group = comm.h_comm->Split(inner_rank % engine_param_.attn_tp_size, 0); + auto h_comm_all_pp = group_ids_[outer_rank]->CreateCommunicator(comm_size_ * engine_param_.pp_size, inner_rank); + + comm.h_comm = h_comm_all_pp->Split(inner_rank / 
comm_size_, 0); + + comm.h_tp_group = comm.h_comm->Split(comm.h_comm->rank() / engine_param_.attn_tp_size, 0); + comm.h_dp_group = comm.h_comm->Split(comm.h_comm->rank() % engine_param_.attn_tp_size, 0); + comm.h_pp_group = h_comm_all_pp->Split(inner_rank % comm_size_, 0); if (comm_size_ > 1) { - comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, inner_rank, comm.h_comm); + comm.d_comm = CreateDeviceCommunicator(communicator_, comm_size_, comm.h_comm->rank(), comm.h_comm); // comm.d_tp_group = 0; if (engine_param_.attn_tp_size != comm_size_) { - comm.d_tp_group = comm.d_comm->Split(inner_rank / engine_param_.attn_tp_size, 0, 0); + comm.d_tp_group = comm.d_comm->Split(comm.h_comm->rank() / engine_param_.attn_tp_size, 0, 0); } } + if (engine_param_.pp_size > 1) { + comm.d_pp_comm = + CreateDeviceCommunicator(communicator_, engine_param_.pp_size, inner_rank / comm_size_, comm.h_pp_group); + } return comm; } @@ -510,7 +559,7 @@ void LlamaTritonModel::createEngine(int device_id, int rank) try { const int dp_rank = engine_param.outer_dp_rank * engine_param.attn_dp_size + engine_param.attn_dp_rank; engines_[device_id] = std::make_unique(dtype_, - engine_param_, // + engine_param, // std::move(model), std::move(ctx), gateway_, @@ -530,7 +579,10 @@ void LlamaTritonModel::createEngine(int device_id, int rank) auto& engine = *engines_[device_id]; try { - engine.Warmup(); + if (engine_param_.pp_size == 1) { + // TODO: support pp + engine.Warmup(); + } } catch (const std::exception& e) { TM_LOG_ERROR("[Engine][Warmup] %s", e.what());
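// Illustrative sketch (not part of this patch). The decoder_layers
// computation in LlamaTritonModel above assigns each pipeline stage a
// contiguous slice of decoder layers, [start_layer, end_layer), with the
// first (layer_num % pp_size) stages each taking one extra layer when the
// count does not divide evenly. The standalone helper below reproduces that
// split; `split_layers` is a hypothetical name, not an engine API.

#include <cstdio>
#include <vector>

// Returns pp_size + 1 boundaries; stage i owns layers [bounds[i], bounds[i + 1]).
std::vector<int> split_layers(int layer_num, int pp_size)
{
    std::vector<int> bounds(pp_size + 1, 0);
    for (int i = 1; i <= pp_size; ++i) {
        int n = layer_num / pp_size;
        if (i <= layer_num % pp_size) {
            ++n;  // leading stages absorb the remainder
        }
        bounds[i] = bounds[i - 1] + n;
    }
    return bounds;
}

int main()
{
    // e.g. 30 layers over 4 stages -> [0, 8) [8, 16) [16, 23) [23, 30)
    const auto bounds = split_layers(30, 4);
    for (int i = 0; i + 1 < (int)bounds.size(); ++i) {
        std::printf("stage %d: [%d, %d)\n", i, bounds[i], bounds[i + 1]);
    }
    return 0;
}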