@@ -246,36 +246,41 @@ protected virtual void TryReuseMatchingPrefix()
246246 /// Decide whether to continue the loop.
247247 /// </summary>
248248 /// <param name="args"></param>
249+ /// <param name="cancellationToken"></param>
249250 /// <returns></returns>
250- protected abstract Task < bool > GetLoopCondition ( InferStateArgs args ) ;
251+ protected abstract Task < bool > GetLoopCondition ( InferStateArgs args , CancellationToken cancellationToken = default ) ;
251252
252253 /// <summary>
253254 /// Preprocess the inputs before the inference.
254255 /// </summary>
255256 /// <param name="text"></param>
256257 /// <param name="args"></param>
257- protected abstract Task PreprocessInputs ( string ? text , InferStateArgs args ) ;
258+ /// <param name="cancellationToken"></param>
259+ protected abstract Task PreprocessInputs ( string ? text , InferStateArgs args , CancellationToken cancellationToken = default ) ;
258260
259261 /// <summary>
260262 /// Do some post processing after the inference.
261263 /// </summary>
262264 /// <param name="inferenceParams"></param>
263265 /// <param name="args"></param>
266+ /// <param name="cancellationToken"></param>
264267 /// <returns></returns>
265- protected abstract Task < ( bool , IReadOnlyList < string > ) > PostProcess ( IInferenceParams inferenceParams , InferStateArgs args ) ;
268+ protected abstract Task < ( bool , IReadOnlyList < string > ) > PostProcess ( IInferenceParams inferenceParams , InferStateArgs args , CancellationToken cancellationToken = default ) ;
266269
267270 /// <summary>
268271 /// The core inference logic.
269272 /// </summary>
270273 /// <param name="inferenceParams"></param>
271274 /// <param name="args"></param>
272- protected abstract Task InferInternal ( IInferenceParams inferenceParams , InferStateArgs args ) ;
275+ /// <param name="cancellationToken"></param>
276+ protected abstract Task InferInternal ( IInferenceParams inferenceParams , InferStateArgs args , CancellationToken cancellationToken = default ) ;
273277
274278 /// <summary>
275279 /// Save the current state to a file.
276280 /// </summary>
277281 /// <param name="filename"></param>
278- public abstract Task SaveState ( string filename ) ;
282+ /// <param name="cancellationToken"></param>
283+ public abstract Task SaveState ( string filename , CancellationToken cancellationToken = default ) ;
279284
280285 /// <summary>
281286 /// Get the current state data.
@@ -287,13 +292,15 @@ protected virtual void TryReuseMatchingPrefix()
287292 /// Load the state from data.
288293 /// </summary>
289294 /// <param name="data"></param>
290- public abstract Task LoadState ( ExecutorBaseState data ) ;
295+ /// <param name="cancellationToken"></param>
296+ public abstract Task LoadState ( ExecutorBaseState data , CancellationToken cancellationToken = default ) ;
291297
292298 /// <summary>
293299 /// Load the state from a file.
294300 /// </summary>
295301 /// <param name="filename"></param>
296- public abstract Task LoadState ( string filename ) ;
302+ /// <param name="cancellationToken"></param>
303+ public abstract Task LoadState ( string filename , CancellationToken cancellationToken = default ) ;
297304
298305
299306 /// <summary>
@@ -318,17 +325,17 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
318325 } ;
319326
320327 AntipromptProcessor . SetAntiprompts ( inferenceParams . AntiPrompts ?? [ ] ) ;
328+ await PreprocessInputs ( text , args , cancellationToken ) ;
321329
322- await PreprocessInputs ( text , args ) ;
323-
324- while ( await GetLoopCondition ( args ) )
330+ while ( await GetLoopCondition ( args , cancellationToken ) )
325331 {
326332 if ( cancellationToken . IsCancellationRequested )
327333 {
328334 break ;
329335 }
336+
330337 args . LastOutput = string . Empty ;
331- await InferInternal ( inferenceParams , args ) ;
338+ await InferInternal ( inferenceParams , args , cancellationToken ) ;
332339
333340 if ( args . ReturnValue )
334341 {
@@ -338,7 +345,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
338345 yield return decoded ;
339346 }
340347
341- var ( breakGeneration , extraOutputs ) = await PostProcess ( inferenceParams , args ) ;
348+ var ( breakGeneration , extraOutputs ) = await PostProcess ( inferenceParams , args , cancellationToken ) ;
342349 if ( extraOutputs is { Count : > 0 } )
343350 {
344351 foreach ( var item in extraOutputs )
@@ -358,8 +365,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
358365 /// It could reduce the latency of the first time response if the first input from the user is not immediate.
359366 /// </summary>
360367 /// <param name="prompt">Prompt to process</param>
368+ /// <param name="cancellationToken"></param>
361369 /// <returns></returns>
362- public virtual async Task PrefillPromptAsync ( string prompt )
370+ public virtual async Task PrefillPromptAsync ( string prompt , CancellationToken cancellationToken = default )
363371 {
364372 var inferenceParams = new InferenceParams
365373 {
@@ -374,11 +382,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
374382 NeedToSaveSession = false
375383 } ;
376384
377- await PreprocessInputs ( prompt , args ) ;
385+ await PreprocessInputs ( prompt , args , cancellationToken ) ;
378386 // First run adds the prompt to the _embeds
379- await InferInternal ( inferenceParams , args ) ;
387+ await InferInternal ( inferenceParams , args , cancellationToken ) ;
380388 // Second run puts it through decode
381- await InferInternal ( inferenceParams , args ) ;
389+ await InferInternal ( inferenceParams , args , cancellationToken ) ;
382390 }
383391
384392 /// <summary>
0 commit comments