diff --git a/src/client/App.test.tsx b/src/client/App.test.tsx index d5ffb89..bc027af 100644 --- a/src/client/App.test.tsx +++ b/src/client/App.test.tsx @@ -13,6 +13,7 @@ vi.mock("./api", () => ({ createTask: vi.fn(), updateTask: vi.fn(), createTaskFollowUp: vi.fn(), + forceTaskFollowUp: vi.fn(), deleteTask: vi.fn(), settings: vi.fn(), patchSettings: vi.fn(), diff --git a/src/client/App.tsx b/src/client/App.tsx index 81fb5ab..4ebfa24 100644 --- a/src/client/App.tsx +++ b/src/client/App.tsx @@ -18,6 +18,7 @@ import type { ProviderStatus, Task, TaskCreateInput, + TaskFollowUpMode, TaskStatus, } from "../shared/types"; import { COLUMN_LABELS, TASK_STATUSES } from "../shared/types"; @@ -288,8 +289,12 @@ export function App() { } } - async function askTaskFollowUp(taskId: string, prompt: string) { - const result = await api.createTaskFollowUp(taskId, prompt); + async function askTaskFollowUp( + taskId: string, + prompt: string, + mode: TaskFollowUpMode = "queue", + ) { + const result = await api.createTaskFollowUp(taskId, prompt, mode); setTasks((current) => current.map((task) => (task.id === taskId ? result.task : task))); setEditingTask((current) => (current?.id === taskId ? result.task : current)); } @@ -300,6 +305,12 @@ export function App() { setEditingTask((current) => (current?.id === taskId ? result.task : current)); } + async function forceTaskFollowUp(taskId: string, executionId: string) { + const result = await api.forceTaskFollowUp(taskId, executionId); + setTasks((current) => current.map((task) => (task.id === taskId ? result.task : task))); + setEditingTask((current) => (current?.id === taskId ? result.task : current)); + } + async function createDraftTask(input: TaskCreateInput) { const result = await api.createTask(input); setTasks((current) => [...current, result.task]); @@ -648,8 +659,17 @@ export function App() { } }} onDelete={editingTask ? () => deleteTask(editingTask.id) : undefined} - onAskFollowUp={editingTask ? (prompt) => askTaskFollowUp(editingTask.id, prompt) : undefined} + onAskFollowUp={ + editingTask + ? (prompt, mode) => askTaskFollowUp(editingTask.id, prompt, mode) + : undefined + } onAbortExecution={editingTask ? () => abortTaskRun(editingTask.id) : undefined} + onForceFollowUp={ + editingTask + ? (executionId) => forceTaskFollowUp(editingTask.id, executionId) + : undefined + } onCreateDraft={editingTask ? createDraftTask : undefined} onSave={createOrUpdateTask} /> diff --git a/src/client/api.ts b/src/client/api.ts index 0d4885f..cefe721 100644 --- a/src/client/api.ts +++ b/src/client/api.ts @@ -10,6 +10,8 @@ import type { Task, TaskCreateInput, TaskExecution, + TaskFollowUpInput, + TaskFollowUpMode, TaskUpdateInput, } from "../shared/types"; @@ -56,15 +58,21 @@ export const api = { body: JSON.stringify(input), }); }, - async createTaskFollowUp(id: string, prompt: string) { + async createTaskFollowUp(id: string, prompt: string, mode: TaskFollowUpMode = "queue") { + const input: TaskFollowUpInput = { prompt, mode }; return request(`/api/tasks/${id}/follow-up`, { method: "POST", - body: JSON.stringify({ prompt }), + body: JSON.stringify(input), }); }, async abortTaskRun(id: string) { return request(`/api/tasks/${id}/abort`, { method: "POST" }); }, + async forceTaskFollowUp(id: string, executionId: string) { + return request(`/api/tasks/${id}/follow-up/${executionId}/force`, { + method: "POST", + }); + }, async deleteTask(id: string) { return request(`/api/tasks/${id}`, { method: "DELETE" }); }, diff --git a/src/client/components/TaskModal.test.tsx b/src/client/components/TaskModal.test.tsx index c83016e..5fb9f36 100644 --- a/src/client/components/TaskModal.test.tsx +++ b/src/client/components/TaskModal.test.tsx @@ -142,6 +142,81 @@ describe("TaskModal", () => { expect(document.body.textContent).toContain("Run cancelled."); }); + it("offers force request on a queued follow-up while active AI work is running", async () => { + const onForceFollowUp = vi.fn().mockResolvedValue(undefined); + await renderTaskModal( + { + ...task(), + execution: execution({ + id: "execution-follow-up", + status: "queued", + endedAt: null, + requestKind: "follow_up", + requestPrompt: "Stop and use this correction.", + progressSummary: "Follow-up queued.", + previousExecutions: [ + execution({ + id: "execution-active", + status: "running", + endedAt: null, + progressSummary: "Preparing the task context.", + }), + ], + }), + }, + { onForceFollowUp }, + ); + + expect(findButton("Force request")).toBeTruthy(); + expect(document.body.textContent).toContain("Stop and use this correction."); + await clickButton("Force request"); + await flushReactPromises(); + + expect(onForceFollowUp).toHaveBeenCalledWith("execution-follow-up"); + expect(document.body.textContent).toContain("Forced request queued."); + }); + + it("keeps queued follow-ups forceable when force fails", async () => { + const onForceFollowUp = vi + .fn() + .mockRejectedValueOnce(new Error("Network offline")) + .mockResolvedValueOnce(undefined); + await renderTaskModal( + { + ...task(), + execution: execution({ + id: "execution-follow-up", + status: "queued", + endedAt: null, + requestKind: "follow_up", + requestPrompt: "Force with retry context.", + progressSummary: "Follow-up queued.", + previousExecutions: [ + execution({ + id: "execution-active", + status: "running", + endedAt: null, + progressSummary: "Preparing the task context.", + }), + ], + }), + }, + { onForceFollowUp }, + ); + + await clickButton("Force request"); + await flushReactPromises(); + + expect(document.body.textContent).toContain("Network offline"); + expect(findButton("Force request")).toBeTruthy(); + + await clickButton("Force request"); + await flushReactPromises(); + + expect(onForceFollowUp).toHaveBeenNthCalledWith(2, "execution-follow-up"); + expect(document.body.textContent).toContain("Forced request queued."); + }); + it("shows cancelled runs as terminal task history", async () => { await renderTaskModal({ ...task(), @@ -608,6 +683,12 @@ async function changeField(selector: string, value: string) { }); } +async function flushReactPromises() { + await act(async () => { + await Promise.resolve(); + }); +} + function getButton(label: string) { const button = findButton(label); expect(button).toBeTruthy(); diff --git a/src/client/components/TaskModal.tsx b/src/client/components/TaskModal.tsx index 90bfa2f..bf63a5b 100644 --- a/src/client/components/TaskModal.tsx +++ b/src/client/components/TaskModal.tsx @@ -20,6 +20,7 @@ import { Trash2, UserRound, X, + Zap, } from "lucide-react"; import type { LucideIcon } from "lucide-react"; import { useEffect, useMemo, useRef, useState } from "react"; @@ -35,6 +36,7 @@ import type { TaskCreateInput, TaskExecution, TaskExecutionArtifact, + TaskFollowUpMode, TaskStatus, } from "../../shared/types"; import { COLUMN_LABELS, PRIORITIES, TASK_STATUSES } from "../../shared/types"; @@ -172,6 +174,7 @@ type PendingFollowUp = { id: string; prompt: string; content: string; + mode: TaskFollowUpMode; createdAt: string; status: "sending" | "failed"; }; @@ -183,8 +186,11 @@ type TaskThreadMessage = content: string; createdAt: string; label: string; - pending?: PendingFollowUp["status"]; - prompt?: string; + pending?: PendingFollowUp["status"]; + prompt?: string; + mode: TaskFollowUpMode; + executionId?: string; + canForce?: boolean; } | { type: "agent_result"; id: string; execution: TaskExecution; runIndex: number } | { type: "next_action"; id: string; execution: TaskExecution; content: string } @@ -209,8 +215,9 @@ export function TaskModal(props: { focusAreas: FocusArea[]; onClose: () => void; onDelete?: () => Promise; - onAskFollowUp?: (prompt: string) => Promise; + onAskFollowUp?: (prompt: string, mode?: TaskFollowUpMode) => Promise; onAbortExecution?: () => Promise; + onForceFollowUp?: (executionId: string) => Promise; onCreateDraft?: (input: TaskCreateInput) => Promise; onSave: (input: TaskCreateInput) => Promise; }) { @@ -312,6 +319,7 @@ export function TaskModal(props: { onTitleTouched={() => setTitleTouched(true)} onAskFollowUp={props.onAskFollowUp} onAbortExecution={props.onAbortExecution} + onForceFollowUp={props.onForceFollowUp} onCreateDraft={props.onCreateDraft} /> ); @@ -404,13 +412,15 @@ function TaskDetailWorkspace(props: { onPriorityChange: (value: Priority) => void; onFocusAreaChange: (value: string) => void; onTitleTouched: () => void; - onAskFollowUp?: (prompt: string) => Promise; + onAskFollowUp?: (prompt: string, mode?: TaskFollowUpMode) => Promise; onAbortExecution?: () => Promise; + onForceFollowUp?: (executionId: string) => Promise; onCreateDraft?: (input: TaskCreateInput) => Promise; }) { const [followUp, setFollowUp] = useState(""); const [asking, setAsking] = useState(false); const [aborting, setAborting] = useState(false); + const [forcingExecutionId, setForcingExecutionId] = useState(null); const [creatingDraft, setCreatingDraft] = useState(false); const [nextActionConsumed, setNextActionConsumed] = useState(false); const [pendingFollowUps, setPendingFollowUps] = useState([]); @@ -566,6 +576,7 @@ function TaskDetailWorkspace(props: { promptOverride?: string, displayOverride?: string, pendingId?: string, + mode: TaskFollowUpMode = "queue", ) { const prompt = (promptOverride ?? followUp).trim(); if (!prompt || !props.onAskFollowUp) { @@ -578,20 +589,23 @@ function TaskDetailWorkspace(props: { setError(null); setPendingFollowUps((current) => pendingId - ? current.map((entry) => (entry.id === pendingId ? { ...entry, status: "sending" } : entry)) + ? current.map((entry) => + entry.id === pendingId ? { ...entry, mode, status: "sending" } : entry, + ) : [ ...current, { id, prompt, content, + mode, createdAt: new Date().toISOString(), status: "sending", }, ], ); try { - await props.onAskFollowUp(prompt); + await props.onAskFollowUp(prompt, mode); setPendingFollowUps((current) => current.filter((entry) => entry.id !== id)); if (!promptOverride) { setFollowUp(""); @@ -600,7 +614,9 @@ function TaskDetailWorkspace(props: { setNextActionConsumed(true); } setNotice( - running + mode === "interrupt" + ? "Forced request queued." + : running ? "Follow-up queued. It will start after the current work finishes." : "Follow-up queued.", ); @@ -631,8 +647,25 @@ function TaskDetailWorkspace(props: { } } + async function forceFollowUp(executionId: string) { + if (!props.onForceFollowUp) { + return; + } + setForcingExecutionId(executionId); + setNotice(null); + setError(null); + try { + await props.onForceFollowUp(executionId); + setNotice("Forced request queued."); + } catch (err) { + setError(err instanceof Error ? err.message : String(err)); + } finally { + setForcingExecutionId(null); + } + } + async function retryPendingFollowUp(entry: PendingFollowUp) { - await askFollowUp(entry.prompt, entry.content, entry.id); + await askFollowUp(entry.prompt, entry.content, entry.id, entry.mode); } async function runNextActionAsFollowUp() { @@ -809,8 +842,10 @@ function TaskDetailWorkspace(props: { wide={inspectorCollapsed} nextActionConsumed={nextActionConsumed} creatingDraft={creatingDraft} + forcingExecutionId={forcingExecutionId} onFollowUpChange={setFollowUp} onAskFollowUp={() => void askFollowUp()} + onForceFollowUp={(executionId) => void forceFollowUp(executionId)} onRetryPendingFollowUp={(entry) => void retryPendingFollowUp(entry)} onRunNextAction={nextAction && props.onAskFollowUp ? () => void runNextActionAsFollowUp() : undefined} onCreateDraft={nextAction && props.onCreateDraft ? () => void createDraftFromNextAction() : undefined} @@ -1102,8 +1137,10 @@ function TaskConversation(props: { wide: boolean; nextActionConsumed: boolean; creatingDraft: boolean; + forcingExecutionId: string | null; onFollowUpChange: (value: string) => void; onAskFollowUp: () => void; + onForceFollowUp: (executionId: string) => void; onRetryPendingFollowUp: (entry: PendingFollowUp) => void; onRunNextAction?: () => void; onCreateDraft?: () => void; @@ -1139,7 +1176,9 @@ function TaskConversation(props: { pendingFollowUps={props.pendingFollowUps} nextActionConsumed={props.nextActionConsumed} creatingDraft={props.creatingDraft} + forcingExecutionId={props.forcingExecutionId} onRetryPendingFollowUp={props.onRetryPendingFollowUp} + onForceFollowUp={props.onForceFollowUp} onRunNextAction={props.onRunNextAction} onCreateDraft={props.onCreateDraft} /> @@ -1177,7 +1216,9 @@ function TaskThread(props: { pendingFollowUps: PendingFollowUp[]; nextActionConsumed: boolean; creatingDraft: boolean; + forcingExecutionId: string | null; onRetryPendingFollowUp: (entry: PendingFollowUp) => void; + onForceFollowUp: (executionId: string) => void; onRunNextAction?: () => void; onCreateDraft?: () => void; }) { @@ -1195,7 +1236,13 @@ function TaskThread(props: { message={message} nextActionConsumed={props.nextActionConsumed} creatingDraft={props.creatingDraft} + forcing={Boolean( + message.type === "user_request" && + message.executionId && + props.forcingExecutionId === message.executionId, + )} onRetryPendingFollowUp={props.onRetryPendingFollowUp} + onForceFollowUp={props.onForceFollowUp} onRunNextAction={props.onRunNextAction} onCreateDraft={props.onCreateDraft} /> @@ -1221,7 +1268,9 @@ function TaskThreadMessageRow(props: { message: TaskThreadMessage; nextActionConsumed: boolean; creatingDraft: boolean; + forcing: boolean; onRetryPendingFollowUp: (entry: PendingFollowUp) => void; + onForceFollowUp: (executionId: string) => void; onRunNextAction?: () => void; onCreateDraft?: () => void; }) { @@ -1246,25 +1295,45 @@ function TaskThreadMessageRow(props: {

{message.content}

- {message.pending === "failed" && message.prompt && ( + {((message.pending === "failed" && message.prompt) || message.canForce) && ( - + {message.pending === "failed" && message.prompt && ( + + )} + {message.canForce && message.executionId && ( + + )} )} @@ -1392,19 +1461,21 @@ function TaskThreadComposer(props: { {running ? : } {running ? "Queues behind current work" : "Attached to task"} - - {props.asking ? ( - - ) : ( - - )} - {props.asking ? "Sending..." : "Send"} - +
+ + {props.asking ? ( + + ) : ( + + )} + {props.asking ? "Sending..." : "Send"} + +
{running && This will queue behind the active run.} @@ -1789,12 +1860,18 @@ function buildTaskThreadMessages( content: originalRequest, createdAt: task.createdAt, label: "Original request", + mode: "queue", }); } const executions = getTaskExecutionHistory(task); executions.forEach((execution, index) => { const isLatestExecution = execution === task.execution; + const hasOtherActiveExecution = executions.some( + (otherExecution) => + otherExecution.id !== execution.id && + (otherExecution.status === "queued" || otherExecution.status === "running"), + ); const requestPrompt = displayExecutionRequestPrompt(execution); if (requestPrompt) { messages.push({ @@ -1803,6 +1880,12 @@ function buildTaskThreadMessages( content: requestPrompt, createdAt: execution.createdAt, label: execution.requestKind === "follow_up" ? "Follow-up" : "Request", + mode: "queue", + executionId: execution.id, + canForce: + execution.requestKind === "follow_up" && + execution.status === "queued" && + hasOtherActiveExecution, }); } if (execution.output.trim() || (isLatestExecution && isExecutionActive(execution)) || execution.error) { @@ -1839,9 +1922,10 @@ function buildTaskThreadMessages( id: entry.id, content: entry.content, createdAt: entry.createdAt, - label: "Follow-up", + label: entry.mode === "interrupt" ? "Forced follow-up" : "Follow-up", pending: entry.status, prompt: entry.prompt, + mode: entry.mode, }); }); diff --git a/src/server/routes.test.ts b/src/server/routes.test.ts index 872ca5a..780af5c 100644 --- a/src/server/routes.test.ts +++ b/src/server/routes.test.ts @@ -724,6 +724,88 @@ describe("routes", () => { store.close(); }); + it("forces a queued follow-up without waiting behind the active task run", async () => { + const dir = mkdtempSync(path.join(tmpdir(), "routes-force-follow-up-")); + tempDirs.push(dir); + const store = new BoardStore(path.join(dir, "board.db")); + const providerRouter = new BlockingTaskProviderRouter(store); + const app = buildServer({ + store, + providerRouter, + memoryRoot: dir, + }); + + const created = await app.inject({ + method: "POST", + url: "/api/tasks", + payload: { + title: "Correct active work", + status: "in_progress", + }, + }); + const task = created.json().task as Task; + await waitUntil(() => providerRouter.toolRequests.length === 1); + + const queued = await app.inject({ + method: "POST", + url: `/api/tasks/${task.id}/follow-up`, + payload: { prompt: "Use this correction immediately." }, + }); + expect(queued.statusCode).toBe(201); + expect(queued.json().task.execution.status).toBe("queued"); + const queuedExecutionId = queued.json().task.execution.id; + + const forced = await app.inject({ + method: "POST", + url: `/api/tasks/${task.id}/follow-up/${queuedExecutionId}/force`, + }); + + expect(forced.statusCode).toBe(200); + expect(forced.json().task.execution.requestPrompt).toBe("Use this correction immediately."); + expect(forced.json().task.execution.status).toBe("queued"); + expect( + forced + .json() + .task.execution.events.map((event: { message: string }) => event.message), + ).toContain("Forced follow-up queued."); + expect(providerRouter.toolRequestSignals[0]?.aborted).toBe(true); + await waitUntil(() => providerRouter.toolRequests.length === 2); + expect(store.getTaskExecution(queuedExecutionId)?.status).toBe("running"); + + const forcedPrompt = providerRouter.toolRequests[1]?.context.messages.find( + (message) => message.role === "user", + ); + const forcedPromptContent = + forcedPrompt?.role === "user" && typeof forcedPrompt.content === "string" + ? forcedPrompt.content + : ""; + expect(forcedPromptContent).toContain("Follow-up request: Use this correction immediately."); + + providerRouter.resolveNextToolRequest("Late result that should remain cancelled."); + await flushQueuedPromises(); + expect(store.getTask(task.id)?.status).toBe("in_progress"); + expect(providerRouter.toolRequests).toHaveLength(2); + + providerRouter.resolveNextToolRequest("Forced follow-up result."); + await waitUntil(() => store.getTask(task.id)?.execution?.output === "Forced follow-up result."); + + const detailed = await app.inject({ + method: "GET", + url: `/api/tasks/${task.id}`, + }); + expect(detailed.json().task.status).toBe("done"); + expect(detailed.json().task.execution.status).toBe("succeeded"); + expect(detailed.json().task.execution.id).toBe(queuedExecutionId); + expect(detailed.json().task.execution.previousExecutions).toHaveLength(1); + expect( + detailed + .json() + .task.execution.previousExecutions.map((execution: { status: string }) => execution.status), + ).toEqual(["cancelled"]); + await app.close(); + store.close(); + }); + it("skips queued follow-up work when the task is deleted before it starts", async () => { const dir = mkdtempSync(path.join(tmpdir(), "routes-queue-delete-")); tempDirs.push(dir); diff --git a/src/server/routes.ts b/src/server/routes.ts index 6e84172..ef11143 100644 --- a/src/server/routes.ts +++ b/src/server/routes.ts @@ -32,6 +32,7 @@ import { import { ProviderRouter } from "./providers/router"; import { abortTaskExecutionForTask, + forceQueuedTaskFollowUpExecutionForTask, startTaskExecutionForTask, startTaskFollowUpExecutionForTask, } from "./task-executor"; @@ -58,6 +59,7 @@ const taskUpdateSchema = z.object({ const taskFollowUpSchema = z.object({ prompt: z.string().trim().min(1), + mode: z.enum(["queue", "interrupt"]).optional().default("queue"), }); const settingsPatchSchema = z.object({ @@ -252,6 +254,7 @@ export function buildServer(options: { task: followUpTask, provider: store.getSettings().selectedProvider, followUp: input.prompt, + interruptExisting: input.mode === "interrupt", runInline: options.runTaskExecutionsInline, memoryRoot, memoryLogger: memoryReviewLogger, @@ -280,6 +283,32 @@ export function buildServer(options: { }; }); + app.post<{ Params: { id: string; executionId: string } }>( + "/api/tasks/:id/follow-up/:executionId/force", + async (request, reply) => { + const result = await forceQueuedTaskFollowUpExecutionForTask({ + store, + router: providerRouter, + taskId: request.params.id, + executionId: request.params.executionId, + provider: store.getSettings().selectedProvider, + runInline: options.runTaskExecutionsInline, + memoryRoot, + memoryLogger: memoryReviewLogger, + }); + if (!result.found) { + return reply.status(404).send({ error: result.reason }); + } + if (!result.forced) { + return reply.status(409).send({ error: result.reason ?? "Follow-up cannot be forced." }); + } + return { + task: store.getTask(request.params.id) ?? result.task, + execution: result.execution, + }; + }, + ); + app.delete<{ Params: { id: string } }>("/api/tasks/:id", async (request, reply) => { if (!store.deleteTask(request.params.id)) { return reply.status(404).send({ error: "Task not found" }); diff --git a/src/server/task-executor.ts b/src/server/task-executor.ts index d51063c..e4beda0 100644 --- a/src/server/task-executor.ts +++ b/src/server/task-executor.ts @@ -36,6 +36,7 @@ type StartTaskExecutionInput = { runInline?: boolean; memoryRoot?: string; memoryLogger?: MemoryReviewLogger; + interruptExisting?: boolean; }; export async function startTaskExecutionForTask( @@ -58,7 +59,9 @@ export async function startTaskFollowUpExecutionForTask( buildPrompt: (task) => buildTaskFollowUpPrompt(task, input.followUp), requestKind: "follow_up", requestPrompt: () => input.followUp, - queuedMessage: "Follow-up queued.", + queuedMessage: input.interruptExisting + ? "Forced follow-up queued." + : "Follow-up queued.", }); } @@ -72,6 +75,16 @@ type AbortTaskExecutionResult = execution?: TaskExecution; }; +type ForceQueuedFollowUpResult = + | { found: false; forced: false; reason: string } + | { + found: true; + forced: boolean; + reason?: string; + task: Task; + execution?: TaskExecution; + }; + export function abortTaskExecutionForTask(input: { store: BoardStore; taskId: string; @@ -111,6 +124,94 @@ export function abortTaskExecutionForTask(input: { }; } +export async function forceQueuedTaskFollowUpExecutionForTask(input: { + store: BoardStore; + router: ProviderRouter; + taskId: string; + executionId: string; + provider: ProviderId; + runInline?: boolean; + memoryRoot?: string; + memoryLogger?: MemoryReviewLogger; +}): Promise { + const task = input.store.getTask(input.taskId); + if (!task) { + return { found: false, forced: false, reason: "Task not found." }; + } + const execution = input.store.getTaskExecution(input.executionId); + if (!execution || execution.taskId !== task.id) { + return { found: false, forced: false, reason: "Queued follow-up not found." }; + } + if (execution.requestKind !== "follow_up" || execution.status !== "queued") { + return { + found: true, + forced: false, + reason: "Only queued follow-ups can be forced.", + task, + execution, + }; + } + + cancelAbortableExecutions({ + store: input.store, + task, + exceptExecutionId: execution.id, + }); + const taskForRun = input.store.updateTask(task.id, { status: "in_progress" }) ?? task; + const controller = new AbortController(); + activeTaskExecutionRuns.set(execution.id, { + taskId: task.id, + executionId: execution.id, + controller, + }); + const updatedExecution = + input.store.updateTaskExecution(execution.id, { + progressSummary: "Forced follow-up queued.", + events: [ + ...execution.events, + { + id: "", + kind: "queued", + message: "Forced follow-up queued.", + createdAt: "", + }, + ], + }) ?? execution; + const run = createTaskExecutionRun( + { + store: input.store, + router: input.router, + task: taskForRun, + provider: input.provider, + memoryRoot: input.memoryRoot, + memoryLogger: input.memoryLogger, + buildPrompt: (task) => buildTaskFollowUpPrompt(task, updatedExecution.requestPrompt), + }, + { + task: taskForRun, + execution: updatedExecution, + controller, + shouldMarkTaskInProgress: true, + model: updatedExecution.model ?? input.store.getProviderConfig(input.provider).model, + }, + ); + const queuedRun = enqueueTaskExecution(task.id, run, { replaceExistingQueue: true }); + if (input.runInline) { + await queuedRun; + } else { + void queuedRun.catch((err) => { + console.warn(err instanceof Error ? err.message : String(err)); + }); + } + const latestTask = input.store.getTask(task.id) ?? taskForRun; + return { + found: true, + forced: true, + task: latestTask, + execution: input.store.getTaskExecution(execution.id) ?? updatedExecution, + }; +} + async function startExecution( input: StartTaskExecutionInput & { buildPrompt: (task: Task) => string; @@ -119,38 +220,86 @@ async function startExecution( queuedMessage: string; }, ): Promise { - const shouldMarkTaskInProgress = input.task.status !== "done"; + let task = input.task; + if (input.interruptExisting) { + abortTaskExecutionForTask({ + store: input.store, + taskId: input.task.id, + }); + task = input.store.getTask(input.task.id) ?? input.task; + } + const shouldMarkTaskInProgress = task.status !== "done"; if (shouldMarkTaskInProgress) { - input.store.updateTask(input.task.id, { status: "in_progress" }); + input.store.updateTask(task.id, { status: "in_progress" }); } const config = input.store.getProviderConfig(input.provider); const execution = input.store.createTaskExecution({ - taskId: input.task.id, + taskId: task.id, provider: input.provider, model: config.model, requestKind: input.requestKind, - requestPrompt: input.requestPrompt(input.task), + requestPrompt: input.requestPrompt(task), status: "queued", progressSummary: input.queuedMessage, events: [{ id: "", kind: "queued", message: input.queuedMessage, createdAt: "" }], }); const controller = new AbortController(); activeTaskExecutionRuns.set(execution.id, { - taskId: input.task.id, + taskId: task.id, executionId: execution.id, controller, }); - const run = async () => { - let taskAtStart: Task = input.task; + const run = createTaskExecutionRun(input, { + task, + execution, + controller, + shouldMarkTaskInProgress, + model: config.model, + }); + const queuedRun = enqueueTaskExecution(task.id, run, { + replaceExistingQueue: input.interruptExisting === true, + }); + if (input.runInline) { + await queuedRun; + } else { + void queuedRun.catch((err) => { + console.warn(err instanceof Error ? err.message : String(err)); + }); + } + return input.store.getTaskExecution(execution.id) ?? execution; +} + +function createTaskExecutionRun( + input: StartTaskExecutionInput & { + buildPrompt: (task: Task) => string; + }, + runInput: { + task: Task; + execution: TaskExecution; + controller: AbortController; + shouldMarkTaskInProgress: boolean; + model: string; + }, +): () => Promise { + const { task, execution, controller, shouldMarkTaskInProgress, model } = runInput; + return async () => { + let taskAtStart: Task = task; let events = execution.events; try { - const existingTask = input.store.getTask(input.task.id); + const existingTask = input.store.getTask(task.id); if (!existingTask) { return; } + const existingExecution = input.store.getTaskExecution(execution.id); + if (!existingExecution || existingExecution.status !== "queued") { + return; + } taskAtStart = existingTask; - events = input.store.getTaskExecution(execution.id)?.events ?? execution.events; - if (isTaskRunCancelled(controller.signal) || isTaskExecutionAlreadyCancelled(input.store, execution.id)) { + events = existingExecution.events.length > 0 ? existingExecution.events : execution.events; + if ( + isTaskRunCancelled(controller.signal) || + isTaskExecutionAlreadyCancelled(input.store, execution.id) + ) { markTaskExecutionCancelled({ store: input.store, task: taskAtStart, @@ -160,7 +309,7 @@ async function startExecution( return; } if (shouldMarkTaskInProgress) { - const updatedTask = input.store.updateTask(input.task.id, { status: "in_progress" }); + const updatedTask = input.store.updateTask(task.id, { status: "in_progress" }); if (!updatedTask) { return; } @@ -184,7 +333,10 @@ async function startExecution( }); events = running?.events ?? events; const appendProgressEvent = (kind: TaskExecutionEvent["kind"], message: string) => { - if (isTaskRunCancelled(controller.signal) || isTaskExecutionAlreadyCancelled(input.store, execution.id)) { + if ( + isTaskRunCancelled(controller.signal) || + isTaskExecutionAlreadyCancelled(input.store, execution.id) + ) { return; } events = [ @@ -202,18 +354,21 @@ async function startExecution( }); events = updated?.events ?? events; }; - const boardContext = buildBoardTaskContext(input.store.listTasks(), input.task.id); + const boardContext = buildBoardTaskContext(input.store.listTasks(), task.id); const systemPrompt = buildTaskSystemPrompt(prompt, boardContext, input.memoryRoot); const result = await completeTaskWithLocalTools({ router: input.router, provider: input.provider, - model: config.model, + model, systemPrompt, prompt, signal: controller.signal, onToolProgress: (message) => appendProgressEvent("progress", message), }); - if (isTaskRunCancelled(controller.signal) || isTaskExecutionAlreadyCancelled(input.store, execution.id)) { + if ( + isTaskRunCancelled(controller.signal) || + isTaskExecutionAlreadyCancelled(input.store, execution.id) + ) { markTaskExecutionCancelled({ store: input.store, task: taskAtStart, @@ -237,12 +392,12 @@ async function startExecution( }, ], }); - input.store.updateTask(input.task.id, { status: "done" }); + input.store.updateTask(task.id, { status: "done" }); throwIfTaskRunCancelled(controller.signal); await reviewAndApplyMemory({ providerRouter: input.router, provider: input.provider, - model: config.model, + model, messages: [ { role: "user", content: prompt }, { role: "assistant", content: result || "Finished." }, @@ -257,7 +412,7 @@ async function startExecution( isTaskRunCancelled(controller.signal) || (controller.signal.aborted && isTaskRunCancellationError(err)) ) { - const task = input.store.getTask(input.task.id) ?? taskAtStart; + const task = input.store.getTask(taskAtStart.id) ?? taskAtStart; markTaskExecutionCancelled({ store: input.store, task, @@ -287,23 +442,14 @@ async function startExecution( }, ], }); - input.store.updateTask(input.task.id, { status: "needs_attention" }); + input.store.updateTask(task.id, { status: "needs_attention" }); } finally { const activeRun = activeTaskExecutionRuns.get(execution.id); - if (activeRun?.executionId === execution.id) { + if (activeRun?.executionId === execution.id && activeRun.controller === controller) { activeTaskExecutionRuns.delete(execution.id); } } }; - const queuedRun = enqueueTaskExecution(input.task.id, run); - if (input.runInline) { - await queuedRun; - } else { - void queuedRun.catch((err) => { - console.warn(err instanceof Error ? err.message : String(err)); - }); - } - return input.store.getTaskExecution(execution.id) ?? execution; } function getAbortableExecutions(store: BoardStore, taskId: string): TaskExecution[] { @@ -313,6 +459,27 @@ function getAbortableExecutions(store: BoardStore, taskId: string): TaskExecutio .reverse(); } +function cancelAbortableExecutions(input: { + store: BoardStore; + task: Task; + exceptExecutionId?: string; +}): TaskExecution[] { + return getAbortableExecutions(input.store, input.task.id) + .filter((execution) => execution.id !== input.exceptExecutionId) + .map((execution) => { + activeTaskExecutionRuns + .get(execution.id) + ?.controller.abort(new Error(CANCELLED_BY_USER_MESSAGE)); + return markTaskExecutionCancelled({ + store: input.store, + task: input.task, + executionId: execution.id, + events: execution.events, + }); + }) + .filter((execution): execution is TaskExecution => Boolean(execution)); +} + function isTaskExecutionAlreadyCancelled(store: BoardStore, executionId: string) { return store.getTaskExecution(executionId)?.status === "cancelled"; } @@ -356,11 +523,18 @@ function markTaskExecutionCancelled(input: { } function releaseCancelledTask(store: BoardStore, task: Task) { - if (task.status === "in_progress") { + const latestTask = store.getTask(task.id) ?? task; + if (latestTask.status === "in_progress" && !hasAbortableTaskExecution(store, task.id)) { store.updateTask(task.id, { status: "ready" }); } } +function hasAbortableTaskExecution(store: BoardStore, taskId: string) { + return store + .listTaskExecutions(taskId) + .some((execution) => execution.status === "running" || execution.status === "queued"); +} + function isTaskRunCancelled(signal?: AbortSignal) { return Boolean(signal?.aborted); } @@ -385,8 +559,14 @@ function isTaskRunCancellationError(err: unknown) { return err.name === "AbortError" || /\b(abort|aborted|cancelled|canceled)\b/i.test(err.message); } -function enqueueTaskExecution(taskId: string, run: () => Promise): Promise { - const previous = taskExecutionQueues.get(taskId) ?? Promise.resolve(); +function enqueueTaskExecution( + taskId: string, + run: () => Promise, + options: { replaceExistingQueue?: boolean } = {}, +): Promise { + const previous = options.replaceExistingQueue + ? Promise.resolve() + : taskExecutionQueues.get(taskId) ?? Promise.resolve(); const current = previous.catch(() => undefined).then(run); const tracked = current.finally(() => { if (taskExecutionQueues.get(taskId) === tracked) { diff --git a/src/shared/types.ts b/src/shared/types.ts index 7b79949..5c830ac 100644 --- a/src/shared/types.ts +++ b/src/shared/types.ts @@ -77,6 +77,12 @@ export type TaskCreateInput = { export type TaskUpdateInput = Partial; +export type TaskFollowUpMode = "queue" | "interrupt"; +export type TaskFollowUpInput = { + prompt: string; + mode?: TaskFollowUpMode; +}; + export type TaskExecutionStatus = | "queued" | "running"