Skip to content

Commit 935ddb8

Browse files
authored
Merge pull request #1260 from Cryptoc1/feature/executor-cancellation-support
Accept `CancellationToken` in `ChatSession.InitializeSessionFromHistoryAsync`
2 parents c1783d4 + 94273ff commit 935ddb8

File tree

6 files changed

+82
-67
lines changed

6 files changed

+82
-67
lines changed

LLama.Examples/Examples/QuantizeModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
22
{
33
public class QuantizeModel
44
{
5-
public static async Task Run()
5+
public static Task Run()
66
{
77
string inputPath = UserSettings.GetModelPath();
88

@@ -21,7 +21,7 @@ public static async Task Run()
2121
Console.WriteLine("Quantization failed!");
2222
}
2323

24-
await Task.CompletedTask;
24+
return Task.CompletedTask;
2525
}
2626
}
2727
}

LLama/ChatSession.cs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ public class ChatSession
7676
/// <param name="executor">The executor for this session</param>
7777
/// <param name="history">History for this session</param>
7878
/// <param name="transform">History Transform for this session</param>
79+
/// <param name="cancellationToken">A token that cancels the operation</param>
7980
/// <returns>A new chat session.</returns>
8081
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
81-
ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
82+
ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
8283
{
8384
if (executor is not StatefulExecutorBase statefulExecutor)
8485
{
@@ -90,7 +91,7 @@ public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
9091
session = session.WithHistoryTransform(transform);
9192
}
9293

93-
await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
94+
await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
9495
return session;
9596
}
9697

@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
311312
/// Compute KV cache for the message and add it to the chat history.
312313
/// </summary>
313314
/// <param name="message">The message to add to the chat history and process</param>
315+
/// <param name="cancellationToken">A token that cancels the operation; it is forwarded to the executor's prompt prefill</param>
314316
/// <returns>This chat session instance</returns>
315-
public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
317+
public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
316318
{
317319
if (Executor is not StatefulExecutorBase statefulExecutor)
318320
{
319321
throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
320322
}
323+
321324
AddMessage(message);
322325
var content = message.Content;
323326
if (message.AuthorRole != AuthorRole.Assistant)
@@ -328,27 +331,27 @@ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
328331
}
329332
}
330333

331-
await statefulExecutor.PrefillPromptAsync(content);
334+
await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
332335
return this;
333336
}
334337

335338
/// <summary>
336339
/// Compute KV cache for the system message and add it to the chat history.
337340
/// </summary>
338-
public Task<ChatSession> AddAndProcessSystemMessage(string content)
339-
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
341+
public Task<ChatSession> AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
342+
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);
340343

341344
/// <summary>
342345
/// Compute KV cache for the user message and add it to the chat history.
343346
/// </summary>
344-
public Task<ChatSession> AddAndProcessUserMessage(string content)
345-
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
347+
public Task<ChatSession> AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
348+
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);
346349

347350
/// <summary>
348351
/// Compute KV cache for the assistant message and add it to the chat history.
349352
/// </summary>
350-
public Task<ChatSession> AddAndProcessAssistantMessage(string content)
351-
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
353+
public Task<ChatSession> AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
354+
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);
352355

353356
/// <summary>
354357
/// Replace a user message with a new message and remove all messages after the new message.

LLama/LLamaContext.cs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
using LLama.Native;
21
using System;
32
using System.Collections.Generic;
43
using System.Diagnostics;
5-
using System.Text;
64
using System.IO;
75
using System.IO.MemoryMappedFiles;
6+
using System.Text;
7+
using System.Threading;
88
using System.Threading.Tasks;
99
using LLama.Abstractions;
10+
using LLama.Native;
1011
using Microsoft.Extensions.Logging;
11-
using System.Threading;
1212

1313
namespace LLama
1414
{
@@ -73,7 +73,7 @@ public int BatchThreads
7373
/// Get the special tokens for the model associated with this context
7474
/// </summary>
7575
public SafeLlamaModelHandle.Vocabulary Vocab { get; }
76-
76+
7777
/// <summary>
7878
/// Create a new LLamaContext for the given LLamaWeights
7979
/// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
396396
{
397397
return Task.Run(() => Decode(batch), cancellationToken);
398398
}
399-
399+
400400
/// <summary>
401401
/// </summary>
402402
/// <param name="batch"></param>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
406406
return 0;
407407
if (batch.EmbeddingsCount > BatchSize)
408408
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
409-
409+
410410
return (DecodeResult)NativeHandle.Decode(batch);
411411
}
412-
412+
413413
/// <summary>
414414
/// </summary>
415415
/// <param name="batch"></param>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
425425
/// <param name="id"></param>
426426
/// <param name="batch"></param>
427427
/// <param name="n_past"></param>
428+
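/// <param name="cancellationToken">Token passed to Task.Run; it can cancel the queued work before it starts, but it is not passed to the native decode call</param>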
428429
/// <returns>A tuple, containing the decode result, the number of tokens that have <b>not</b> been decoded yet and the total number of tokens that have been decoded.</returns>
429-
public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
430+
public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
430431
{
431432
return Task.Run(() =>
432433
{
433434
var past = n_past;
434435
var res = NativeHandle.Decode(tokens, id, batch, ref past);
435436
return (res.Item1, res.Item2, past);
436-
});
437+
}, cancellationToken);
437438
}
438439
#endregion
439440

LLama/LLamaExecutorBase.cs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -246,36 +246,41 @@ protected virtual void TryReuseMatchingPrefix()
246246
/// Decide whether to continue the loop.
247247
/// </summary>
248248
/// <param name="args"></param>
249+
/// <param name="cancellationToken"></param>
249250
/// <returns></returns>
250-
protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
251+
protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);
251252

252253
/// <summary>
253254
/// Preprocess the inputs before the inference.
254255
/// </summary>
255256
/// <param name="text"></param>
256257
/// <param name="args"></param>
257-
protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
258+
/// <param name="cancellationToken"></param>
259+
protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);
258260

259261
/// <summary>
260262
/// Do some post processing after the inference.
261263
/// </summary>
262264
/// <param name="inferenceParams"></param>
263265
/// <param name="args"></param>
266+
/// <param name="cancellationToken"></param>
264267
/// <returns></returns>
265-
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
268+
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
266269

267270
/// <summary>
268271
/// The core inference logic.
269272
/// </summary>
270273
/// <param name="inferenceParams"></param>
271274
/// <param name="args"></param>
272-
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
275+
/// <param name="cancellationToken"></param>
276+
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
273277

274278
/// <summary>
275279
/// Save the current state to a file.
276280
/// </summary>
277281
/// <param name="filename"></param>
278-
public abstract Task SaveState(string filename);
282+
/// <param name="cancellationToken"></param>
283+
public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);
279284

280285
/// <summary>
281286
/// Get the current state data.
@@ -287,13 +292,15 @@ protected virtual void TryReuseMatchingPrefix()
287292
/// Load the state from data.
288293
/// </summary>
289294
/// <param name="data"></param>
290-
public abstract Task LoadState(ExecutorBaseState data);
295+
/// <param name="cancellationToken"></param>
296+
public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);
291297

292298
/// <summary>
293299
/// Load the state from a file.
294300
/// </summary>
295301
/// <param name="filename"></param>
296-
public abstract Task LoadState(string filename);
302+
/// <param name="cancellationToken"></param>
303+
public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);
297304

298305

299306
/// <summary>
@@ -318,17 +325,17 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
318325
};
319326

320327
AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? []);
328+
await PreprocessInputs(text, args, cancellationToken);
321329

322-
await PreprocessInputs(text, args);
323-
324-
while (await GetLoopCondition(args))
330+
while (await GetLoopCondition(args, cancellationToken))
325331
{
326332
if (cancellationToken.IsCancellationRequested)
327333
{
328334
break;
329335
}
336+
330337
args.LastOutput = string.Empty;
331-
await InferInternal(inferenceParams, args);
338+
await InferInternal(inferenceParams, args, cancellationToken);
332339

333340
if (args.ReturnValue)
334341
{
@@ -338,7 +345,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
338345
yield return decoded;
339346
}
340347

341-
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
348+
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
342349
if (extraOutputs is { Count: > 0 })
343350
{
344351
foreach (var item in extraOutputs)
@@ -358,8 +365,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
358365
/// It could reduce the latency of the first time response if the first input from the user is not immediate.
359366
/// </summary>
360367
/// <param name="prompt">Prompt to process</param>
368+
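/// <param name="cancellationToken">A token that cancels the operation; it is forwarded to input preprocessing and to both inference passes</param>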
361369
/// <returns></returns>
362-
public virtual async Task PrefillPromptAsync(string prompt)
370+
public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
363371
{
364372
var inferenceParams = new InferenceParams
365373
{
@@ -374,11 +382,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
374382
NeedToSaveSession = false
375383
};
376384

377-
await PreprocessInputs(prompt, args);
385+
await PreprocessInputs(prompt, args, cancellationToken);
378386
// First run adds the prompt to the _embeds
379-
await InferInternal(inferenceParams, args);
387+
await InferInternal(inferenceParams, args, cancellationToken);
380388
// Second run puts it through decode
381-
await InferInternal(inferenceParams, args);
389+
await InferInternal(inferenceParams, args, cancellationToken);
382390
}
383391

384392
/// <summary>

LLama/LLamaInstructExecutor.cs

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
using LLama.Abstractions;
2-
using LLama.Common;
3-
using LLama.Native;
41
using System;
52
using System.Collections.Generic;
63
using System.IO;
74
using System.Linq;
85
using System.Text.Json;
96
using System.Text.Json.Serialization;
7+
using System.Threading;
108
using System.Threading.Tasks;
9+
using LLama.Abstractions;
10+
using LLama.Common;
1111
using LLama.Exceptions;
12+
using LLama.Native;
1213
using LLama.Sampling;
1314
using Microsoft.Extensions.Logging;
1415

@@ -65,9 +66,9 @@ public override ExecutorBaseState GetStateData()
6566
return state;
6667
}
6768
/// <inheritdoc />
68-
public override Task LoadState(ExecutorBaseState data)
69+
public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
6970
{
70-
if(data is InstructExecutorState state)
71+
if (data is InstructExecutorState state)
7172
{
7273
_n_session_consumed = state.ConsumedSessionCount;
7374
_embed_inps = state.EmbedInps!.ToList();
@@ -91,35 +92,35 @@ public override Task LoadState(ExecutorBaseState data)
9192
}
9293

9394
/// <inheritdoc />
94-
public override async Task SaveState(string filename)
95+
public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
9596
{
9697
var state = (InstructExecutorState)GetStateData();
9798
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
9899
{
99-
await JsonSerializer.SerializeAsync(fs, state);
100+
await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
100101
}
101102
}
102103

103104
/// <inheritdoc />
104-
public override async Task LoadState(string filename)
105+
public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
105106
{
106107
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
107108
{
108109
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
109-
await LoadState(state!);
110+
await LoadState(state!, cancellationToken);
110111
}
111112
}
112113

113114
/// <inheritdoc />
114-
protected override Task<bool> GetLoopCondition(InferStateArgs args)
115+
protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
115116
{
116117
return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
117118
}
118119

119120
/// <inheritdoc />
120-
protected override Task PreprocessInputs(string? text, InferStateArgs args)
121+
protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
121122
{
122-
args.Antiprompts ??= [ ];
123+
args.Antiprompts ??= [];
123124
if (!args.Antiprompts.Contains(_instructionPrefix))
124125
args.Antiprompts.Add(_instructionPrefix);
125126

@@ -155,7 +156,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
155156
}
156157

157158
/// <inheritdoc />
158-
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
159+
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
159160
{
160161
if (_embed_inps.Count <= _consumedTokensCount)
161162
{
@@ -167,7 +168,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
167168

168169
if (_pastTokensCount > 0 && args.WaitForInput)
169170
{
170-
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, [ "\n> " ]));
171+
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
171172
}
172173
}
173174

@@ -185,7 +186,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
185186
}
186187

187188
/// <inheritdoc />
188-
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
189+
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
189190
{
190191
var batch = new LLamaBatch();
191192

@@ -253,7 +254,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
253254

254255
return;
255256
}
256-
257+
257258
/// <summary>
258259
/// The descriptor of the state of the instruct executor.
259260
/// </summary>

0 commit comments

Comments
 (0)