Taking inspiration from the LlamaEmbedder and the multimodal support that has been added to LlamaInteractExecutor, I have been trying to implement a multimodal embedder. The main idea is to support Qwen2-VL-related models specialized in screenshot embedding, such as:
IMO, it should work the way I did it, i.e.:
- building the prompt manually (the exact prompt layout is shown below)
- tokenizing the prompt before and after the image marker separately
- feeding the model the tokens before the image, then the image using LlavaWeights.EvalImageEmbed, and finally the tokens after the image
- getting the embedding of the last token (<|endoftext|>) and normalizing it
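For concreteness, for an image-only request the prompt built in the first step looks like this (this is exactly what the promptBuilder in the code below produces; <|image_pad|> is the marker the tokenization is split on):

```
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|><|im_end|>
<|im_start|>assistant
<|endoftext|>
```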
But:
- I don't get the same vectors as the ones I obtain with Python code
- even worse, embedding the same image twice doesn't give me the same vector
- and even when using two different context instances, I don't get the same vector

Does this ring a bell for anyone?
Here is my class, plus a dummy unit test comparing two runs of the image embedding computation:
```csharp
using System.Numerics.Tensors;
using System.Text;
using LLama.Common;
using LLama.Extensions;
using LLama.Native;

namespace LLama.Unittest;

public sealed class LlamaMultimodalEmbedder : IDisposable
{
    private readonly LLavaWeights _llavaWeights;
    private readonly LLamaContext _context;

    public LlamaMultimodalEmbedder(LLamaContext context, LLavaWeights llavaWeights)
    {
        if (context.Params.UBatchSize != context.Params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size");

        _llavaWeights = llavaWeights;
        _context = context;

        NativeApi.llama_set_embeddings(_context.NativeHandle, true);
    }

    private bool _disposed;

    public void Dispose()
    {
        if (_disposed)
            return;

        _context.Dispose();
        _llavaWeights.Dispose();
        _disposed = true;
    }

    private const string ImageMarker = "<|image_pad|>";
    private readonly int _imageMarkerSize = ImageMarker.Length;

    private async Task<float[]> GetEmbedding(
        string? text,
        byte[]? image,
        CancellationToken cancellationToken = default)
    {
        // Clear previous kv_cache values
        _context.NativeHandle.KvCacheClear();
        _context.NativeHandle.KvCacheRemove(LLamaSeqId.Zero, -1, -1);

        var hasText = !string.IsNullOrEmpty(text);
        var hasImage = image != null;
        if (!hasText && !hasImage)
            throw new ArgumentException("At least one of text or image must be provided");

        // Even if it implies a loss of genericity, we build the prompt manually for two reasons:
        // * history doesn't handle image content
        // * we aim to support Qwen2-VL-like models
        var promptBuilder = new StringBuilder();
        promptBuilder
            .Append("<|im_start|>system\n")
            .Append("You are a helpful assistant.<|im_end|>\n");
        promptBuilder.Append("<|im_start|>user\n");
        if (hasImage)
            promptBuilder.Append("<|vision_start|>").Append(ImageMarker).Append("<|vision_end|>");
        if (hasText)
            promptBuilder.Append(text);
        promptBuilder.Append("<|im_end|>\n");
        promptBuilder
            .Append("<|im_start|>assistant\n")
            .Append("<|endoftext|>");
        var prompt = promptBuilder.ToString();

        // Compute embeddings of the input image to be fed into the model
        using var imageEmbeddingHandle = hasImage ? GetImageEmbeddingHandle(image!) : null;

        var tokens = new List<LLamaToken>();
        var imageTokenIndex = -1;
        if (hasImage)
        {
            var imageIndexInPrompt = prompt.IndexOf(ImageMarker, StringComparison.Ordinal);

            // Tokenize text segment before <|image_pad|> tag
            var promptBeforeImage = prompt[..imageIndexInPrompt];
            var tokensBeforeImage = _context.Tokenize(promptBeforeImage, addBos: true, special: true);

            // Remember the position to add the image embeddings
            imageTokenIndex = tokensBeforeImage.Length;

            // Tokenize text segment after <|image_pad|> tag
            var promptAfterImage = prompt[(imageIndexInPrompt + _imageMarkerSize)..];
            var tokensAfterImage = _context.Tokenize(promptAfterImage, addBos: false, special: true);

            tokens.AddRange(tokensBeforeImage);
            tokens.AddRange(tokensAfterImage);
        }
        else
        {
            tokens.AddRange(_context.Tokenize(prompt, addBos: true, special: true));
        }

        var tokensCount = tokens.Count;
        if (tokensCount > _context.ContextSize)
            throw new ArgumentException(
                $"Embedding prompt is longer than the context window ({tokensCount} > {_context.ContextSize})");

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Evaluate prompt in batch-size chunks
        var batch = new LLamaBatch();
        var nPast = 0;
        var decodeResponse = await _context
            .DecodeAsync(tokens.GetRange(0, hasImage ? imageTokenIndex : tokensCount), LLamaSeqId.Zero, batch, nPast)
            .ConfigureAwait(false);
        nPast = decodeResponse.Item3;

        if (hasImage)
        {
            _llavaWeights.EvalImageEmbed(_context, imageEmbeddingHandle!, ref nPast);

            decodeResponse = await _context
                .DecodeAsync(tokens.GetRange(imageTokenIndex, tokensCount - imageTokenIndex), LLamaSeqId.Zero, batch,
                    nPast)
                .ConfigureAwait(false);
            nPast = decodeResponse.Item3;
        }

        var poolingType = _context.NativeHandle.PoolingType;
        if (poolingType != LLamaPoolingType.None)
            throw new NotSupportedException("Unsupported pooling type");

        var positions = batch.GetLogitPositions();
        if (positions == null)
            throw new InvalidOperationException("GetLogitPositions returned null");

        var embedding = _context.NativeHandle.GetEmbeddingsIth(positions[^1].Item2).ToArray();
        embedding.EuclideanNormalization();
        return embedding;
    }

    private SafeLlavaImageEmbedHandle GetImageEmbeddingHandle(byte[] imageBytes)
    {
        if (_llavaWeights == null)
            throw new InvalidOperationException("LLavaWeights is not loaded.");

        var embeddingsHandle = _llavaWeights.CreateImageEmbeddings(imageBytes);
        if (embeddingsHandle.IsInvalid)
            throw new InvalidOperationException(
                "Failed to create embedding handle, make sure that the image is a valid base 64 encoded string.");

        return embeddingsHandle;
    }

    public async Task<float[]> GetTextEmbedding(string text, CancellationToken cancellationToken) =>
        await GetEmbedding(text, null, cancellationToken).ConfigureAwait(false);

    public async Task<float[]> GetImageEmbedding(byte[] imageBytes, CancellationToken cancellationToken) =>
        await GetEmbedding(null, imageBytes, cancellationToken).ConfigureAwait(false);
}
```
```csharp
public sealed class LLamaMultimodalEmbedderTests
{
    private const string ModelPath = @"path\to\model.gguf";
    private const string MmprojPath = @"path\to\mmproj.gguf";
    private const string ImagePath = @"path\to\image.png";

    [Fact]
    public async Task TestBasic()
    {
        var parameters = new ModelParams(ModelPath)
        {
            GpuLayerCount = 5
        };
        var model = await LLamaWeights.LoadFromFileAsync(parameters);
        var llavaWeights = await LLavaWeights.LoadFromFileAsync(MmprojPath);
        var context = model.CreateContext(parameters);
        var multimodalEmbedder = new LlamaMultimodalEmbedder(context, llavaWeights);

        var embedding1 = await multimodalEmbedder.GetImageEmbedding(
            await File.ReadAllBytesAsync(ImagePath),
            CancellationToken.None);
        var embedding2 = await multimodalEmbedder.GetImageEmbedding(
            await File.ReadAllBytesAsync(ImagePath),
            CancellationToken.None);

        var diff = TensorPrimitives.Norm(
            embedding1.Zip(embedding2, (a, b) => a - b).ToArray());
        Assert.True(diff < 10e-1);
    }
}
```
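For reference, here is a sketch of how the C# embedding could be compared against a vector exported from the Python pipeline. It assumes the Python side dumps the reference as raw little-endian float32 values (e.g. `np.asarray(vec, dtype=np.float32).tofile(...)`); the file path and dump format are placeholders, not something the code above produces. It also needs `using System.Runtime.InteropServices;` in addition to the usings already present:

```csharp
// Compare an embedding from LlamaMultimodalEmbedder with a reference vector exported from Python.
// The reference file is assumed to contain raw little-endian float32 values (path is a placeholder).
private static float CompareWithPythonReference(float[] embedding, string referencePath)
{
    var bytes = File.ReadAllBytes(referencePath);

    // Reinterpret the raw bytes as float32 values
    var reference = MemoryMarshal.Cast<byte, float>(bytes).ToArray();

    if (reference.Length != embedding.Length)
        throw new InvalidOperationException(
            $"Dimension mismatch: {reference.Length} vs {embedding.Length}");

    // Both vectors are L2-normalized, so a cosine similarity close to 1.0 means they match
    return TensorPrimitives.CosineSimilarity(embedding, reference);
}
```

Cosine similarity is more forgiving than the raw norm of the difference used in the test above, so it at least separates "slightly off" from "completely different" vectors.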