diff --git a/agents/src/tts/stream_adapter.ts b/agents/src/tts/stream_adapter.ts index e1ab402df..58f6d8165 100644 --- a/agents/src/tts/stream_adapter.ts +++ b/agents/src/tts/stream_adapter.ts @@ -51,7 +51,9 @@ export class StreamAdapterWrapper extends SynthesizeStream { async #run() { const forwardInput = async () => { - for await (const input of this.input) { + while (true) { + const { done, value: input } = await this.inputReader.read(); + if (done) break; if (input === SynthesizeStream.FLUSH_SENTINEL) { this.#sentenceStream.flush(); } else { @@ -65,10 +67,10 @@ export class StreamAdapterWrapper extends SynthesizeStream { const synthesize = async () => { for await (const ev of this.#sentenceStream) { for await (const audio of this.#tts.synthesize(ev.token)) { - this.output.put(audio); + this.outputWriter.write(audio); } } - this.output.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); }; Promise.all([forwardInput(), synthesize()]); diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index 7826b446a..da66799ed 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -4,8 +4,12 @@ import type { AudioFrame } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import { EventEmitter } from 'node:events'; +import type { ReadableStream } from 'node:stream/web'; +import { log } from '../log.js'; import type { TTSMetrics } from '../metrics/base.js'; -import { AsyncIterableQueue, mergeFrames } from '../utils.js'; +import { DeferredReadableStream } from '../stream/deferred_stream.js'; +import { IdentityTransform } from '../stream/identity_transform.js'; +import { mergeFrames } from '../utils.js'; /** SynthesizedAudio is a packet of speech synthesis as returned by the TTS. */ export interface SynthesizedAudio { @@ -105,22 +109,73 @@ export abstract class SynthesizeStream { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); static readonly END_OF_STREAM = Symbol('END_OF_STREAM'); - protected input = new AsyncIterableQueue(); - protected queue = new AsyncIterableQueue< + protected inputReader: ReadableStreamDefaultReader< + string | typeof SynthesizeStream.FLUSH_SENTINEL + >; + protected outputWriter: WritableStreamDefaultWriter< SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM - >(); - protected output = new AsyncIterableQueue< - SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM - >(); + >; protected closed = false; abstract label: string; #tts: TTS; #metricsPendingTexts: string[] = []; #metricsText = ''; - #monitorMetricsTask?: Promise; + + private deferredInputStream: DeferredReadableStream< + string | typeof SynthesizeStream.FLUSH_SENTINEL + >; + private metricsStream: ReadableStream; + private input = new IdentityTransform(); + private output = new IdentityTransform< + SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM + >(); + private inputWriter: WritableStreamDefaultWriter; + private outputReader: ReadableStreamDefaultReader< + SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM + >; + private logger = log(); + private inputClosed = false; constructor(tts: TTS) { this.#tts = tts; + this.deferredInputStream = new DeferredReadableStream(); + + this.inputWriter = this.input.writable.getWriter(); + this.inputReader = this.input.readable.getReader(); + this.outputWriter = this.output.writable.getWriter(); + + const [outputStream, metricsStream] = this.output.readable.tee(); + this.outputReader = outputStream.getReader(); + this.metricsStream = metricsStream; + + this.pumpDeferredStream(); + this.monitorMetrics(); + } + + /** + * Reads from the deferred input stream and forwards chunks to the input writer. + * + * Note: we can't just do this.deferredInputStream.stream.pipeTo(this.input.writable) + * because the inputWriter locks the this.input.writable stream. All writes must go through + * the inputWriter. + */ + private async pumpDeferredStream() { + const reader = this.deferredInputStream.stream.getReader(); + try { + while (true) { + const { done, value } = await reader.read(); + if (done || value === SynthesizeStream.FLUSH_SENTINEL) { + break; + } + this.inputWriter.write(value); + } + } catch (error) { + this.logger.error(error, 'Error reading deferred input stream'); + } finally { + reader.releaseLock(); + this.flush(); + this.endInput(); + } } protected async monitorMetrics() { @@ -148,9 +203,11 @@ export abstract class SynthesizeStream } }; - for await (const audio of this.queue) { - this.output.put(audio); - if (audio === SynthesizeStream.END_OF_STREAM) continue; + const metricsReader = this.metricsStream.getReader(); + + while (true) { + const { done, value: audio } = await metricsReader.read(); + if (done || audio === SynthesizeStream.END_OF_STREAM) break; requestId = audio.requestId; if (!ttfb) { ttfb = process.hrtime.bigint() - startTime; @@ -164,23 +221,24 @@ export abstract class SynthesizeStream if (requestId) { emit(); } - this.output.close(); + } + + updateInputStream(text: ReadableStream) { + this.deferredInputStream.setSource(text); } /** Push a string of text to the TTS */ + /** @deprecated Use `updateInputStream` instead */ pushText(text: string) { - if (!this.#monitorMetricsTask) { - this.#monitorMetricsTask = this.monitorMetrics(); - } this.#metricsText += text; - if (this.input.closed) { + if (this.inputClosed) { throw new Error('Input is closed'); } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(text); + this.inputWriter.write(text); } /** Flush the TTS, causing it to process all pending text */ @@ -189,34 +247,41 @@ export abstract class SynthesizeStream this.#metricsPendingTexts.push(this.#metricsText); this.#metricsText = ''; } - if (this.input.closed) { + if (this.inputClosed) { throw new Error('Input is closed'); } if (this.closed) { throw new Error('Stream is closed'); } - this.input.put(SynthesizeStream.FLUSH_SENTINEL); + this.inputWriter.write(SynthesizeStream.FLUSH_SENTINEL); } /** Mark the input as ended and forbid additional pushes */ endInput() { - if (this.input.closed) { + if (this.inputClosed) { throw new Error('Input is closed'); } if (this.closed) { throw new Error('Stream is closed'); } - this.input.close(); + this.inputClosed = true; + this.inputWriter.close(); } next(): Promise> { - return this.output.next(); + return this.outputReader.read().then(({ done, value }) => { + if (done) { + return { done: true, value: undefined }; + } + return { done: false, value }; + }); } /** Close both the input and output of the TTS stream */ close() { - this.input.close(); - this.output.close(); + if (!this.inputClosed) { + this.inputWriter.close(); + } this.closed = true; } @@ -240,17 +305,26 @@ export abstract class SynthesizeStream * exports its own child ChunkedStream class, which inherits this class's methods. */ export abstract class ChunkedStream implements AsyncIterableIterator { - protected queue = new AsyncIterableQueue(); - protected output = new AsyncIterableQueue(); + protected outputWriter: WritableStreamDefaultWriter< + SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM + >; protected closed = false; abstract label: string; #text: string; #tts: TTS; + private output = new IdentityTransform(); + private outputReader: ReadableStreamDefaultReader; + private metricsStream: ReadableStream; constructor(text: string, tts: TTS) { this.#text = text; this.#tts = tts; + this.outputWriter = this.output.writable.getWriter(); + const [outputStream, metricsStream] = this.output.readable.tee(); + this.outputReader = outputStream.getReader(); + this.metricsStream = metricsStream; + this.monitorMetrics(); } @@ -260,15 +334,18 @@ export abstract class ChunkedStream implements AsyncIterableIterator> { - return this.output.next(); + async next(): Promise> { + const { done, value } = await this.outputReader.read(); + if (done) { + return { done: true, value: undefined }; + } + return { done: false, value }; } /** Close both the input and output of the TTS stream */ close() { - this.queue.close(); - this.output.close(); + if (!this.closed) { + this.outputWriter.close(); + } this.closed = true; } diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 2e135df3a..8ff1798ae 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -84,12 +84,9 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter { protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL'); - protected input = new IdentityTransform(); - protected output = new IdentityTransform(); - protected inputWriter: WritableStreamDefaultWriter; + protected inputReader: ReadableStreamDefaultReader; protected outputWriter: WritableStreamDefaultWriter; - protected outputReader: ReadableStreamDefaultReader; protected closed = false; protected inputClosed = false; @@ -97,8 +94,12 @@ export abstract class VADStream implements AsyncIterableIterator { #lastActivityTime = BigInt(0); private logger = log(); private deferredInputStream: DeferredReadableStream; - + private input = new IdentityTransform(); + private output = new IdentityTransform(); private metricsStream: ReadableStream; + private outputReader: ReadableStreamDefaultReader; + private inputWriter: WritableStreamDefaultWriter; + constructor(vad: VAD) { this.#vad = vad; this.deferredInputStream = new DeferredReadableStream(); @@ -207,7 +208,7 @@ export abstract class VADStream implements AsyncIterableIterator { throw new Error('Stream is closed'); } this.inputClosed = true; - this.input.writable.close(); + this.inputWriter.close(); } async next(): Promise> { @@ -220,7 +221,9 @@ export abstract class VADStream implements AsyncIterableIterator { } close() { - this.input.writable.close(); + if (!this.inputClosed) { + this.inputWriter.close(); + } this.closed = true; } diff --git a/plugins/cartesia/src/tts.ts b/plugins/cartesia/src/tts.ts index 28c991d76..c68cc5fb6 100644 --- a/plugins/cartesia/src/tts.ts +++ b/plugins/cartesia/src/tts.ts @@ -107,7 +107,7 @@ export class ChunkedStream extends tts.ChunkedStream { (res) => { res.on('data', (chunk) => { for (const frame of bstream.write(chunk)) { - this.queue.put({ + this.outputWriter.write({ requestId, frame, final: false, @@ -117,14 +117,14 @@ export class ChunkedStream extends tts.ChunkedStream { }); res.on('close', () => { for (const frame of bstream.flush()) { - this.queue.put({ + this.outputWriter.write({ requestId, frame, final: false, segmentId: requestId, }); } - this.queue.close(); + this.close(); }); }, ); @@ -178,7 +178,9 @@ export class SynthesizeStream extends tts.SynthesizeStream { }; const inputTask = async () => { - for await (const data of this.input) { + while (true) { + const { done, value: data } = await this.inputReader.read(); + if (done) break; if (data === SynthesizeStream.FLUSH_SENTINEL) { this.#tokenizer.flush(); continue; @@ -195,7 +197,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + this.outputWriter.write({ requestId, segmentId, frame: lastFrame, final }); lastFrame = undefined; } }; @@ -215,7 +217,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { lastFrame = frame; } sendLastFrame(segmentId, true); - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); if (segmentId === requestId) { closing = true; diff --git a/plugins/elevenlabs/src/tts.ts b/plugins/elevenlabs/src/tts.ts index eb80664cc..4d796c35d 100644 --- a/plugins/elevenlabs/src/tts.ts +++ b/plugins/elevenlabs/src/tts.ts @@ -148,7 +148,9 @@ export class SynthesizeStream extends tts.SynthesizeStream { const tokenizeInput = async () => { let stream: tokenize.WordStream | null = null; - for await (const text of this.input) { + while (true) { + const { done, value: text } = await this.inputReader.read(); + if (done) break; if (text === SynthesizeStream.FLUSH_SENTINEL) { stream?.endInput(); stream = null; @@ -166,7 +168,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { const runStream = async () => { for await (const stream of segments) { await this.#runWS(stream); - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); } }; @@ -246,7 +248,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + this.outputWriter.write({ requestId, segmentId, frame: lastFrame, final }); lastFrame = undefined; } }; @@ -278,7 +280,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { lastFrame = frame; } sendLastFrame(segmentId, true); - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); if (segmentId === requestId) { ws.close(); diff --git a/plugins/neuphonic/src/tts.ts b/plugins/neuphonic/src/tts.ts index 4636da75d..2b54f3929 100644 --- a/plugins/neuphonic/src/tts.ts +++ b/plugins/neuphonic/src/tts.ts @@ -109,7 +109,7 @@ export class ChunkedStream extends tts.ChunkedStream { if (parsedMessage?.data?.audio) { for (const frame of bstream.write(parsedMessage.data.audio)) { - this.queue.put({ + this.outputWriter.write({ requestId, frame, final: false, @@ -123,14 +123,14 @@ export class ChunkedStream extends tts.ChunkedStream { }); res.on('close', () => { for (const frame of bstream.flush()) { - this.queue.put({ + this.outputWriter.write({ requestId, frame, final: false, segmentId: requestId, }); } - this.queue.close(); + this.close(); }); }, ); @@ -156,7 +156,9 @@ export class SynthesizeStream extends tts.SynthesizeStream { let closing = false; const sendTask = async (ws: WebSocket) => { - for await (const data of this.input) { + while (true) { + const { done, value: data } = await this.inputReader.read(); + if (done) break; if (data === SynthesizeStream.FLUSH_SENTINEL) { ws.send(JSON.stringify({ text: '' })); continue; @@ -172,7 +174,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + this.outputWriter.write({ requestId, segmentId, frame: lastFrame, final }); lastFrame = undefined; } }; @@ -195,7 +197,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { lastFrame = frame; } sendLastFrame(requestId, true); - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); closing = true; ws.close(); diff --git a/plugins/openai/src/tts.ts b/plugins/openai/src/tts.ts index 6fd91053c..4a5965acc 100644 --- a/plugins/openai/src/tts.ts +++ b/plugins/openai/src/tts.ts @@ -97,7 +97,7 @@ export class ChunkedStream extends tts.ChunkedStream { let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + this.outputWriter.write({ requestId, segmentId, frame: lastFrame, final }); lastFrame = undefined; } }; @@ -108,6 +108,6 @@ export class ChunkedStream extends tts.ChunkedStream { } sendLastFrame(requestId, true); - this.queue.close(); + this.close(); } } diff --git a/plugins/resemble/src/tts.ts b/plugins/resemble/src/tts.ts index 59798049d..35da02054 100644 --- a/plugins/resemble/src/tts.ts +++ b/plugins/resemble/src/tts.ts @@ -116,7 +116,7 @@ export class ChunkedStream extends tts.ChunkedStream { const audioBytes = Buffer.from(audioContentB64, 'base64'); for (const frame of bstream.write(audioBytes)) { - this.queue.put({ + this.outputWriter.write({ requestId, frame, final: false, @@ -125,7 +125,7 @@ export class ChunkedStream extends tts.ChunkedStream { } for (const frame of bstream.flush()) { - this.queue.put({ + this.outputWriter.write({ requestId, frame, final: false, @@ -133,16 +133,16 @@ export class ChunkedStream extends tts.ChunkedStream { }); } - this.queue.close(); + this.close(); } catch (error) { this.#logger.error('Error processing Resemble API response:', error); - this.queue.close(); + this.close(); } }); res.on('error', (error) => { this.#logger.error('Resemble API error:', error); - this.queue.close(); + this.close(); }); }, ); @@ -187,7 +187,9 @@ export class SynthesizeStream extends tts.SynthesizeStream { }; const inputTask = async () => { - for await (const data of this.input) { + while (true) { + const { done, value: data } = await this.inputReader.read(); + if (done) break; if (data === SynthesizeStream.FLUSH_SENTINEL) { this.#tokenizer.flush(); continue; @@ -204,7 +206,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { let lastFrame: AudioFrame | undefined; const sendLastFrame = (segmentId: string, final: boolean) => { if (lastFrame) { - this.queue.put({ requestId, segmentId, frame: lastFrame, final }); + this.outputWriter.write({ requestId, segmentId, frame: lastFrame, final }); lastFrame = undefined; } }; @@ -234,7 +236,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { activeRequests.delete(Number(segmentId)); if (activeRequests.size === 0 && this.#tokenizer.closed) { - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); closing = true; ws.close(); } @@ -245,7 +247,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { closing = true; ws.close(); - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); } } catch (error) { this.#logger.error(`Error parsing WebSocket message: ${error}`); @@ -256,7 +258,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { this.#logger.error(`WebSocket error: ${error}`); if (!closing) { closing = true; - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); ws.close(); } }); @@ -264,7 +266,7 @@ export class SynthesizeStream extends tts.SynthesizeStream { ws.on('close', (code, reason) => { if (!closing) { this.#logger.error(`WebSocket closed with code ${code}: ${reason}`); - this.queue.put(SynthesizeStream.END_OF_STREAM); + this.outputWriter.write(SynthesizeStream.END_OF_STREAM); } ws.removeAllListeners(); });