|
1 | 1 | use crate::flash_attn::flash_attn_varlen; |
2 | 2 | use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; |
3 | 3 | use crate::models::{Model, Qwen3Config}; |
4 | | -use candle::{DType, Device, IndexOp, Result, Tensor}; |
| 4 | +use candle::{DType, Device, IndexOp, Result, Tensor, D}; |
5 | 5 | use candle_nn::{Embedding, Module, VarBuilder}; |
6 | 6 | use candle_rotary::apply_rotary_inplace; |
7 | 7 | use text_embeddings_backend_core::{Batch, ModelType, Pool}; |
@@ -592,10 +592,13 @@ impl Model for FlashQwen3Model { |
592 | 592 |
|
593 | 593 | let h_last = Tensor::stack(&last_hidden_states, 0)?; // [bs, hidden_size] |
594 | 594 |
|
595 | | - let true_id = 9693u32; |
596 | | - let false_id = 2152u32; |
| 595 | + // Correct token IDs for Qwen3 (verified from tokenizer) |
| 596 | + let yes_id = 9454u32; // "yes" token ID |
| 597 | + let no_id = 2901u32; // "no" token ID |
597 | 598 |
|
598 | | - let ids = Tensor::from_vec(vec![false_id, true_id], 2, &self.device)?; |
| 599 | + tracing::debug!("Using Qwen3 token IDs - yes: {}, no: {}", yes_id, no_id); |
| 600 | + |
| 601 | + let ids = Tensor::from_vec(vec![no_id, yes_id], 2, &self.device)?; |
599 | 602 | let w = self.lm_head_weight.index_select(&ids, 0)?; // [2, hidden_size] |
600 | 603 | let logits = h_last.matmul(&w.t()?)?; // [bs, 2] (no, yes) |
601 | 604 | let log_probs = candle_nn::ops::log_softmax(&logits, D::Minus1)?; |
|
0 commit comments