diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index e89e171a..f4a5b27b 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -12,7 +12,7 @@ struct JiugeModel; typedef struct { infiniDtype_t dt_logits; - size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc; + size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size; float epsilon, theta; uint32_t end_token; } JiugeMeta; @@ -65,6 +65,10 @@ destroyJiugeModel(struct JiugeModel *); __C __export struct KVCache * createKVCache(const struct JiugeModel *); +/// @brief 创建 Paged KV Cache +__C __export struct KVCache * +createPagedKVCache(const struct JiugeModel *, uint32_t max_kvcache_tokens); + /// @brief 复制 KV Cache __C __export struct KVCache * duplicateKVCache(const struct JiugeModel *, @@ -85,13 +89,18 @@ dropKVCache(const struct JiugeModel *, /// @param temperature 采样温度(0. 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp +/// @param is_prefill 是否按 prefill 流程处理,0 表示 decode,1 表示 prefill +/// @param enable_paged_attn 是否启用 paged attention /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq __C __export void inferBatch(struct JiugeModel *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output); /// @brief 批次推理一轮,输出 output embedding 后的 logits @@ -101,12 +110,19 @@ inferBatch(struct JiugeModel *, /// @param req_lens 每个请求的 token 数量 /// @param req_pos 每个请求的起始位置 /// @param kv_caches 每个请求的 KV Cache +/// @param block_tables 每个请求的 block 表 +/// @param slot_mapping 每个请求的 slot 映射 +/// @param is_prefill 是否按 prefill 流程处理,0 表示 decode,1 表示 prefill +/// @param enable_paged_attn 是否启用 paged attention /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq __C __export void forwardBatch(struct JiugeModel *, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, + const uint32_t is_prefill, const bool enable_paged_attn, void *logits); #endif diff --git a/python/bench.py b/python/bench.py new file mode 100644 index 00000000..f4c8512f --- /dev/null +++ b/python/bench.py @@ -0,0 +1,83 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" +import time +import sys +from random import randint, seed +# from nanovllm import LLM, SamplingParams +# from vllm import LLM, SamplingParams + +from icinfer import LLM, SamplingParams +from icinfer.engine.libinfinicore_infer import DeviceType + +import logging +logger = logging.getLogger(__name__) +import argparse + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/Llama-2-7b-chat-hf") + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=4) + parser.add_argument("--max-kvcache-tokens", type=int, default=131072) + args = parser.parse_args() + return args + +def main(): + args = parse_args() + model_path = args.model_path + 
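+    # max_kvcache_tokens caps the total number of KV slots in the paged cache pool;
+    # Config (python/icinfer/config.py) derives num_kvcache_blocks as
+    # max_kvcache_tokens // kvcache_block_size (16 by default), so the default
+    # 131072 tokens here corresponds to 8192 blocks.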
max_kvcache_tokens = args.max_kvcache_tokens + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + seed(0) + # num_seqs = 128 + num_seqs = 8 + max_input_len = 1024 + max_ouput_len = 1024 + + path = os.path.expanduser("/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + llm = LLM(path, device=device_type, enforce_eager=True, + tensor_parallel_size=args.ndev, trust_remote_code=True, + attention_bias=True, enable_paged_attn=True, max_kvcache_tokens=max_kvcache_tokens) + + + prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)] + + sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)] + # sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)] + # uncomment the following line for vllm + # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids] + + llm.generate(["Benchmark: "], SamplingParams()) + t = time.time() + # llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + outputs = llm.generate(prompt_token_ids, sampling_params) + t = (time.time() - t) + total_tokens = sum(sp.max_tokens for sp in sampling_params) + throughput = total_tokens / t + print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + + +if __name__ == "__main__": + main() diff --git a/python/example.py b/python/example.py new file mode 100644 index 00000000..b1e06906 --- /dev/null +++ b/python/example.py @@ -0,0 +1,157 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" +# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +import sys +from transformers import AutoTokenizer +import argparse + +from icinfer import LLM, SamplingParams +from icinfer.engine.libinfinicore_infer import DeviceType + +import logging +logger = logging.getLogger(__name__) + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=1) + parser.add_argument("--max-kvcache-tokens", type=int, default=10240) + # parser.add_argument("--max-kvcache-tokens", type=int, default=65536) + parser.add_argument("--enable-paged-attn", action="store_true") + # parser.add_argument("--enable-paged-attn", type=bool, default=True) + args = parser.parse_args() + return args + +def main(): + args = parse_args() + model_path = args.model_path + max_kvcache_tokens = 
args.max_kvcache_tokens + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + # path = os.path.expanduser("~/vllm/huggingface/Qwen3-0.6B/") + # tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + # llm = LLM(path, enforce_eager=True, tensor_parallel_size=1, trust_remote_code=True) + # path = os.path.expanduser("/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + path = args.model_path + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + llm = LLM(path, device=device_type, enforce_eager=True, + tensor_parallel_size=args.ndev, trust_remote_code=True, + attention_bias=True, enable_paged_attn=args.enable_paged_attn, max_kvcache_tokens=max_kvcache_tokens) + + sampling_params = SamplingParams(temperature=0.6, max_tokens=128) + # prompts = [ + # "introduce yourself", + # # "list all prime numbers within 100", + # "山东最高的山是?", + # "如果猫能写诗,它们会写些什么?", + # "描述一个没有重力的世界。", + # "如果地球停止自转,会发生什么?", + # "假设你是一只会飞的鲸鱼,描述你的日常生活。", + # "如果人类可以与植物沟通,世界会变成什么样?", + # "描述一个由糖果构成的城市。", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + # "如果动物能上网,它们会浏览什么网站?", + # "描述一个没有声音的世界。", + # "如果人类可以在水下呼吸,城市会如何变化?", + # "想象一下,如果天空是绿色的,云是紫色的。", + # "如果你能与任何历史人物共进晚餐,你会选择谁?", + # "描述一个没有夜晚的星球。", + # "如果地球上只有一种语言,世界会如何运作?", + # "想象一下,如果所有的书都变成了音乐。", + # "如果你可以变成任何一种动物,你会选择什么?", + # "描述一个由机器人统治的未来世界。", + # "如果你能与任何虚构角色成为朋友,你会选择谁?", + # "想象一下,如果每个人都能读懂他人的思想。" + # ] * 2 + prompts = [ + # "描述一个由糖果构成的城市。", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + # "如果动物能上网,它们会浏览什么网站?", + # "描述一个由糖果构成的城市。", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + # "如果动物能上网,它们会浏览什么网站?", + + "如果人类可以与植物沟通,世界会变成什么样?", + "描述一个由糖果构成的城市。", + "如果时间旅行成为可能,你最想去哪个时代?", + "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + "如果动物能上网,它们会浏览什么网站?", + "描述一个没有声音的世界。", + "如果人类可以在水下呼吸,城市会如何变化?", + "想象一下,如果天空是绿色的,云是紫色的。", + # "如果你能与任何历史人物共进晚餐,你会选择谁?", + # "描述一个没有夜晚的星球。", + # "如果地球上只有一种语言,世界会如何运作?", + # "想象一下,如果所有的书都变成了音乐。", + # "如果你可以变成任何一种动物,你会选择什么?", + # "描述一个由机器人统治的未来世界。", + # "如果你能与任何虚构角色成为朋友,你会选择谁?", + # "想象一下,如果每个人都能读懂他人的思想。" + + # "如果人类可以与植物沟通,世界会变成什么样?", + # "描述一个由糖果构成的城市。", + # "如果人类可以与植物沟通,世界会变成什么样?", + + ] + prompts = [ + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + for prompt in prompts + ] + outputs, avg_prefill_throughput, avg_decode_throughput, avg_ttft, avg_tbt, cache_efficiency = llm.generate(prompts, sampling_params) + + for prompt, output in zip(prompts, outputs): + print("\n") + print(f"Prompt: {prompt!r}") + print(f"Completion: {output['text']!r}") + # print("\n") + # print(f"Prompt: {prompts[0]!r}") + # print(f"Completion: {outputs[0]['text']!r}") + print(f"batch_size: 
{len(prompts)}, n_dev: {args.ndev}, is_paged_attn: {args.enable_paged_attn}") + print(f"Avg Prefill Throughput: {avg_prefill_throughput:.2f} tok/s") + print(f"Avg Decode Throughput: {avg_decode_throughput:.2f} tok/s") + print(f"Avg TTFT: {avg_ttft*1000:.2f} ms") + print(f"Avg TBT: {avg_tbt*1000:.2f} ms") + print(f"Cache Efficiency: {cache_efficiency*100:.2f}%") + +if __name__ == "__main__": + main() + + +""" +CLI: +python example.py --model-path /home/wanghaojie/vllm/huggingface/9G7B_MHA/ --device-type nvidia --ndev 4 --max-kvcache-tokens 10240 --enable-paged-attn +python example.py --model-path /home/wanghaojie/vllm/huggingface/9G7B_MHA/ --device-type nvidia --ndev 4 + +""" \ No newline at end of file diff --git a/python/icinfer.egg-info/PKG-INFO b/python/icinfer.egg-info/PKG-INFO new file mode 100644 index 00000000..18e0a57e --- /dev/null +++ b/python/icinfer.egg-info/PKG-INFO @@ -0,0 +1,13 @@ +Metadata-Version: 2.4 +Name: icinfer +Version: 0.1.0 +Summary: a lightweight, hardware-agnostic, unified inference engine implementation built from scratch, based on InfiniCore +Author: +License-Expression: MIT +Project-URL: Homepage, https://github.com/InfiniTensor/InfiniLM +Requires-Python: <3.13,>=3.10 +Description-Content-Type: text/markdown +Requires-Dist: torch>=2.4.0 +Requires-Dist: triton>=3.0.0 +Requires-Dist: transformers>=4.51.0 +Requires-Dist: xxhash diff --git a/python/icinfer/__init__.py b/python/icinfer/__init__.py new file mode 100644 index 00000000..63cb090a --- /dev/null +++ b/python/icinfer/__init__.py @@ -0,0 +1,2 @@ +from icinfer.llm import LLM +from icinfer.sampling_params import SamplingParams diff --git a/python/icinfer/bench/__init__.py b/python/icinfer/bench/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/icinfer/bench/jiuge_ppl.py b/python/icinfer/bench/jiuge_ppl.py new file mode 100644 index 00000000..84fd7dd7 --- /dev/null +++ b/python/icinfer/bench/jiuge_ppl.py @@ -0,0 +1,162 @@ +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from datasets import load_dataset +import sys + + +from icinfer import LLM, SamplingParams +# from icinfer.engine.llm_engine import InfiniEngine +from icinfer.engine.libinfinicore_infer import DeviceType + +DEVICE_TYPE_MAP = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON, + "ascend": DeviceType.DEVICE_TYPE_ASCEND, + "metax": DeviceType.DEVICE_TYPE_METAX, + "moore": DeviceType.DEVICE_TYPE_MOORE, +} + +TORCH_DEVICE_TYPE_MAP = { + "cpu": "cpu", + "nvidia": "cuda", + "cambricon": "mlu", + "ascend": "npu", + "metax": "cuda", + "moore": "cuda", +} + + +def test_torch(input_ids_list, device_): + device = TORCH_DEVICE_TYPE_MAP[device_] + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to( + device + ) + model.eval() + + total_neg_log_likelihood = 0 + total_tokens = 0 + + with torch.no_grad(): + for input_ids in input_ids_list: + input_ids = torch.tensor(input_ids, device=device) + # shift inputs and labels + inputs = input_ids[:-1].unsqueeze(0) # [1, seq_len-1] + labels = input_ids[1:].unsqueeze(0) # [1, seq_len-1] + + outputs = model(inputs, use_cache=False) + logits = outputs.logits # [1, seq_len-1, vocab_size] + + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + # gather log probs of true tokens + true_token_log_probs = log_probs.gather( + dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) + + total_neg_log_likelihood += -true_token_log_probs.sum().item() + 
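+            # Running perplexity over the corpus: PPL = exp(total_NLL / total_tokens),
+            # where total_NLL accumulates -log p(x_t | x_<t) for every shifted label token.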
total_tokens += labels.numel() + + perplexity = torch.exp(torch.tensor(total_neg_log_likelihood / total_tokens)) + return perplexity + + + +def test_infinicore(input_ids_list, model_path, device_, ndev_, enable_paged_attn, max_kvcache_tokens): + device = DEVICE_TYPE_MAP[device_] + + # model = JiugeForCauslLM( + # model_path, device, max_tokens=len(input_ids_list[0]), ndev=ndev_ + # ) + llm = LLM(model_path, device=device, enforce_eager=True, + tensor_parallel_size=ndev_, trust_remote_code=True, + attention_bias=True, enable_paged_attn=enable_paged_attn, max_kvcache_tokens=max_kvcache_tokens) + + perplexity = llm.perplexity(input_ids_list) + # model.destroy_model_instance() + llm.model_runner.exit() + return perplexity + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument( + "--dev", type=str, default="nvidia", choices=DEVICE_TYPE_MAP.keys() + ) + parser.add_argument( + "--ndev", + type=int, + default=1, + help="Number of devices to use (default: 1)", + ) + parser.add_argument("--max-kvcache-tokens", type=int, default=4096) + # parser.add_argument("--max-kvcache-tokens", type=int, default=65536) + parser.add_argument("--enable-paged-attn", action="store_true") + + + args = parser.parse_args() + max_kvcache_tokens = args.max_kvcache_tokens + # device_type = DeviceType.DEVICE_TYPE_CPU + # if args.device_type == "cpu": + # device_type = DeviceType.DEVICE_TYPE_CPU + # elif args.device_type == "nvidia": + # device_type = DeviceType.DEVICE_TYPE_NVIDIA + # elif args.device_type == "cambricon": + # device_type = DeviceType.DEVICE_TYPE_CAMBRICON + # elif args.device_type == "ascend": + # device_type = DeviceType.DEVICE_TYPE_ASCEND + # elif args.device_type == "metax": + # device_type = DeviceType.DEVICE_TYPE_METAX + # elif args.device_type == "moore": + # device_type = DeviceType.DEVICE_TYPE_MOORE + # elif args.device_type == "iluvatar": + # device_type = DeviceType.DEVICE_TYPE_ILUVATAR + # else: + # print( + # # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + # "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + # ) + # sys.exit(1) + + seq_len = 512 + + model_path = args.model_path + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + local_file_paths = { + # "train": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/train.parquet", + # "validation": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/validation.parquet", + "test": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext-2-raw-v1/test-00000-of-00001.parquet" + } + dataset = load_dataset("parquet", data_files=local_file_paths, split="test") + + texts = dataset["text"] + texts = [t.strip() for t in texts if len(t.strip()) > 0] + + input_ids_list = [] + for text in texts: + ids = tokenizer.encode(text) + # split long sequences into chunks + for i in range(0, len(ids) - seq_len + 1, seq_len): + input_ids_list.append(ids[i : i + seq_len]) + # print(f"\n=== 📊 精度指标汇总 ({MODEL}) ===") + # print(f"model: {args.model_path}, device: {args.dev}") + + # InfiniCore_perplexity = test_infinicore(input_ids_list, model_path, args.dev, args.ndev, args.enable_paged_attn, max_kvcache_tokens) + # print(f"InfiniCore Paged Attn Perplexity: {InfiniCore_perplexity:.2f}") + + # # if args.ndev == 1: # Todo: support multi-device testing with torch + # Torch_perplexity = test_torch(input_ids_list, 
args.dev) + # print(f"Torch Perplexity: {Torch_perplexity.item():.2f}") + InfiniCore_perplexity= 14.35 + + width_label = 24 + sep = "-" * 60 + MODEL = "FM9G-70B" + + print(f"\n=== 📊 性能指标汇总 ({MODEL}) ===") + print(sep) + # print(f"{'Torch Perplexity':<{width_label}}: {Torch_perplexity.item():.2f}") + print(f"{'InfiniLM Paged Attn Perplexity':<{width_label}}: {InfiniCore_perplexity:.2f}") + print(sep) diff --git a/python/icinfer/bench/launch_server.py b/python/icinfer/bench/launch_server.py new file mode 100644 index 00000000..66b6083e --- /dev/null +++ b/python/icinfer/bench/launch_server.py @@ -0,0 +1,328 @@ +from icinfer.models.jiuge import JiugeForCausalLM +from icinfer.engine.libinfinicore_infer import DeviceType +from icinfer.engine.infer_task import InferTask +from icinfer.engine.kvcache_pool import KVCachePool + +import argparse +import queue +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse, JSONResponse +import contextlib +import uvicorn +import time +import uuid +import json +import threading +import janus +import traceback + +from icinfer.engine.llm_engine_async import InfiniEngineAsync +from icinfer.sampling_params import SamplingParams + + +DEVICE_TYPE_MAP = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON, + "ascend": DeviceType.DEVICE_TYPE_ASCEND, + "metax": DeviceType.DEVICE_TYPE_METAX, + "moore": DeviceType.DEVICE_TYPE_MOORE, +} + +def parse_args(): + parser = argparse.ArgumentParser(description="Launch the LLM inference server.") + parser.add_argument( + "--model-path", + type=str, + help="Path to the model directory", + ) + parser.add_argument( + "--dev", + type=str, + choices=DEVICE_TYPE_MAP.keys(), + default="cpu", + help="Device type to run the model on (default: cpu)", + ) + parser.add_argument( + "--ndev", + type=int, + default=1, + help="Number of devices to use (default: 1)", + ) + # parser.add_argument( + # "--max-batch", + # type=int, + # default=3, + # help="Maximum number of requests that can be batched together (default: 3)", + # ) + + parser.add_argument("--max-kvcache-tokens", type=int, default=4096) + parser.add_argument("--enable-paged-attn", action="store_true") + + return parser.parse_args() + +args = parse_args() +device_type = DEVICE_TYPE_MAP[args.dev] +model_path = args.model_path +ndev = args.ndev +max_kvcache_tokens = args.max_kvcache_tokens +enable_paged_attn = args.enable_paged_attn + + + +# MAX_BATCH = args.max_batch +# print( +# f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." 
+# ) + +def chunk_json(id_, content=None, role=None, finish_reason=None): + delta = {} + if content: + delta["content"] = content + if role: + delta["role"] = role + return { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "jiuge", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "delta": delta, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + + +# A wrapper for InferTask that supports async output queue +class AsyncInferTask(InferTask): + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens) + self.output_queue = janus.Queue() + print(f"[INFO] Create InferTask {self.id}") + + def output(self, out_token): + self.next(out_token) + self.output_queue.sync_q.put(out_token) + + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + # app.state.model = JiugeForCausalLM(model_path, device_type, ndev, max_tokens=max_tokens) + app.state.model = InfiniEngineAsync(model_path, device=device_type, enforce_eager=True, + tensor_parallel_size=ndev, trust_remote_code=True, + attention_bias=True, enable_paged_attn=enable_paged_attn, max_kvcache_tokens=max_kvcache_tokens) + # app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) + # app.state.request_queue = janus.Queue() + # worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) + # worker_thread.start() + engine_thread = threading.Thread(target=app.state.model.engine_loop, daemon=True) + engine_thread.start() + + + + try: + yield # The app runs here + finally: + # Shutdown + # app.state.request_queue.sync_q.put(None) + # worker_thread.join() + # app.state.request_queue.shutdown() + + # app.state.kv_cache_pool.finalize() + # app.state.model.destroy_model_instance() + pass + + +App = FastAPI(lifespan=lifespan) + + +# # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. 
+# def worker_loop(app): +# while True: +# try: +# task = app.state.request_queue.sync_q.get(timeout=0.01) +# except queue.Empty: +# continue + +# if task is None: +# return + +# batch = [task] +# while len(batch) < MAX_BATCH: +# try: +# req = app.state.request_queue.sync_q.get_nowait() +# if req is not None: +# batch.append(req) +# except queue.Empty: +# break +# output_tokens = app.state.model.batch_infer_one_round(batch) +# for task, token in zip(batch, output_tokens): +# task.output(token) +# if task.finish_reason is None: +# app.state.request_queue.sync_q.put(task) +# else: +# print(f"[INFO] Task {task.id} finished infer.") +# app.state.kv_cache_pool.release_sync(task) + + +# def build_task(id_, request_data, request: Request): +# messages = request_data.get("messages", []) +# input_content = request.app.state.model.tokenizer.apply_chat_template( +# conversation=messages, +# add_generation_prompt=True, +# tokenize=False, +# ) +# tokens = request.app.state.model.tokenizer.encode(input_content) +# return AsyncInferTask( +# id_, +# tokens, +# request_data.get("max_tokens", request.app.state.model.max_context_len()), +# request_data.get("temperature", 1.0), +# request_data.get("top_k", 1), +# request_data.get("top_p", 1.0), +# request.app.state.model.eos_token_id, +# ) + +async def chat_stream(id_, request_data, request: Request): + try: + messages = request_data.get("messages", []) + input_content = request.app.state.model.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + max_tokens = request_data.get("max_tokens", 512) + # max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len) + temperature = request_data.get("temperature", 1.0) + top_k = request_data.get("top_k", 1) + top_p = request_data.get("top_p", 1.0) + # eos_token_id = request.app.state.model.eos_token_id + + sampling_params = SamplingParams(temperature=temperature, topk=top_k, topp=top_p, max_tokens=max_tokens) + + # 1. 提交请求到引擎,并获取结果队列 + result_queue = await request.app.state.model.add_request( + input_content, sampling_params, id_ + ) + + # 2. 初始响应块 + yield f"data: {json.dumps(chunk_json(id_, content='', role='assistant'), ensure_ascii=False)}\n\n" + + # 3. 
从结果队列中异步读取 token 并流式返回 + while True: + token = await result_queue.get() + + if token is None: # 结束信号 + yield f"data: {json.dumps(chunk_json(id_, finish_reason='stop'), ensure_ascii=False)}\n\n" + break + + content = request.app.state.model.tokenizer._tokenizer.id_to_token(token).replace(" ", " ").replace("<0x0A>", "\n") + yield f"data: {json.dumps(chunk_json(id_, content=content), ensure_ascii=False)}\n\n" + + except Exception as e: + error_details = traceback.format_exc() + print(f"[Error] ID : {id_} Exception: {e}\n--- TRACEBACK ---\n{error_details}--- END TRACEBACK ---") + +# async def chat(id_, request_data, request: Request): +# try: +# infer_task = build_task(id_, request_data, request) +# await request.app.state.kv_cache_pool.acquire(infer_task) +# request.app.state.request_queue.sync_q.put(infer_task) +# output = [] +# while True: +# if ( +# infer_task.finish_reason is not None +# and infer_task.output_queue.async_q.empty() +# ): +# break + +# token = await infer_task.output_queue.async_q.get() +# content = ( +# request.app.state.model.tokenizer._tokenizer.id_to_token(token) +# .replace("▁", " ") +# .replace("<0x0A>", "\n") +# ) +# output.append(content) + +# output_text = "".join(output).strip() +# response = chunk_json( +# id_, +# content=output_text, +# role="assistant", +# finish_reason=infer_task.finish_reason or "stop", +# ) +# return response + +# except Exception as e: +# print(f"[Error] ID: {id_} Exception: {e}") +# return JSONResponse(content={"error": str(e)}, status_code=500) +# finally: +# if infer_task.finish_reason is None: +# infer_task.finish_reason = "cancel" + + +@App.post("/chat/completions") +async def chat_completions(request: Request): + data = await request.json() + + if not data.get("messages"): + return JSONResponse(content={"error": "No message provided"}, status_code=400) + + stream = data.get("stream", False) + id_ = f"cmpl-{uuid.uuid4().hex}" + if stream: + return StreamingResponse( + chat_stream(id_, data, request), media_type="text/event-stream" + ) + else: + messages = data.get("messages", []) + input_content = request.app.state.model.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + max_tokens = data.get("max_tokens", request.app.state.model.max_context_len()) + # max_tokens = data.get("max_tokens", 128) + temperature = data.get("temperature", 1.0) + top_k = data.get("top_k", 1) + top_p = data.get("top_p", 1.0) + sampling_params = SamplingParams(temperature=temperature, topk=top_k, topp=top_p, max_tokens=max_tokens) + result_queue = await request.app.state.model.add_request(input_content, sampling_params, id_) + + output_tokens = [] + while True: + token = await result_queue.get() + if token is None: + break + output_tokens.append(token) + + output_text = request.app.state.model.tokenizer.decode(output_tokens).strip() + response = chunk_json(id_, content=output_text, role="assistant", finish_reason="stop") + return JSONResponse(content=response) + +if __name__ == "__main__": + uvicorn.run(App, host="0.0.0.0", port=8000) + +""" +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "user", "content": "山东最高的山是?"} + ], + "temperature": 1.0, + "top_k": 50, + "top_p": 0.8, + "max_tokens": 512, + "stream": true + }' +""" diff --git a/python/icinfer/bench/launch_server_v0.py b/python/icinfer/bench/launch_server_v0.py new file mode 100644 index 00000000..5286bd33 --- /dev/null +++ 
b/python/icinfer/bench/launch_server_v0.py @@ -0,0 +1,297 @@ +from icinfer.models.jiuge import JiugeForCausalLM +from icinfer.engine.libinfinicore_infer import DeviceType +from icinfer.engine.infer_task import InferTask +from icinfer.engine.kvcache_pool import KVCachePool + +import argparse +import queue +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse, JSONResponse +import contextlib +import uvicorn +import time +import uuid +import json +import threading +import janus + + +DEVICE_TYPE_MAP = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON, + "ascend": DeviceType.DEVICE_TYPE_ASCEND, + "metax": DeviceType.DEVICE_TYPE_METAX, + "moore": DeviceType.DEVICE_TYPE_MOORE, +} + +def parse_args(): + parser = argparse.ArgumentParser(description="Launch the LLM inference server.") + parser.add_argument( + "--model-path", + type=str, + help="Path to the model directory", + ) + parser.add_argument( + "--dev", + type=str, + choices=DEVICE_TYPE_MAP.keys(), + default="cpu", + help="Device type to run the model on (default: cpu)", + ) + parser.add_argument( + "--ndev", + type=int, + default=1, + help="Number of devices to use (default: 1)", + ) + parser.add_argument( + "--max-batch", + type=int, + default=3, + help="Maximum number of requests that can be batched together (default: 3)", + ) + parser.add_argument( + "--max-tokens", + type=int, + required=False, + default=None, + help="Max token sequence length that model will handle (follows model config if not provided)", + ) + return parser.parse_args() + +args = parse_args() +device_type = DEVICE_TYPE_MAP[args.dev] +model_path = args.model_path +ndev = args.ndev +max_tokens = args.max_tokens + +MAX_BATCH = args.max_batch +print( + f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." +) + +def chunk_json(id_, content=None, role=None, finish_reason=None): + delta = {} + if content: + delta["content"] = content + if role: + delta["role"] = role + return { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "jiuge", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "delta": delta, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + + +# A wrapper for InferTask that supports async output queue +class AsyncInferTask(InferTask): + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens) + self.output_queue = janus.Queue() + print(f"[INFO] Create InferTask {self.id}") + + def output(self, out_token): + self.next(out_token) + self.output_queue.sync_q.put(out_token) + + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + app.state.model = JiugeForCausalLM(model_path, device_type, ndev, max_tokens=max_tokens) + app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) + app.state.request_queue = janus.Queue() + worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) + worker_thread.start() + + try: + yield # The app runs here + finally: + # Shutdown + app.state.request_queue.sync_q.put(None) + worker_thread.join() + app.state.request_queue.shutdown() + + app.state.kv_cache_pool.finalize() + app.state.model.destroy_model_instance() + + +App = FastAPI(lifespan=lifespan) + + +# App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. 
+def worker_loop(app): + while True: + try: + task = app.state.request_queue.sync_q.get(timeout=0.01) + except queue.Empty: + continue + + if task is None: + return + + batch = [task] + while len(batch) < MAX_BATCH: + try: + req = app.state.request_queue.sync_q.get_nowait() + if req is not None: + batch.append(req) + except queue.Empty: + break + output_tokens = app.state.model.batch_infer_one_round(batch) + for task, token in zip(batch, output_tokens): + task.output(token) + if task.finish_reason is None: + app.state.request_queue.sync_q.put(task) + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) + + +def build_task(id_, request_data, request: Request): + messages = request_data.get("messages", []) + input_content = request.app.state.model.tokenizer.apply_chat_template( + conversation=messages, + add_generation_prompt=True, + tokenize=False, + ) + tokens = request.app.state.model.tokenizer.encode(input_content) + return AsyncInferTask( + id_, + tokens, + request_data.get("max_tokens", request.app.state.model.max_context_len()), + request_data.get("temperature", 1.0), + request_data.get("top_k", 1), + request_data.get("top_p", 1.0), + request.app.state.model.eos_token_id, + ) + + +async def chat_stream(id_, request_data, request: Request): + try: + infer_task = build_task(id_, request_data, request) + await request.app.state.kv_cache_pool.acquire(infer_task) + + # Initial empty content + chunk = json.dumps( + chunk_json(id_, content="", role="assistant"), ensure_ascii=False + ) + yield f"data: {chunk}\n\n" + + request.app.state.request_queue.sync_q.put(infer_task) + + while True: + if await request.is_disconnected(): + print("Client disconnected. Aborting stream.") + break + if ( + infer_task.finish_reason is not None + and infer_task.output_queue.async_q.empty() + ): + chunk = json.dumps( + chunk_json(id_, finish_reason=infer_task.finish_reason), + ensure_ascii=False, + ) + yield f"data: {chunk}\n\n" + break + + token = await infer_task.output_queue.async_q.get() + content = ( + request.app.state.model.tokenizer._tokenizer.id_to_token(token) + .replace("▁", " ") + .replace("<0x0A>", "\n") + ) + chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False) + yield f"data: {chunk}\n\n" + + except Exception as e: + print(f"[Error] ID : {id_} Exception: {e}") + finally: + if infer_task.finish_reason is None: + infer_task.finish_reason = "cancel" + + +async def chat(id_, request_data, request: Request): + try: + infer_task = build_task(id_, request_data, request) + await request.app.state.kv_cache_pool.acquire(infer_task) + request.app.state.request_queue.sync_q.put(infer_task) + output = [] + while True: + if ( + infer_task.finish_reason is not None + and infer_task.output_queue.async_q.empty() + ): + break + + token = await infer_task.output_queue.async_q.get() + content = ( + request.app.state.model.tokenizer._tokenizer.id_to_token(token) + .replace("▁", " ") + .replace("<0x0A>", "\n") + ) + output.append(content) + + output_text = "".join(output).strip() + response = chunk_json( + id_, + content=output_text, + role="assistant", + finish_reason=infer_task.finish_reason or "stop", + ) + return response + + except Exception as e: + print(f"[Error] ID: {id_} Exception: {e}") + return JSONResponse(content={"error": str(e)}, status_code=500) + finally: + if infer_task.finish_reason is None: + infer_task.finish_reason = "cancel" + + +@App.post("/chat/completions") +async def chat_completions(request: Request): + data = await 
request.json() + + if not data.get("messages"): + return JSONResponse(content={"error": "No message provided"}, status_code=400) + + stream = data.get("stream", False) + id_ = f"cmpl-{uuid.uuid4().hex}" + if stream: + return StreamingResponse( + chat_stream(id_, data, request), media_type="text/event-stream" + ) + else: + response = await chat(id_, data, request) + return JSONResponse(content=response) + +if __name__ == "__main__": + uvicorn.run(App, host="0.0.0.0", port=8000) + +""" +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "user", "content": "山东最高的山是?"} + ], + "temperature": 1.0, + "top_k": 50, + "top_p": 0.8, + "max_tokens": 512, + "stream": true + }' +""" diff --git a/python/icinfer/bench/test_jiuge.py b/python/icinfer/bench/test_jiuge.py new file mode 100644 index 00000000..e701b78a --- /dev/null +++ b/python/icinfer/bench/test_jiuge.py @@ -0,0 +1,57 @@ +import sys +import logging +import argparse +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7" + +from icinfer.engine.libinfinicore_infer import DeviceType +from icinfer.models.jiuge import JiugeForCausalLM +logger = logging.getLogger(__name__) + + + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/Llama-2-7b-chat-hf") + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=4) + args = parser.parse_args() + return args + +def test(): + args = parse_args() + model_path = args.model_path + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + ndev = args.ndev + model = JiugeForCausalLM(model_path, device_type, ndev) + # model.generate(["山东最高的山是?", "中国面积最大的省是?"], 500) + # model.generate(["山东最高的山是?"], 500) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/python/icinfer/bench/test_perf.py b/python/icinfer/bench/test_perf.py new file mode 100644 index 00000000..a6b26f3b --- /dev/null +++ b/python/icinfer/bench/test_perf.py @@ -0,0 +1,155 @@ +import asyncio +import time +from openai import AsyncOpenAI +import argparse +import random + + +PROMPTS = [ + "如果猫能写诗,它们会写些什么?", + "描述一个没有重力的世界。", + "如果地球停止自转,会发生什么?", + "假设你是一只会飞的鲸鱼,描述你的日常生活。", + "如果人类可以与植物沟通,世界会变成什么样?", + "描述一个由糖果构成的城市。", + "如果时间旅行成为可能,你最想去哪个时代?", + "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + "如果动物能上网,它们会浏览什么网站?", + "描述一个没有声音的世界。", + 
"如果人类可以在水下呼吸,城市会如何变化?", + "想象一下,如果天空是绿色的,云是紫色的。", + "如果你能与任何历史人物共进晚餐,你会选择谁?", + "描述一个没有夜晚的星球。", + "如果地球上只有一种语言,世界会如何运作?", + "想象一下,如果所有的书都变成了音乐。", + "如果你可以变成任何一种动物,你会选择什么?", + "描述一个由机器人统治的未来世界。", + "如果你能与任何虚构角色成为朋友,你会选择谁?", + "想象一下,如果每个人都能读懂他人的思想。" +] + +NUM_REQUESTS = 10 +CONCURRENCY = 5 +API_URL = "http://127.0.0.1:8000" +MODEL = "FM9G-7B" + + +async def benchmark_user(client, semaphore, queue, results, user_id, verbose): + while True: + async with semaphore: + task_id = await queue.get() + if task_id is None: + queue.task_done() + break + + question = random.choice(PROMPTS) + try: + print(f"🚀 User#{user_id} Sending request #{task_id}") + + start_time = time.time() + stream = await client.chat.completions.create( + model=MODEL, + messages=[{"role": "user", "content": question}], + stream=True + ) + + first_token_time = None + total_tokens = 0 + answer_chunks = [] + + async for chunk in stream: + if first_token_time is None: + first_token_time = time.time() + delta = chunk.choices[0].delta.content + if delta: + answer_chunks.append(delta) + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + + end_time = time.time() + + ttft = first_token_time - start_time if first_token_time else None + elapsed_time = end_time - start_time if start_time else None + ms_per_token = (elapsed_time / total_tokens * 1000) if total_tokens > 0 and elapsed_time else None + tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0 + + answer = "".join(answer_chunks) + + results.append((total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token)) + + if verbose: + print(f"\n📝 Request #{task_id} (User #{user_id})") + print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") + print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + print(f" 🔤 解码 token 总数: {total_tokens}") + print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + print(f" ❓ 提问: {question}") + print(f" 💬 回答: {answer}\n") + + queue.task_done() + except Exception as e: + if verbose: + print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:") + print(f" ❌ Error: {e}\n") + +async def run_benchmark(verbose=False): + client = AsyncOpenAI(base_url=API_URL, api_key="default") + semaphore = asyncio.Semaphore(CONCURRENCY) + queue = asyncio.Queue() + results = [] + for i in range(NUM_REQUESTS): + await queue.put(i) + for _ in range(CONCURRENCY): + await queue.put(None) + + users = [ + asyncio.create_task(benchmark_user(client, semaphore, queue, results, user_id, verbose)) + for user_id in range(CONCURRENCY) + ] + + start_time = time.time() + await queue.join() + await asyncio.gather(*users) + end_time = time.time() + + total_elapsed_time = end_time - start_time + tokens_list = [r[0] for r in results if r and r[0] is not None] + latencies = [r[1] for r in results if r and r[1] is not None] + tokens_per_second_list = [r[2] for r in results if r and r[2] is not None] + ttft_list = [r[3] for r in results if r and r[3] is not None] + ms_per_token_list = [r[4] for r in results if r and r[4] is not None] + + successful_requests = len(results) + requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + avg_latency = sum(latencies) / len(latencies) if latencies else 0 + avg_tokens_per_second = sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0 + avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 + avg_ms_per_token = sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + + width_label = 24 + sep = "-" * 60 + + print(f"\n=== 📊 性能指标汇总 ({MODEL}) 
===") + print(sep) + print(f"{'并发数':<{width_label}}: {CONCURRENCY}") + print(f"{'请求总数':<{width_label}}: {NUM_REQUESTS}") + print(f"{'成功请求数':<{width_label}}: {successful_requests}") + print(f"{'总耗时':<{width_label}}: {total_elapsed_time:.2f} s") + print(f"{'总输出token数':<{width_label}}: {sum(tokens_list)}") + print(f"{'请求速率 (RPS)':<{width_label}}: {requests_per_second:.2f} requests/s") + print(sep) + print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s") + print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") + print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token") + print(f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + asyncio.run(run_benchmark( + args.verbose + )) diff --git a/python/icinfer/bench/test_ppl.py b/python/icinfer/bench/test_ppl.py new file mode 100644 index 00000000..268a9f7d --- /dev/null +++ b/python/icinfer/bench/test_ppl.py @@ -0,0 +1,62 @@ +import math +import requests +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--endpoint", type=str, default="/completions") + parser.add_argument("--chunk", type=int, default=512) + args = parser.parse_args() + + API_URL = "http://localhost:" + str(args.port) + args.endpoint + CHUNK_SIZE = args.chunk + + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + # Local tokenizer used for chunking + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + + total_neg_log_likelihood = 0.0 + total_tokens = 0 + + for example in tqdm(dataset, desc="Evaluating PPL"): + text = example["text"].strip() + if not text: + continue + + # endcode, chunk and decode + tokens = tokenizer.encode(text, add_special_tokens=False) + for i in range(0, len(tokens), CHUNK_SIZE): + chunk_tokens = tokens[i : min(i + CHUNK_SIZE, len(tokens))] + chunk_text = tokenizer.decode(chunk_tokens) + + resp = requests.post( + API_URL, + headers={"Content-Type": "application/json"}, + json={ + "model": "", + "prompt": chunk_text, + "max_tokens": 0, + "temperature": 1.0, + "echo": True, + "logprobs": 0, + }, + ).json() + + logprobs = resp["choices"][0]["logprobs"]["token_logprobs"] + # skip first token's None + valid_logprobs = [lp for lp in logprobs[1:] if lp is not None] + + total_neg_log_likelihood += -sum(valid_logprobs) + total_tokens += len(valid_logprobs) + + # ==== Compute final PPL ==== + ppl = math.exp(total_neg_log_likelihood / total_tokens) + print(f"Perplexity: {ppl:.4f}") diff --git a/python/icinfer/config.py b/python/icinfer/config.py new file mode 100644 index 00000000..5fe498ab --- /dev/null +++ b/python/icinfer/config.py @@ -0,0 +1,43 @@ +import os +from dataclasses import dataclass +from transformers import AutoConfig + + +@dataclass +class Config: + model: str + max_num_batched_tokens: int = 16384 + max_num_seqs: int = 512 + max_model_len: int = 1024 + gpu_memory_utilization: float = 0.9 + tensor_parallel_size: int = 1 + enforce_eager: bool = False + hf_config: AutoConfig | None = None + eos: int = -1 + kvcache_block_size: int = 16 + max_kvcache_tokens: int = -1 + num_kvcache_blocks: int = -1 + trust_remote_code: bool = False 
+ attention_bias: bool = False + enable_paged_attn: bool = False + + def __post_init__(self): + assert os.path.isdir(self.model) + assert self.kvcache_block_size % 4 == 0 + assert 1 <= self.tensor_parallel_size <= 8 + self.model_path = self.model + self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=self.trust_remote_code) + print(self.model_path) + self.check_hf_config() + self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings) + if self.num_kvcache_blocks < 0 and self.max_kvcache_tokens > 0: + self.num_kvcache_blocks = self.max_kvcache_tokens // self.kvcache_block_size + assert self.max_num_batched_tokens >= self.max_model_len + + def check_hf_config(self): + if getattr(self.hf_config, "head_dim", None) is None: + self.hf_config.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads + if getattr(self.hf_config, "attention_bias", None) is None: + self.hf_config.attention_bias = self.attention_bias + if getattr(self.hf_config, "kvcache_block_size", None) is None: + self.hf_config.kvcache_block_size = self.kvcache_block_size diff --git a/python/icinfer/engine/block_manager.py b/python/icinfer/engine/block_manager.py new file mode 100644 index 00000000..e5fda4e0 --- /dev/null +++ b/python/icinfer/engine/block_manager.py @@ -0,0 +1,114 @@ +from collections import deque +import xxhash +import numpy as np + +from icinfer.engine.sequence import Sequence + + +class Block: + + def __init__(self, block_id): + self.block_id = block_id + self.ref_count = 0 + self.hash = -1 + self.token_ids = [] + + def update(self, hash: int, token_ids: list[int]): + self.hash = hash + self.token_ids = token_ids + + def reset(self): + self.ref_count = 1 + self.hash = -1 + self.token_ids = [] + + +class BlockManager: + + def __init__(self, num_blocks: int, block_size: int): + assert num_blocks > 0 + self.block_size = block_size + self.blocks: list[Block] = [Block(i) for i in range(num_blocks)] + self.hash_to_block_id: dict[int, int] = dict() + self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.used_block_ids: set[int] = set() + + @classmethod + def compute_hash(cls, token_ids: list[int], prefix: int = -1): + h = xxhash.xxh64() + if prefix != -1: + h.update(prefix.to_bytes(8, "little")) + h.update(np.array(token_ids).tobytes()) + return h.intdigest() + + def _allocate_block(self, block_id: int) -> Block: + block = self.blocks[block_id] + assert block.ref_count == 0 + block.reset() + self.free_block_ids.remove(block_id) + self.used_block_ids.add(block_id) + return self.blocks[block_id] + + def _deallocate_block(self, block_id: int) -> Block: + assert self.blocks[block_id].ref_count == 0 + self.used_block_ids.remove(block_id) + self.free_block_ids.append(block_id) + + def can_allocate(self, seq: Sequence) -> bool: + return len(self.free_block_ids) >= seq.num_blocks + + def allocate(self, seq: Sequence): + # TODO 对于这个机制还有点疑惑。 for i in range(seq.num_blocks): + assert not seq.block_table + h = -1 + cache_miss = False + for i in range(seq.num_blocks): + token_ids = seq.block(i) + h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1 + block_id = self.hash_to_block_id.get(h, -1) + if block_id == -1 or self.blocks[block_id].token_ids != token_ids: + cache_miss = True + if cache_miss: + block_id = self.free_block_ids[0] + block = self._allocate_block(block_id) + else: + seq.num_cached_tokens += self.block_size + if block_id in self.used_block_ids: + block = self.blocks[block_id] + block.ref_count += 1 + else: + 
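+                    # Prefix-cache hit on a block that is currently on the free list:
+                    # take it back via _allocate_block(); its KV contents were never
+                    # overwritten, so the cached tokens can be reused without recompute.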
block = self._allocate_block(block_id) + if h != -1: + block.update(h, token_ids) + self.hash_to_block_id[h] = block_id + seq.block_table.append(block_id) + + def deallocate(self, seq: Sequence): + for block_id in reversed(seq.block_table): + block = self.blocks[block_id] + block.ref_count -= 1 + if block.ref_count == 0: + self._deallocate_block(block_id) + seq.num_cached_tokens = 0 + seq.block_table.clear() + + def can_append(self, seq: Sequence) -> bool: + return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) + + def may_append(self, seq: Sequence): + block_table = seq.block_table + last_block = self.blocks[block_table[-1]] + if len(seq) % self.block_size == 1: + assert last_block.hash != -1 + block_id = self.free_block_ids[0] + self._allocate_block(block_id) + block_table.append(block_id) + elif len(seq) % self.block_size == 0: + assert last_block.hash == -1 + token_ids = seq.block(seq.num_blocks-1) + prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 + h = self.compute_hash(token_ids, prefix) + last_block.update(h, token_ids) + self.hash_to_block_id[h] = last_block.block_id + else: + assert last_block.hash == -1 diff --git a/python/icinfer/engine/infer_task.py b/python/icinfer/engine/infer_task.py new file mode 100644 index 00000000..8cd1ea22 --- /dev/null +++ b/python/icinfer/engine/infer_task.py @@ -0,0 +1,195 @@ +from typing import List +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +from icinfer.engine.libinfinicore_infer import ( + KVCacheCStruct, +) + + + + +class InferTask: + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + self.id = id + self.finish_reason = None + self.tokens = tokens + self.max_tokens = max_tokens + self.temperature = temperature + self.topk = topk + self.topp = topp + self.end_tokens = end_tokens + self._kv_cache = None + self.pos = 0 + + def bind_kvcache(self, kv_cache, pos=0): + self._kv_cache = kv_cache + self.pos = pos + self.tokens = self.tokens[pos:] + + def release_kvcache(self): + cache = self._kv_cache + self._kv_cache = None + return cache + + def kvcache(self): + return self._kv_cache + + def next(self, out_token): + if self._kv_cache is not None: + self._kv_cache.update_tokens(self.tokens, self.pos) + + self.pos += len(self.tokens) + if out_token == None or out_token in self.end_tokens: + self.finish_reason = "stop" + elif self.pos >= self.max_tokens: + self.finish_reason = "length" + else: + self.tokens = [out_token] + + +class InferBatchedTask: + def __init__(self, tasks: List[InferTask], is_prefill: int=1): + self.tasks = tasks + self.nreq = len(tasks) + self.is_prefill = is_prefill + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.block_tables = POINTER(c_int)() + self.slot_mapping = POINTER(c_int)() + self.temperaturas = 
(c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.block_tables, + self.slot_mapping, + self.temperaturas, + self.topks, + self.topps, + self.is_prefill, + ) + + +class InferPagedBatchedTask: + def __init__(self, tasks: List[InferTask], batch_block_tables: list[int]=[], slot_mapping: list[int]=[], paged_kvcache=None, is_prefill: int=1): + self.tasks = tasks + self.nreq = len(tasks) + self.is_prefill = is_prefill + self.batch_block_tables = batch_block_tables + self.slot_mapping = slot_mapping + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [paged_kvcache.data()] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + self.n_blocks = len(batch_block_tables) # self.nreq * max_block_table_lens + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * 1)(*self.kv_cache_ptrs) + self.block_tables = (c_int * self.n_blocks)(*batch_block_tables) + self.slot_mapping = (c_int * self.ntok)(*slot_mapping) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.block_tables, + self.slot_mapping, + self.temperaturas, + self.topks, + self.topps, + self.is_prefill, + ) + + def input_args_for_logits(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.block_tables, + self.slot_mapping, + self.is_prefill, + ) + + + +class KVCache: + def __init__(self, model): + self._kvcache = model.create_kv_cache() + self.tokens = [0 for _ in range(model.max_context_len())] + + def data(self): + return self._kvcache + + def drop(self, model): + model.drop_kv_cache(self._kvcache) + + def update_tokens(self, tokens, pos): + end = pos + len(tokens) + max_len = len(self.tokens) + + # If overflow, truncate tokens to fit + if end > max_len: + tokens = tokens[: max_len - pos] + end = max_len + + self.tokens[pos:end] = tokens + +class PagedKVCache: + def __init__(self, paged_kvcache): + self._kvcache = paged_kvcache + # self.tokens = [0 for _ in range(model.max_context_len())] + + def data(self): + return self._kvcache + + def drop(self, model): + model.drop_kv_cache(self._kvcache) + + def update_tokens(self, tokens, pos): + print("PagedKVCache need not to update tokens.") + pass diff --git a/python/icinfer/engine/kvcache_pool.py b/python/icinfer/engine/kvcache_pool.py new file mode 100644 index 00000000..b48d2695 --- /dev/null +++ b/python/icinfer/engine/kvcache_pool.py @@ -0,0 +1,90 @@ +from icinfer.engine.infer_task import KVCache + +import asyncio +from typing import List +import threading + + +class KVCachePool: + def 
__init__(self, model, max_caches: int = 32): + self.max_caches = max_caches + self.model = model + self._available: List[KVCache] = [] + self.num_caches = len(self._available) + self._lock = threading.Lock() + self._not_empty = threading.Condition(self._lock) + self._shutdown = False + + def acquire_sync(self, infer_task): + with self._not_empty: + while True: + if self._shutdown: + raise RuntimeError( + "KVCachePool is shutting down; cannot acquire new cache." + ) + if len(self._available) == 0: + if self.num_caches < self.max_caches: + self.num_caches += 1 + print( + f"[INFO] Task {infer_task.id} created new KVCachePoolItem" + ) + return infer_task.bind_kvcache(KVCache(self.model), 0) + else: + self._not_empty.wait() + else: + max_match, max_match_index = self.find_most_matching_cache( + infer_task.tokens + ) + kvcache = self._available.pop(max_match_index) + print( + f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches" + ) + return infer_task.bind_kvcache(kvcache, max_match) + + def release_sync(self, infer_task): + with self._not_empty: + print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool") + self._available.append(infer_task.release_kvcache()) + self._not_empty.notify() + + async def acquire(self, infer_task): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.acquire_sync, infer_task) + + async def release(self, infer_task): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.release_sync, infer_task) + + def find_most_matching_cache(self, tokens: List[int]): + max_match = 0 + max_match_index = 0 + + def first_different_index(a_, b_): + for i_, (x_, y_) in enumerate(zip(a_, b_)): + if x_ != y_: + return i_ + return min(len(a_), len(b_)) + + for i, kvcache in enumerate(self._available): + common_elements = first_different_index(tokens, kvcache.tokens) + # print(f"{tokens}") + # print(f"{kvcache.tokens[:len(tokens)]}") + if common_elements > max_match: + max_match = common_elements + max_match_index = i + + return (min(max_match, len(tokens) - 1), max_match_index) + + def finalize(self): + with self._not_empty: + self._shutdown = True + while len(self._available) < self.num_caches: + self._not_empty.wait() + + for kvcache in self._available: + if kvcache is not None: + kvcache.drop(self.model) + + self._available.clear() + self.max_caches = 0 + self.num_caches = 0 diff --git a/python/icinfer/engine/libinfinicore_infer.py b/python/icinfer/engine/libinfinicore_infer.py new file mode 100644 index 00000000..219b773b --- /dev/null +++ b/python/icinfer/engine/libinfinicore_infer.py @@ -0,0 +1,149 @@ +import ctypes +from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, c_bool, POINTER +import os + + +class DataType(ctypes.c_int): + INFINI_DTYPE_INVALID = 0 + INFINI_DTYPE_BYTE = 1 + INFINI_DTYPE_BOOL = 2 + INFINI_DTYPE_I8 = 3 + INFINI_DTYPE_I16 = 4 + INFINI_DTYPE_I32 = 5 + INFINI_DTYPE_I64 = 6 + INFINI_DTYPE_U8 = 7 + INFINI_DTYPE_U16 = 8 + INFINI_DTYPE_U32 = 9 + INFINI_DTYPE_U64 = 10 + INFINI_DTYPE_F8 = 11 + INFINI_DTYPE_F16 = 12 + INFINI_DTYPE_F32 = 13 + INFINI_DTYPE_F64 = 14 + INFINI_DTYPE_C16 = 15 + INFINI_DTYPE_C32 = 16 + INFINI_DTYPE_C64 = 17 + INFINI_DTYPE_C128 = 18 + INFINI_DTYPE_BF16 = 19 + + +class DeviceType(ctypes.c_int): + DEVICE_TYPE_CPU = 0 + DEVICE_TYPE_NVIDIA = 1 + DEVICE_TYPE_CAMBRICON = 2 + DEVICE_TYPE_ASCEND = 3 + DEVICE_TYPE_METAX = 4 + DEVICE_TYPE_MOORE = 5 + DEVICE_TYPE_ILUVATAR = 6 + + +class JiugeMetaCStruct(ctypes.Structure): + _fields_ = [ 
+ ("dt_logits", DataType), + ("nlayer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("kvcache_block_size", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_uint), + ] + + +# Define the JiugeWeights struct +class JiugeWeightsCStruct(ctypes.Structure): + _fields_ = [ + ("nlayer", c_size_t), + ("dt_norm", DataType), + ("dt_mat", DataType), + ("transpose_linear_weights", c_int), + ("input_embd", c_void_p), + ("output_norm", c_void_p), + ("output_embd", c_void_p), + ("attn_norm", POINTER(c_void_p)), + ("attn_qkv", POINTER(c_void_p)), + ("attn_qkv_b", POINTER(c_void_p)), + ("attn_o", POINTER(c_void_p)), + ("ffn_norm", POINTER(c_void_p)), + ("ffn_gate_up", POINTER(c_void_p)), + ("ffn_down", POINTER(c_void_p)), + ] + + +class JiugeModelCSruct(ctypes.Structure): + pass + + +class KVCacheCStruct(ctypes.Structure): + pass + + +def __open_library__(): + lib_path = os.path.join( + os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so" + ) + lib = ctypes.CDLL(lib_path) + lib.createJiugeModel.restype = POINTER(JiugeModelCSruct) + lib.createJiugeModel.argtypes = [ + POINTER(JiugeMetaCStruct), # JiugeMeta const * + POINTER(JiugeWeightsCStruct), # JiugeWeights const * + DeviceType, # DeviceType + c_int, # int ndev + POINTER(c_int), # int const *dev_ids + ] + lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)] + lib.createKVCache.argtypes = [POINTER(JiugeModelCSruct)] + lib.createKVCache.restype = POINTER(KVCacheCStruct) + lib.createPagedKVCache.argtypes = [POINTER(JiugeModelCSruct), c_uint] + lib.createPagedKVCache.restype = POINTER(KVCacheCStruct) + lib.dropKVCache.argtypes = [POINTER(JiugeModelCSruct), POINTER(KVCacheCStruct)] + lib.inferBatch.restype = None + lib.inferBatch.argtypes = [ + POINTER(JiugeModelCSruct), # struct JiugeModel const * + POINTER(c_uint), # unsigned int const *tokens + c_uint, # unsigned int ntok + POINTER(c_uint), # unsigned int const *req_lens + c_uint, # unsigned int nreq + POINTER(c_uint), # unsigned int const *req_pos + POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches + POINTER(c_int), # unsigned int const *block_tables + POINTER(c_int), # unsigned int const *slot_mapping + POINTER(c_float), # float temperature + POINTER(c_uint), # unsigned int topk + POINTER(c_float), # float topp + c_uint, # unsigned int is_prefill + c_bool, # bool enable_paged_attn + POINTER(c_uint), # unsigned int *output + ] + lib.forwardBatch.restype = None + lib.forwardBatch.argtypes = [ + POINTER(JiugeModelCSruct), # struct JiugeModel const * + POINTER(c_uint), # unsigned int const *tokens + c_uint, # unsigned int ntok + POINTER(c_uint), # unsigned int const *req_lens + c_uint, # unsigned int nreq + POINTER(c_uint), # unsigned int const *req_pos + POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches + POINTER(c_int), # unsigned int const *block_tables + POINTER(c_int), # unsigned int const *slot_mapping + c_uint, # unsigned int is_prefill + c_bool, # bool enable_paged_attn + c_void_p, # void *logits + ] + + return lib + + +LIB = __open_library__() + +create_jiuge_model = LIB.createJiugeModel +destroy_jiuge_model = LIB.destroyJiugeModel +create_kv_cache = LIB.createKVCache +create_paged_kv_cache = LIB.createPagedKVCache +drop_kv_cache = LIB.dropKVCache +infer_batch = LIB.inferBatch +forward_batch = LIB.forwardBatch diff --git a/python/icinfer/engine/llm_engine.py b/python/icinfer/engine/llm_engine.py new file mode 100644 
index 00000000..359a5594 --- /dev/null +++ b/python/icinfer/engine/llm_engine.py @@ -0,0 +1,196 @@ +import atexit +from dataclasses import fields +from time import perf_counter +from tqdm.auto import tqdm +from transformers import AutoTokenizer +import torch.multiprocessing as mp +import math +from typing import List +import uuid + +from icinfer.config import Config +from icinfer.sampling_params import SamplingParams +from icinfer.engine.sequence import Sequence +from icinfer.engine.scheduler import Scheduler +from icinfer.engine.model_runner import ModelRunner +from icinfer.engine.infer_task import KVCache, InferTask + +import logging +logger = logging.getLogger(__name__) + + +class InfiniEngine: + + def __init__(self, model, device, **kwargs): + config_fields = {field.name for field in fields(Config)} + config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} + config = Config(model, **config_kwargs) + + self.ps = [] + self.events = [] + # ctx = mp.get_context("spawn") + # for i in range(1, config.tensor_parallel_size): + # event = ctx.Event() + # process = ctx.Process(target=ModelRunner, args=(config, i, event)) + # process.start() + # self.ps.append(process) + # self.events.append(event) + self.model_runner = ModelRunner(config, device, 0, self.events) + self.eos_token_id = self.model_runner.eos_token_id + self.max_context_len = self.model_runner.max_context_len() + self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=kwargs["trust_remote_code"]) + config.eos = self.tokenizer.eos_token_id + self.scheduler = Scheduler(config) + atexit.register(self.exit) + + def exit(self): + self.model_runner.call("exit") + del self.model_runner + for p in self.ps: + p.join() + + def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): + if isinstance(prompt, str): + prompt = self.tokenizer.encode(prompt) + seq = Sequence(prompt, sampling_params, block_size=self.scheduler.block_size) + infer_task = InferTask(seq.seq_id, prompt, self.max_context_len, sampling_params.temperature, sampling_params.topk, sampling_params.topp, self.eos_token_id) + if self.model_runner.enable_paged_attn: + pass + else: + infer_task.bind_kvcache(KVCache(self.model_runner)) + seq.bind_infer_task(infer_task) + self.scheduler.add(seq) + return prompt + + def step(self): + seqs, is_prefill = self.scheduler.schedule() + token_ids = self.model_runner.call("run", seqs, is_prefill) + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids) + if self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + return outputs, num_tokens + + def is_finished(self): + return self.scheduler.is_finished() + + def generate( + self, + prompts: list[str] | list[list[int]], + sampling_params: SamplingParams | list[SamplingParams], + use_tqdm: bool = True, + ) -> list[str]: + if use_tqdm: + pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) + if not isinstance(sampling_params, list): + sampling_params = [sampling_params] * len(prompts) + prompts_list = [] + for prompt, sp in zip(prompts, sampling_params): + prompts_list.append(self.add_request(prompt, sp)) + outputs = {} + prefill_throughput = decode_throughput = 0. 
+ logger.info("start generating") + # perfile + avg_prefill_throughput = 0 + prefill_time = 0 + avg_decode_throughput = 0 + decode_time = 0 + ttft = 0 + ttft_count = 0 + tbt = 0 + tbt_count = 0 + + while not self.is_finished(): + t = perf_counter() + output, num_tokens = self.step() + if use_tqdm: + if num_tokens > 0: + check_time = perf_counter() + prefill_throughput = num_tokens / (check_time - t) + ttft += (check_time - t) + ttft_count += 1 + avg_prefill_throughput = (avg_prefill_throughput * prefill_time + num_tokens)/(prefill_time+(check_time - t)) + prefill_time += (check_time - t) + else: + check_time = perf_counter() + decode_throughput = -num_tokens / (check_time - t) + tbt += (check_time - t) + tbt_count += 1 + avg_decode_throughput = (avg_decode_throughput * decode_time - num_tokens)/(decode_time+(check_time - t)) + decode_time += (check_time - t) + pbar.set_postfix({ + "Prefill": f"{int(prefill_throughput)}tok/s", + "Decode": f"{int(decode_throughput)}tok/s", + }) + for seq_id, token_ids in output: + outputs[seq_id] = token_ids + if use_tqdm: + pbar.update(1) + outputs = [outputs[seq_id] for seq_id in sorted(outputs)] + outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] + avg_ttft = ttft / ttft_count + avg_tbt = tbt / tbt_count + if not self.model_runner.enable_paged_attn: + max_model_len = self.model_runner.config.max_model_len + num_seqs = len(outputs) + used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + used_tokens_count = sum(used_tokens) + cache_efficiency = used_tokens_count / (num_seqs * max_model_len) + else: + max_model_len = self.model_runner.config.max_model_len + num_seqs = len(outputs) + used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + block_size = self.model_runner.config.kvcache_block_size + cache_memory = [(i_tokens + block_size - 1) // block_size * block_size for i_tokens in used_tokens] + cache_efficiency = sum(used_tokens) / sum(cache_memory) + + if use_tqdm: + pbar.close() + self.model_runner.exit() + return outputs, avg_prefill_throughput, avg_decode_throughput, avg_ttft, avg_tbt, cache_efficiency + + def add_perplexity_request(self, prompt: str | list[int], sampling_params: SamplingParams): + if isinstance(prompt, str): + prompt = self.tokenizer.encode(prompt) + input_tokens = prompt[:-1] + true_tokens = prompt[1:] + seq = Sequence(input_tokens, sampling_params, block_size=self.scheduler.block_size) + infer_task = InferTask(seq.seq_id, input_tokens, self.max_context_len, 1.0, 1, 1.0, self.eos_token_id) + seq.true_tokens = true_tokens + if self.model_runner.enable_paged_attn: + pass + else: + infer_task.bind_kvcache(KVCache(self.model_runner)) + seq.bind_infer_task(infer_task) + self.scheduler.add(seq) + + def perplexity_step(self): + seqs, is_prefill = self.scheduler.schedule() + nll, total_len, token_ids_none = self.model_runner.call("run_for_logits", seqs, is_prefill) + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids_none) + if self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + return nll, total_len + # outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + # num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + # return outputs, num_tokens + + def perplexity(self, test_sequences: List[List[int]]): + nll = 0.0 + total_len = 0 + + for i in range(len(test_sequences)): + 
self.add_perplexity_request(test_sequences[i], SamplingParams(temperature=1.0, topk=1, topp=1.0, max_tokens=1)) + while not self.is_finished(): + nll_i, total_len_i = self.perplexity_step() + nll += nll_i + total_len += total_len_i + + return math.exp(nll / total_len) \ No newline at end of file diff --git a/python/icinfer/engine/llm_engine_async.py b/python/icinfer/engine/llm_engine_async.py new file mode 100644 index 00000000..0e0d6ed3 --- /dev/null +++ b/python/icinfer/engine/llm_engine_async.py @@ -0,0 +1,264 @@ +import atexit +from dataclasses import fields +from time import perf_counter +from tqdm.auto import tqdm +from transformers import AutoTokenizer +import torch.multiprocessing as mp +import math +from typing import List +import uuid +import threading +import queue +import asyncio +from typing import Dict +import time +import collections + +from icinfer.config import Config +from icinfer.sampling_params import SamplingParams +from icinfer.engine.sequence import Sequence +from icinfer.engine.scheduler import Scheduler +from icinfer.engine.model_runner import ModelRunner +from icinfer.engine.infer_task import KVCache, InferTask + +import logging +logger = logging.getLogger(__name__) + + +class InfiniEngineAsync: + + def __init__(self, model, device, **kwargs): + config_fields = {field.name for field in fields(Config)} + config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} + config = Config(model, **config_kwargs) + + self.ps = [] + self.events = [] + # ctx = mp.get_context("spawn") + # for i in range(1, config.tensor_parallel_size): + # event = ctx.Event() + # process = ctx.Process(target=ModelRunner, args=(config, i, event)) + # process.start() + # self.ps.append(process) + # self.events.append(event) + self.model_runner = ModelRunner(config, device, 0, self.events) + self.eos_token_id = self.model_runner.eos_token_id + self.max_context_len = self.model_runner.max_context_len() + self.request_queue = queue.Queue() + self.result_queues: Dict[str, asyncio.Queue] = {} + self.main_loop = None + + self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=kwargs["trust_remote_code"]) + config.eos = self.tokenizer.eos_token_id + self.scheduler = Scheduler(config) + atexit.register(self.exit) + + + def exit(self): + self.model_runner.call("exit") + del self.model_runner + for p in self.ps: + p.join() + + async def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, request_id: str): + if self.main_loop is None: + self.main_loop = asyncio.get_running_loop() + + result_queue = asyncio.Queue() + self.result_queues[request_id] = result_queue + self.request_queue.put((prompt, sampling_params, request_id)) + + return result_queue + + def add_request_action(self, prompt: str | list[int], sp, req_id): + if isinstance(prompt, str): + prompt_tokens = self.tokenizer.encode(prompt) + else: + prompt_tokens = prompt + + seq = Sequence(prompt_tokens, sp, block_size=self.scheduler.block_size, req_id=req_id) + infer_task = InferTask(seq.req_id, prompt_tokens, self.max_context_len, sp.temperature, sp.topk, sp.topp, self.eos_token_id) + if self.model_runner.enable_paged_attn: + pass + else: + infer_task.bind_kvcache(KVCache(self.model_runner)) + seq.bind_infer_task(infer_task) + self.scheduler.add(seq) + + def step(self): + seqs, is_prefill = self.scheduler.schedule() + token_ids = self.model_runner.call("run", seqs, is_prefill) + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids) + if 
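The perplexity() helper above reduces to exp of the mean negative log-likelihood over every scored token. The same reduction on made-up per-token log-probabilities:

import math

# Per-token log-probabilities of the true next tokens, one list per request
# (values are illustrative, not model output).
token_logprobs = [[-1.2, -0.7, -2.0], [-0.3, -1.5]]

nll = sum(-lp for req in token_logprobs for lp in req)
total_len = sum(len(req) for req in token_logprobs)
print(math.exp(nll / total_len))  # exp(5.7 / 5) ≈ 3.13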
self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + outputs = [(seq.req_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + return outputs, num_tokens + + def is_finished(self): + return self.scheduler.is_finished() + + def engine_loop(self): + while True: + # 1. Pull new requests off the queue and add them to the scheduler + while not self.request_queue.empty(): + prompt, sp, req_id = self.request_queue.get() + + self.add_request_action(prompt, sp, req_id) + + if self.scheduler.is_finished(): + time.sleep(0.1) + continue + + + # 2. Run one inference step + if not self.scheduler.is_finished(): + seqs, is_prefill = self.scheduler.schedule() + logger.debug(f"seqs_len: {len(seqs)}") + + # token_ids is a flat list, ordered the same way the sequences were scheduled + token_ids = self.model_runner.call("run", seqs, is_prefill) + + for seq_order_i in range(len(seqs)): + seq = seqs[seq_order_i] + new_token = token_ids[seq_order_i] + result_queue = self.result_queues.get(seq.req_id) + if result_queue: + self.main_loop.call_soon_threadsafe(result_queue.put_nowait, new_token) + + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids) + if self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + + # 3. Push the end-of-stream sentinel for finished sequences + for seq in seqs: + if seq.is_finished: + result_queue = self.result_queues.get(seq.req_id) + if result_queue: + self.main_loop.call_soon_threadsafe(result_queue.put_nowait, None) + self.result_queues.pop(seq.req_id, None) + else: + time.sleep(0.01) + + # def generate( + # self, + # prompts: list[str] | list[list[int]], + # sampling_params: SamplingParams | list[SamplingParams], + # use_tqdm: bool = True, + # ) -> list[str]: + # if use_tqdm: + # pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) + # if not isinstance(sampling_params, list): + # sampling_params = [sampling_params] * len(prompts) + # prompts_list = [] + # for prompt, sp in zip(prompts, sampling_params): + # prompts_list.append(self.add_request(prompt, sp)) + # outputs = {} + # prefill_throughput = decode_throughput = 0.
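engine_loop above runs in a worker thread and streams tokens back to asyncio consumers through per-request queues, using call_soon_threadsafe and a None sentinel for end-of-stream. A minimal, self-contained sketch of that handoff pattern (independent of the engine classes):

import asyncio
import threading
import time

async def consume(result_queue: asyncio.Queue):
    # Collect streamed tokens until the None end-of-stream sentinel arrives.
    tokens = []
    while (tok := await result_queue.get()) is not None:
        tokens.append(tok)
    return tokens

async def main():
    loop = asyncio.get_running_loop()
    q: asyncio.Queue = asyncio.Queue()

    def worker():
        # Stands in for the inference thread: push tokens, then signal completion.
        for tok in (11, 22, 33):
            loop.call_soon_threadsafe(q.put_nowait, tok)
            time.sleep(0.01)
        loop.call_soon_threadsafe(q.put_nowait, None)

    threading.Thread(target=worker, daemon=True).start()
    print(await consume(q))  # [11, 22, 33]

asyncio.run(main())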
+ # logger.info("start generating") + # # perfile + # avg_prefill_throughput = 0 + # prefill_time = 0 + # avg_decode_throughput = 0 + # decode_time = 0 + # ttft = 0 + # ttft_count = 0 + # tbt = 0 + # tbt_count = 0 + + # while not self.is_finished(): + # t = perf_counter() + # output, num_tokens = self.step() + # if use_tqdm: + # if num_tokens > 0: + # check_time = perf_counter() + # prefill_throughput = num_tokens / (check_time - t) + # ttft += (check_time - t) + # ttft_count += 1 + # avg_prefill_throughput = (avg_prefill_throughput * prefill_time + num_tokens)/(prefill_time+(check_time - t)) + # prefill_time += (check_time - t) + # else: + # check_time = perf_counter() + # decode_throughput = -num_tokens / (check_time - t) + # tbt += (check_time - t) + # tbt_count += 1 + # avg_decode_throughput = (avg_decode_throughput * decode_time - num_tokens)/(decode_time+(check_time - t)) + # decode_time += (check_time - t) + # pbar.set_postfix({ + # "Prefill": f"{int(prefill_throughput)}tok/s", + # "Decode": f"{int(decode_throughput)}tok/s", + # }) + # for seq_id, token_ids in output: + # outputs[seq_id] = token_ids + # if use_tqdm: + # pbar.update(1) + # outputs = [outputs[seq_id] for seq_id in sorted(outputs)] + # outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] + # avg_ttft = ttft / ttft_count + # avg_tbt = tbt / tbt_count + # if not self.model_runner.enable_paged_attn: + # max_model_len = self.model_runner.config.max_model_len + # num_seqs = len(outputs) + # used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + # used_tokens_count = sum(used_tokens) + # cache_efficiency = used_tokens_count / (num_seqs * max_model_len) + # else: + # max_model_len = self.model_runner.config.max_model_len + # num_seqs = len(outputs) + # used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + # block_size = self.model_runner.config.kvcache_block_size + # cache_memory = [(i_tokens + block_size - 1) // block_size * block_size for i_tokens in used_tokens] + # cache_efficiency = sum(used_tokens) / sum(cache_memory) + + # if use_tqdm: + # pbar.close() + # self.model_runner.exit() + # return outputs, avg_prefill_throughput, avg_decode_throughput, avg_ttft, avg_tbt, cache_efficiency + + # def add_perplexity_request(self, prompt: str | list[int], sampling_params: SamplingParams): + # if isinstance(prompt, str): + # prompt = self.tokenizer.encode(prompt) + # input_tokens = prompt[:-1] + # true_tokens = prompt[1:] + # seq = Sequence(input_tokens, sampling_params, block_size=self.scheduler.block_size) + # infer_task = InferTask(seq.seq_id, input_tokens, self.max_context_len, 1.0, 1, 1.0, self.eos_token_id) + # seq.true_tokens = true_tokens + # if self.model_runner.enable_paged_attn: + # pass + # else: + # infer_task.bind_kvcache(KVCache(self.model_runner)) + # seq.bind_infer_task(infer_task) + # self.scheduler.add(seq) + + # def perplexity_step(self): + # seqs, is_prefill = self.scheduler.schedule() + # nll, total_len, token_ids_none = self.model_runner.call("run_for_logits", seqs, is_prefill) + # drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids_none) + # if self.model_runner.enable_paged_attn: + # pass + # else: + # for kv_cache in drop_kvcache_list: + # kv_cache.drop(self.model_runner) + # return nll, total_len + # # outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + # # num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + # # return outputs, 
num_tokens + + # def perplexity(self, test_sequences: List[List[int]]): + # nll = 0.0 + # total_len = 0 + + # for i in range(len(test_sequences)): + # self.add_perplexity_request(test_sequences[i], SamplingParams(temperature=1.0, topk=1, topp=1.0, max_tokens=1)) + # while not self.is_finished(): + # nll_i, total_len_i = self.perplexity_step() + # nll += nll_i + # total_len += total_len_i + + # return math.exp(nll / total_len) \ No newline at end of file diff --git a/python/icinfer/engine/model_runner.py b/python/icinfer/engine/model_runner.py new file mode 100644 index 00000000..17826f2a --- /dev/null +++ b/python/icinfer/engine/model_runner.py @@ -0,0 +1,478 @@ +import pickle +import torch +import torch.distributed as dist +from multiprocessing.synchronize import Event +from multiprocessing.shared_memory import SharedMemory +from ctypes import c_uint +from typing import List +import logging +import itertools + + +from icinfer.config import Config +from icinfer.engine.sequence import Sequence +from icinfer.engine.libinfinicore_infer import ( + JiugeMetaCStruct, + JiugeWeightsCStruct, + KVCacheCStruct, + DataType, + DeviceType, + create_jiuge_model, + destroy_jiuge_model, + create_kv_cache, + create_paged_kv_cache, + drop_kv_cache, + infer_batch, + forward_batch, +) + +from icinfer.layers.sampler import Sampler +from icinfer.utils.context import set_context, get_context, reset_context +# from icinfer.utils.loader import load_model +from icinfer.utils.jiuge_weights_loader import load_model +from icinfer.engine.infer_task import InferTask, InferBatchedTask, InferPagedBatchedTask, PagedKVCache + + +# infinicore infer +from typing import List, Sequence +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import time +import math +from icinfer.engine.infer_task import InferTask, KVCache + + +logger = logging.getLogger(__name__) + + +class ModelRunner: + + def __init__(self, config: Config, device: DeviceType, rank: int, event: Event | list[Event]): + self.config = config + self.hf_config = config.hf_config + self.device = device + self.block_size = config.kvcache_block_size + self.enforce_eager = config.enforce_eager + self.enable_paged_attn = config.enable_paged_attn + self.world_size = config.tensor_parallel_size + self.meta = None + self.kv_cache = None + self.rank = rank + self.event = event + + # dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank) + # torch.cuda.set_device(rank) + # default_dtype = torch.get_default_dtype() + # torch.set_default_dtype(hf_config.torch_dtype) + # torch.set_default_device("cuda") + + # model_set = { + # "qwen3": Qwen3ForCausalLM, + # "fm9g7b": FM9GForCausalLM, + # } + # ModelForCausalLm = model_set[hf_config.model_type] + # self.model = ModelForCausalLm(hf_config) + # load_model(self.model, config.model) + + self.model, self.meta = load_model(self.config, device) + # self.tokenizer = transformers.AutoTokenizer.from_pretrained( + # model_dir_path, trust_remote_code=True + # ) + + eos_token_id = self.hf_config.eos_token_id + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + + # self.sampler = Sampler() + # self.warmup_model() + # TODO 暂时先关掉 + if self.enable_paged_attn: + self.allocate_kv_cache() + if not self.enforce_eager: + self.capture_cudagraph() + # torch.set_default_device("cpu") + # torch.set_default_dtype(default_dtype) + + # if self.world_size > 1: + # if rank == 0: + # self.shm = SharedMemory(name="nanovllm", create=True, size=2**20) + # dist.barrier() 
+ # else: + # dist.barrier() + # self.shm = SharedMemory(name="nanovllm") + # self.loop() + + def exit(self): + # if self.world_size > 1: + # self.shm.close() + # dist.barrier() + # if self.rank == 0: + # self.shm.unlink() + if not self.enforce_eager: + del self.graphs, self.graph_pool + # torch.cuda.synchronize() + self.destroy() + # dist.destroy_process_group() + + def __del__(self): + self.destroy() + + def destroy(self): + """ + 在程序退出时,安全地释放 C++ 侧的资源。 + """ + if hasattr(self, 'kv_cache') and self.kv_cache: + print("drop_kv_cache") + drop_kv_cache(self.model, self.kv_cache.data()) + self.kv_cache = None + if hasattr(self, 'model') and self.model: + destroy_jiuge_model(self.model) + self.model = None + + logger.info("ModelRunner model resources have been released.") + + # def loop(self): + # while True: + # method_name, args = self.read_shm() + # self.call(method_name, *args) + # if method_name == "exit": + # break + + # def read_shm(self): + # assert self.world_size > 1 and self.rank + # self.event.wait() + # n = int.from_bytes(self.shm.buf[0:4], "little") + # method_name, *args = pickle.loads(self.shm.buf[4:n+4]) + # self.event.clear() + # return method_name, args + + # def write_shm(self, method_name, *args): + # assert self.world_size > 1 and not self.rank + # data = pickle.dumps([method_name, *args]) + # n = len(data) + # self.shm.buf[0:4] = n.to_bytes(4, "little") + # self.shm.buf[4:n+4] = data + # for event in self.event: + # event.set() + + def call(self, method_name, *args): + # if self.world_size > 1 and self.rank == 0: + # self.write_shm(method_name, *args) + method = getattr(self, method_name, None) + return method(*args) + + # def warmup_model(self): + # torch.cuda.empty_cache() + # torch.cuda.reset_peak_memory_stats() + # max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len + # num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) + # seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] + # self.run(seqs, True) + # torch.cuda.empty_cache() + + # def _calculate_num_blocks(self, num_kv_heads: int) -> int: + # config = self.config + # hf_config = config.hf_config + # gpu_memory_utilization = config.gpu_memory_utilization + + # free, total = torch.cuda.mem_get_info() + # used = total - free + # # todo torch.cuda需要用一个什么来替代这部分 + # peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + # current = torch.cuda.memory_stats()["allocated_bytes.all.current"] + # block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize + # num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak + current) // block_bytes + # assert num_kvcache_blocks > 0 + # return num_kvcache_blocks + + def allocate_kv_cache(self): + kv_cache = self.create_paged_kv_cache(self.config.max_kvcache_tokens) + self.kv_cache = PagedKVCache(kv_cache) + print("kvcache allocated ") + # config = self.config + # hf_config = config.hf_config + # num_kv_heads = hf_config.num_key_value_heads // self.world_size + # config.num_kvcache_blocks = self._calculate_num_blocks(num_kv_heads) + # self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim) + # layer_id = 0 + # for module in self.model.modules(): + # if hasattr(module, "k_cache") and hasattr(module, "v_cache"): + # module.k_cache = self.kv_cache[0, layer_id] + # module.v_cache = self.kv_cache[1, layer_id] + # layer_id += 
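The commented-out _calculate_num_blocks above sizes the paged cache from a memory budget: every cached token stores K and V for each layer and KV head. A worked example of that arithmetic under assumed model dimensions (not the 9G7B values):

# Assumed dimensions (roughly a 7B-class model in fp16, after tensor-parallel split);
# only the formula matters, not these exact numbers.
num_hidden_layers = 32
num_kv_heads = 8
head_dim = 128
block_size = 256
dtype_bytes = 2            # fp16

bytes_per_token = 2 * num_hidden_layers * num_kv_heads * head_dim * dtype_bytes  # K and V
block_bytes = bytes_per_token * block_size

budget_bytes = 8 * 1024**3                 # e.g. 8 GiB reserved for the paged KV cache
num_blocks = budget_bytes // block_bytes
max_kvcache_tokens = num_blocks * block_size
print(bytes_per_token, block_bytes // 2**20, num_blocks, max_kvcache_tokens)
# 131072 B/token, 32 MiB/block, 256 blocks, 65536 tokens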
1 + + def prepare_block_tables(self, seqs: list[Sequence]): + max_len = max(len(seq.block_table) for seq in seqs) + padded_lists_generator = ( + (seq.block_table + [0] * (max_len - len(seq.block_table))) + for seq in seqs + ) + block_tables_flat = list(itertools.chain.from_iterable(padded_lists_generator)) + # block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs] + # block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + return block_tables_flat + + def prepare_prefill(self, seqs: list[Sequence]): + input_ids = [] + positions = [] + cu_seqlens_q = [0] + cu_seqlens_k = [0] + max_seqlen_q = 0 + max_seqlen_k = 0 + slot_mapping = [] + block_tables = [] + for seq in seqs: + seqlen = len(seq) + input_ids.extend(seq[seq.num_cached_tokens:]) + positions.extend(list(range(seq.num_cached_tokens, seqlen))) + seqlen_q = seqlen - seq.num_cached_tokens + seqlen_k = seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) + max_seqlen_q = max(seqlen_q, max_seqlen_q) + max_seqlen_k = max(seqlen_k, max_seqlen_k) + if not seq.block_table: + continue + for i in range(seq.num_cached_blocks, seq.num_blocks): + start = seq.block_table[i] * self.block_size + if i != seq.num_blocks - 1: + end = start + self.block_size + else: + end = start + seq.last_block_num_tokens + slot_mapping.extend(list(range(start, end))) + # if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache + block_tables = self.prepare_block_tables(seqs) + # input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables) + # return input_ids, positions + return block_tables, slot_mapping + + def prepare_decode(self, seqs: list[Sequence]): + input_ids = [] + positions = [] + slot_mapping = [] + context_lens = [] + for seq in seqs: + input_ids.append(seq.last_token) + positions.append(len(seq)) + context_lens.append(len(seq)) + slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1) + # input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + block_tables = self.prepare_block_tables(seqs) + # set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) + return block_tables, slot_mapping + + # def prepare_sample(self, seqs: list[Sequence]): + # temperatures = [] + # for seq in seqs: + # temperatures.append(seq.temperature) + # temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True) + # return temperatures + + @torch.inference_mode() + def run_model(self, input_ids: torch.Tensor, positions: 
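prepare_block_tables and prepare_prefill above reduce each sequence's block table to a zero-padded flat array plus a per-token slot mapping, where slot = block_id * block_size + offset. A small worked example with assumed block assignments:

import itertools

block_size = 4
# Two sequences: 6 tokens stored in blocks [2, 5], and 3 tokens stored in block [7].
block_tables = [[2, 5], [7]]
seq_lens = [6, 3]

# Pad to the longest table and flatten, as prepare_block_tables does.
max_len = max(len(bt) for bt in block_tables)
flat_table = list(itertools.chain.from_iterable(
    bt + [0] * (max_len - len(bt)) for bt in block_tables))
print(flat_table)  # [2, 5, 7, 0]

# Prefill-style slot mapping: one physical slot per token of each sequence.
slot_mapping = []
for bt, n in zip(block_tables, seq_lens):
    for i, block in enumerate(bt):
        tokens_in_block = min(block_size, n - i * block_size)
        start = block * block_size
        slot_mapping.extend(range(start, start + tokens_in_block))
print(slot_mapping)  # [8, 9, 10, 11, 20, 21, 28, 29, 30]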
torch.Tensor, is_prefill: bool): + if is_prefill or self.enforce_eager or input_ids.size(0) > 512: + return self.model.compute_logits(self.model(input_ids, positions)) + else: + bs = input_ids.size(0) + context = get_context() + graph = self.graphs[next(x for x in self.graph_bs if x >= bs)] + graph_vars = self.graph_vars + for k, v in graph_vars.items(): + if k != "outputs": + v.zero_() + graph_vars["input_ids"][:bs] = input_ids + graph_vars["positions"][:bs] = positions + graph_vars["slot_mapping"][:bs] = context.slot_mapping + graph_vars["context_lens"][:bs] = context.context_lens + graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables + graph.replay() + return self.model.compute_logits(graph_vars["outputs"][:bs]) + + + # def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: + # input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + # temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + # logits = self.run_model(input_ids, positions, is_prefill) + # token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + # reset_context() + # return token_ids + + # @torch.inference_mode() + # def capture_cudagraph(self): + # config = self.config + # hf_config = config.hf_config + # max_bs = min(self.config.max_num_seqs, 512) + # max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + # input_ids = torch.zeros(max_bs, dtype=torch.int64) + # positions = torch.zeros(max_bs, dtype=torch.int64) + # slot_mapping = torch.zeros(max_bs, dtype=torch.int32) + # context_lens = torch.zeros(max_bs, dtype=torch.int32) + # block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32) + # outputs = torch.zeros(max_bs, hf_config.hidden_size) + # self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) + # self.graphs = {} + # self.graph_pool = None + + # for bs in reversed(self.graph_bs): + # graph = torch.cuda.CUDAGraph() + # set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup + # with torch.cuda.graph(graph, self.graph_pool): + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture + # if self.graph_pool is None: + # self.graph_pool = graph.pool() + # self.graphs[bs] = graph + # torch.cuda.synchronize() + # reset_context() + + # self.graph_vars = dict( + # input_ids=input_ids, + # positions=positions, + # slot_mapping=slot_mapping, + # context_lens=context_lens, + # block_tables=block_tables, + # outputs=outputs, + # ) + + # @torch.inference_mode() + # def capture_cudagraph(self): + # config = self.config + # hf_config = config.hf_config + # max_bs = min(self.config.max_num_seqs, 512) + # max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + # input_ids = torch.zeros(max_bs, dtype=torch.int64) + # positions = torch.zeros(max_bs, dtype=torch.int64) + # slot_mapping = torch.zeros(max_bs, dtype=torch.int32) + # context_lens = torch.zeros(max_bs, dtype=torch.int32) + # block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32) + # outputs = torch.zeros(max_bs, hf_config.hidden_size) + # self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) + # self.graphs = {} + # self.graph_pool = None + + # for bs in reversed(self.graph_bs): + # graph = torch.cuda.CUDAGraph() + # set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], 
block_tables=block_tables[:bs]) + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup + # with torch.cuda.graph(graph, self.graph_pool): + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture + # if self.graph_pool is None: + # self.graph_pool = graph.pool() + # self.graphs[bs] = graph + # torch.cuda.synchronize() + # reset_context() + + # self.graph_vars = dict( + # input_ids=input_ids, + # positions=positions, + # slot_mapping=slot_mapping, + # context_lens=context_lens, + # block_tables=block_tables, + # outputs=outputs, + # ) + + + # infinifore infer + def max_context_len(self): + return self.meta.dctx + # return self.config.max_model_len + + def create_kv_cache(self): + return create_kv_cache(self.model) + + def drop_kv_cache(self, kv_cache): + drop_kv_cache(self.model, kv_cache) + + def create_paged_kv_cache(self, max_kvcache_tokens): + return create_paged_kv_cache(self.model, max_kvcache_tokens) + + # @torch.inference_mode() + # def batch_infer_one_round(self, tasks: List[InferTask]): + # output = (c_uint * len(tasks))() + # batch_inputs = InferBatchedTask(tasks) + # infer_batch( + # self.model, + # *(batch_inputs.input_args()), + # output, + # ) + # return list(output) + + def batch_infer_one_round(self, tasks: List[InferTask], is_prefill: int, batch_block_tables: list[int], slot_mapping: list[int]): + output = (c_uint * len(tasks))() + batch_inputs = None + if self.enable_paged_attn: + batch_inputs = InferPagedBatchedTask(tasks, batch_block_tables, slot_mapping, self.kv_cache, is_prefill) + else: + batch_inputs = InferBatchedTask(tasks, is_prefill) + infer_batch( + self.model, + *(batch_inputs.input_args()), + self.enable_paged_attn, + output, + ) + return list(output) + + def run(self, seqs: list[Sequence], is_prefill: int) -> list[int]: + # input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + # temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + # logits = self.run_model(input_ids, positions, is_prefill) + # token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + # reset_context() + batch_block_tables, slot_mapping = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + tasks = [seq.infer_task for seq in seqs] + token_ids = self.batch_infer_one_round(tasks, is_prefill, batch_block_tables, slot_mapping) + + return token_ids + + + def batch_infer_one_round_for_logits(self, tasks: List[InferTask], is_prefill: int, batch_block_tables: list[int], slot_mapping: list[int]): + batch_inputs = None + if self.enable_paged_attn: + batch_inputs = InferPagedBatchedTask(tasks, batch_block_tables, slot_mapping, self.kv_cache, is_prefill) + else: + batch_inputs = InferBatchedTask(tasks, is_prefill) + logits = torch.zeros((batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits) + forward_batch( + self.model, + *(batch_inputs.input_args_for_logits()), + self.enable_paged_attn, + logits.data_ptr(), + ) + return logits, batch_inputs.req_lens_list, batch_inputs.ntok + + def run_for_logits(self, seqs: list[Sequence], is_prefill: int) -> torch.Tensor: + # input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + # temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + # logits = self.run_model(input_ids, positions, is_prefill) + # token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + # reset_context() + nll = 0.0 + total_len = 0 + batch_block_tables, 
slot_mapping = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + tasks = [seq.infer_task for seq in seqs] + true_tokens = [seq.true_tokens for seq in seqs] + logits, req_lens_list, ntok = self.batch_infer_one_round_for_logits(tasks, is_prefill, batch_block_tables, slot_mapping) + token_ids_none = [None] * len(seqs) + + logits = logits.float() + token_ids = torch.tensor(true_tokens, dtype=torch.int64).reshape(-1) # [ntok,] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + token_logprobs = log_probs[ + torch.arange(ntok), token_ids + ] # (ntok,) + + start = 0 + for l in req_lens_list: + nll += -token_logprobs[start : start + l].sum().item() + start += l + total_len += token_logprobs.numel() + + return nll, total_len, token_ids_none diff --git a/python/icinfer/engine/scheduler.py b/python/icinfer/engine/scheduler.py new file mode 100644 index 00000000..6782792d --- /dev/null +++ b/python/icinfer/engine/scheduler.py @@ -0,0 +1,91 @@ +from collections import deque + +from icinfer.config import Config +from icinfer.engine.sequence import Sequence, SequenceStatus +from icinfer.engine.block_manager import BlockManager +from icinfer.engine.infer_task import KVCache + + +class Scheduler: + + def __init__(self, config: Config): + self.max_num_seqs = config.max_num_seqs + self.max_num_batched_tokens = config.max_num_batched_tokens + self.eos = config.eos + self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size) + self.waiting: deque[Sequence] = deque() + self.running: deque[Sequence] = deque() + + + def is_finished(self): + return not self.waiting and not self.running + + def add(self, seq: Sequence): + self.waiting.append(seq) + + def schedule(self) -> tuple[list[Sequence], int]: + # prefill + scheduled_seqs = [] + num_seqs = 0 + num_batched_tokens = 0 + is_prefill = 0 + while self.waiting and num_seqs < self.max_num_seqs: + seq = self.waiting[0] + if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq): + break + num_seqs += 1 + self.block_manager.allocate(seq) + num_batched_tokens += len(seq) - seq.num_cached_tokens + seq.status = SequenceStatus.RUNNING + self.waiting.popleft() + self.running.append(seq) + scheduled_seqs.append(seq) + if scheduled_seqs: + is_prefill = 1 + return scheduled_seqs, is_prefill + + # decode + while self.running and num_seqs < self.max_num_seqs: + seq = self.running.popleft() + while not self.block_manager.can_append(seq): + if self.running: + self.preempt(self.running.pop()) + else: + self.preempt(seq) + break + else: + num_seqs += 1 + self.block_manager.may_append(seq) + scheduled_seqs.append(seq) + assert scheduled_seqs + self.running.extendleft(reversed(scheduled_seqs)) + # print(f"is_prefill: {is_prefill}, schedule over.\n") + return scheduled_seqs, is_prefill + + def preempt(self, seq: Sequence): + seq.status = SequenceStatus.WAITING + self.block_manager.deallocate(seq) + self.waiting.appendleft(seq) + + # def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]: + # for seq, token_id in zip(seqs, token_ids): + # seq.append_token(token_id) + # if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: + # seq.status = SequenceStatus.FINISHED + # self.block_manager.deallocate(seq) + # self.running.remove(seq) + + def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[KVCache]: + drop_kvcache_list = [] + for seq, token_id in zip(seqs, token_ids): + 
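run_for_logits above scores the true next tokens by indexing the log-softmax of the returned logits. A compact sketch of that gather on a toy logits tensor (values are random stand-ins for the forwardBatch output buffer):

import torch

ntok, vocab = 4, 6
torch.manual_seed(0)
logits = torch.randn(ntok, vocab)            # stand-in for the forwardBatch output buffer
true_tokens = torch.tensor([1, 3, 0, 5])     # next-token labels, one per position

log_probs = torch.log_softmax(logits.float(), dim=-1)        # (ntok, vocab)
token_logprobs = log_probs[torch.arange(ntok), true_tokens]  # (ntok,)

req_lens = [3, 1]                            # tokens contributed by each request
nll, start = 0.0, 0
for n in req_lens:
    nll += -token_logprobs[start:start + n].sum().item()
    start += n
print(nll, token_logprobs.numel())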
seq.append_token(token_id) + if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: + seq.status = SequenceStatus.FINISHED + drop_kvcache_list.append(seq.infer_task.release_kvcache()) + self.block_manager.deallocate(seq) + self.running.remove(seq) + return drop_kvcache_list + + @property + def block_size(self): + return self.block_manager.block_size diff --git a/python/icinfer/engine/sequence.py b/python/icinfer/engine/sequence.py new file mode 100644 index 00000000..f27b5469 --- /dev/null +++ b/python/icinfer/engine/sequence.py @@ -0,0 +1,96 @@ +from copy import copy +from enum import Enum, auto +from itertools import count + +from icinfer.sampling_params import SamplingParams +from icinfer.engine.infer_task import InferTask + + +class SequenceStatus(Enum): + WAITING = auto() + RUNNING = auto() + FINISHED = auto() + + +class Sequence: + # block_size = 256 + counter = count() + + def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), block_size = 256, req_id = None): + self.seq_id = next(Sequence.counter) + self.status = SequenceStatus.WAITING + self.token_ids = copy(token_ids) + self.last_token = token_ids[-1] + self.num_tokens = len(self.token_ids) + self.num_prompt_tokens = len(token_ids) + self.num_cached_tokens = 0 + self.block_size = block_size + # self.block_table = None + self.block_table = [] + self.infer_task = None + + # for online serving + self.req_id = req_id + + self.true_tokens = None # for perplexity + self.temperature = sampling_params.temperature + self.max_tokens = sampling_params.max_tokens + self.ignore_eos = sampling_params.ignore_eos + + def __len__(self): + return self.num_tokens + + def __getitem__(self, key): + return self.token_ids[key] + + @property + def is_finished(self): + return self.status == SequenceStatus.FINISHED + + @property + def num_completion_tokens(self): + return self.num_tokens - self.num_prompt_tokens + + @property + def prompt_token_ids(self): + return self.token_ids[:self.num_prompt_tokens] + + @property + def completion_token_ids(self): + return self.token_ids[self.num_prompt_tokens:] + + @property + def num_cached_blocks(self): + return self.num_cached_tokens // self.block_size + + @property + def num_blocks(self): + return (self.num_tokens + self.block_size - 1) // self.block_size + + @property + def last_block_num_tokens(self): + return self.num_tokens - (self.num_blocks - 1) * self.block_size + + def block(self, i): + assert 0 <= i < self.num_blocks + return self.token_ids[i*self.block_size: (i+1)*self.block_size] + + def append_token(self, token_id: int): + self.token_ids.append(token_id) + self.infer_task.next(token_id) + self.last_token = token_id + self.num_tokens += 1 + + def bind_infer_task(self, infer_task: InferTask): + self.infer_task = infer_task + + def __getstate__(self): + return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, + self.token_ids if self.num_completion_tokens == 0 else self.last_token) + + def __setstate__(self, state): + self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1] + if self.num_completion_tokens == 0: + self.token_ids = state[-1] + else: + self.last_token = state[-1] diff --git a/python/icinfer/layers/sampler.py b/python/icinfer/layers/sampler.py new file mode 100644 index 00000000..e4b9816e --- /dev/null +++ b/python/icinfer/layers/sampler.py @@ -0,0 +1,18 @@ +import torch +from torch import nn + + +class Sampler(nn.Module): + + def __init__(self): + 
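Sequence above derives its paged-cache footprint from num_tokens and block_size via ceiling division. A quick numeric check of those properties (numbers chosen for illustration):

block_size = 256
num_tokens = 600          # prompt + generated tokens so far
num_cached_tokens = 512   # tokens already materialized in the cache

num_blocks = (num_tokens + block_size - 1) // block_size            # ceil -> 3
last_block_num_tokens = num_tokens - (num_blocks - 1) * block_size  # 600 - 512 = 88
num_cached_blocks = num_cached_tokens // block_size                 # 2 full blocks

print(num_blocks, last_block_num_tokens, num_cached_blocks)  # 3 88 2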
super().__init__() + + def forward(self, logits: torch.Tensor, temperatures: torch.Tensor): + logits = logits.to(torch.float) + greedy_tokens = logits.argmax(dim=-1) + logits.div_(temperatures.unsqueeze(dim=1)) + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + epsilon = 1e-10 + sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1) + return torch.where(temperatures == 0, greedy_tokens, sample_tokens) diff --git a/python/icinfer/llm.py b/python/icinfer/llm.py new file mode 100644 index 00000000..be38ecb0 --- /dev/null +++ b/python/icinfer/llm.py @@ -0,0 +1,5 @@ +from icinfer.engine.llm_engine import InfiniEngine + + +class LLM(InfiniEngine): + pass diff --git a/python/icinfer/models/jiuge.py b/python/icinfer/models/jiuge.py new file mode 100644 index 00000000..7f45f3f0 --- /dev/null +++ b/python/icinfer/models/jiuge.py @@ -0,0 +1,718 @@ +from typing import List, Sequence + +from sympy import true + +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import math +import torch +import transformers + +from icinfer.engine.libinfinicore_infer import ( + JiugeMetaCStruct, + JiugeWeightsCStruct, + KVCacheCStruct, + DataType, + DeviceType, + create_jiuge_model, + destroy_jiuge_model, + create_kv_cache, + drop_kv_cache, + infer_batch, + forward_batch, +) +from icinfer.engine.infer_task import InferTask, KVCache + +import logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +torch.set_default_device("cpu") + + +class LlamaWeightsNaming: + def input_embd(self): + return "model.embed_tokens.weight" + + def output_norm(self): + return "model.norm.weight" + + def output_embd(self): + return "lm_head.weight" + + def attn_norm(self, i): + return f"model.layers.{i}.input_layernorm.weight" + + def attn_q(self, i): + return f"model.layers.{i}.self_attn.q_proj.weight" + + def attn_k(self, i): + return f"model.layers.{i}.self_attn.k_proj.weight" + + def attn_v(self, i): + return f"model.layers.{i}.self_attn.v_proj.weight" + + def attn_o(self, i): + return f"model.layers.{i}.self_attn.o_proj.weight" + + def attn_q_b(self, i): + return f"model.layers.{i}.self_attn.q_proj.bias" + + def attn_k_b(self, i): + return f"model.layers.{i}.self_attn.k_proj.bias" + + def attn_v_b(self, i): + return f"model.layers.{i}.self_attn.v_proj.bias" + + def ffn_norm(self, i): + return f"model.layers.{i}.post_attention_layernorm.weight" + + def gate(self, i): + return f"model.layers.{i}.mlp.gate_proj.weight" + + def up(self, i): + return f"model.layers.{i}.mlp.up_proj.weight" + + def down(self, i): + return f"model.layers.{i}.mlp.down_proj.weight" + + def match(state_dict): + return ( + "model.norm.weight" in state_dict + and "model.layers.0.self_attn.q_proj.weight" in state_dict + ) + + +class JiugeMetaFromLlama(JiugeMetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.float32: + dt_ = DataType.INFINI_DTYPE_F32 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + self.scale_input = 1.0 + self.scale_output = 1.0 + self.scale_o = 1.0 + self.scale_down = 1.0 + if ( + config["model_type"] in ["fm9g", "minicpm"] + and "scale_emb" in 
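Sampler above draws from the temperature-scaled distribution with an exponential race: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax is equivalent to sampling the categorical distribution (the Gumbel-max trick in another form), while temperature 0 falls back to plain argmax. A standalone sketch that checks the empirical frequencies:

import torch

torch.manual_seed(0)
n, vocab = 10000, 3
logits = torch.tensor([2.0, 1.0, 0.1]).repeat(n, 1)
temperatures = torch.full((n,), 0.8)

probs = torch.softmax(logits / temperatures.unsqueeze(1), dim=-1)
noise = torch.empty_like(probs).exponential_(1)
samples = (probs / noise).argmax(dim=-1)     # ~ Categorical(probs)

# Empirical frequencies approach the softmax probabilities.
print(torch.bincount(samples, minlength=vocab) / n)
print(probs[0])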
config + and "scale_depth" in config + and "dim_model_base" in config + ): + self.scale_input = config["scale_emb"] + self.scale_output = config["hidden_size"] // config["dim_model_base"] + self.scale_o = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + self.scale_down = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + + super().__init__( + dt_logits=dt_, + nlayer=config["num_hidden_layers"], + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=( + config["num_key_value_heads"] + if "num_key_value_heads" in config + else config["num_attention_heads"] + ), + dh=config["hidden_size"] // config["num_attention_heads"], + di=config["intermediate_size"], + dctx=( + config["max_position_embeddings"] if max_tokens is None else max_tokens + ), + dvoc=config["vocab_size"], + block_size=config["block_size"], + epsilon=config["rms_norm_eps"], + theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), + end_token=2, + ) + self.torch_dtype_logits = dtype + + +class JiugeWeightsImpl(JiugeWeightsCStruct): + def __init__( + self, + meta, + naming, + state_dict, + torch_dt_mat=torch.float16, + torch_dt_norm=torch.float32, + ndev=1, + transpose_weight=True, + ): + nlayer = meta.nlayer + nh = meta.nh + nkvh = meta.nkvh + dh = meta.dh + d = meta.d + di = meta.di + scale_input = meta.scale_input + scale_output = meta.scale_output + scale_o = meta.scale_o + scale_down = meta.scale_down + assert nh % nkvh == 0 + assert nh % ndev == 0 + assert nkvh % ndev == 0 + assert di % ndev == 0 + torch_dt_logits = meta.torch_dtype_logits + if torch_dt_mat == torch.float16: + self.dt_mat = DataType.INFINI_DTYPE_F16 + elif torch_dt_mat == torch.float32: + self.dt_mat = DataType.INFINI_DTYPE_F32 + elif torch_dt_mat == torch.bfloat16: + self.dt_mat = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported proj weight data type") + if torch_dt_norm == torch.float16: + self.dt_norm = DataType.INFINI_DTYPE_F16 + elif torch_dt_norm == torch.float32: + self.dt_norm = DataType.INFINI_DTYPE_F32 + elif torch_dt_norm == torch.bfloat16: + self.dt_norm = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported norm weight data type") + + input_embd_naming = ( + naming.input_embd() + if naming.input_embd() in state_dict + else naming.output_embd() + ) + output_embd_naming = ( + naming.output_embd() + if naming.output_embd() in state_dict + else naming.input_embd() + ) + self.transpose_linear_weights = 1 if transpose_weight else 0 + self.nlayer = nlayer + self.input_embd_tensor = ( + state_dict[input_embd_naming].to(torch_dt_logits) * scale_input + ) + self.input_embd = self.input_embd_tensor.data_ptr() + self.output_norm_tensor = ( + state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + ) + self.output_norm = self.output_norm_tensor.data_ptr() + self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) + if not transpose_weight: + self.output_embd_tensor = self.output_embd_tensor.transpose( + 0, 1 + ).contiguous() + self.output_embd = self.output_embd_tensor.data_ptr() + + self.attn_norm_tensors = [ + state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.attn_norm_ptrs = [ + self.attn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs) + + def qkv_slices(_i): + _Q = ( + state_dict[naming.attn_q(_i)] + .reshape([nh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _K = ( + state_dict[naming.attn_k(_i)] + .reshape([nkvh, 2, dh // 2, d]) 
+ .transpose(1, 2) + ) + _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :]) + _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :]) + _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + return _result + + self.qkv_tensor = [ + torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.qkv_tensor[i] = ( + self.qkv_tensor[i] + .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d) + .transpose(1, 2) + .contiguous() + ) + self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) + + def qkv_b_slices(_i): + _QB = ( + state_dict[naming.attn_q_b(_i)] + .reshape([nh, 2, dh // 2]) + .transpose(1, 2) + ) + _KB = ( + state_dict[naming.attn_k_b(_i)] + .reshape([nkvh, 2, dh // 2]) + .transpose(1, 2) + ) + _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten()) + _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + return _result + + if naming.attn_q_b(0) in state_dict: + self.qkv_b_tensors = [ + torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer) + ] + self.qkv_b_tensor_ptrs = [ + self.qkv_b_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs) + else: + self.attn_qkv_b = None + + self.attn_o_tensor = [ + ( + state_dict[naming.attn_o(i)] + .to(torch_dt_mat) + .reshape([d, ndev, nh // ndev * dh]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.attn_o(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_o + for i in range(nlayer) + ] + self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs) + + self.ffn_norm_tensors = [ + state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.ffn_norm_ptrs = [ + self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs) + + def gate_up_slices(_i): + _result = [] + _di = di // ndev + for _idev in range(ndev): + _start = _idev * _di + _end = (_idev + 1) * _di + _result.append(state_dict[naming.gate(_i)][_start:_end, :]) + _result.append(state_dict[naming.up(_i)][_start:_end, :]) + return _result + + self.gate_up_tensors = [ + torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.gate_up_tensors[i] = ( + self.gate_up_tensors[i] + .reshape(ndev, 2 * di // ndev, d) + .transpose(1, 2) + .contiguous() + ) + self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] + self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) + + self.ffn_down_tensor = [ + ( + state_dict[naming.down(i)] + .to(torch_dt_mat) + .reshape([d, ndev, di // ndev]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.down(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_down + for i in range(nlayer) + ] + self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in 
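qkv_slices above regroups each head's rotary halves (the [heads, 2, dh//2, d] reshape followed by a transpose) and slices heads per device so Q, K and V land in one fused tensor per layer. A shape-only sketch on dummy weights (dimensions are illustrative, not the model's):

import torch

nh, nkvh, dh, d, ndev = 8, 4, 16, 64, 2
Wq = torch.randn(nh * dh, d)
Wk = torch.randn(nkvh * dh, d)
Wv = torch.randn(nkvh * dh, d)

# Regroup each head's rotary halves: (heads, 2, dh//2, d) -> (heads, dh//2, 2, d).
Q = Wq.reshape(nh, 2, dh // 2, d).transpose(1, 2)
K = Wk.reshape(nkvh, 2, dh // 2, d).transpose(1, 2)
V = Wv.reshape(nkvh, dh // 2, 2, d)

# Interleave per-device head slices so each device reads one contiguous Q/K/V chunk.
parts, nh_dev, nkvh_dev = [], nh // ndev, nkvh // ndev
for idev in range(ndev):
    parts.append(Q[idev * nh_dev:(idev + 1) * nh_dev])
    parts.append(K[idev * nkvh_dev:(idev + 1) * nkvh_dev])
    parts.append(V[idev * nkvh_dev:(idev + 1) * nkvh_dev])

fused = torch.concat(parts)
print(fused.shape)  # (nh + 2 * nkvh, dh // 2, 2, d) == (16, 8, 2, 64)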
range(nlayer)] + self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs) + + +class JiugeBatchedTask: + def __init__(self, tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + self.is_prefill = 1 + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + print(list(self.tokens)) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + self.is_prefill, + ) + + +class JiugeForCausalLM: + def __init__( + self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + def load_all_safetensors_from_dir(dir_path_: str): + tensors_ = {} + dir_path_ = Path(dir_path_) + for file in sorted(dir_path_.glob("*.safetensors")): + data_ = safetensors.safe_open(file, "pt") + for name_ in data_.keys(): + tensors_[name_] = data_.get_tensor(name_) + return tensors_ + + print("Loading model weights to host...") + load_start_time = time.time() + + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + eos_token_id = self.config["eos_token_id"] + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + transpose_weight = ( + device != DeviceType.DEVICE_TYPE_ASCEND + ) # y = xW is faster than y=xW^T on Ascend + if "llama" == config["model_type"]: + model = ( + transformers.LlamaForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + load_statets_time = time.time() + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path, trust_remote_code=True) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + elif "fm9g" == config["model_type"]: + logger.info(f"fm9g load start.") + # ) + model = ( + transformers.AutoModelForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + + logger.info(f"load over.") + load_statets_time = time.time() + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True) + elif "fm9g7b" == config["model_type"]: + logger.info(f"fm9g7b load start.") + model = ( + 
transformers.AutoModelForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + logger.info(f"load over.") + load_statets_time = time.time() + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path, trust_remote_code=True) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + elif "qwen2" == config["model_type"]: + state_dict = load_all_safetensors_from_dir(model_dir_path) + if LlamaWeightsNaming.match(state_dict): + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + state_dict, + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + else: + raise ValueError("Unsupported model architecture") + + load_end_time = time.time() + logger.info(f"Time overall used: {load_end_time - load_start_time:.3f}s, " + f"load_states_time: {load_statets_time - load_start_time:.3f}s, " + f"load_weights_impl_time: {load_end_time - load_statets_time:.3f}s") + + logger.info(f"Creating model on {ndev} devices...") + load_start_time = time.time() + dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + self.model_instance = create_jiuge_model( + byref(self.meta), + byref(self.weights), + device, + ndev, + dev_ids, + ) + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + def max_context_len(self): + return self.meta.dctx + + def create_kv_cache(self): + return create_kv_cache(self.model_instance) + + def drop_kv_cache(self, kv_cache): + drop_kv_cache(self.model_instance, kv_cache) + + def batch_infer_one_round(self, tasks: List[InferTask]): + output = (c_uint * len(tasks))() + batch_inputs = JiugeBatchedTask(tasks) + infer_batch( + self.model_instance, + *(batch_inputs.input_args()), + output, + ) + return list(output) + + def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + input_content = self.tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": input_content}], + add_generation_prompt=True, + tokenize=False, + ) + print(input_content, end="", flush=True) + tokens = self.tokenizer.encode(input_content) + infer_task = InferTask( + 0, + tokens, + self.max_context_len(), + temperature_, + topk_, + topp_, + self.eos_token_id, + ) + infer_task.bind_kvcache(KVCache(self)) + + steps = 0 + total_time = 0 + output_content = "" + + for step_i in range(max_steps): + start_time = time.time() + output_tokens = self.batch_infer_one_round([infer_task]) + end_time = time.time() + steps += 1 + output_str = ( + self.tokenizer._tokenizer.id_to_token(output_tokens[0]) + .replace("▁", " ") + .replace("<0x0A>", "\n") + ) + output_content += output_str + print(output_str, end="", flush=True) + if output_tokens[0] in self.eos_token_id: + break + infer_task.next(output_tokens[0]) + + if step_i > 0: + total_time += end_time - start_time + + print("\n") + avg_time = total_time * 1000 / (steps - 1) + print(f"Time per step: {avg_time:.3f}ms") + + # infer_task._kv_cache.drop(self) + # infer_task.release_kvcache().drop(self) + infer_task.release_kvcache().drop(self) + + return output_content, avg_time + + # def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + # input_content = 
self.tokenizer.apply_chat_template( + # conversation=[{"role": "user", "content": "山东最高的山是?"}], + # add_generation_prompt=True, + # tokenize=False, + # ) + # print(input_content, end="", flush=True) + # tokens = self.tokenizer.encode(input_content) + # infer_task = InferTask( + # 0, + # tokens, + # self.max_context_len(), + # temperature_, + # topk_, + # topp_, + # self.eos_token_id, + # ) + # infer_task.bind_kvcache(KVCache(self)) + + # input_content1 = self.tokenizer.apply_chat_template( + # conversation=[{"role": "user", "content": "中国最高的山和最长的河是?"}], + # add_generation_prompt=True, + # tokenize=False, + # ) + # tokens1 = self.tokenizer.encode(input_content1) + # infer_task1 = InferTask( + # 1, + # tokens1, + # self.max_context_len(), + # temperature_, + # topk_, + # topp_, + # self.eos_token_id, + # ) + # infer_task1.bind_kvcache(KVCache(self)) + + + # steps = 0 + # total_time = 0 + # output_content = "" + + # for step_i in range(max_steps): + # start_time = time.time() + # output_tokens = self.batch_infer_one_round([infer_task, infer_task1]) + # end_time = time.time() + # steps += 1 + # output_str = ( + # self.tokenizer._tokenizer.id_to_token(output_tokens[0]) + # .replace("▁", " ") + # .replace("<0x0A>", "\n") + # ) + # output_content += output_str + # print(output_str, end="", flush=True) + # if output_tokens[0] in self.eos_token_id: + # break + # infer_task.next(output_tokens[0]) + # infer_task1.next(output_tokens[1]) + + # if step_i > 0: + # total_time += end_time - start_time + + # print("\n") + # avg_time = total_time * 1000 / (steps - 1) + # print(f"Time per step: {avg_time:.3f}ms") + + # # infer_task._kv_cache.drop(self) + # # infer_task.release_kvcache().drop(self) + # infer_task.release_kvcache().drop(self) + # infer_task1.release_kvcache().drop(self) + + # return output_content, avg_time + + def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): + tasks = [ + InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + for i in range(batch_size) + ] + kv_caches = [KVCache(self) for _ in range(batch_size)] + + nll = 0.0 + total_len = 0 + + for i in range(0, len(test_sequences), batch_size): + batch_id = 0 + true_tokens = [] + while batch_id < batch_size and batch_id + i < len(test_sequences): + input_tokens = test_sequences[i + batch_id][:-1] + true_tokens.extend(test_sequences[i + batch_id][1:]) + tasks[batch_id].tokens = input_tokens + tasks[batch_id].bind_kvcache(kv_caches[batch_id]) + batch_id += 1 + + batch_inputs = JiugeBatchedTask(tasks[:batch_id]) + logits = torch.zeros( + (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits + ) + forward_batch( + self.model_instance, + batch_inputs.tokens, + batch_inputs.ntok, + batch_inputs.req_lens, + batch_inputs.nreq, + batch_inputs.req_pos, + batch_inputs.kv_caches, + logits.data_ptr(), + ) + + logits = logits.float() + token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + token_logprobs = log_probs[ + torch.arange(batch_inputs.ntok), token_ids + ] # (ntok,) + + start = 0 + for l in batch_inputs.req_lens_list: + nll += -token_logprobs[start : start + l].sum().item() + start += l + total_len += token_logprobs.numel() + + for task in tasks: + task.release_kvcache() + + return math.exp(nll / total_len) + + def destroy_model_instance(self): + destroy_jiuge_model(self.model_instance) + print("Model destroyed") + diff --git a/python/icinfer/sampling_params.py 
b/python/icinfer/sampling_params.py new file mode 100644 index 00000000..38733f2e --- /dev/null +++ b/python/icinfer/sampling_params.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + + +@dataclass +class SamplingParams: + temperature: float = 1.0 + topp: float = 1.0 + topk: int = 1 + max_tokens: int = 64 + ignore_eos: bool = False diff --git a/python/icinfer/utils/context.py b/python/icinfer/utils/context.py new file mode 100644 index 00000000..2281888f --- /dev/null +++ b/python/icinfer/utils/context.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +import torch + + +@dataclass +class Context: + is_prefill: bool = False + cu_seqlens_q: torch.Tensor | None = None + cu_seqlens_k: torch.Tensor | None = None + max_seqlen_q: int = 0 + max_seqlen_k: int = 0 + slot_mapping: torch.Tensor | None = None + context_lens: torch.Tensor | None = None + block_tables: torch.Tensor | None = None + +_CONTEXT = Context() + +def get_context(): + return _CONTEXT + +def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None): + global _CONTEXT + _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + +def reset_context(): + global _CONTEXT + _CONTEXT = Context() diff --git a/python/icinfer/utils/jiuge_weights_loader.py b/python/icinfer/utils/jiuge_weights_loader.py new file mode 100644 index 00000000..28307717 --- /dev/null +++ b/python/icinfer/utils/jiuge_weights_loader.py @@ -0,0 +1,467 @@ +# 文件路径: icinfer/engine/weights_loader.py + +import os +import json +import torch +import transformers +from typing import Tuple +import math +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import math +import torch +import transformers + +from icinfer.engine.libinfinicore_infer import ( + JiugeMetaCStruct, + JiugeWeightsCStruct, + DataType, + create_jiuge_model, + DeviceType +) +from icinfer.config import Config + +import logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +class LlamaWeightsNaming: + def input_embd(self): + return "model.embed_tokens.weight" + + def output_norm(self): + return "model.norm.weight" + + def output_embd(self): + return "lm_head.weight" + + def attn_norm(self, i): + return f"model.layers.{i}.input_layernorm.weight" + + def attn_q(self, i): + return f"model.layers.{i}.self_attn.q_proj.weight" + + def attn_k(self, i): + return f"model.layers.{i}.self_attn.k_proj.weight" + + def attn_v(self, i): + return f"model.layers.{i}.self_attn.v_proj.weight" + + def attn_o(self, i): + return f"model.layers.{i}.self_attn.o_proj.weight" + + def attn_q_b(self, i): + return f"model.layers.{i}.self_attn.q_proj.bias" + + def attn_k_b(self, i): + return f"model.layers.{i}.self_attn.k_proj.bias" + + def attn_v_b(self, i): + return f"model.layers.{i}.self_attn.v_proj.bias" + + def ffn_norm(self, i): + return f"model.layers.{i}.post_attention_layernorm.weight" + + def gate(self, i): + return f"model.layers.{i}.mlp.gate_proj.weight" + + def up(self, i): + return f"model.layers.{i}.mlp.up_proj.weight" + + def down(self, i): + return f"model.layers.{i}.mlp.down_proj.weight" + + def match(state_dict): + return ( + "model.norm.weight" in state_dict + and "model.layers.0.self_attn.q_proj.weight" in state_dict + ) + + +class 
JiugeMetaFromLlama(JiugeMetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.float32: + dt_ = DataType.INFINI_DTYPE_F32 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + self.scale_input = 1.0 + self.scale_output = 1.0 + self.scale_o = 1.0 + self.scale_down = 1.0 + if ( + config.model_type in ["fm9g", "minicpm"] + and hasattr(config, "scale_emb") + and hasattr(config, "scale_depth") + and hasattr(config, "dim_model_base") + ): + self.scale_input = config.scale_emb + self.scale_output = config.hidden_size // config.dim_model_base + self.scale_o = config.scale_depth / math.sqrt( + config.num_hidden_layers + ) + self.scale_down = config.scale_depth / math.sqrt( + config.num_hidden_layers + ) + + super().__init__( + dt_logits=dt_, + nlayer=config.num_hidden_layers, + d=config.hidden_size, + nh=config.num_attention_heads, + nkvh=( + config.num_key_value_heads + if hasattr(config, "num_key_value_heads") + else config.num_attention_heads + ), + dh=config.hidden_size // config.num_attention_heads, + di=config.intermediate_size, + dctx=( + config.max_position_embeddings if max_tokens is None else max_tokens + ), + dvoc=config.vocab_size, + kvcache_block_size=config.kvcache_block_size, + epsilon=config.rms_norm_eps, + theta=(config.rope_theta if hasattr(config, "rope_theta") else 100000.0), + end_token=2, + ) + self.torch_dtype_logits = dtype + + +class JiugeWeightsImpl(JiugeWeightsCStruct): + def __init__( + self, + meta, + naming, + state_dict, + torch_dt_mat=torch.float16, + torch_dt_norm=torch.float32, + ndev=1, + transpose_weight=True, + ): + nlayer = meta.nlayer + nh = meta.nh + nkvh = meta.nkvh + dh = meta.dh + d = meta.d + di = meta.di + scale_input = meta.scale_input + scale_output = meta.scale_output + scale_o = meta.scale_o + scale_down = meta.scale_down + assert nh % nkvh == 0 + assert nh % ndev == 0 + assert nkvh % ndev == 0 + assert di % ndev == 0 + torch_dt_logits = meta.torch_dtype_logits + if torch_dt_mat == torch.float16: + self.dt_mat = DataType.INFINI_DTYPE_F16 + elif torch_dt_mat == torch.float32: + self.dt_mat = DataType.INFINI_DTYPE_F32 + elif torch_dt_mat == torch.bfloat16: + self.dt_mat = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported proj weight data type") + if torch_dt_norm == torch.float16: + self.dt_norm = DataType.INFINI_DTYPE_F16 + elif torch_dt_norm == torch.float32: + self.dt_norm = DataType.INFINI_DTYPE_F32 + elif torch_dt_norm == torch.bfloat16: + self.dt_norm = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported norm weight data type") + + input_embd_naming = ( + naming.input_embd() + if naming.input_embd() in state_dict + else naming.output_embd() + ) + output_embd_naming = ( + naming.output_embd() + if naming.output_embd() in state_dict + else naming.input_embd() + ) + self.transpose_linear_weights = 1 if transpose_weight else 0 + self.nlayer = nlayer + self.input_embd_tensor = ( + state_dict[input_embd_naming].to(torch_dt_logits) * scale_input + ) + self.input_embd = self.input_embd_tensor.data_ptr() + self.output_norm_tensor = ( + state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + ) + self.output_norm = self.output_norm_tensor.data_ptr() + self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) + if not transpose_weight: + self.output_embd_tensor = self.output_embd_tensor.transpose( + 0, 1 + ).contiguous() + 
self.output_embd = self.output_embd_tensor.data_ptr() + + self.attn_norm_tensors = [ + state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.attn_norm_ptrs = [ + self.attn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs) + + def qkv_slices(_i): + _Q = ( + state_dict[naming.attn_q(_i)] + .reshape([nh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _K = ( + state_dict[naming.attn_k(_i)] + .reshape([nkvh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :]) + _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :]) + _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + return _result + + self.qkv_tensor = [ + torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.qkv_tensor[i] = ( + self.qkv_tensor[i] + .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d) + .transpose(1, 2) + .contiguous() + ) + self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) + + def qkv_b_slices(_i): + _QB = ( + state_dict[naming.attn_q_b(_i)] + .reshape([nh, 2, dh // 2]) + .transpose(1, 2) + ) + _KB = ( + state_dict[naming.attn_k_b(_i)] + .reshape([nkvh, 2, dh // 2]) + .transpose(1, 2) + ) + _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten()) + _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + return _result + + if naming.attn_q_b(0) in state_dict: + self.qkv_b_tensors = [ + torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer) + ] + self.qkv_b_tensor_ptrs = [ + self.qkv_b_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs) + else: + self.attn_qkv_b = None + + self.attn_o_tensor = [ + ( + state_dict[naming.attn_o(i)] + .to(torch_dt_mat) + .reshape([d, ndev, nh // ndev * dh]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.attn_o(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_o + for i in range(nlayer) + ] + self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs) + + self.ffn_norm_tensors = [ + state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.ffn_norm_ptrs = [ + self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs) + + def gate_up_slices(_i): + _result = [] + _di = di // ndev + for _idev in range(ndev): + _start = _idev * _di + _end = (_idev + 1) * _di + _result.append(state_dict[naming.gate(_i)][_start:_end, :]) + _result.append(state_dict[naming.up(_i)][_start:_end, :]) + return _result + + self.gate_up_tensors = [ + torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.gate_up_tensors[i] = ( + self.gate_up_tensors[i] + .reshape(ndev, 2 * di // ndev, d) + .transpose(1, 2) + .contiguous() + ) + 
self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] + self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) + + self.ffn_down_tensor = [ + ( + state_dict[naming.down(i)] + .to(torch_dt_mat) + .reshape([d, ndev, di // ndev]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.down(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_down + for i in range(nlayer) + ] + self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] + self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs) + + + +def load_weights_to_cpu(config: Config, device: DeviceType) -> Tuple[JiugeMetaCStruct, JiugeWeightsImpl]: + """ + 复用旧 infiniinfer 的权重加载逻辑。 + 在 CPU 上加载模型权重和配置,并将其转换为 C++ 兼容的结构体。 + """ + + def load_all_safetensors_from_dir(dir_path_: str): + tensors_ = {} + dir_path_ = Path(dir_path_) + for file in sorted(dir_path_.glob("*.safetensors")): + data_ = safetensors.safe_open(file, "pt") + for name_ in data_.keys(): + tensors_[name_] = data_.get_tensor(name_) + return tensors_ + + max_tokens = config.max_model_len + model_dir_path = config.model_path + ndev = config.tensor_parallel_size + hf_config = config.hf_config + + print("Loading model weights to host...") + load_start_time = time.time() + + transpose_weight = ( + device != DeviceType.DEVICE_TYPE_ASCEND + ) # y = xW is faster than y=xW^T on Ascend + if "llama" == hf_config.model_type: + model = ( + transformers.LlamaForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + load_statets_time = time.time() + meta = JiugeMetaFromLlama(hf_config, max_tokens=max_tokens) + weights = JiugeWeightsImpl( + meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + elif "fm9g" == hf_config.model_type: + logger.info(f"fm9g load start.") + # ) + model = ( + transformers.AutoModelForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + + logger.info(f"load over.") + load_statets_time = time.time() + meta = JiugeMetaFromLlama(hf_config, max_tokens=max_tokens) + weights = JiugeWeightsImpl( + meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + elif "fm9g7b" == hf_config.model_type: + logger.info(f"fm9g7b load start.") + model = ( + transformers.AutoModelForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + logger.info(f"load over.") + load_statets_time = time.time() + meta = JiugeMetaFromLlama(hf_config, max_tokens=max_tokens) + weights = JiugeWeightsImpl( + meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + elif "qwen2" == hf_config.model_type: + state_dict = load_all_safetensors_from_dir(model_dir_path) + if LlamaWeightsNaming.match(state_dict): + meta = JiugeMetaFromLlama(hf_config, max_tokens=max_tokens) + weights = JiugeWeightsImpl( + meta, + LlamaWeightsNaming(), + state_dict, + ndev=ndev, + transpose_weight=transpose_weight, + ) + else: + raise ValueError("Unsupported model architecture") + + load_end_time = time.time() + logger.info(f"Time overall used: {load_end_time - load_start_time:.3f}s, " + f"load_states_time: {load_statets_time - load_start_time:.3f}s, " + f"load_weights_impl_time: {load_end_time - load_statets_time:.3f}s") + + logger.info(f"Creating model on {ndev} devices...") + load_start_time = 
time.time() + + print("Weights loaded to CPU successfully.") + return meta, weights + +def load_model(config: Config, device: DeviceType): + ndev = config.tensor_parallel_size + meta, weights = load_weights_to_cpu(config, device) + dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + model = create_jiuge_model( + byref(meta), + byref(weights), + device, + ndev, + dev_ids, + ) + return model, meta \ No newline at end of file diff --git a/python/icinfer/utils/loader.py b/python/icinfer/utils/loader.py new file mode 100644 index 00000000..1499e099 --- /dev/null +++ b/python/icinfer/utils/loader.py @@ -0,0 +1,92 @@ +import os +from glob import glob +import torch +from torch import nn +from safetensors import safe_open + + +def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor): + param.data.copy_(loaded_weight) + + +# def load_model(model: nn.Module, path: str): +# packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) +# for file in glob(os.path.join(path, "*.safetensors")): +# with safe_open(file, "pt", "cpu") as f: +# for weight_name in f.keys(): +# for k in packed_modules_mapping: +# if k in weight_name: +# v, shard_id = packed_modules_mapping[k] +# param_name = weight_name.replace(k, v) +# param = model.get_parameter(param_name) +# weight_loader = getattr(param, "weight_loader") +# weight_loader(param, f.get_tensor(weight_name), shard_id) +# break +# else: +# param = model.get_parameter(weight_name) +# weight_loader = getattr(param, "weight_loader", default_weight_loader) +# weight_loader(param, f.get_tensor(weight_name)) + + +def load_model(model: nn.Module, path: str): + """ + 智能加载模型权重。 + 优先查找并使用 .safetensors 文件。如果找不到,则回退到查找并使用 .bin 文件。 + """ + packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) + + # 优先尝试加载 .safetensors 文件 + model_files = glob(os.path.join(path, "*.safetensors")) + is_safetensors = True + + # 如果没有找到 .safetensors,则回退到加载 .bin 文件 + if not model_files: + model_files = glob(os.path.join(path, "*.bin")) + is_safetensors = False + + # 如果两种文件都找不到,则报错 + if not model_files: + raise FileNotFoundError(f"No model weights found in {path}. Looked for .safetensors and .bin files.") + + # 核心加载逻辑 + for file_path in model_files: + if is_safetensors: + with safe_open(file_path, "pt", "cpu") as f: + for weight_name in f.keys(): + tensor = f.get_tensor(weight_name) + _load_and_dispatch(model, weight_name, tensor, packed_modules_mapping) + else: # .bin format + state_dict = torch.load(file_path, map_location="cpu") + for weight_name, tensor in state_dict.items(): + _load_and_dispatch(model, weight_name, tensor, packed_modules_mapping) + +def _load_and_dispatch(model, weight_name, tensor, packed_modules_mapping): + """ + 一个辅助函数,用于分派权重到正确的加载器。 + """ + is_packed = False + for packed_key, (target_name, shard_id) in packed_modules_mapping.items(): + if packed_key in weight_name: + # 替换名称以匹配模型中的合并层参数 + param_name = weight_name.replace(packed_key, target_name) + try: + param = model.get_parameter(param_name) + # 调用参数上附加的专用加载器 (例如 QKVParallelLinear.weight_loader) + getattr(param, "weight_loader")(param, tensor, shard_id) + except AttributeError: + print(f"Warning: Could not find parameter '{param_name}' for packed weight '{weight_name}'. 
Skipping.") + is_packed = True + break + + if not is_packed: + try: + param = model.get_parameter(weight_name) + # 调用参数上附加的加载器,或使用默认加载器 + loader = getattr(param, "weight_loader", default_weight_loader) + loader(param, tensor) + except AttributeError: + # 某些权重(如_pre_transformer_block.0.norm_1.weight)可能不存在于我们的模型中,可以安全地忽略 + pass + + + diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 00000000..5dad46e6 --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" + +[project] +name = "icinfer" +version = "0.1.0" +authors = [{ name = "" }] +license = "MIT" +license-files = ["LICENSE"] +readme = "README.md" +description = "a lightweight, hardware-agnostic, unified inference engine implementation built from scratch, based on InfiniCore" +requires-python = ">=3.10,<3.13" +dependencies = [ + "torch>=2.4.0", + "triton>=3.0.0", + "transformers>=4.51.0", + "xxhash", +] + +[project.urls] +Homepage="https://github.com/InfiniTensor/InfiniLM" + +[tool.setuptools.packages.find] +where = ["."] +include = ["icinfer*"] diff --git a/python/tests/test_attention.py b/python/tests/test_attention.py new file mode 100644 index 00000000..16590ab3 --- /dev/null +++ b/python/tests/test_attention.py @@ -0,0 +1,522 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +from typing import Optional + +import pytest +import torch + +from tests.kernels.allclose_default import get_default_atol, get_default_rtol +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops +from vllm.attention.layer import Attention, MultiHeadAttention +from vllm.platforms import current_platform +from vllm.utils import get_max_shared_memory_bytes + +if not current_platform.is_rocm(): + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + + from vllm.attention.backends.xformers import _make_alibi_bias + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. 
+NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 +DTYPES = [torch.bfloat16] +NUM_GEN_SEQS = [7] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing + +# This should be sync with get_supported_head_sizes() in +# vllm.attention.ops.paged_attn.PagedAttention +HEAD_SIZES = [32, 80, 128, 256] + +BLOCK_SIZES = [16, 32] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto", "fp8"] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables_lst = block_tables.cpu().tolist() + seq_lens_lst = seq_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables_lst[i] + seq_len = int(seq_lens_lst[i]) + + keys_lst: list[torch.Tensor] = [] + values_lst: list[torch.Tensor] = [] + for j in range(seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys_lst.append(k) + + v = value_cache[block_number, :, :, block_offset] + values_lst.append(v) + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
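+            # The bias below is slope * (position - last_position): zero at the query token and increasingly negative for older tokens, broadcast per head to shape (num_heads, 1, seq_len).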
+ position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize( + "version", + ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + if ((kv_cache_dtype == "fp8" and head_size % 16) + or (version == "rocm" and head_size not in (64, 128))): + pytest.skip() + + if (version == "rocm" and current_platform.is_navi() + and (kv_cache_dtype == "fp8" or head_size != 128 + or block_size != 16 or use_alibi)): + pytest.skip() + + global PARTITION_SIZE + + current_platform.seed_everything(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables_lst: list[list[int]] = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables_lst.append(block_table) + + block_tables = torch.tensor(block_tables_lst, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + # Call the paged attention kernel. 
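+    # v1 processes each sequence without partitioning; v2 (and the ROCm variant) splits long sequences into PARTITION_SIZE chunks and then reduces the per-partition softmax statistics (exp_sums / max_logits).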
+ output = torch.empty_like(query) + if version == "v1": + ops.paged_attention( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention, + (output, query, key_cache, value_cache, num_kv_heads, scale, + block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) + + elif version in ("v2", "rocm"): + if current_platform.is_rocm() and version == "rocm": + PARTITION_SIZE = PARTITION_SIZE_ROCM + + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) + + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + None, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._rocm_C.paged_attention, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, None, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) + + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_key_cache, key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_value_cache, value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + seq_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. 
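+    # Note: the ROCm-specific defaults computed on the next two lines are overwritten by the unconditional atol/rtol assignment a few lines below; only the fp8 branch actually changes the tolerance.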
+ atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + + +def ref_multi_query_kv_attention( + cu_seq_lens: list[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + alibi_bias: Optional[list[torch.Tensor]], + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs: list[torch.Tensor] = [] + if alibi_bias: + assert len(alibi_bias) == num_seqs + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. ALiBi already includes a tril causal mask. + if alibi_bias: + attn_mask = alibi_bias[i] + else: + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + + return torch.cat(ref_outputs, dim=0) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + seed: int, + device: str, + use_alibi: bool = False, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + seq_lens = random.sample(range(1, max_len), num_seqs) + num_tokens = sum(seq_lens) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + qkv = torch.empty(num_tokens, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype) + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + + num_queries_per_kv = num_query_heads // num_kv_heads + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + alibi_bias = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, + seq_lens) + output = torch.empty_like(query) + start = 0 + # Dynamic sequence length not supported with custom attn_bias. 
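+        # so each sequence is passed through xformers separately and the result is copied back into the packed output tensor at its offset.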
+ for i, seq_len in enumerate(seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + output[start:end].copy_(out.view_as(query[start:end])) + start += seq_len + # xformers.AttentionBias to Tensor for use in reference impl. + alibi_bias = [ + b.materialize((1, num_query_heads, i, i), device=device).squeeze() + for b, i in zip(attn_bias, seq_lens) + ] + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) + + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + ref_output = ref_multi_query_kv_attention( + cu_seq_lens, + query, + key, + value, + scale, + alibi_bias, + dtype, + ) + atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +@torch.inference_mode() +def test_multi_query_kv_attention_with_alibi( + num_seqs: int, + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + return test_multi_query_kv_attention( + num_seqs, + num_heads, + head_size, + dtype, + seed, + device, + use_alibi=True, + ) + + +@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) +def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None: + head_size = 64 + scale = float(1.0 / (head_size**0.5)) + num_heads = 16 + num_kv_heads = 5 + with pytest.raises(AssertionError): + _ = attention_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + ) diff --git a/scripts/jiuge.py b/scripts/jiuge.py index a2e591f8..24a73355 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -27,6 +27,14 @@ import torch import transformers +import logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + torch.set_default_device("cpu") @@ -393,7 +401,7 @@ def input_args(self): ) -class JiugeForCauslLM: +class JiugeForCausalLM: def __init__( self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None ): @@ -421,12 +429,11 @@ def load_all_safetensors_from_dir(dir_path_: str): ) # y = xW is faster than y=xW^T on Ascend if "llama" == config["model_type"]: model = ( - transformers.LlamaForCausalLM.from_pretrained(model_dir_path) - .cpu() - .half() + transformers.LlamaForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) ) + load_statets_time = time.time() self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) - self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path, 
trust_remote_code=True) self.weights = JiugeWeightsImpl( self.meta, LlamaWeightsNaming(), @@ -434,56 +441,41 @@ def load_all_safetensors_from_dir(dir_path_: str): ndev=ndev, transpose_weight=transpose_weight, ) - elif "fm9g" == config["model_type"] or "minicpm" == config["model_type"]: - if any( - file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() - ): - state_dict = load_all_safetensors_from_dir(model_dir_path) - else: - state_dict = torch.load( - os.path.join(model_dir_path, "pytorch_model.bin"), - weights_only=True, - map_location="cpu", - ) - if LlamaWeightsNaming.match(state_dict): - self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) - self.weights = JiugeWeightsImpl( - self.meta, - LlamaWeightsNaming(), - state_dict, - ndev=ndev, - transpose_weight=transpose_weight, - ) - self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_dir_path, trust_remote_code=True - ) - else: - raise ValueError("Unsupported weight naming") + elif "fm9g" == config["model_type"]: + logger.info(f"fm9g load start.") + # ) + model = ( + transformers.AutoModelForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + + logger.info(f"load over.") + load_statets_time = time.time() + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True) elif "fm9g7b" == config["model_type"]: - if any( - file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() - ): - state_dict = load_all_safetensors_from_dir(model_dir_path) - else: - state_dict = torch.load( - os.path.join(model_dir_path, "pytorch_model.bin"), - weights_only=True, - map_location="cpu", - ) - if LlamaWeightsNaming.match(state_dict): - self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) - self.weights = JiugeWeightsImpl( - self.meta, - LlamaWeightsNaming(), - state_dict, - ndev=ndev, - transpose_weight=transpose_weight, - ) - self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_dir_path, trust_remote_code=True - ) - else: - raise ValueError("Unsupported weight naming") + logger.info(f"fm9g7b load start.") + model = ( + transformers.AutoModelForCausalLM.from_pretrained(model_dir_path, device_map="cpu", torch_dtype=torch.bfloat16, trust_remote_code=True) + ) + logger.info(f"load over.") + load_statets_time = time.time() + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path, trust_remote_code=True) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) elif "qwen2" == config["model_type"]: state_dict = load_all_safetensors_from_dir(model_dir_path) if LlamaWeightsNaming.match(state_dict): @@ -496,15 +488,17 @@ def load_all_safetensors_from_dir(dir_path_: str): transpose_weight=transpose_weight, ) self.tokenizer = transformers.AutoTokenizer.from_pretrained( - model_dir_path + model_dir_path, trust_remote_code=True ) else: raise ValueError("Unsupported model architecture") load_end_time = time.time() - print(f"Time used: {load_end_time - load_start_time:.3f}s") + logger.info(f"Time overall used: {load_end_time - load_start_time:.3f}s, " + f"load_states_time: {load_statets_time - 
load_start_time:.3f}s, " + f"load_weights_impl_time: {load_end_time - load_statets_time:.3f}s") - print(f"Creating model on {ndev} devices...") + logger.info(f"Creating model on {ndev} devices...") load_start_time = time.time() dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) self.model_instance = create_jiuge_model( @@ -642,40 +636,3 @@ def destroy_model_instance(self): destroy_jiuge_model(self.model_instance) print("Model destroyed") - -def test(): - if len(sys.argv) < 3: - print( - "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" - ) - sys.exit(1) - model_path = sys.argv[2] - device_type = DeviceType.DEVICE_TYPE_CPU - if sys.argv[1] == "--cpu": - device_type = DeviceType.DEVICE_TYPE_CPU - elif sys.argv[1] == "--nvidia": - device_type = DeviceType.DEVICE_TYPE_NVIDIA - elif sys.argv[1] == "--cambricon": - device_type = DeviceType.DEVICE_TYPE_CAMBRICON - elif sys.argv[1] == "--ascend": - device_type = DeviceType.DEVICE_TYPE_ASCEND - elif sys.argv[1] == "--metax": - device_type = DeviceType.DEVICE_TYPE_METAX - elif sys.argv[1] == "--moore": - device_type = DeviceType.DEVICE_TYPE_MOORE - elif sys.argv[1] == "--iluvatar": - device_type = DeviceType.DEVICE_TYPE_ILUVATAR - else: - print( - "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" - ) - sys.exit(1) - - ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 - model = JiugeForCauslLM(model_path, device_type, ndev) - model.generate("山东最高的山是?", 500) - model.destroy_model_instance() - - -if __name__ == "__main__": - test() diff --git a/scripts/jiuge_ppl.py b/scripts/jiuge_ppl.py index 67dc2326..f836871d 100644 --- a/scripts/jiuge_ppl.py +++ b/scripts/jiuge_ppl.py @@ -1,7 +1,7 @@ import torch from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset -from jiuge import JiugeForCauslLM +from jiuge import JiugeForCausalLM from libinfinicore_infer import DeviceType DEVICE_TYPE_MAP = { @@ -25,7 +25,7 @@ def test_torch(input_ids_list, device_): device = TORCH_DEVICE_TYPE_MAP[device_] - model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to( + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to( device ) model.eval() @@ -59,7 +59,7 @@ def test_torch(input_ids_list, device_): def test_infinicore(input_ids_list, device_, ndev_): device = DEVICE_TYPE_MAP[device_] - model = JiugeForCauslLM( + model = JiugeForCausalLM( model_path, device, max_tokens=len(input_ids_list[0]), ndev=ndev_ ) perplexity = model.perplexity(input_ids_list) @@ -99,9 +99,9 @@ def test_infinicore(input_ids_list, device_, ndev_): for i in range(0, len(ids) - seq_len + 1, seq_len): input_ids_list.append(ids[i : i + seq_len]) - perplexity = test_infinicore(input_ids_list, args.dev, args.ndev) - print(f"InfiniCore Perplexity: {perplexity:.2f}") + InfiniCore_perplexity = test_infinicore(input_ids_list, args.dev, args.ndev) + print(f"InfiniCore Perplexity: {InfiniCore_perplexity:.2f}") if args.ndev == 1: # Todo: support multi-device testing with torch - perplexity = test_torch(input_ids_list, args.dev) - print(f"Torch Perplexity: {perplexity.item():.2f}") + Torch_perplexity = test_torch(input_ids_list, args.dev) + print(f"Torch Perplexity: {Torch_perplexity.item():.2f}") diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 4847a477..575eb600 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -1,4 +1,4 @@ -from jiuge import 
JiugeForCauslLM +from jiuge import JiugeForCausalLM from libinfinicore_infer import DeviceType from infer_task import InferTask from kvcache_pool import KVCachePool @@ -109,7 +109,7 @@ def output(self, out_token): @contextlib.asynccontextmanager async def lifespan(app: FastAPI): # Startup - app.state.model = JiugeForCauslLM(model_path, device_type, ndev, max_tokens=max_tokens) + app.state.model = JiugeForCausalLM(model_path, device_type, ndev, max_tokens=max_tokens) app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) app.state.request_queue = janus.Queue() worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) diff --git a/scripts/test_jiuge.py b/scripts/test_jiuge.py new file mode 100644 index 00000000..9dbb7f6c --- /dev/null +++ b/scripts/test_jiuge.py @@ -0,0 +1,55 @@ +from jiuge import JiugeForCausalLM +import sys +import logging +import argparse + +from libinfinicore_infer import DeviceType +logger = logging.getLogger(__name__) + + + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/Llama-2-7b-chat-hf") + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=4) + args = parser.parse_args() + return args + +def test(): + args = parse_args() + model_path = args.model_path + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + ndev = args.ndev + model = JiugeForCausalLM(model_path, device_type, ndev) + # model.generate(["山东最高的山是?", "中国面积最大的省是?"], 500) + # model.generate(["山东最高的山是?"], 500) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/scripts/test_perf.py b/scripts/test_perf.py index a6b26f3b..5ef34a31 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -28,8 +28,8 @@ "想象一下,如果每个人都能读懂他人的思想。" ] -NUM_REQUESTS = 10 -CONCURRENCY = 5 +NUM_REQUESTS = 20 +CONCURRENCY = 20 API_URL = "http://127.0.0.1:8000" MODEL = "FM9G-7B" @@ -122,6 +122,7 @@ async def run_benchmark(verbose=False): successful_requests = len(results) requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + throughput = sum(tokens_list) / total_elapsed_time if total_elapsed_time > 0 else 0 avg_latency = sum(latencies) / len(latencies) if latencies else 0 avg_tokens_per_second = sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0 avg_ttft = sum(ttft_list) / len(ttft_list) if 
ttft_list else 0 @@ -139,6 +140,7 @@ async def run_benchmark(verbose=False): print(f"{'总输出token数':<{width_label}}: {sum(tokens_list)}") print(f"{'请求速率 (RPS)':<{width_label}}: {requests_per_second:.2f} requests/s") print(sep) + print(f"{'吞吐量':<{width_label}}: {throughput:.2f} tok/s") print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s") print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token") diff --git a/scripts/test_ppl.py b/scripts/test_ppl.py index 268a9f7d..1278c569 100644 --- a/scripts/test_ppl.py +++ b/scripts/test_ppl.py @@ -11,17 +11,25 @@ parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, required=True) parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--endpoint", type=str, default="/completions") + parser.add_argument("--endpoint", type=str, default="/chat/completions") parser.add_argument("--chunk", type=int, default=512) args = parser.parse_args() API_URL = "http://localhost:" + str(args.port) + args.endpoint CHUNK_SIZE = args.chunk - - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + print("Loading dataset...") + local_file_paths = { + # "train": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/train.parquet", + # "validation": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/validation.parquet", + "test": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext-2-raw-v1/test-00000-of-00001.parquet" + } + dataset = load_dataset("parquet", data_files=local_file_paths, split="test") + print("Dataset loaded.") + # dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") # Local tokenizer used for chunking - tokenizer = AutoTokenizer.from_pretrained(args.model_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) total_neg_log_likelihood = 0.0 total_tokens = 0 @@ -41,8 +49,10 @@ API_URL, headers={"Content-Type": "application/json"}, json={ + "messages": [ + {"role": "user", "content": chunk_text} + ], "model": "", - "prompt": chunk_text, "max_tokens": 0, "temperature": 1.0, "echo": True, diff --git a/setup.sh b/setup.sh new file mode 100755 index 00000000..a06bc889 --- /dev/null +++ b/setup.sh @@ -0,0 +1,589 @@ +#!/bin/bash + +set -e + +echo "===================================================================" +echo "Anthropic API 环境变量配置脚本" +echo "注意:本脚本需要在bash环境中运行" +echo "Windows用户请在git bash终端环境下使用" +echo "Mac/Linux用户可直接在终端中运行" +echo "===================================================================" + +# 1. 检查终端环境 +echo "正在检查运行环境..." + +# 检查是否为bash环境 +if [ -z "$BASH_VERSION" ]; then + echo "❌ 错误: 当前不是bash环境" + echo "请在bash终端中运行此脚本:" + echo " - Windows: 使用 Git Bash 或 WSL" + echo " - Mac/Linux: 使用系统终端" + exit 1 +fi + +# 检测操作系统 +OS_TYPE="unknown" +case "$(uname -s)" in + Linux*) OS_TYPE="Linux";; + Darwin*) OS_TYPE="Mac";; + CYGWIN*|MINGW*|MSYS*) OS_TYPE="Windows";; + *) OS_TYPE="unknown";; +esac + +echo "✓ 检测到操作系统: $OS_TYPE" +echo "✓ bash环境检查通过 (版本: $BASH_VERSION)" + +# Node.js 安装函数 +install_nodejs() { + local platform=$(uname -s) + + case "$platform" in + Linux|Darwin) + echo "🚀 正在安装 Node.js..." + + echo "📥 下载并安装 nvm..." + curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.3/install.sh | bash + + echo "🔄 加载 nvm 环境..." + \. "$HOME/.nvm/nvm.sh" + + echo "📦 下载并安装 Node.js v22..." 
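+            # nvm.sh was sourced above, so the nvm function is available in this non-interactive shell.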
+ nvm install 22 + + echo -n "✅ Node.js 安装完成!版本: " + node -v + echo -n "✅ npm 版本: " + npm -v + ;; + *) + echo "❌ 不支持的平台: $platform" + echo "请手动安装 Node.js: https://nodejs.org/" + exit 1 + ;; + esac +} + +# 检查 Node.js 环境 +echo "检查 Node.js 环境..." +if command -v node >/dev/null 2>&1; then + current_version=$(node -v | sed 's/v//') + major_version=$(echo $current_version | cut -d. -f1) + + if [ "$major_version" -ge 18 ]; then + echo "✓ Node.js 已安装: v$current_version" + else + echo "⚠️ Node.js v$current_version 版本过低 (需要 >= 18),正在升级..." + install_nodejs + fi +else + echo "📦 Node.js 未安装,正在安装..." + install_nodejs +fi + +# 检查 npm 环境 +if command -v npm >/dev/null 2>&1; then + echo "✓ npm 已安装: $(npm -v)" +else + echo "❌ npm 未找到,Node.js 安装可能有问题" + exit 1 +fi + +# 2. 确定环境变量配置文件 +echo "正在扫描所有可能的环境变量配置文件..." + +# 初始化配置文件数组 +CONFIG_FILES=() + +# 检测当前shell类型 +current_shell=$(basename "$SHELL") + +# 根据shell类型和操作系统,列出所有可能的配置文件 +case "$current_shell" in + bash) + # Bash配置文件优先级顺序 + if [ "$OS_TYPE" = "Mac" ]; then + # macOS上bash配置文件 + potential_files=( + "$HOME/.bash_profile" + "$HOME/.bashrc" + "$HOME/.profile" + ) + else + # Linux/Windows上bash配置文件 + potential_files=( + "$HOME/.bashrc" + "$HOME/.bash_profile" + "$HOME/.profile" + ) + fi + ;; + zsh) + # Zsh配置文件优先级顺序 + potential_files=( + "$HOME/.zshrc" + "$HOME/.zprofile" + "$HOME/.zshenv" + "$HOME/.profile" + ) + + # 检查是否使用Oh My Zsh,避免冲突 + if [ -n "$ZSH" ] && [ -d "$ZSH" ]; then + echo "⚠️ 检测到Oh My Zsh环境,将在配置文件末尾添加变量" + fi + ;; + fish) + # Fish shell配置文件 + potential_files=( + "$HOME/.config/fish/config.fish" + ) + + # 创建fish配置目录(如果不存在) + if [ ! -d "$HOME/.config/fish" ]; then + mkdir -p "$HOME/.config/fish" + echo "创建fish配置目录: ~/.config/fish/" + fi + ;; + *) + # 其他shell的通用配置文件 + potential_files=( + "$HOME/.profile" + "$HOME/.bashrc" + ) + ;; +esac + +# 检查每个可能的配置文件 +echo "检查以下配置文件:" +for file in "${potential_files[@]}"; do + if [ -f "$file" ]; then + CONFIG_FILES+=("$file") + echo " ✓ 找到: ${file/#$HOME/~}" + else + echo " × 不存在: ${file/#$HOME/~}" + fi +done + +# 如果没有找到任何配置文件,创建默认的 +if [ ${#CONFIG_FILES[@]} -eq 0 ]; then + # 根据shell类型创建默认配置文件 + case "$current_shell" in + bash) + if [ "$OS_TYPE" = "Mac" ]; then + DEFAULT_FILE="$HOME/.bash_profile" + else + DEFAULT_FILE="$HOME/.bashrc" + fi + ;; + zsh) + DEFAULT_FILE="$HOME/.zshrc" + ;; + fish) + DEFAULT_FILE="$HOME/.config/fish/config.fish" + ;; + *) + DEFAULT_FILE="$HOME/.profile" + ;; + esac + + touch "$DEFAULT_FILE" + CONFIG_FILES+=("$DEFAULT_FILE") + echo "创建新的配置文件: ${DEFAULT_FILE/#$HOME/~}" +fi + +echo "" +echo "✓ 将更新 ${#CONFIG_FILES[@]} 个配置文件" + +# 3. 检查现有配置(支持不同shell语法) +echo "" +echo "检查现有Anthropic配置..." 
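+# EXISTING_CONFIGS collects config files that already contain Anthropic variables; BACKUP_FILES records the timestamped backups taken before they are overwritten.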
+EXISTING_CONFIGS=() +BACKUP_FILES=() + +# 检查每个配置文件中的现有配置 +for config_file in "${CONFIG_FILES[@]}"; do + has_config=false + + # 根据文件名判断语法类型 + if [[ "$config_file" == *"fish"* ]]; then + # fish shell 语法: set -x ANTHROPIC_AUTH_TOKEN + if grep -q "set -x ANTHROPIC_AUTH_TOKEN\|set -x ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null; then + has_config=true + fi + else + # bash/zsh 语法: export ANTHROPIC_AUTH_TOKEN + if grep -q "ANTHROPIC_AUTH_TOKEN\|ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null; then + has_config=true + fi + fi + + if [ "$has_config" = true ]; then + EXISTING_CONFIGS+=("$config_file") + echo "⚠️ 在 ${config_file/#$HOME/~} 中检测到已存在的Anthropic配置:" + if [[ "$config_file" == *"fish"* ]]; then + grep -n "set -x ANTHROPIC_" "$config_file" | sed 's/^/ /' || true + else + grep -n "ANTHROPIC_" "$config_file" | sed 's/^/ /' || true + fi + fi +done + +# 如果有现有配置,询问是否覆盖 +if [ ${#EXISTING_CONFIGS[@]} -gt 0 ]; then + echo "" + echo "📋 在 ${#EXISTING_CONFIGS[@]} 个文件中发现现有配置" + read -p "是否要覆盖所有现有配置?(y/N): " overwrite + if [[ ! "$overwrite" =~ ^[Yy]$ ]]; then + echo "操作已取消" + exit 0 + fi + + # 备份所有包含配置的文件 + echo "" + echo "正在备份现有配置文件..." + for config_file in "${EXISTING_CONFIGS[@]}"; do + backup_file="${config_file}.backup.$(date +%Y%m%d_%H%M%S)" + cp "$config_file" "$backup_file" + BACKUP_FILES+=("$backup_file") + echo " ✓ 已备份: ${backup_file/#$HOME/~}" + done +fi + +# 颜色定义 +colorReset='\033[0m' +colorBright='\033[1m' +colorCyan='\033[36m' +colorYellow='\033[33m' +colorMagenta='\033[35m' +colorRed='\033[31m' +colorBlue='\033[34m' +colorWhite='\033[37m' +colorGreen='\033[32m' + +# 显示API密钥获取横幅 +show_api_banner() { + printf "${colorBright}${colorRed} █████╗ ██╗ ${colorBlue}██████╗ ██████╗ ██████╗ ███████╗${colorMagenta} ██╗ ██╗██╗████████╗██╗ ██╗${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██╗██║ ${colorBlue}██╔════╝██╔═══██╗██╔══██╗██╔════╝${colorMagenta} ██║ ██║██║╚══██╔══╝██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ███████║██║ ${colorBlue}██║ ██║ ██║██║ ██║█████╗ ${colorMagenta} ██║ █╗ ██║██║ ██║ ███████║${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██║██║ ${colorBlue}██║ ██║ ██║██║ ██║██╔══╝ ${colorMagenta} ██║███╗██║██║ ██║ ██╔══██║${colorReset}\n" + printf "${colorBright}${colorRed} ██║ ██║██║ ${colorBlue}╚██████╗╚██████╔╝██████╔╝███████╗${colorMagenta} ╚███╔███╔╝██║ ██║██╗██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ╚═╝ ╚═╝╚═╝ ${colorBlue} ╚═════╝ ╚═════╝ ╚═════╝ ╚══════╝${colorMagenta} ╚══╝╚══╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝${colorReset}\n" + printf "\n" + printf "${colorBright}${colorYellow}🌐 请从以下网址获取您的API密钥:${colorReset}\n" + printf "${colorBright}${colorCyan}📋 https://aicodewith.com/dashboard/api-keys${colorReset}\n" + printf "\n" + printf "${colorBright}${colorGreen}📝 API密钥格式: sk-acw-********-****************${colorReset}\n" + printf "\n" +} + +# 4. 获取API密钥 +echo "" +show_api_banner + +# 输入API密钥并验证 +while true; do + read -p "请输入ANTHROPIC_AUTH_TOKEN: " auth_token + echo "" + + # 基本格式验证 + if [[ "$auth_token" =~ ^sk-acw-.{8}-.{16}$ ]]; then + echo "✓ API密钥格式验证通过" + break + else + echo "❌ API密钥格式不正确" + echo " 正确格式: sk-acw-********-****************" + echo " 请重新输入" + fi +done + +# 5. 更新配置文件 +echo "" +echo "正在更新配置文件..." 
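+# UPDATE_COUNT counts config files that were written and verified successfully; FAILED_FILES collects the ones whose verification failed.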
+UPDATE_COUNT=0 +FAILED_FILES=() + +# 处理每个配置文件 +for config_file in "${CONFIG_FILES[@]}"; do + echo " 📝 处理: ${config_file/#$HOME/~}" + + # 判断文件类型和语法 + is_fish=false + if [[ "$config_file" == *"fish"* ]]; then + is_fish=true + fi + + # 移除旧的Anthropic配置 + if grep -q "ANTHROPIC_AUTH_TOKEN\|ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null || \ + grep -q "set -x ANTHROPIC_AUTH_TOKEN\|set -x ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null; then + + # 创建临时文件,移除旧配置 + temp_file=$(mktemp) + if [ "$is_fish" = true ]; then + # 移除fish语法的配置行 + grep -v "set -x ANTHROPIC_AUTH_TOKEN\|set -x ANTHROPIC_BASE_URL" "$config_file" > "$temp_file" + else + # 移除bash/zsh语法的配置行 + grep -v "ANTHROPIC_AUTH_TOKEN\|ANTHROPIC_BASE_URL" "$config_file" > "$temp_file" + fi + mv "$temp_file" "$config_file" + fi + + # 添加新配置 + if [ "$is_fish" = true ]; then + # fish shell 语法 + { + echo "" + echo "# Anthropic API Configuration - $(date '+%Y-%m-%d %H:%M:%S')" + echo "set -x ANTHROPIC_AUTH_TOKEN $auth_token" + echo "set -x ANTHROPIC_BASE_URL https://api.jiuwanliguoxue.com/" + } >> "$config_file" + else + # bash/zsh 语法 + { + echo "" + echo "# Anthropic API Configuration - $(date '+%Y-%m-%d %H:%M:%S')" + echo "export ANTHROPIC_AUTH_TOKEN=$auth_token" + echo "export ANTHROPIC_BASE_URL=https://api.jiuwanliguoxue.com/" + } >> "$config_file" + fi + + # 验证是否写入成功 + if [ "$is_fish" = true ]; then + if grep -q "set -x ANTHROPIC_AUTH_TOKEN $auth_token" "$config_file" && \ + grep -q "set -x ANTHROPIC_BASE_URL" "$config_file"; then + echo " ✓ 配置成功写入" + ((UPDATE_COUNT++)) + else + echo " ❌ 配置写入失败" + FAILED_FILES+=("$config_file") + fi + else + if grep -q "ANTHROPIC_AUTH_TOKEN=$auth_token" "$config_file" && \ + grep -q "ANTHROPIC_BASE_URL=" "$config_file"; then + echo " ✓ 配置成功写入" + ((UPDATE_COUNT++)) + else + echo " ❌ 配置写入失败" + FAILED_FILES+=("$config_file") + fi + fi +done + +echo "" +echo "✓ 成功更新 $UPDATE_COUNT/${#CONFIG_FILES[@]} 个配置文件" + +# 如果有失败的文件,显示错误信息 +if [ ${#FAILED_FILES[@]} -gt 0 ]; then + echo "" + echo "❌ 以下文件更新失败:" + for failed_file in "${FAILED_FILES[@]}"; do + echo " - ${failed_file/#$HOME/~}" + done +fi + +# 6. 加载环境变量并验证 +echo "" +echo "正在加载和验证环境变量..." + +# 尝试从非fish配置文件加载环境变量 +if [[ "$current_shell" != "fish" ]]; then + # 从所有非fish配置文件中提取并加载Anthropic环境变量 + for config_file in "${CONFIG_FILES[@]}"; do + if [[ "$config_file" != *"fish"* ]]; then + eval "$(grep "^export ANTHROPIC_" "$config_file" 2>/dev/null || true)" + fi + done +else + echo "⚠️ Fish shell配置文件不兼容bash,跳过自动加载" +fi + +# 手动设置环境变量用于当前会话 +export ANTHROPIC_AUTH_TOKEN=$auth_token +export ANTHROPIC_BASE_URL=https://api.jiuwanliguoxue.com/ + +# 验证配置是否成功 +if [ "$UPDATE_COUNT" -eq "${#CONFIG_FILES[@]}" ]; then + echo "✅ 所有配置文件更新成功!" + echo "" + echo "📊 当前配置:" + echo " ANTHROPIC_BASE_URL: $ANTHROPIC_BASE_URL" + echo " ANTHROPIC_AUTH_TOKEN: ${ANTHROPIC_AUTH_TOKEN:0:12}...(已隐藏)" + echo "" + + # 显示更新的配置文件列表 + echo "📋 已更新的配置文件:" + for config_file in "${CONFIG_FILES[@]}"; do + echo " - ${config_file/#$HOME/~}" + done + echo "" + echo "🎉 配置完成!" + echo "" + + # 7. 检查并安装/更新Claude Code客户端 + echo "🔍 检查Claude Code客户端..." + if command -v claude >/dev/null 2>&1; then + echo "✓ Claude Code已安装: $(claude --version)" + echo "" + echo "🚀 是否要更新Claude Code客户端到最新版本?" + read -p "这将执行: npm uninstall/install -g @anthropic-ai/claude-code (y/N): " update_claude + + if [[ "$update_claude" =~ ^[Yy]$ ]]; then + echo "🔄 正在更新Claude Code客户端..." + + echo "步骤1: 卸载旧版本..." + npm uninstall -g @anthropic-ai/claude-code + + echo "步骤2: 安装最新版本..." 
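+            # --registry points npm at the npmmirror.com mirror, which is usually faster and more reliable than the default registry from mainland China.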
+ if npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com; then + echo "✅ Claude Code客户端更新成功!" + else + echo "❌ Claude Code客户端安装失败,请手动执行:" + echo " npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com" + fi + fi + else + echo "📦 Claude Code未安装,正在安装..." + if npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com; then + echo "✅ Claude Code客户端安装成功!" + else + echo "❌ Claude Code客户端安装失败,请手动执行:" + echo " npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com" + exit 1 + fi + fi + + # 8. 配置Claude Code跳过引导 + echo "" + echo "🔧 配置Claude Code跳过引导..." + node --eval " + const fs = require('fs'); + const os = require('os'); + const path = require('path'); + + const homeDir = os.homedir(); + const filePath = path.join(homeDir, '.claude.json'); + + try { + if (fs.existsSync(filePath)) { + const content = JSON.parse(fs.readFileSync(filePath, 'utf-8')); + fs.writeFileSync(filePath, JSON.stringify({ ...content, hasCompletedOnboarding: true }, null, 2), 'utf-8'); + console.log('✅ 已更新现有Claude配置文件'); + } else { + fs.writeFileSync(filePath, JSON.stringify({ hasCompletedOnboarding: true }, null, 2), 'utf-8'); + console.log('✅ 已创建Claude配置文件并跳过引导'); + } + } catch (error) { + console.log('⚠️ 配置Claude引导跳过时出错:', error.message); + } + " + echo "" + + # 9. 检测并清理Claude配置文件中的代理设置 + echo "" + echo "🔍 检测Claude配置文件中的代理设置..." + # Claude配置文件可能的路径(优先检查settings.json) + CLAUDE_SETTING_FILE="" + if [ -f "$HOME/.claude/settings.json" ]; then + CLAUDE_SETTING_FILE="$HOME/.claude/settings.json" + elif [ -f "$HOME/.claude/settings.local.json" ]; then + CLAUDE_SETTING_FILE="$HOME/.claude/settings.local.json" + elif [ -f "$HOME/.claude/setting.json" ]; then + CLAUDE_SETTING_FILE="$HOME/.claude/setting.json" + fi + + if [ -n "$CLAUDE_SETTING_FILE" ]; then + echo "✓ 找到Claude配置文件: ${CLAUDE_SETTING_FILE/#$HOME/~}" + + # 检测是否存在代理设置 + PROXY_FOUND=false + PROXY_SETTINGS="" + + # 检查是否有HTTP代理设置(不区分大小写) + if grep -iq "http_proxy\|https_proxy\|httpproxy\|httpsproxy" "$CLAUDE_SETTING_FILE" 2>/dev/null; then + PROXY_FOUND=true + echo "" + echo "⚠️ 检测到残留的代理配置:" + PROXY_SETTINGS=$(grep -in "http_proxy\|https_proxy\|httpproxy\|httpsproxy" "$CLAUDE_SETTING_FILE" | sed 's/^/ /') + echo "$PROXY_SETTINGS" + echo "" + echo "📝 这些代理设置可能会影响Claude Code的正常使用" + echo " 建议删除这些设置以避免连接问题" + echo "" + + read -p "是否要删除这些代理设置?(y/N): " remove_proxy + if [[ "$remove_proxy" =~ ^[Yy]$ ]]; then + # 备份原配置文件 + backup_claude_file="${CLAUDE_SETTING_FILE}.backup.$(date +%Y%m%d_%H%M%S)" + cp "$CLAUDE_SETTING_FILE" "$backup_claude_file" + echo "✓ 已备份Claude配置到: ${backup_claude_file/#$HOME/~}" + + # 删除代理设置行(不区分大小写) + # 使用sed删除包含代理相关设置的行 + if [[ "$OS_TYPE" = "Mac" ]]; then + # Mac版本的sed需要备份文件参数 + sed -i '' '/[Hh][Tt][Tt][Pp]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss][Pp][Rr][Oo][Xx][Yy]/d' "$CLAUDE_SETTING_FILE" + else + # Linux/Windows版本的sed + sed -i '/[Hh][Tt][Tt][Pp]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss][Pp][Rr][Oo][Xx][Yy]/d' "$CLAUDE_SETTING_FILE" + fi + + # 验证删除结果(不区分大小写) + if ! 
grep -iq "http_proxy\|https_proxy\|httpproxy\|httpsproxy" "$CLAUDE_SETTING_FILE" 2>/dev/null; then + echo "✅ 代理设置已成功删除" + echo "📋 Claude Code现在应该能正常使用默认网络连接" + else + echo "❌ 代理设置删除失败" + echo " 请手动编辑文件: $CLAUDE_SETTING_FILE" + echo " 或恢复备份: cp $backup_claude_file $CLAUDE_SETTING_FILE" + fi + else + echo "跳过代理设置清理" + fi + else + echo "✓ 未发现代理设置,配置文件正常" + fi + else + echo "ℹ️ 未找到Claude配置文件(${CLAUDE_SETTING_FILE/#$HOME/~})" + echo " 这是正常的,配置文件会在首次使用Claude Code时自动创建" + fi + echo "" + +# 显示配置完成横幅 +show_complete_banner() { + printf "\n" + printf "${colorBright}${colorRed} █████╗ ██╗ ${colorBlue}██████╗ ██████╗ ██████╗ ███████╗${colorMagenta} ██╗ ██╗██╗████████╗██╗ ██╗${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██╗██║ ${colorBlue}██╔════╝██╔═══██╗██╔══██╗██╔════╝${colorMagenta} ██║ ██║██║╚══██╔══╝██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ███████║██║ ${colorBlue}██║ ██║ ██║██║ ██║█████╗ ${colorMagenta} ██║ █╗ ██║██║ ██║ ███████║${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██║██║ ${colorBlue}██║ ██║ ██║██║ ██║██╔══╝ ${colorMagenta} ██║███╗██║██║ ██║ ██╔══██║${colorReset}\n" + printf "${colorBright}${colorRed} ██║ ██║██║ ${colorBlue}╚██████╗╚██████╔╝██████╔╝███████╗${colorMagenta} ╚███╔███╔╝██║ ██║██╗██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ╚═╝ ╚═╝╚═╝ ${colorBlue} ╚═════╝ ╚═════╝ ╚═════╝ ╚══════╝${colorMagenta} ╚══╝╚══╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝${colorReset}\n" + printf "\n" + printf "${colorBright}${colorYellow}📌 请执行以下命令使配置立即生效:${colorReset}\n" + printf "${colorBright}${colorCyan} source ${CONFIG_FILE/#$HOME/~}${colorReset}\n" + printf "\n" + printf "${colorBright}${colorGreen}🔄 或者重启终端让配置自动生效${colorReset}\n" + printf "\n" +} + + show_complete_banner + echo "" + echo "🔧 如需修改配置,可编辑: ${CONFIG_FILE/#$HOME/~}" +else + # 方案3: 改进错误提示,说明可能的原因 + echo "❌ 配置文件验证失败,可能的原因:" + echo " 1. 配置文件写入过程中出现错误" + echo " 2. 磁盘空间不足或权限问题" + echo " 3. API密钥格式在写入时被意外修改" + echo "" + echo "🔍 调试信息:" + echo " 配置文件路径: $CONFIG_FILE" + echo " API密钥长度: ${#auth_token}" + echo " 配置文件末尾内容:" + tail -5 "$CONFIG_FILE" 2>/dev/null || echo " 无法读取配置文件" + echo "" + echo "💡 建议解决方案:" + echo " 1. 检查磁盘空间: df -h $HOME" + echo " 2. 检查文件权限: ls -la $CONFIG_FILE" + echo " 3. 手动验证配置: cat $CONFIG_FILE | grep ANTHROPIC" + echo " 4. 
重新运行脚本" + exit 1 +fi \ No newline at end of file diff --git a/src/models/cache_manager.hpp b/src/models/cache_manager.hpp index 4d1b5aa7..1893b147 100644 --- a/src/models/cache_manager.hpp +++ b/src/models/cache_manager.hpp @@ -142,6 +142,7 @@ class CacheManager { const size_t DEFAULT_CACHE_CAPACITY = 128; LRUDescriptorCache add_cache; + LRUDescriptorCache mul_cache; LRUDescriptorCache rms_norm_cache; LRUDescriptorCache gemm_cache; LRUDescriptorCache rope_cache; @@ -149,17 +150,22 @@ class CacheManager { LRUDescriptorCache causal_softmax_cache; LRUDescriptorCache swiglu_cache; LRUDescriptorCache random_sample_cache; + LRUDescriptorCache paged_caching_cache; + LRUDescriptorCache paged_attention_cache; public: CacheManager(size_t capacity = 100) : add_cache(capacity, infiniopDestroyAddDescriptor), + mul_cache(capacity, infiniopDestroyMulDescriptor), rms_norm_cache(capacity, infiniopDestroyRMSNormDescriptor), gemm_cache(capacity, infiniopDestroyGemmDescriptor), rope_cache(capacity, infiniopDestroyRoPEDescriptor), rearrange_cache(capacity, infiniopDestroyRearrangeDescriptor), causal_softmax_cache(capacity, infiniopDestroyCausalSoftmaxDescriptor), swiglu_cache(capacity, infiniopDestroySwiGLUDescriptor), - random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor) {} + random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor), + paged_caching_cache(capacity, infiniopDestroyPagedCachingDescriptor), + paged_attention_cache(capacity, infiniopDestroyPagedAttentionDescriptor) {} // Add operations bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) { @@ -170,6 +176,15 @@ class CacheManager { add_cache.put(key, desc); } + // Mul operations + bool getMulDescriptor(size_t key, infiniopMulDescriptor_t &desc) { + return mul_cache.get(key, desc); + } + + void putMulDescriptor(size_t key, const infiniopMulDescriptor_t &desc) { + mul_cache.put(key, desc); + } + // RMSNorm operations bool getRMSNormDescriptor(size_t key, infiniopRMSNormDescriptor_t &desc) { return rms_norm_cache.get(key, desc); @@ -233,6 +248,24 @@ class CacheManager { random_sample_cache.put(key, desc); } + // Paged Caching operations + bool getPagedCachingDescriptor(size_t key, infiniopPagedCachingDescriptor_t &desc) { + return paged_caching_cache.get(key, desc); + } + + void putPagedCachingDescriptor(size_t key, const infiniopPagedCachingDescriptor_t &desc) { + paged_caching_cache.put(key, desc); + } + + // Paged Attention operations + bool getPagedAttentionDescriptor(size_t key, infiniopPagedAttentionDescriptor_t &desc) { + return paged_attention_cache.get(key, desc); + } + + void putPagedAttentionDescriptor(size_t key, const infiniopPagedAttentionDescriptor_t &desc) { + paged_attention_cache.put(key, desc); + } + template static size_t createDescriptorKey(Tensors... 
tensors) { size_t seed = 0; diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index fd0dea64..b7f69999 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -33,6 +33,27 @@ void InferenceContext::add(std::shared_ptr c, c->data(), a->data(), b->data(), stream)); } +void InferenceContext::mul(std::shared_ptr c, + std::shared_ptr a, + std::shared_ptr b) { + size_t key = CacheManager::createDescriptorKey(c, a, b); + + infiniopMulDescriptor_t desc; + if (!cache_manager->getMulDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateMulDescriptor(rsrc->handle, &desc, c->desc(), a->desc(), b->desc())); + cache_manager->putMulDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetMulWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopMul( + desc, workspace, workspace_size, + c->data(), a->data(), b->data(), stream)); +} + void InferenceContext::rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, @@ -231,3 +252,71 @@ void InferenceContext::linear(std::shared_ptr c, add(c, c, bias->view_as(c->shape(), strides)); } } + +void InferenceContext::pagedCaching(std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr slot_mapping) { + size_t key = CacheManager::createDescriptorKey(k, v, k_cache, v_cache, slot_mapping); + + infiniopPagedCachingDescriptor_t desc; + if (!cache_manager->getPagedCachingDescriptor(key, desc)) { + RUN_INFINI(infiniopCreatePagedCachingDescriptor( + rsrc->handle, &desc, k->desc(), v->desc(), + k_cache->desc(), v_cache->desc(), slot_mapping->desc())); + cache_manager->putPagedCachingDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopPagedCaching( + desc, workspace, workspace_size, + k->data(), v->data(), + k_cache->data(), v_cache->data(), + slot_mapping->data(), stream)); +} + +void InferenceContext::pagedAttention(std::shared_ptr out, + std::shared_ptr q, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr block_tables, + std::shared_ptr seq_lens, + std::shared_ptr alibi_slopes, // can be nullptr + float scale) { + + size_t key = CacheManager::createDescriptorKey(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes); + + infiniopPagedAttentionDescriptor_t desc; + if (!cache_manager->getPagedAttentionDescriptor(key, desc)) { + infiniopTensorDescriptor_t alibi_desc = alibi_slopes ? alibi_slopes->desc() : nullptr; + RUN_INFINI(infiniopCreatePagedAttentionDescriptor( + rsrc->handle, &desc, out->desc(), q->desc(), + k_cache->desc(), v_cache->desc(), block_tables->desc(), + seq_lens->desc(), alibi_desc, scale)); + cache_manager->putPagedAttentionDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + const void* alibi_data = alibi_slopes ? 
alibi_slopes->data() : nullptr; + RUN_INFINI(infiniopPagedAttention( + desc, workspace, workspace_size, + out->data(), q->data(), k_cache->data(), v_cache->data(), + block_tables->data(), seq_lens->data(), alibi_data, + stream)); +} + + + + + + + diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index dd5f4b78..a55b5c25 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -19,6 +19,9 @@ struct InferenceContext { void add(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b); + void mul(std::shared_ptr c, + std::shared_ptr a, + std::shared_ptr b); void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, @@ -49,6 +52,21 @@ struct InferenceContext { float alpha, float beta, std::shared_ptr residual, std::shared_ptr bias); + + void pagedCaching(std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr slot_mapping); + + void pagedAttention(std::shared_ptr out, + std::shared_ptr q, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr block_tables, + std::shared_ptr seq_lens, + std::shared_ptr alibi_slopes, // can be nullptr + float scale); }; namespace { @@ -68,6 +86,11 @@ inline void add(std::shared_ptr c, std::shared_ptr a, std::share getInferenceContext().add(c, a, b); } +inline void mul(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b) { + getInferenceContext().mul(c, a, b); +} + + inline void rmsnorm(std::shared_ptr y, std::shared_ptr x, std::shared_ptr w, float epsilon) { getInferenceContext().rmsnorm(y, x, w, epsilon); @@ -106,4 +129,21 @@ inline void linear(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta, std::shared_ptr residual, std::shared_ptr bias) { getInferenceContext().linear(c, a, b, alpha, beta, residual, bias); + } + +inline void pagedCaching(std::shared_ptr k, std::shared_ptr v, + std::shared_ptr k_cache, std::shared_ptr v_cache, + std::shared_ptr slot_mapping) { + getInferenceContext().pagedCaching(k, v, k_cache, v_cache, slot_mapping); +} + +inline void pagedAttention(std::shared_ptr out, std::shared_ptr q, + std::shared_ptr k_cache, std::shared_ptr v_cache, + std::shared_ptr block_tables, std::shared_ptr seq_lens, + std::shared_ptr alibi_slopes, float scale) { + getInferenceContext().pagedAttention(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale); +} + + + diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 6e4e79b4..5c3a8f23 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta, const JiugeWeights *weights, @@ -116,7 +118,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output, void *last_logits) { auto nlayer = meta.nlayer; auto nkvh = meta.nkvh / ndev; @@ -130,13 +135,13 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, auto dvoc = meta.dvoc; auto stream = rsrc.stream; bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0; - // Allocate buffers auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); auto 
logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, rsrc.memory_pool); auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool); + auto q_buf = Tensor::buffer(dt_logits, {ntok, nh, dh}, rsrc.memory_pool); auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool); auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); auto result_cpu = std::vector(nreq); @@ -145,11 +150,14 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, // Prepare inputs auto batch_pos_ids = std::vector(ntok); + auto batch_seq_lens = std::vector(nreq); + size_t req_start = 0; for (uint32_t req = 0; req < nreq; req++) { for (uint32_t i = 0; i < req_lens[req]; i++) { batch_pos_ids[req_start + i] = req_pos[req] + i; } + batch_seq_lens[req] = req_lens[req] + req_pos[req]; req_start += req_lens[req]; } @@ -167,6 +175,27 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); } + std::shared_ptr slot_mapping_buf, block_tables_buf, seq_lens_buf; + size_t max_seq_len_in_batch = 0; + if (enable_paged_attn) { + max_seq_len_in_batch = *std::max_element(batch_seq_lens.begin(), batch_seq_lens.end()); + // The KV cache block size comes from the model metadata; each request needs + // ceil(total_len / block_size) cache blocks, so the block-table width for this batch is derived from the longest sequence. + size_t block_size = meta.kvcache_block_size; + size_t max_blocks_per_seq = (max_seq_len_in_batch + block_size - 1) / block_size; + + + slot_mapping_buf = Tensor::buffer(INFINI_DTYPE_I32, {ntok}, rsrc.memory_pool); + block_tables_buf = Tensor::buffer(INFINI_DTYPE_I32, {(uint32_t)nreq, (uint32_t)max_blocks_per_seq}, rsrc.memory_pool); + seq_lens_buf = Tensor::buffer(INFINI_DTYPE_I32, {nreq}, rsrc.memory_pool); + RUN_INFINI(infinirtMemcpyAsync(slot_mapping_buf->data(), slot_mapping, sizeof(int32_t) * ntok, INFINIRT_MEMCPY_H2D, stream)); + RUN_INFINI(infinirtMemcpyAsync(block_tables_buf->data(), block_tables, sizeof(int32_t) * (nreq * max_blocks_per_seq), INFINIRT_MEMCPY_H2D, stream)); + RUN_INFINI(infinirtMemcpyAsync(seq_lens_buf->data(), batch_seq_lens.data(), sizeof(int32_t) * nreq, INFINIRT_MEMCPY_H2D, stream)); + + } + // Attention // attention inner size_t max_qk_size = 0; @@ -187,11 +216,12 @@ auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); + // MLP buffers auto gate_buf = gate_up_buf->slice(1, 0, di); auto up_buf = gate_up_buf->slice(1, di, di); - // Compute + for (uint32_t layer = 0; layer < nlayer; layer++) { // 1.
Attention // rms norm @@ -202,34 +232,106 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); - size_t token_offset = 0; - for (uint32_t req = 0; req < nreq; req++) { - auto past_len = req_pos[req]; - auto seq_len = req_lens[req]; - auto total_len = past_len + seq_len; - auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); - auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); - auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); - auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); - - // self attention - // concat - rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); - rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); - // qk - rearrange(q_rearrange, q); - auto qk_gemm = qk_buf->view({nkvh, ngroup * seq_len, total_len}); - auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); - linear(qk_gemm, rearrange_q_buf, k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); - // softmax - auto qk_softmax = qk_buf->view({nh, seq_len, total_len}); - causalSoftmax(qk_softmax, qk_softmax); - auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); - linear(attn_val_buf, qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); - // rearrange attn val - rearrange(o, attn_val_gemm); - - token_offset += seq_len; + if (enable_paged_attn) { + auto k = qkv_rope->slice({ {0, 0, ntok}, {1, nh, nkvh} }); + auto v = qkv_rope->slice({ {0, 0, ntok}, {1, nh + nkvh, nkvh} }); + + // Assuming kv_caches[0] gives access to the entire cache pool for this device. + // This part may need adjustment based on the actual KVCache struct definition. 
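+                // Paged cache pool layout (see createPagedKVCache): [max_num_blocks, nkvh, kvcache_block_size, dh] per layer and device.
+                // slot_mapping is assumed to hold, for each of the ntok new tokens, the flat slot index in this pool
+                // (block_id * kvcache_block_size + in-block offset), and block_tables lists each request's block ids
+                // padded to max_blocks_per_seq, following the usual vLLM-style paged KV cache convention.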
+ auto k_cache_pool = kv_caches[0]->k[idev][layer]; + auto v_cache_pool = kv_caches[0]->v[idev][layer]; + pagedCaching(k, v, k_cache_pool, v_cache_pool, slot_mapping_buf); + // printf("o_buf: pass pagedCaching\n"); + + + if (is_prefill) { + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); + auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); + auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + // qk + // std::cout << "rearrange q" << std::endl; + // std::cout << "q shape: " << q->info() << std::endl; + rearrange(q_rearrange->slice(2, 0, seq_len), q); + // std::cout << "qk_gemm" << std::endl; + // std::cout << "qk_buf: " << qk_buf->info() << std::endl; + auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = k->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + // std::cout << "qk_softmax" << std::endl; + auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + // std::cout << "v_gemm" << std::endl; + auto v_gemm = v->permute({1, 0, 2}); + // std::cout << "attn_val_buf" << std::endl; + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + // rearrange attn val + rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); + + token_offset += seq_len; + } + } else { + auto o = o_buf->slice({{0, 0, ntok}})->view({ntok, nh, dh}); + auto q_batch = qkv_rope->slice({ {0, 0, ntok}, {1, 0, nh} })->view({ntok, nh, dh}); + // std::cout << "q_batch: " << q_batch->info() << std::endl; + // q_batch->debug + // std::cout << "q_batch: " << q_batch->isContiguous() << std::endl; + // std::cout << "q_batch: " << q_batch->strides() << std::endl; + // q_buf->copyFrom(q_batch, rsrc.handle, stream); + // std::cout << "q_buf: " << q_buf->info() << std::endl; + + float scale = 1.f / float(sqrt(dh)); + // pagedAttention(o, q_buf, k_cache_pool, v_cache_pool, + // block_tables_buf, seq_lens_buf, nullptr /* alibi_slopes */, scale); + pagedAttention(o, q_batch, k_cache_pool, v_cache_pool, + block_tables_buf, seq_lens_buf, nullptr /* alibi_slopes */, scale); + + + } + + } else { + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); + auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); + + // self attention + // concat + rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); + rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); + // qk + // std::cout << "rearrange q" << std::endl; + // std::cout << "q shape: " << q->info() << std::endl; + 
rearrange(q_rearrange->slice(2, 0, seq_len), q); + // std::cout << "qk_gemm" << std::endl; + // std::cout << "qk_buf: " << qk_buf->info() << std::endl; + auto qk_gemm = qk_buf->slice(1, 0, seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + // std::cout << "qk_softmax" << std::endl; + auto qk_softmax = qk_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + // std::cout << "v_gemm" << std::endl; + auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); + // std::cout << "attn_val_buf" << std::endl; + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + // rearrange attn val + rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); + + token_offset += seq_len; + } } // o_proj @@ -255,7 +357,10 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, INFINICCL_SUM, rsrc.comm, stream)); RUN_INFINI(infinirtStreamSynchronize(stream)); } + // printf("o_buf: pass layer %d\n", layer); } + // printf("o_buf: pass all layers\n"); + // Sample and Output if (idev == 0) { if (last_logits != nullptr) { @@ -282,8 +387,8 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc, for (uint32_t req = 0; req < nreq; req++) { auto seq_len = req_lens[req]; float random_val = std::uniform_real_distribution(0, 1)(gen); - randomSample(result_buf->memShare({}, result_buf->dtype()), - prob_buf->view_as({dvoc}, {1}), + randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), + prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), random_val, topp[req], topk[req], temperature[req]); token_offset += seq_len; } @@ -301,8 +406,11 @@ __C void inferBatch(struct JiugeModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, + struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output) { model->req.tokens = tokens; model->req.ntok = ntok; @@ -310,11 +418,15 @@ inferBatch(struct JiugeModel *model, model->req.nreq = nreq; model->req.req_pos = req_pos; model->req.kv_caches = kv_caches; + model->req.block_tables = block_tables; + model->req.slot_mapping = slot_mapping; model->req.output = output; model->req.logits = nullptr; model->req.temperature = temperature; model->req.topk = topk; model->req.topp = topp; + model->req.is_prefill = is_prefill; + model->req.enable_paged_attn = enable_paged_attn; for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -335,6 +447,9 @@ forwardBatch(struct JiugeModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, + const uint32_t is_prefill, const bool enable_paged_attn, void *logits) { model->req.tokens = tokens; model->req.ntok = ntok; @@ -342,11 +457,15 @@ forwardBatch(struct JiugeModel *model, model->req.nreq = nreq; model->req.req_pos = req_pos; model->req.kv_caches = kv_caches; + model->req.block_tables = block_tables; + 
model->req.slot_mapping = slot_mapping; model->req.output = nullptr; model->req.logits = logits; model->req.temperature = nullptr; model->req.topk = nullptr; model->req.topp = nullptr; + model->req.is_prefill = is_prefill; + model->req.enable_paged_attn = enable_paged_attn; for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -390,7 +509,10 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, - req.temperature, req.topk, req.topp, req.output, req.logits); + req.block_tables, req.slot_mapping, + req.temperature, req.topk, req.topp, + req.is_prefill, req.enable_paged_attn, + req.output, req.logits); state.proceed = false; lock.unlock(); diff --git a/src/models/jiuge/jiuge_impl.hpp b/src/models/jiuge/jiuge_impl.hpp index be05b0e8..c3a03ad2 100644 --- a/src/models/jiuge/jiuge_impl.hpp +++ b/src/models/jiuge/jiuge_impl.hpp @@ -45,10 +45,14 @@ struct InferRequest { uint32_t nreq; const uint32_t *req_pos; struct KVCache **kv_caches; + const int32_t *block_tables; + const int32_t *slot_mapping; const float *temperature; const uint32_t *topk; const float *topp; uint32_t *output; + uint32_t is_prefill; + bool enable_paged_attn; void *logits; }; diff --git a/src/models/jiuge/jiuge_kv_cache.cpp b/src/models/jiuge/jiuge_kv_cache.cpp index db10f94e..bb6d2d44 100644 --- a/src/models/jiuge/jiuge_kv_cache.cpp +++ b/src/models/jiuge/jiuge_kv_cache.cpp @@ -1,4 +1,5 @@ #include "jiuge_impl.hpp" +#include __C struct KVCache *createKVCache(const JiugeModel *model) { KVCache *cache = new KVCache(); @@ -6,6 +7,8 @@ __C struct KVCache *createKVCache(const JiugeModel *model) { auto nkvh = model->meta.nkvh / ndev; auto max_len = model->meta.dctx; auto dh = model->meta.dh; + // auto block_size = model->meta.block_size; + // auto max_num_blocks = model->meta.max_num_blocks; auto shape = std::vector{max_len, nkvh, dh}; for (unsigned int idev = 0; idev < ndev; idev++) { RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev])); @@ -22,6 +25,33 @@ __C struct KVCache *createKVCache(const JiugeModel *model) { return cache; } + +__C struct KVCache *createPagedKVCache(const JiugeModel *model, uint32_t max_kvcache_tokens) { + KVCache *cache = new KVCache(); + auto ndev = model->dev_resources.size(); + auto nkvh = model->meta.nkvh / ndev; + // auto max_len = model->meta.dctx; + auto dh = model->meta.dh; + auto kvcache_block_size = model->meta.kvcache_block_size; + auto max_num_blocks = max_kvcache_tokens / kvcache_block_size; + assert(kvcache_block_size > 0); + auto shape = std::vector{max_num_blocks, nkvh, kvcache_block_size, dh}; + for (unsigned int idev = 0; idev < ndev; idev++) { + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev])); + auto kcache = std::vector>(); + auto vcache = std::vector>(); + for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) { + kcache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, shape))); + vcache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, shape))); + } + cache->k.push_back(kcache); + cache->v.push_back(vcache); + } + + return cache; +} + + __C struct KVCache *duplicateKVCache(const JiugeModel *model, const KVCache *kv_cache, unsigned int seq_len) { @@ -56,4 +86,4 @@ __C void dropKVCache(JiugeModel const *model, KVCache *kv_cache) { } } delete kv_cache; -} +} \ No newline at end of file diff --git a/src/tensor.hpp 
b/src/tensor.hpp index 59d56642..070ff733 100644 --- a/src/tensor.hpp +++ b/src/tensor.hpp @@ -136,6 +136,12 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr view_as(const std::vector &new_shape) const; std::shared_ptr view_as(const std::vector &new_shape, const std::vector &new_strides) const; + // template + // void init_value(T value, infiniopHandle_t handle, infinirtStream_t stream); + + // template + // void init_value_simple(T value, infiniopHandle_t handle, infinirtStream_t stream); + ~Tensor(); }; diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index b3a80cca..275912f7 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -410,3 +410,61 @@ void Tensor::debug(const std::string &filename) const { } void Tensor::debug() const { this->debug(""); } + + +// template +// void Tensor::init_value(T value, infiniopHandle_t handle, +// infinirtStream_t stream) { +// ASSERT_EQ(dsize(this->dtype()), sizeof(T)); + +// size_t numel = 1; +// for (size_t dim : this->shape()) { +// numel *= dim; +// } +// if (numel == 0) { +// return; +// } + +// RUN_INFINI(infinirtMemcpy(this->data(), &value, sizeof(T), +// INFINIRT_MEMCPY_H2D)); + +// auto ndim = this->ndim(); +// auto shape = this->shape(); +// auto bcast_strides = std::vector(ndim, 0); +// auto src_desc = TensorDesc::create(this->dtype(), shape, bcast_strides); + +// infiniopRearrangeDescriptor_t rearrange_desc; +// RUN_INFINI(infiniopCreateRearrangeDescriptor( +// handle, &rearrange_desc, this->desc(), src_desc->desc())); +// RUN_INFINI(infiniopRearrange(rearrange_desc, this->data(), this->data(), +// stream)); + +// RUN_INFINI(infiniopDestroyRearrangeDescriptor(rearrange_desc)); +// } +// template +// void Tensor::init_value_simple(T value, infiniopHandle_t handle, +// infinirtStream_t stream) { +// // 1. 安全检查:确保类型匹配 +// ASSERT_EQ(dsize(this->dtype()), sizeof(T)); + +// // 2. 计算张量元素总数 +// size_t numel = 1; +// for (size_t dim : this->shape()) { +// numel *= dim; +// } +// if (numel == 0) { +// return; +// } + +// // 3. 在 Host (CPU) 上创建一个填满目标值的临时数据源 +// std::vector host_data(numel, value); + +// // 4. 使用 Tensor::weight 功能在设备上创建一个临时的、内容正确的源张量。 +// // 这个源张量的形状与当前张量相同,但内存是连续的。 +// // Tensor::weight 内部会处理从 Host 到 Device 的数据拷贝。 +// auto src_tensor = Tensor::weight(host_data.data(), this->dtype(), this->shape()); + +// // 5. 使用现有的、安全的 copyFrom 函数完成赋值。 +// // copyFrom 会正确处理当前张量(this)可能存在的非连续内存布局(strides)。 +// this->copyFrom(src_tensor, handle, stream); +// }