diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index eca815656ca..bc97a49c394 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -1142,13 +1142,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
    */
   m.def("recover_decode_task", &RecoverDecodeTask,
         "recover decode task for scheduler v1 function");

-  /**
-   * extract_text_token_output.cu
-   * extract_text_token_output
-   */
-  m.def("extract_text_token_output", &ExtractTextTokenOutput,
-        "extract_text_token_output function");
-
   m.def("group_swiglu_with_masked", &GroupSwigluWithMasked,
         "group_swiglu_with_masked function");

diff --git a/custom_ops/gpu_ops/extract_text_token_output.cu b/custom_ops/gpu_ops/extract_text_token_output.cu
deleted file mode 100644
index 4459b967eae..00000000000
--- a/custom_ops/gpu_ops/extract_text_token_output.cu
+++ /dev/null
@@ -1,101 +0,0 @@
-// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "helper.h"
-
-template <int THREADBLOCK_SIZE>
-__global__ void extract_text_token_output_kernel(int *max_seq_len,
-                                                 int *max_seq_len_index,
-                                                 int *mm_token_num_len,
-                                                 int *seq_lens_this_time,
-                                                 int *cu_seqlens_q,
-                                                 float *hidden_states,
-                                                 float *output,
-                                                 const int bsz,
-                                                 const int hidden_size) {
-  int bsz_index = threadIdx.x;
-  int block_idx = blockIdx.x;
-  if (bsz_index >= bsz) return;
-
-  int max_seq_len_data = max_seq_len[0];
-  int max_seq_len_index_data = max_seq_len_index[0];
-  int mm_token_num_len_data = mm_token_num_len[0];
-  int true_bsz = cu_seqlens_q[bsz_index + 1] - 1;
-  if (max_seq_len_data == mm_token_num_len_data && bsz_index == max_seq_len_index_data) {
-    output[bsz_index * hidden_size + block_idx] = 0.0;
-  } else {
-    if (seq_lens_this_time[bsz_index] != 0) {
-      output[bsz_index * hidden_size + block_idx] = hidden_states[true_bsz * hidden_size + block_idx];
-    }
-  }
-  __syncthreads();
-}
-
-std::vector<paddle::Tensor> ExtractTextTokenOutput(
-    const paddle::Tensor& max_seq_len,
-    const paddle::Tensor& max_seq_len_index,
-    const paddle::Tensor& mm_token_num_len,
-    const paddle::Tensor& seq_lens_this_time,
-    const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& hidden_states) {
-
-  const int bsz = seq_lens_this_time.shape()[0];
-  const int hidden_size = hidden_states.shape()[1];
-  paddle::Tensor output = paddle::full({bsz, hidden_size}, 1, paddle::DataType::FLOAT32, hidden_states.place());
-
-  extract_text_token_output_kernel<1024><<<hidden_size, 1024, 0, hidden_states.stream()>>>(
-      const_cast<int*>(max_seq_len.data<int>()),
-      const_cast<int*>(max_seq_len_index.data<int>()),
-      const_cast<int*>(mm_token_num_len.data<int>()),
-      const_cast<int*>(seq_lens_this_time.data<int>()),
-      const_cast<int*>(cu_seqlens_q.data<int>()),
-      const_cast<float*>(hidden_states.data<float>()),
-      output.data<float>(),
-      bsz,
-      hidden_size
-  );
-  return {output};
-}
-
-std::vector<std::vector<int64_t>> ExtractTextTokenOutputInferShape(const std::vector<int64_t>& max_seq_len_shape,
-                                                                   const std::vector<int64_t>& max_seq_len_index_shape,
-                                                                   const std::vector<int64_t>& mm_token_num_len_shape,
-                                                                   const std::vector<int64_t>& seq_lens_this_time_shape,
-                                                                   const std::vector<int64_t>& cu_seqlens_q_shape,
-                                                                   const std::vector<int64_t>& hidden_states_shape) {
-  const int bsz = seq_lens_this_time_shape[0];
-  const int hidden_size = hidden_states_shape[1];
-  return {{bsz, hidden_size}};
-}
-
-std::vector<paddle::DataType> ExtractTextTokenOutputInferDtype(const paddle::DataType& max_seq_len_dtype,
-                                                               const paddle::DataType& max_seq_len_index_dtype,
-                                                               const paddle::DataType& mm_token_num_len_dtype,
-                                                               const paddle::DataType& seq_lens_this_time_dtype,
-                                                               const paddle::DataType& cu_seqlens_q_dtype,
-                                                               const paddle::DataType& hidden_states_dtype) {
-  return {hidden_states_dtype};
-}
-
-PD_BUILD_STATIC_OP(extract_text_token_output)
-    .Inputs({"max_seq_len",
-             "max_seq_len_index",
-             "mm_token_num_len",
-             "seq_lens_this_time",
-             "cu_seqlens_q",
-             "hidden_states"})
-    .Outputs({"output"})
-    .SetKernelFn(PD_KERNEL(ExtractTextTokenOutput))
-    .SetInferShapeFn(PD_INFER_SHAPE(ExtractTextTokenOutputInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(ExtractTextTokenOutputInferDtype));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index 8636c3de440..15832eb079f 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -290,7 +290,6 @@ def find_end_files(directory, end_str):
                 "gpu_ops/cpp_extensions.cc",
                 "gpu_ops/share_external_data.cu",
                 "gpu_ops/per_token_quant_fp8.cu",
-                "gpu_ops/extract_text_token_output.cu",
                 "gpu_ops/update_split_fuse_input.cu",
                 "gpu_ops/text_image_index_out.cu",
                 "gpu_ops/text_image_gather_scatter.cu",
@@ -536,6 +535,9 @@ def find_end_files(directory, end_str):
                 "gpu_ops/token_penalty_multi_scores.cu",
                 "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
                 "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
+                "gpu_ops/text_image_index_out.cu",
+                "gpu_ops/text_image_gather_scatter.cu",
+                "gpu_ops/set_data_ipc.cu",
                 "iluvatar_ops/moe_dispatch.cu",
                 "iluvatar_ops/moe_reduce.cu",
                 "iluvatar_ops/paged_attn.cu",
@@ -594,7 +596,6 @@ def find_end_files(directory, end_str):
                 "gpu_ops/read_data_ipc.cu",
                 "gpu_ops/dequant_int8.cu",
                 "gpu_ops/share_external_data.cu",
-                "gpu_ops/extract_text_token_output.cu",
                 "gpu_ops/moe/tritonmoe_preprocess.cu",
                 "gpu_ops/moe/moe_topk_select.cu",
                 "gpu_ops/recover_decode_task.cu",
diff --git a/docs/get_started/installation/iluvatar_gpu.md b/docs/get_started/installation/iluvatar_gpu.md
index 393f250a1ca..80e969b831e 100644
--- a/docs/get_started/installation/iluvatar_gpu.md
+++ b/docs/get_started/installation/iluvatar_gpu.md
@@ -409,3 +409,148 @@
 Accuracy: 0.962
 Invaild: 0.000
 Latency: 17332.728 s
 ```
+
+# Run the ERNIE-4.5-VL-28B-A3B-Paddle model on an Iluvatar machine
+
+## Machine Preparation
+The ERNIE-4.5-VL-28B-A3B-Paddle model runs with `TP=2`, so you need to prepare a machine with the following configuration:
+
+| CPU | Memory | Card | Hard Disk |
+| :---: | :---: | :---: | :---: |
+| x86 | 1TB | 2xBI150 | 1TB |
+
+## Image Preparation
+Pull the Docker image:
+
+```bash
+docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
+```
+
+## Container Preparation
+### Start Container
+
+```bash
+docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
+docker exec -it paddle_infer bash
+```
+
+`/home/paddle` contains the model files, `*.whl` packages, and scripts.
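+
+Once Paddle is installed (next step), you can quickly confirm that the container sees the Iluvatar devices. This is a minimal sketch, assuming the `paddle-iluvatar-gpu` wheel registers the `iluvatar_gpu` custom device (the same name used by `PADDLE_XCCL_BACKEND` in the demo script below):
+
+```python
+import paddle
+
+# Should list "iluvatar_gpu" among the registered custom device types.
+print(paddle.device.get_all_custom_device_type())
+```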
+
+### Install Paddle
+
+```bash
+pip3 install paddlepaddle==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
+pip3 install paddle-iluvatar-gpu==3.0.0.dev20250926 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
+```
+For the latest Paddle version on Iluvatar, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/).
+
+### Install FastDeploy
+```bash
+pip3 install fastdeploy_iluvatar_gpu==2.3.0.dev0 -i https://www.paddlepaddle.org.cn/packages/stable/ixuca/ --extra-index-url https://mirrors.aliyun.com/pypi/simple/
+```
+
+## Prepare the inference demo script
+
+The scripts are listed below:
+
+`run_demo_vl.sh`:
+
+```bash
+#!/bin/bash
+export PADDLE_XCCL_BACKEND=iluvatar_gpu
+export INFERENCE_MSG_QUEUE_ID=232132
+export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
+export FD_SAMPLING_CLASS=rejection
+export FD_DEBUG=1
+python3 run_demo_vl.py
+```
+
+`run_demo_vl.py`:
+
+```python
+import io
+import requests
+from PIL import Image
+
+from fastdeploy.entrypoints.llm import LLM
+from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
+
+
+PATH = "/home/paddle/ERNIE-4.5-VL-28B-A3B-Paddle"
+tokenizer = Ernie4_5Tokenizer.from_pretrained(PATH)
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image_url", "image_url": {"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg"}},
+            {"type": "text", "text": "图中的文物属于哪个年代"},
+        ],
+    }
+]
+prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+images, videos = [], []
+for message in messages:
+    content = message["content"]
+    if not isinstance(content, list):
+        continue
+    for part in content:
+        if part["type"] == "image_url":
+            url = part["image_url"]["url"]
+            image_bytes = requests.get(url).content
+            img = Image.open(io.BytesIO(image_bytes))
+            images.append(img)
+        elif part["type"] == "video_url":
+            url = part["video_url"]["url"]
+            video_bytes = requests.get(url).content
+            videos.append({
+                "video": video_bytes,
+                "max_frames": 30,
+            })
+
+sampling_params = SamplingParams(temperature=0.1, max_tokens=6400)
+llm = LLM(model=PATH, tensor_parallel_size=2, max_model_len=32768, block_size=16, quantization="wint8", limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl")
+outputs = llm.generate(prompts={
+    "prompt": prompt,
+    "multimodal_data": {
+        "image": images,
+        "video": videos,
+    },
+}, sampling_params=sampling_params)
+
+# Output results
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs.text
+    reasoning_text = output.outputs.reasoning_content
+    print(f"generated_text={generated_text}")
+```
+
+## Run the demo
+
+```bash
+./run_demo_vl.sh
+```
+
+The following logs will be printed:
+
+```
+[2025-09-23 10:13:10,844] [    INFO] - Using download source: huggingface
+[2025-09-23 10:13:10,844] [    INFO] - loading configuration file /home/paddle/ERNIE-4.5-VL-28B-A3B-Paddle/preprocessor_config.json
+[2025-09-23 10:13:10,845] [    INFO] - Using download source: huggingface
+[2025-09-23 10:13:10,845] [    INFO] - Loading configuration file /home/paddle/ERNIE-4.5-VL-28B-A3B-Paddle/generation_config.json
+/usr/local/lib/python3.10/site-packages/paddleformers/generation/configuration_utils.py:250: UserWarning: using greedy search strategy. However, `temperature` is set to `0.2` -- this flag is only used in sample-based generation modes. You should set `decode_strategy="greedy_search" ` or
+unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
+  warnings.warn(
+/usr/local/lib/python3.10/site-packages/paddleformers/generation/configuration_utils.py:255: UserWarning: using greedy search strategy. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `decode_strategy="greedy_search" ` or unset
+`top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
+  warnings.warn(
+INFO 2025-09-23 10:13:11,969 3880245 engine.py[line:136] Waiting worker processes ready...
+Loading Weights: 100%|████████████████████████████████████████| 100/100 [02:21<00:00,  1.41s/it]
+Loading Layers: 100%|█████████████████████████████████████████| 100/100 [00:15<00:00,  6.65it/s]
+INFO 2025-09-23 10:15:53,672 3880245 engine.py[line:173] Worker processes are launched with 181.2426426410675 seconds.
+prompts: 100%|███████████████████████████████████| 1/1 [01:52<00:00, 112.74s/it, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
+generated_text=
+图中的文物是**北齐释迦牟尼佛像**,属于**北齐(公元550年-577年)**的文物。
+
+这件佛像具有典型的北齐风格,佛像结跏趺坐于莲花座上,身披通肩袈裟,面部圆润,神态安详,体现了北齐佛教艺术的独特魅力。
+```
diff --git a/docs/zh/get_started/installation/iluvatar_gpu.md b/docs/zh/get_started/installation/iluvatar_gpu.md
index 1ece14ea216..72d683885c2 100644
--- a/docs/zh/get_started/installation/iluvatar_gpu.md
+++ b/docs/zh/get_started/installation/iluvatar_gpu.md
@@ -409,3 +409,148 @@
 Accuracy: 0.962
 Invaild: 0.000
 Latency: 17332.728 s
 ```
+
+# 如何在天数机器上运行 ERNIE-4.5-VL-28B-A3B-Paddle 模型
+
+## 准备机器
+运行 ERNIE-4.5-VL-28B-A3B-Paddle 模型需要 `TP=2`,所以您需要准备以下配置的机器:
+
+| CPU | Memory | Card | Hard Disk |
+| :---: | :---: | :---: | :---: |
+| x86 | 1TB | 2xBI150 | 1TB |
+
+## 准备镜像
+拉取镜像:
+
+```bash
+docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
+```
+
+## 准备容器
+### 启动容器
+
+```bash
+docker run -itd --name paddle_infer -v /usr/src:/usr/src -v /lib/modules:/lib/modules -v /dev:/dev -v /home/paddle:/home/paddle --privileged --cap-add=ALL --pid=host ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
+docker exec -it paddle_infer bash
+```
+
+/home/paddle 为模型文件、whl 包、脚本所在目录。
+
+### 安装paddle
+
+```bash
+pip3 install paddlepaddle==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
+pip3 install paddle-iluvatar-gpu==3.0.0.dev20250926 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
+```
+获取天数平台最新版本的 Paddle,请参考 [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/)。
+
+### 安装FastDeploy
+```bash
+pip3 install fastdeploy_iluvatar_gpu==2.3.0.dev0 -i https://www.paddlepaddle.org.cn/packages/stable/ixuca/ --extra-index-url https://mirrors.aliyun.com/pypi/simple/
+```
+
+## 准备推理demo脚本
+
+脚本列表如下所示:
+
+`run_demo_vl.sh`:
+
+```bash
+#!/bin/bash
+export PADDLE_XCCL_BACKEND=iluvatar_gpu
+export INFERENCE_MSG_QUEUE_ID=232132
+export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
+export FD_SAMPLING_CLASS=rejection
+export FD_DEBUG=1
+python3 run_demo_vl.py
+```
+
+`run_demo_vl.py`:
+
+```python
+import io
+import requests
+from PIL import Image
+
+from fastdeploy.entrypoints.llm import LLM
+from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
+
+
+PATH = "/home/paddle/ERNIE-4.5-VL-28B-A3B-Paddle"
+tokenizer = Ernie4_5Tokenizer.from_pretrained(PATH)
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image_url", "image_url": {"url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg"}},
+            {"type": "text", "text": "图中的文物属于哪个年代"},
+        ],
+    }
+]
+prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+images, videos = [], []
+for message in messages:
+    content = message["content"]
+    if not isinstance(content, list):
+        continue
+    for part in content:
+        if part["type"] == "image_url":
+            url = part["image_url"]["url"]
+            image_bytes = requests.get(url).content
+            img = Image.open(io.BytesIO(image_bytes))
+            images.append(img)
+        elif part["type"] == "video_url":
+            url = part["video_url"]["url"]
+            video_bytes = requests.get(url).content
+            videos.append({
+                "video": video_bytes,
+                "max_frames": 30,
+            })
+
+sampling_params = SamplingParams(temperature=0.1, max_tokens=6400)
+llm = LLM(model=PATH, tensor_parallel_size=2, max_model_len=32768, block_size=16, quantization="wint8", limit_mm_per_prompt={"image": 100}, reasoning_parser="ernie-45-vl")
+outputs = llm.generate(prompts={
+    "prompt": prompt,
+    "multimodal_data": {
+        "image": images,
+        "video": videos,
+    },
+}, sampling_params=sampling_params)
+
+# Output results
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs.text
+    reasoning_text = output.outputs.reasoning_content
+    print(f"generated_text={generated_text}")
+```
+
+## 运行demo
+
+```bash
+./run_demo_vl.sh
+```
+
+打印如下log:
+
+```
+[2025-09-23 10:13:10,844] [    INFO] - Using download source: huggingface
+[2025-09-23 10:13:10,844] [    INFO] - loading configuration file /home/paddle/ERNIE-4.5-VL-28B-A3B-Paddle/preprocessor_config.json
+[2025-09-23 10:13:10,845] [    INFO] - Using download source: huggingface
+[2025-09-23 10:13:10,845] [    INFO] - Loading configuration file /home/paddle/ERNIE-4.5-VL-28B-A3B-Paddle/generation_config.json
+/usr/local/lib/python3.10/site-packages/paddleformers/generation/configuration_utils.py:250: UserWarning: using greedy search strategy. However, `temperature` is set to `0.2` -- this flag is only used in sample-based generation modes. You should set `decode_strategy="greedy_search" ` or
+unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
+  warnings.warn(
+/usr/local/lib/python3.10/site-packages/paddleformers/generation/configuration_utils.py:255: UserWarning: using greedy search strategy. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `decode_strategy="greedy_search" ` or unset
+`top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
+  warnings.warn(
+INFO 2025-09-23 10:13:11,969 3880245 engine.py[line:136] Waiting worker processes ready...
+Loading Weights: 100%|████████████████████████████████████████| 100/100 [02:21<00:00,  1.41s/it]
+Loading Layers: 100%|█████████████████████████████████████████| 100/100 [00:15<00:00,  6.65it/s]
+INFO 2025-09-23 10:15:53,672 3880245 engine.py[line:173] Worker processes are launched with 181.2426426410675 seconds.
+prompts: 100%|███████████████████████████████████| 1/1 [01:52<00:00, 112.74s/it, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
+generated_text=
+图中的文物是**北齐释迦牟尼佛像**,属于**北齐(公元550年-577年)**的文物。
+
+这件佛像具有典型的北齐风格,佛像结跏趺坐于莲花座上,身披通肩袈裟,面部圆润,神态安详,体现了北齐佛教艺术的独特魅力。
+```
diff --git a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
index db3a09ce8fc..a88ea4b4a7e 100644
--- a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
@@ -86,11 +86,19 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_
         self.scale = 1.0 / sqrt(head_dim)
         self.num_layers = fd_config.model_config.num_hidden_layers
         self.dtype = paddle.get_default_dtype()
+        self.enable_mm = fd_config.model_config.enable_mm

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
-        self.rope_cos = forward_meta.rotary_embs[0, 0, :, :, :]
-        self.rope_sin = forward_meta.rotary_embs[1, 0, :, :, :]
+        if self.enable_mm:
+            # VL TODO: the first 0 may need to be replaced with the batch index
+            # (up to max_num_seqs) once the multi-batch case is supported later.
+            self.rope_cos = forward_meta.rotary_embs[0, 0, 0, :, :, :]
+            self.rope_sin = forward_meta.rotary_embs[0, 1, 0, :, :, :]
+        else:
+            # text
+            self.rope_cos = forward_meta.rotary_embs[0, 0, :, :, :]
+            self.rope_sin = forward_meta.rotary_embs[1, 0, :, :, :]
         self.prefill_info_dict = {}
         self.decode_info_dict = {}
         self.prefill_info_dict["batch_ids"] = paddle.where(forward_meta.seq_lens_encoder)[0]
@@ -115,7 +123,10 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
         self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
             self.prefill_info_dict["batch_ids"], 0
         ]
-        self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"])
+        # NOTE: The explicit dtype='int32' is required for Iluvatar hardware compatibility.
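+        # (Assumption: the Iluvatar attention kernels consume cu_seqlens as int32,
+        # so the cumsum result must stay int32 rather than being promoted.)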
+        self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(
+            self.prefill_info_dict["cu_seqlens_q"], dtype="int32"
+        )

         self.tmp_buffer = paddle.zeros(
             [self.prefill_num_tokens + self.decode_len, self.hidden_dim], dtype=self.dtype
diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py
index eb1d65695a3..9859ea91544 100644
--- a/fastdeploy/model_executor/layers/rotary_embedding.py
+++ b/fastdeploy/model_executor/layers/rotary_embedding.py
@@ -411,6 +411,9 @@ def __call__(self, position_ids):
         rot_emb[0] = cos_thw
         rot_emb[1] = sin_thw

+        if current_platform.is_iluvatar():
+            rot_emb = paddle.stack([rot_emb, rot_emb], axis=-1).reshape([2, 1, self.max_position, 1, self.rotary_dim])
+
         return rot_emb

diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py
index 1933db2b0fd..de5366e67b0 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/dfnrope/modeling.py
@@ -35,6 +35,7 @@
 from fastdeploy.model_executor.layers.utils import divide, get_tensor
 from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.platforms import current_platform

 from .activation import ACT2FN
 from .configuration import DFNRopeVisionTransformerConfig
@@ -174,7 +175,7 @@ def __init__(
             mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
             weight_attr=None,
             has_bias=True,
-            fuse_matmul_bias=True,
+            fuse_matmul_bias=not current_platform.is_iluvatar(),
             gather_output=False,
         )
         self.proj = RowParallelLinear(
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py b/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py
index 4324d921f34..e5fbb3be3f9 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py
@@ -26,6 +26,11 @@
         text_image_gather_scatter,
         text_image_index_out,
     )
+elif current_platform.is_iluvatar():
+    from fastdeploy.model_executor.ops.iluvatar import (
+        text_image_gather_scatter,
+        text_image_index_out,
+    )
 else:
     raise ImportError("Unsupported platform, only support CUDA and XPU")

diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py
index 90ab12c7126..8a54b0d2cff 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py
@@ -31,6 +31,7 @@
     scatter_axis,
 )
 from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.platforms import current_platform


 class ScatterOp(PyLayer):
@@ -172,7 +173,7 @@ def __init__(
                 self.spatial_dim,
                 input_is_parallel=True,
                 has_bias=True,
-                fuse_matmul_bias=True,
+                fuse_matmul_bias=not current_platform.is_iluvatar(),
             )
             if self.tensor_parallel_degree > 1
             else nn.Linear(self.spatial_dim, self.spatial_dim)
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index bd39d6efb53..092e981dc33 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -44,7 +44,10 @@
 from fastdeploy.platforms import current_platform

 if current_platform.is_iluvatar():
-    from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
+    from fastdeploy.model_executor.ops.iluvatar import (
+        set_data_ipc,
+        set_value_by_flags_and_idx,
+    )
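+    # set_data_ipc is now also compiled for Iluvatar (see setup_ops.py);
+    # presumably used by the same KV-cache IPC path as the CUDA branch.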
     recover_decode_task = None
     share_external_data = None
diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py
index 5ea6408be7e..e8ef1b69cb5 100644
--- a/fastdeploy/worker/iluvatar_model_runner.py
+++ b/fastdeploy/worker/iluvatar_model_runner.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+import paddle
+
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention import IluvatarAttnBackend
@@ -36,8 +38,23 @@ def __init__(
         assert self.guided_backend is None, "Iluvatar does not support guided decoding"
         assert not envs.ENABLE_V1_KVCACHE_SCHEDULER, "Iluvatar does not support v1 kvcache scheduler"
         assert not self.cache_config.enable_prefix_caching, "Iluvatar does not support prefix caching"
+        self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
+        assert not self.mla_cache, "Iluvatar does not support MLA"
+        if self.enable_mm:
+            assert (
+                not self.cache_config.enable_chunked_prefill
+            ), "Iluvatar does not support chunked prefill for VL model"
+        # VL neox style = True
+        if self.enable_mm:
+            emb_shape = self.share_inputs["rope_emb"].shape
+            emb_shape[-1] *= 2
+            self.share_inputs["rope_emb"] = paddle.full(
+                shape=emb_shape,
+                fill_value=0,
+                dtype="float32",
+            )

-    def initialize_attn_backend(self) -> None:
+    def _initialize_attn_backend(self) -> None:
         """
         Initialize attention backends
         """
diff --git a/fastdeploy/worker/iluvatar_worker.py b/fastdeploy/worker/iluvatar_worker.py
index c1b06058875..f8501124db1 100644
--- a/fastdeploy/worker/iluvatar_worker.py
+++ b/fastdeploy/worker/iluvatar_worker.py
@@ -40,6 +40,8 @@ def __init__(
         local_rank: int,
         rank: int,
     ):
+        if fd_config.model_config.enable_mm:
+            paddle.set_flags({"FLAGS_enable_ixattnbkd": True, "FLAGS_enable_ixdnn_attn": False})
         super(IluvatarWorker, self).__init__(
             fd_config=fd_config,
             local_rank=local_rank,
diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt
index 7983b3b5843..188ba4ae4a2 100644
--- a/requirements_iluvatar.txt
+++ b/requirements_iluvatar.txt
@@ -10,7 +10,7 @@ tqdm
 pynvml
 uvicorn==0.29.0
 fastapi
-paddleformers
+paddleformers==0.3.1
 redis
 etcd3
 httpx
@@ -38,3 +38,4 @@ opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
 partial_json_parser
+msgspec
diff --git a/tests/operators/test_extract_text_token_output.py b/tests/operators/test_extract_text_token_output.py
deleted file mode 100644
index ef180460e73..00000000000
--- a/tests/operators/test_extract_text_token_output.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" - -import unittest - -import numpy as np -import paddle - -from fastdeploy.model_executor.ops.gpu import extract_text_token_output - - -class TestExtractTextTokenOutput(unittest.TestCase): - def setUp(self): - paddle.set_device("gpu") - np.random.seed(42) - - def _run_and_check( - self, - bsz, - hidden_size, - max_seq_len_v, - max_seq_len_index_v, - mm_token_num_len_v, - seq_lens_this_time_v, - cu_seqlens_q_v, - hidden_states_v, - ): - - max_seq_len = paddle.to_tensor([max_seq_len_v], dtype="int32") - max_seq_len_index = paddle.to_tensor([max_seq_len_index_v], dtype="int32") - mm_token_num_len = paddle.to_tensor([mm_token_num_len_v], dtype="int32") - seq_lens_this_time = paddle.to_tensor(seq_lens_this_time_v, dtype="int32") - cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_v, dtype="int32") - hidden_states = paddle.to_tensor(hidden_states_v, dtype="float32") - - out = extract_text_token_output( - max_seq_len, max_seq_len_index, mm_token_num_len, seq_lens_this_time, cu_seqlens_q, hidden_states - )[0] - out_np = out.numpy() - - expect = np.ones((bsz, hidden_size), dtype="float32") - for i in range(bsz): - true_bsz = cu_seqlens_q_v[i + 1] - 1 - if (max_seq_len_v == mm_token_num_len_v) and (i == max_seq_len_index_v): - expect[i, :] = 0.0 - else: - if seq_lens_this_time_v[i] != 0: - expect[i, :] = hidden_states_v[true_bsz, :] - - if out_np.ndim == 1: - np.testing.assert_allclose(out_np, expect[0], rtol=1e-5, atol=1e-5) - else: - np.testing.assert_allclose(out_np, expect, rtol=1e-5, atol=1e-5) - - def test_basic_case(self): - bsz, hidden_size = 2, 4 - max_seq_len_v = 3 - max_seq_len_index_v = 0 - mm_token_num_len_v = 2 - seq_lens_this_time_v = [2, 1] - cu_seqlens_q_v = [0, 2, 3] - hidden_states_v = np.arange(12).reshape(3, 4).astype("float32") - - self._run_and_check( - bsz, - hidden_size, - max_seq_len_v, - max_seq_len_index_v, - mm_token_num_len_v, - seq_lens_this_time_v, - cu_seqlens_q_v, - hidden_states_v, - ) - - def test_zero_case(self): - bsz, hidden_size = 2, 4 - max_seq_len_v = 5 - max_seq_len_index_v = 1 - mm_token_num_len_v = 5 - seq_lens_this_time_v = [1, 1] - cu_seqlens_q_v = [0, 1, 2] - hidden_states_v = np.random.randn(2, hidden_size).astype("float32") - - self._run_and_check( - bsz, - hidden_size, - max_seq_len_v, - max_seq_len_index_v, - mm_token_num_len_v, - seq_lens_this_time_v, - cu_seqlens_q_v, - hidden_states_v, - ) - - -if __name__ == "__main__": - unittest.main()