Add Qwen3 model support #423
base: main
Conversation
Implements support for the Qwen3 model family, including Qwen3-4B-Instruct.

Key features:
- QK normalization for improved training stability
- Grouped Query Attention (32 query heads, 8 KV heads)
- High RoPE theta (5M) for extended context (262K tokens)
- Support for causal language modeling and sequence classification
- Complete parameter mapping for HuggingFace model loading
- Example scripts demonstrating text generation and chat usage

Tested with Qwen3-4B-Instruct-2507; generates coherent English output.
I will test it tomorrow with my H200 to be sure that everything is working. With my mbp the answers seem OK, but the generation is slow.
Implements a last token pooling strategy in text_embedding to support Qwen3-Embedding models, which use the last token's hidden state for generating text embeddings.

- Add :last_token_pooling option to text_embedding
- Extract the last non-padding token using attention_mask
- Add Qwen3-Embedding-0.6B example demonstrating:
  - Text embedding generation (1024-dim vectors)
  - Semantic similarity computation
  - Instruction-aware embeddings
  - Batch processing

Tested with Qwen3-Embedding-0.6B; produces correct similarity scores.
Implements a :for_embedding architecture for Qwen3 models with last token pooling, enabling direct use with Bumblebee.Text.text_embedding/3.

Changes:
- Add :for_embedding architecture to the Qwen3 model
- Register Qwen3ForEmbedding in the model mappings
- Add instruction prompts example showing Qwen team recommendations
- Update examples to use the cleaner serving-based API
- Add .lexical/ to gitignore
- Clean up mix.exs dependencies (remove emlx, nx override)

Examples demonstrate:
- Basic embedding generation (1024-dim vectors)
- Semantic similarity computation
- Instruction-aware prompts (1-5% performance improvement)
- Custom task instructions for code search
- Multilingual embedding support

Tested with Qwen3-Embedding-0.6B; generates correct similarity scores.
Implements document reranking using Qwen3-Reranker models. Rerankers score query-document pairs for relevance, improving retrieval quality in RAG and search applications.

Features:
- Automatic yes/no token detection from the tokenizer
- Proper input format with instruction, query, and document
- Softmax-based relevance scoring (0-1 range)
- Support for custom task instructions

Example demonstrates:
- Basic query-document scoring
- Custom instructions for code search
- Reranking search results (top-k selection)

Results show correct ranking:
- Relevant docs score 0.99+
- Irrelevant docs score near 0.0
- Custom instructions work for domain-specific tasks

Works with Qwen3-Reranker-0.6B/4B/8B models.
Move all Qwen3-related examples and documentation into examples/qwen3/ for better organization and discoverability.

Changes:
- Create examples/qwen3/ directory
- Move qwen3.exs, qwen3_embedding.exs, qwen3_embedding_prompts.exs, qwen3_reranker.exs
- Move QWEN3_IEX_GUIDE.md to examples/qwen3/
- Update examples/README.md to reference the qwen3/ subdirectory

All examples are now accessible under examples/qwen3/ with a consistent structure.
I was interested in getting a Qwen3 vision model working, like https://huggingface.co/huihui-ai/Huihui-MiniCPM-V-4_5-abliterated
lib/bumblebee/text/qwen3.ex
Outdated
# QK Normalization (Qwen3-specific) - normalize over head_dim
query =
  if spec.use_qk_norm do
    Layers.rms_norm(query,
      name: join(name, "query_norm"),
      epsilon: spec.layer_norm_epsilon,
      channel_index: -1
    )
  else
    query
  end

key =
  if spec.use_qk_norm do
    Layers.rms_norm(key,
      name: join(name, "key_norm"),
      epsilon: spec.layer_norm_epsilon,
      channel_index: -1
    )
  else
    key
  end
This is the only divergence from the usual logic, right? Instead of rewriting all of the implementation here, you can add a new option to Layers.Transformer.blocks. I would add :query_norm and :key_norm, both being a 2-arity function. There is already a :layer_norm option kinda similar to that (and we already have q/k/v-specific options: :query_use_bias, :key_use_bias, :value_use_bias).
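For illustration, a minimal sketch of how the model side could pass such options, assuming the suggested :query_norm and :key_norm accept a 2-arity builder; the exact callback arguments and the surrounding variables (hidden_state, spec, name, join/2) are assumptions for this example:

Layers.Transformer.blocks(hidden_state,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  # Qwen3-specific QK normalization applied to the projected query/key states
  query_norm: fn query, norm_name ->
    Layers.rms_norm(query,
      name: norm_name,
      epsilon: spec.layer_norm_epsilon,
      channel_index: -1
    )
  end,
  key_norm: fn key, norm_name ->
    Layers.rms_norm(key,
      name: norm_name,
      epsilon: spec.layer_norm_epsilon,
      channel_index: -1
    )
  end,
  name: join(name, "blocks")
)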
- Remove .lexical/ from the project gitignore (should be in the global gitignore)
- Add :qwen2 tokenizer type with correct Qwen3 special tokens
- Refactor QK normalization to use a generalized approach:
  - Add :query_norm and :key_norm options to Layers.Transformer
  - Apply normalization after head splitting, before rotary embedding
  - Update Qwen3 to use Layers.Transformer.blocks instead of a custom implementation
  - Remove ~200 lines of custom decoder/attention code
- Remove the standalone examples directory per review feedback

The generalized QK normalization approach makes the transformer layer more flexible and maintainable, allowing other models to use similar patterns.
Use 'decoder.blocks' as the name prefix when calling Layers.Transformer.blocks
to match the expected params mapping pattern decoder.blocks.{n}.*.
This aligns with how other models like BERT use the transformer blocks.
Fix model_type_to_tokenizer_type mapping to use :qwen2 instead of :gpt2 for qwen3 models. This ensures Qwen3 models load with the correct tokenizer configuration including proper special tokens.
Create notebooks/qwen3.livemd demonstrating:
- Text generation using Qwen3-4B-Instruct-2507
- Embeddings using Qwen3-Embedding-0.6B with similarity examples
- Reranking using Qwen3-Reranker-0.6B with query-document scoring

This replaces the deleted standalone examples with a consolidated, easy-to-follow notebook format, as suggested in the PR review.
Update the embeddings section to use the proper instruction format:
'Instruct: Given a query, retrieve relevant documents\nQuery: {query}\n{text}'
This ensures consistency with the reranker example and follows Qwen3
embedding best practices for better semantic search results.
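For reference, a tiny sketch of that format as a plain string builder; the helper name is made up for this example:

format_embedding_input = fn query, text ->
  # Mirrors the instruction format above: instruction, query, then the text to embed
  "Instruct: Given a query, retrieve relevant documents\nQuery: #{query}\n#{text}"
end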
Add a comprehensive test suite for Qwen3 using tiny-random/qwen3:
- Test :base architecture with QK normalization enabled
- Test :for_causal_language_modeling with logits verification
- Test :for_sequence_classification (shape only, random params)
- Test :for_embedding architecture

Reference values generated from tiny-random/qwen3 model predictions. All tests pass successfully (4 tests, 0 failures).
Generation looking good!

iex(16)> prompt = """
...(16)> <|im_start|>system
...(16)> You are a helpful assistant.<|im_end|>
...(16)> <|im_start|>user
...(16)> What is the capital of France?<|im_end|>
...(16)> <|im_start|>assistant
...(16)> """
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
iex(17)>
nil
iex(18)> result = Nx.Serving.run(serving, prompt)
%{
results: [
%{
text: "The capital of France is Paris.",
token_summary: %{input: 26, output: 8, padding: 0}
}
]
}
Still more tests to do and write!
@jonatanklosko I used a light model with the Qwen3 architecture to write some basic tests, similar to the other PR. Let me know if this is enough.
…n_id

The Qwen3 :for_embedding architecture was incorrectly using pad_token_id to find the last non-padding token for pooling. This caused embeddings to differ from the reference Python implementation.

Root cause:
- Token 151643 serves as both the pad_token_id AND the EOS token
- When tokenizing "hello!", the output is [14990, 0, 151643] with attention_mask [1, 1, 1], meaning all tokens should be attended
- The EOS token at the end is part of the actual sequence, not padding
- Only explicitly added padding tokens have attention_mask = 0

The fix changes the pooling logic to use attention_mask.sum(dim=-1) - 1 to find the last attended token, matching the official HuggingFace implementation's last_token_pool function.

Debugging process:
1. Compared raw transformer hidden states between Python and Elixir
2. Found both were producing identical hidden states (norm ~102 for all tokens)
3. Discovered Python was pooling index 2 (last token) while Elixir pooled index 1 (last non-pad_token_id token)
4. Investigated tokenizer behavior: attention_mask was [1,1,1], not [1,1,0]
5. Confirmed with explicit padding that only added padding has mask = 0
6. Updated the architecture to use attention_mask for the pooling logic

After the fix, embeddings match the Python implementation within bf16 precision:
- Python: [0.00039361, -0.02717206, -0.01105759, ...]
- Elixir: [5.552e-4, -0.027919, -0.011104, ...]

Also updated notebooks/qwen3.livemd to specify architecture: :for_embedding explicitly and remove the unnecessary output_pool parameter.
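A rough Nx sketch of the attention-mask-based pooling described above, not the literal Bumblebee code; shapes are assumed to be {batch, seq_len, hidden} for hidden_state and {batch, seq_len} for attention_mask:

last_token_pool = fn hidden_state, attention_mask ->
  {batch, _seq_len, hidden} = Nx.shape(hidden_state)

  # Last attended position per sequence: number of mask ones minus one
  last_index =
    attention_mask
    |> Nx.sum(axes: [-1])
    |> Nx.subtract(1)
    |> Nx.reshape({batch, 1, 1})
    |> Nx.broadcast({batch, 1, hidden})

  # Gather the hidden state at that position for each batch element
  hidden_state
  |> Nx.take_along_axis(last_index, axis: 1)
  |> Nx.squeeze(axes: [1])
end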
Implements binary relevance classification (reranking) for Qwen3 models.
Changes:
- Add :for_reranker architecture to Qwen3 model
- Extracts logits at last attended token position
- Returns full vocab logits for binary classification
- Create new TextReranking serving module
- Handles query-document pair formatting
- Applies Qwen3 reranker prompt template
- Computes relevance scores from yes/no token logits
- Uses log_softmax for score normalization
- Update notebook with proper reranker usage
- Changed from :for_embedding to :for_reranker architecture
- Uses Bumblebee.Text.text_reranking/3 API
- Simplified query-document pair handling
The reranker follows the official HuggingFace implementation:
1. Format: "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n<Instruct>: {task}\n<Query>: {query}\n<Document>: {doc}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
2. Extract logits at last attended token (using attention mask)
3. Get yes/no token logits
4. Apply log_softmax([no_logit, yes_logit])
5. Return exp(yes_log_prob) as relevance score
Tested with Qwen3-Reranker-0.6B, scores match Python reference (both ~1.0).
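For reference, a sketch of step 1 as a plain string builder; the function name is made up, and the system prompt content is elided ("...") in the template above, so it stays a parameter here:

format_reranker_input = fn system_prompt, task, query, doc ->
  # Builds the query-document prompt in the reranker chat format from step 1
  "<|im_start|>system\n#{system_prompt}<|im_end|>\n" <>
    "<|im_start|>user\n<Instruct>: #{task}\n<Query>: #{query}\n<Document>: #{doc}<|im_end|>\n" <>
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
end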
Let's remove this file as well :)
* `:query_norm` - configuration for query normalization. If set, normalizes
  the query projection before rotary embedding. Configured with the same
  options as `:layer_norm` in the block function. Defaults to `nil`
Let's make it always a function. For :layer_norm we allow a keyword list because it used to always be layer norm, but it needed to be a different norm for some specific model (though perhaps it would make sense to change it to always be a function :D).
eos: "<|im_end|>",
pad: "<|endoftext|>"
We want the same defaults as hf/transformers. If a particular uploaded model uses different ones, it is in the configuration files and we load those.
Suggested change:

unk: "<|endoftext|>",
eos: "<|endoftext|>",
pad: "<|endoftext|>"
Let's not update dependencies as part of this PR.
| test ":base" do | ||
| assert {:ok, %{model: model, params: params, spec: spec}} = | ||
| Bumblebee.load_model({:hf, "tiny-random/qwen3"}, architecture: :base) |
I pushed a smaller version of each model to the bumblebee-testing org, you can test using those :)
Generated with:
from transformers import Qwen3Config, Qwen3Model, Qwen3ForCausalLM, Qwen3ForSequenceClassification

config = Qwen3Config(
    vocab_size=1024,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=16,
    is_decoder=False,
    initializer_range=0.02,
    pad_token_id=0
)

for c in [Qwen3Config, Qwen3Model, Qwen3ForCausalLM, Qwen3ForSequenceClassification]:
    name = c.__name__
    c(config).save_pretrained(f"bumblebee-testing/tiny-random-{name}", repo_id=f"bumblebee-testing/tiny-random-{name}", push_to_hub=True)

        [0.4693, -0.0452, 0.2521]
      ]
    ]),
    atol: 1.0e-3
Just double-checking, these values come from Python, right?
We should be able to get within the default atol 1.0e-4. There are very rare cases where that doesn't hold, but so far there have been no issues with text transformers.
If the values are not close enough to Python, it may indicate some tiny difference in the implementation (for example a missing/extra/out-of-place normalization layer).
def model(%__MODULE__{architecture: :for_embedding} = spec) do
  inputs = inputs(spec)
Architectures should match differences in the model layers, but the embedding model is also :for_causal_language_modeling, so it should map to :for_causal_language_modeling.
Instead of having the pooling logic here, we should add :last_token_pooling as an option to the text embedding serving:
bumblebee/lib/bumblebee/text.ex
Lines 379 to 388 in 79199e0
* `:output_pool` - pooling to apply on top of the model output, in case
  it is not already a pooled embedding. Supported values:

  * `:mean_pooling` - performs a mean across all tokens

  * `:cls_token_pooling` - takes the embedding for the special CLS token.
    Note that we currently assume that the CLS token is the first token
    in the sequence

  By default no pooling is applied
Then the user loads the model as usual (automatically mapped to :for_causal_language_modeling), and when building the serving they will pass output_pool: :last_token_pooling
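Under that proposal, the user-facing flow might look roughly like this; a sketch assuming a :last_token_pooling option gets added as suggested, which is not part of Bumblebee at this point in the review:

# Loads with the default architecture mapping, no dedicated :for_embedding
{:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Embedding-0.6B"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})

serving =
  Bumblebee.Text.text_embedding(model_info, tokenizer,
    # :last_token_pooling is the option proposed in this review
    output_pool: :last_token_pooling,
    embedding_processor: :l2_norm
  )

Nx.Serving.run(serving, "The cat sat on the mat")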
  })
end

def model(%__MODULE__{architecture: :for_reranker} = spec) do
Same here, the model should be loaded as :for_causal_language_modeling as in the upstream repo.
* `:yes_token` - the token ID corresponding to "yes" for relevance scoring.
  If not provided, will be inferred from the tokenizer

* `:no_token` - the token ID corresponding to "no" for relevance scoring.
  If not provided, will be inferred from the tokenizer

* `:instruction_prefix` - the instruction prefix to use. Defaults to the
  Qwen3 reranker format

* `:instruction_suffix` - the instruction suffix to use. Defaults to the
  Qwen3 reranker format

* `:task_description` - the task description to include in prompts. Defaults
  to "Given a web search query, retrieve relevant passages that answer the query"
This re-ranking approach has a fair amount of Qwen specifics. I am not entirely sure if we should ship this serving, but at the least we should name it Text.text_reranking_qwen3. We do something similar for Whisper:
bumblebee/lib/bumblebee/audio.ex
Lines 179 to 186 in 79199e0
defdelegate speech_to_text_whisper(
              model_info,
              featurizer,
              tokenizer,
              generation_config,
              opts \\ []
            ),
            to: Bumblebee.Audio.SpeechToTextWhisper
Sorry, work caught up with me; I will continue the PR this weekend.
Add Qwen3 Model Family Support
Summary
This PR adds comprehensive support for the Qwen3 model family from Alibaba Cloud, including text generation,
embeddings, and reranking models. Qwen3 is a state-of-the-art multilingual language model with advanced features like
QK normalization and support for up to 262K context length.
What's New
Architectures:
- :base
- :for_causal_language_modeling
- :for_sequence_classification
- :for_embedding (last token pooling)
- :for_reranker

Key Features:
- QK normalization of query/key states (Qwen3-specific innovation)
- Grouped Query Attention (32 query heads, 8 KV heads)
- High RoPE theta (5M) for extended context (262K tokens)
Files Changed
Core Implementation:
Examples:
Documentation:
Testing
Text Generation (Qwen3-4B-Instruct)
{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
{:ok, config} = Bumblebee.load_generation_config({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
serving = Bumblebee.Text.generation(model, tokenizer, config)
Nx.Serving.run(serving, "The future of AI")
Results: Generates coherent English text, answers questions correctly, creates stories and code.
Text Embeddings (Qwen3-Embedding-0.6B)
{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Embedding-0.6B"},
architecture: :for_embedding
)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})
serving = Bumblebee.Text.text_embedding(model, tokenizer,
output_attribute: :embedding,
embedding_processor: :l2_norm
)
e1 = Nx.Serving.run(serving, "The cat sat on the mat")
e2 = Nx.Serving.run(serving, "A feline rested on the rug")
Nx.dot(e1.embedding, e2.embedding) |> Nx.to_number() # 0.73 (similar)
Results: Produces 1024-dim embeddings with similarity scores matching the Python reference within bf16 precision.
Reranking (Qwen3-Reranker-0.6B)
{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})
# Score query-document relevance
# Relevant docs: 0.99+, irrelevant: ~0.0
Results: Correctly ranks documents by relevance to queries.
Compatible Models
Text Generation: Qwen/Qwen3-4B-Instruct-2507 (tested)
Embeddings: Qwen/Qwen3-Embedding-0.6B (tested)
Reranking: Qwen/Qwen3-Reranker-0.6B (tested), -4B, -8B
Technical Implementation
QK Normalization
Unlike standard transformers, Qwen3 applies RMS normalization to query and key states:
hidden -> dense -> split_heads -> rms_norm -> rotary -> attention
Architecture Support
QK normalization is implemented via the new :query_norm and :key_norm options on Layers.Transformer.blocks, so Qwen3 reuses Bumblebee's shared transformer block implementation.
Embedding Architecture
New :for_embedding architecture automatically pools the last attended token (determined via the attention mask) for text embedding tasks.
Reranking
Uses the causal LM architecture with yes/no token logit extraction and softmax scoring.
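As a sketch of that scoring step; variable names are illustrative, logits is the vocabulary logit vector at the last attended position, and yes_token_id/no_token_id come from the tokenizer:

# Stack the "no" and "yes" logits, take a log-softmax over the pair,
# and return exp(log P(yes)) as the relevance score in [0, 1]
pair_logits = Nx.stack([logits[no_token_id], logits[yes_token_id]])
shifted = Nx.subtract(pair_logits, Nx.reduce_max(pair_logits))
log_probs = Nx.subtract(shifted, Nx.log(Nx.sum(Nx.exp(shifted))))
score = log_probs[1] |> Nx.exp() |> Nx.to_number()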
Breaking Changes
None. This is purely additive.
References