17 changes: 15 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -3,6 +3,7 @@
import copy
import bisect
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -191,7 +192,12 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
for batch_size in progress_bar:
# Get available memory info
avail_mem, total_mem = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
@@ -246,7 +252,14 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
for batch_size in progress_bar:
# Get available memory info
avail_mem, total_mem = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(
f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
)
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
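For readers skimming the diff, here is a minimal, self-contained sketch of the progress-reporting pattern added above: iterate the batch sizes largest-first, query free GPU memory, and update the tqdm description before each capture. It assumes a CUDA device is available, and capture_one_graph is a hypothetical placeholder for the real per-batch-size capture step, not part of the patch.

import torch
from tqdm import tqdm

def capture_one_graph(batch_size: int) -> None:
    # Hypothetical stand-in for the real CUDA graph capture of one batch size.
    pass

def warmup_with_progress(cuda_graph_batch_sizes):
    progress_bar = tqdm(cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
    for batch_size in progress_bar:
        # Report free device memory so regressions in capture overhead are visible.
        avail_mem, _total_mem = torch.cuda.mem_get_info()
        avail_mem_gb = avail_mem / (1024 ** 3)
        progress_bar.set_description(
            f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
        )
        capture_one_graph(batch_size)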
@@ -9,3 +9,4 @@
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
from .fused_moe_weight_ep import FusedMoeWeightEP
from .parameter_weight import ParameterWeight, TpParameterWeight
@@ -0,0 +1,44 @@
import torch
from typing import Dict
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id


class ParameterWeight(BaseWeightTpl):
def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None):
super().__init__()
self.weight_name = weight_name
self.bias_name = bias_name
self.data_type_ = data_type
self.weight = None
self.bias = None

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
if self.weight_name in weights:
self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id())

def verify_load(self):
load_ok = True
# Verify the weight; it must not be None.
load_ok = load_ok and self.weight is not None
# Verify the bias: if bias_name is set, the bias must not be None.
if self.bias_name is not None:
load_ok = load_ok and self.bias is not None
return load_ok


class TpParameterWeight(ParameterWeight):
def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None):
super().__init__(weight_name, data_type, bias_name)
self.split_n_embed = split_n_embed

def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
start = self.split_n_embed * self.tp_rank_
end = self.split_n_embed * (self.tp_rank_ + 1)

if self.weight_name in weights:
self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
if self.bias_name in weights:
self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
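A small standalone sketch of the tensor-parallel slicing that TpParameterWeight.load_hf_weights performs above: rank r keeps the rows [r * split_n_embed, (r + 1) * split_n_embed). The tensor shape and the tp_slice helper below are illustrative assumptions, not part of the patch.

import torch

def tp_slice(full_weight: torch.Tensor, split_n_embed: int, tp_rank: int) -> torch.Tensor:
    # Each tensor-parallel rank keeps a contiguous slice along the split dimension.
    start = split_n_embed * tp_rank
    end = split_n_embed * (tp_rank + 1)
    return full_weight[start:end]

full_weight = torch.randn(8, 4)       # pretend checkpoint tensor with 8 rows
shard0 = tp_slice(full_weight, 4, 0)  # rows 0..3 for rank 0
shard1 = tp_slice(full_weight, 4, 1)  # rows 4..7 for rank 1
assert torch.equal(torch.cat([shard0, shard1]), full_weight)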
170 changes: 96 additions & 74 deletions lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -24,16 +24,9 @@
logger = init_logger(__name__)


class MemoryManager:
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
class BaseAllocator:
def __init__(self, size, mem_manager_name=None):
self.size = size
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self.always_copy = always_copy
self.dtype = dtype
# profile the max total token num if the size is None
self.profile_size(mem_fraction)

self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
@@ -48,22 +41,102 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
self.can_use_mem_size = self.size

# Shared via shared memory so the router module can read it for accurate scheduling estimates; the nccl port acts as a per-instance marker on a single machine to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name

if mem_manager_name is None:
mem_manager_name = get_unique_server_name()
rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)
self.shared_can_use_token_num = SharedInt(f"{mem_manager_name}_mem_manger_can_use_token_num_{rank_in_node}")

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.HOLD_TOKEN_MEMINDEX = self.size

def alloc(self, need_size) -> torch.Tensor:
if need_size > self.mark_end - self.mark_start:
logger.error(f"not enough cache: need_size {need_size}, left_size {self.can_use_mem_size}")
assert False, "error alloc state"

start = self.mark_start
end = self.mark_start + need_size
self.mark_start += need_size

self.can_use_mem_size -= need_size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

# Return through a reusable staging buffer to avoid memory races under asynchronous execution.
if self._return_start + need_size > self._mem_state_return.shape[0]:
self._return_start = 0
ans = self._mem_state_return[self._return_start : self._return_start + need_size]
ans.copy_(self.mem_state[start:end])
self._return_start += need_size
return ans

def free(self, free_index: Union[torch.Tensor, List[int]]):
"""Return previously allocated token indices to the allocator pool.

Args:
free_index (torch.Tensor | List[int]): token indices to release.
"""
end = self.mark_start
start = self.mark_start - len(free_index)
assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"

if isinstance(free_index, list):
self.mem_state.numpy()[start:end] = free_index
else:
# Copying from GPU to CPU blocks within the stream.
self.mem_state[start:end] = free_index

self.mark_start -= len(free_index)

self.can_use_mem_size += len(free_index)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

if self.can_use_mem_size == len(self.mem_state):
logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
return

def free_all(self):
self.can_use_mem_size = len(self.mem_state)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
self.mark_start = 0
self.mark_end = len(self.mem_state)

def resize_mem(self, new_size):
"""
just for test code
"""
self.size = new_size
self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self.mark_start = 0
self.mark_end = self.size
self.can_use_mem_size = self.size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
return


class MemoryManager(BaseAllocator):
def __init__(
self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, mem_manager_name=None
):
self.size = size
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self.always_copy = always_copy
self.dtype = dtype
# profile the max total token num if the size is None
self.profile_size(mem_fraction)
super().__init__(self.size, mem_manager_name)

self._init_buffers(
self.size,
dtype,
head_num,
head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size

def get_cell_size(self):
return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -326,59 +399,13 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
def _free_buffers(self):
self.kv_buffer = None

def alloc(self, need_size) -> torch.Tensor:
if need_size > self.mark_end - self.mark_start:
logger.error(f"not enough cache: need_size {need_size}, left_size {self.can_use_mem_size}")
assert False, "error alloc state"

start = self.mark_start
end = self.mark_start + need_size
self.mark_start += need_size

self.can_use_mem_size -= need_size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

# Return through a reusable staging buffer to avoid memory races under asynchronous execution.
if self._return_start + need_size > self._mem_state_return.shape[0]:
self._return_start = 0
ans = self._mem_state_return[self._return_start : self._return_start + need_size]
ans.copy_(self.mem_state[start:end])
self._return_start += need_size
return ans

def free(self, free_index: Union[torch.Tensor, List[int]]):
"""Return previously allocated token indices to the allocator pool.

Args:
free_index (torch.Tensor | List[int]): token indices to release.
"""

end = self.mark_start
start = self.mark_start - len(free_index)
assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"

if isinstance(free_index, list):
self.mem_state.numpy()[start:end] = free_index
else:
# Copying from GPU to CPU blocks within the stream.
self.mem_state[start:end] = free_index

self.mark_start -= len(free_index)

self.can_use_mem_size += len(free_index)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

if self.can_use_mem_size == len(self.mem_state):
logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
return
def get_index_kv_buffer(self, index):
return {"kv_buffer": self.kv_buffer[:, index]}

def free_all(self):
self.can_use_mem_size = len(self.mem_state)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
self.mark_start = 0
self.mark_end = len(self.mem_state)
def load_index_kv_buffer(self, index, load_tensor_dict):
self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])

# Override resize_mem to also call _free_buffers and _init_buffers.
def resize_mem(self, new_size):
"""
just for test code
Expand All @@ -389,14 +416,9 @@ def resize_mem(self, new_size):
head_dim = self.head_dim
layer_num = self.layer_num

self.size = new_size
self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self.mark_start = 0
self.mark_end = self.size
self.can_use_mem_size = self.size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
# Call the parent class's resize_mem.
super().resize_mem(new_size)

self._free_buffers()
self._init_buffers(size, dtype, head_num, head_dim, layer_num)
return
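To make the refactored allocator easier to follow, here is a minimal, CPU-only sketch of the bump-pointer bookkeeping that BaseAllocator uses above (mark_start/mark_end over an index array, with LIFO reuse of freed indices). The shared-memory counter, pinned staging buffer, and CUDA details are deliberately omitted; this is an illustration under those simplifying assumptions, not the class itself.

import torch

class BumpIndexPool:
    """CPU-only illustration of BaseAllocator's mark_start/mark_end bookkeeping."""

    def __init__(self, size: int):
        self.size = size
        self.mem_state = torch.arange(0, size, dtype=torch.int32)
        self.mark_start = 0
        self.mark_end = size
        self.can_use_mem_size = size

    def alloc(self, need_size: int) -> torch.Tensor:
        assert need_size <= self.mark_end - self.mark_start, "not enough free indices"
        ans = self.mem_state[self.mark_start : self.mark_start + need_size].clone()
        self.mark_start += need_size
        self.can_use_mem_size -= need_size
        return ans

    def free(self, free_index: torch.Tensor) -> None:
        start = self.mark_start - len(free_index)
        assert start >= 0, "freeing more indices than were allocated"
        # Freed indices are written back just below mark_start, so the next
        # alloc reuses them first, mirroring the real allocator's behavior.
        self.mem_state[start : self.mark_start] = free_index
        self.mark_start = start
        self.can_use_mem_size += len(free_index)

pool = BumpIndexPool(16)
idx = pool.alloc(4)
pool.free(idx)
assert pool.can_use_mem_size == 16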
@@ -0,0 +1,8 @@
{
"4": {
"BK": 128,
"BV": 64,
"num_stages": 4,
"num_warps": 4
}
}
@@ -0,0 +1,8 @@
{
"4": {
"BK": 128,
"BV": 64,
"num_stages": 2,
"num_warps": 4
}
}
@@ -0,0 +1,8 @@
{
"4": {
"BK": 64,
"BV": 128,
"num_stages": 3,
"num_warps": 4
}
}
@@ -0,0 +1,7 @@
{
"4": {
"BV": 32,
"num_stages": 4,
"num_warps": 4
}
}
@@ -0,0 +1,41 @@
{
"1": {
"num_warps": 4
},
"100": {
"num_warps": 8
},
"1024": {
"num_warps": 4
},
"128": {
"num_warps": 1
},
"16": {
"num_warps": 4
},
"164096": {
"num_warps": 1
},
"2048": {
"num_warps": 2
},
"256": {
"num_warps": 1
},
"32": {
"num_warps": 8
},
"4096": {
"num_warps": 2
},
"64": {
"num_warps": 8
},
"8": {
"num_warps": 8
},
"8448": {
"num_warps": 1
}
}
@@ -0,0 +1,7 @@
{
"4": {
"BK": 64,
"num_stages": 3,
"num_warps": 4
}
}
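The JSON files above map an integer size key to Triton launch parameters (BK, BV, num_stages, num_warps). As a hedged illustration of how such a file could be consumed, the helper below loads one and falls back to the closest tuned key; the nearest-key rule and the function names are assumptions made for this example, not lightllm's actual config loader.

import json
from typing import Dict

def load_kernel_config(path: str) -> Dict[int, dict]:
    # Keys are stored as strings in JSON; convert them to integers for lookup.
    with open(path) as f:
        raw = json.load(f)
    return {int(k): v for k, v in raw.items()}

def pick_config(configs: Dict[int, dict], size: int) -> dict:
    # Fall back to the closest tuned size when there is no exact match.
    key = min(configs, key=lambda k: abs(k - size))
    return configs[key]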