Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 130 additions & 31 deletions backends/intel_hpu/custom_ops/llama_infer/fused_qkv_rope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ struct FusedQkvRopeParams {
bool use_neox_style = true;
bool transpose = true;
bool with_qkv_biases = false;
bool use_fp8 = false;
bool fp8_proj = false;
bool fp8_out = false;
};

class FusedQkvRope : public HpuFusedOperator {
Expand All @@ -45,6 +46,8 @@ class FusedQkvRope : public HpuFusedOperator {
int qkv_weights_index = 1;
int rotary_embs_index = 2;
int qkv_biases_index = 3;
int scale_input_index = (params.with_qkv_biases ? (qkv_biases_index + 1)
: (rotary_embs_index + 1));

auto src = createTensorFromCT(&ct, src_index);
auto qkv_weights = createTensorFromCT(&ct, qkv_weights_index);
Expand All @@ -66,7 +69,8 @@ class FusedQkvRope : public HpuFusedOperator {

std::vector<synTensor> reshape_inputs;

if ((!params.use_fp8) && (params.transpose)) { // bfloat16 + transpose=true
if ((!params.fp8_proj) &&
(params.transpose)) { // bfloat16 + transpose=true
if (params.with_qkv_biases) {
linear_inputs.push_back(qkv_biases);
}
Expand All @@ -77,10 +81,7 @@ class FusedQkvRope : public HpuFusedOperator {
gemm_params.transpose_a = false;
gemm_params.transpose_b = params.transpose;

if (params.use_fp8) {
int scale_input_index =
(params.with_qkv_biases ? (qkv_biases_index + 1)
: (rotary_embs_index + 1));
if (params.fp8_proj) {
auto scale_input = createTensorFromCT(&ct, scale_input_index);
auto scale_weight = createTensorFromCT(&ct, scale_input_index + 1);
linear_inputs.push_back(scale_input);
Expand Down Expand Up @@ -183,7 +184,12 @@ class FusedQkvRope : public HpuFusedOperator {
inputs_q.push_back(sin_sq);
inputs_q.push_back(cos_sq);

auto q_states = createTensorFromCT(&ct, 0, false);
synTensor q_states = nullptr;
if (params.fp8_out) {
q_states = createTensorNoPresist("q_states", dtype_, outs[0].dims);
} else {
q_states = createTensorFromCT(&ct, 0, false);
}
outputs_q.push_back(q_states);

ns_RoPESt2::ParamsV2 ropeParams;
Expand All @@ -204,12 +210,42 @@ class FusedQkvRope : public HpuFusedOperator {
AddNodeRope<T>(inputs_k, outputs_k, ropeParams, guid_ + "rope_k");

std::vector<synTensor> inputs_concat;
std::vector<synTensor> outputs_concat;
inputs_concat.push_back(k_rope);
inputs_concat.push_back(v_split);
if (params.fp8_out) {
ns_CastKernel::Params cast_to_fp8_params;
cast_to_fp8_params.round_mode = CAST_ROUND_HALF_NE;
auto scale_q = createTensorFromCT(&ct, scale_input_index + 2);
auto scale_k = createTensorFromCT(&ct, scale_input_index + 3);
auto scale_v = createTensorFromCT(&ct, scale_input_index + 4);

auto q_state_fp8 = createTensorFromCT(&ct, 0, false);
std::vector<synTensor> cast_q_ins = {q_states, scale_q};
std::vector<synTensor> cast_q_outs = {q_state_fp8};
AddNodeConvertToFP8<T>(
cast_q_ins, cast_q_outs, cast_to_fp8_params, guid_ + "cast_q");

auto k_state_fp8 = createTensorNoPresist(
"k_state_fp8", ins[qkv_weights_index].type, kv_dims);
std::vector<synTensor> cast_k_ins = {k_rope, scale_k};
std::vector<synTensor> cast_k_outs = {k_state_fp8};
AddNodeConvertToFP8<T>(
cast_k_ins, cast_k_outs, cast_to_fp8_params, guid_ + "cast_k");

auto v_state_fp8 = createTensorNoPresist(
"v_state_fp8", ins[qkv_weights_index].type, kv_dims);
std::vector<synTensor> cast_v_ins = {v_split, scale_v};
std::vector<synTensor> cast_v_outs = {v_state_fp8};
AddNodeConvertToFP8<T>(
cast_v_ins, cast_v_outs, cast_to_fp8_params, guid_ + "cast_v");
inputs_concat.push_back(k_state_fp8);
inputs_concat.push_back(v_state_fp8);
} else {
inputs_concat.push_back(k_rope);
inputs_concat.push_back(v_split);
}

kv_dims[0] *= 2;
auto kv_concat = createTensorNoPresist("kv_concat", dtype_, kv_dims);
auto kv_concat = createTensorNoPresist("kv_concat", outs[1].type, kv_dims);
std::vector<synTensor> outputs_concat;
outputs_concat.push_back(kv_concat);

synConcatenateParams concatParams;
Expand All @@ -221,7 +257,6 @@ class FusedQkvRope : public HpuFusedOperator {

auto kv_state = createTensorFromCT(&ct, 1, false);
outputs_stack.push_back(kv_state);

AddNodeReshape(outputs_concat, outputs_stack, guid_ + "reshaped_kv");
}

Expand All @@ -237,6 +272,9 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
const phi::DenseTensor& rotary_embs,
const paddle::optional<phi::DenseTensor>& scale_input,
const paddle::optional<phi::DenseTensor>& scale_weight,
const paddle::optional<phi::DenseTensor>& scale_q,
const paddle::optional<phi::DenseTensor>& scale_k,
const paddle::optional<phi::DenseTensor>& scale_v,
phi::DenseTensor* query_states,
phi::DenseTensor* key_value_states,
const phi::Scalar& head_dim,
Expand Down Expand Up @@ -272,23 +310,32 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
ct.Add(query_states, false);
ct.Add(key_value_states, false);

std::string guid_prefix = "fused_qkv_rope_fwd_";
std::string guid_prefix = "fused_qkv_rope";
if (qkv_biases) {
ct.Add(qkv_biases.get());
guid_prefix = "fused_qkv_bias_rope_fwd_";
guid_prefix += "_bias";
}

if (scale_input && scale_weight) {
guid_prefix += "_fp8";
ct.Add(scale_input.get());
ct.Add(scale_weight.get());
guid_prefix = "fused_fp8_qkv_rope_fwd_";
if (qkv_biases) {
guid_prefix = "fused_fp8_qkv_bias_rope_fwd_";
if (scale_q && scale_k && scale_v) {
ct.Add(scale_q.get());
ct.Add(scale_k.get());
ct.Add(scale_v.get());
guid_prefix += "_hf8";
} else if (scale_q || scale_k || scale_v) {
throw std::runtime_error(
"Need all scale_q, scale_k and scale_v for FusedFp8QkvRopeKernel");
} else {
guid_prefix += "_bf16";
}
} else if (scale_input || scale_weight) {
throw std::runtime_error(
"Need both scale_input and scale_weight for FusedFp8QkvRopeKernel");
}
guid_prefix += "_fwd_";

OpCacheOperator op_info;
op_info.prepareOpInfo<T, nullptr_t>(
Expand All @@ -308,7 +355,10 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
params.with_qkv_biases = true;
}
if (scale_input) {
params.use_fp8 = true;
params.fp8_proj = true;
}
if (scale_q) {
params.fp8_out = true;
}

FusedQkvRope op(guid_prefix, op_info.datatype_);
Expand All @@ -335,6 +385,9 @@ void CallFusedQkvRopeKernel(
const phi::DenseTensor& rotary_embs,
const paddle::optional<phi::DenseTensor>& scale_input,
const paddle::optional<phi::DenseTensor>& scale_weight,
const paddle::optional<phi::DenseTensor>& scale_q,
const paddle::optional<phi::DenseTensor>& scale_k,
const paddle::optional<phi::DenseTensor>& scale_v,
phi::DenseTensor* query_states,
phi::DenseTensor* key_value_states,
const phi::Scalar& head_dim,
Expand All @@ -350,6 +403,9 @@ void CallFusedQkvRopeKernel(
rotary_embs,
scale_input,
scale_weight,
scale_q,
scale_k,
scale_v,
query_states,
key_value_states,
head_dim,
Expand All @@ -365,6 +421,9 @@ void CallFusedQkvRopeKernel(
rotary_embs,
scale_input,
scale_weight,
scale_q,
scale_k,
scale_v,
query_states,
key_value_states,
head_dim,
Expand Down Expand Up @@ -428,6 +487,9 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
*rotary_embs_tensor,
paddle::optional<phi::DenseTensor>(),
paddle::optional<phi::DenseTensor>(),
paddle::optional<phi::DenseTensor>(),
paddle::optional<phi::DenseTensor>(),
paddle::optional<phi::DenseTensor>(),
query_states.get(),
key_value_states.get(),
phi::Scalar(head_dim),
Expand Down Expand Up @@ -465,7 +527,7 @@ std::vector<paddle::DataType> FusedQkvRopeDtype(
return {src_dtype, src_dtype};
}

PD_BUILD_OP(fused_qkv_rope)
PD_BUILD_OP(fused_qkv_rope_bf16)
.Inputs(
{"src", "qkv_weights", paddle::Optional("qkv_biases"), "rotary_embs"})
.Outputs({"query_states", "key_value_states"})
Expand All @@ -483,8 +545,11 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
const paddle::Tensor& qkv_weights,
const paddle::optional<paddle::Tensor>& qkv_biases,
const paddle::Tensor& rotary_embs,
const paddle::Tensor& scale_input,
const paddle::Tensor& scale_weight,
const paddle::optional<paddle::Tensor>& scale_input,
const paddle::optional<paddle::Tensor>& scale_weight,
const paddle::optional<paddle::Tensor>& scale_q,
const paddle::optional<paddle::Tensor>& scale_k,
const paddle::optional<paddle::Tensor>& scale_v,
int head_dim,
int num_head,
int total_batch,
Expand All @@ -505,12 +570,34 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
qkv_biases_tensor = paddle::optional<phi::DenseTensor>(*qkv_biases_dt);
}

auto _scale_input =
static_cast<const phi::DenseTensor*>(scale_input.impl().get());
auto scale_input_tensor = paddle::optional<phi::DenseTensor>(*_scale_input);
auto _scale_weight =
static_cast<const phi::DenseTensor*>(scale_weight.impl().get());
auto scale_weight_tensor = paddle::optional<phi::DenseTensor>(*_scale_weight);
auto scale_input_tensor = paddle::optional<phi::DenseTensor>();
auto scale_weight_tensor = paddle::optional<phi::DenseTensor>();
if (scale_input) {
auto scale_input_dt =
static_cast<phi::DenseTensor*>(scale_input->impl().get());
scale_input_tensor = paddle::optional<phi::DenseTensor>(*scale_input_dt);
}
if (scale_weight) {
auto scale_weight_dt =
static_cast<phi::DenseTensor*>(scale_weight->impl().get());
scale_weight_tensor = paddle::optional<phi::DenseTensor>(*scale_weight_dt);
}

auto scale_q_tensor = paddle::optional<phi::DenseTensor>();
auto scale_k_tensor = paddle::optional<phi::DenseTensor>();
auto scale_v_tensor = paddle::optional<phi::DenseTensor>();
if (scale_q) {
auto scale_q_dt = static_cast<phi::DenseTensor*>(scale_q->impl().get());
scale_q_tensor = paddle::optional<phi::DenseTensor>(*scale_q_dt);
}
if (scale_k) {
auto scale_k_dt = static_cast<phi::DenseTensor*>(scale_k->impl().get());
scale_k_tensor = paddle::optional<phi::DenseTensor>(*scale_k_dt);
}
if (scale_v) {
auto scale_v_dt = static_cast<phi::DenseTensor*>(scale_v->impl().get());
scale_v_tensor = paddle::optional<phi::DenseTensor>(*scale_v_dt);
}

// allocate memory on device.
int64_t bsz = src.dims()[0];
Expand All @@ -523,13 +610,19 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
std::make_shared<phi::DenseTensor>();
query_states->Resize(
phi::make_ddim({total_batch, seq_len, num_head, head_dim}));
dev_ctx->Alloc(query_states.get(), src_tensor->dtype());

std::shared_ptr<phi::DenseTensor> key_value_states =
std::make_shared<phi::DenseTensor>();
key_value_states->Resize(
phi::make_ddim({2, total_batch, seq_len, kv_num_head, head_dim}));
dev_ctx->Alloc(key_value_states.get(), src_tensor->dtype());

if (scale_q) {
dev_ctx->Alloc(query_states.get(), qkv_weights_tensor->dtype());
dev_ctx->Alloc(key_value_states.get(), qkv_weights_tensor->dtype());
} else {
dev_ctx->Alloc(query_states.get(), src_tensor->dtype());
dev_ctx->Alloc(key_value_states.get(), src_tensor->dtype());
}

CallFusedQkvRopeKernel(*dev_ctx,
*src_tensor,
Expand All @@ -538,6 +631,9 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
*rotary_embs_tensor,
scale_input_tensor,
scale_weight_tensor,
scale_q_tensor,
scale_k_tensor,
scale_v_tensor,
query_states.get(),
key_value_states.get(),
phi::Scalar(head_dim),
Expand Down Expand Up @@ -579,13 +675,16 @@ std::vector<paddle::DataType> FusedFp8QkvRopeDtype(
return {src_dtype, src_dtype};
}

PD_BUILD_OP(fused_fp8_qkv_rope)
PD_BUILD_OP(fused_qkv_rope)
.Inputs({"src",
"qkv_weights",
paddle::Optional("qkv_biases"),
"rotary_embs",
"scale_input",
"scale_weight"})
"scale_weight",
"scale_q",
"scale_k",
"scale_v"})
.Outputs({"query_states", "key_value_states"})
.Attrs({"head_dim: int",
"num_head: int",
Expand Down
Loading
Loading