Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: smooth and combine token output (#936)
* feat: smooth and combine token output * fix: stop generating button not triggering message updates abort
- Loading branch information
Showing
2 changed files
with
264 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
import type { MessageUpdate, TextStreamUpdate } from "$lib/types/MessageUpdate"; | ||
|
||
/** Options accepted by `fetchMessageUpdates` when posting to a conversation. */
interface MessageUpdateRequestOptions {
  /** URL prefix placed before `/conversation/<conversationId>`. */
  base: string;
  /** Sent as `inputs` in the JSON request body. */
  inputs?: string;
  /** Sent as `id` in the JSON request body. */
  messageId?: string;
  /** Sent as `is_retry` in the JSON request body. */
  isRetry: boolean;
  /** Sent as `is_continue` in the JSON request body. */
  isContinue: boolean;
  /** Sent as `web_search` in the JSON request body. */
  webSearch: boolean;
  /** Sent as `files` in the JSON request body. */
  files?: string[];
}
/**
 * POSTs a message request for a conversation and returns the server's updates
 * as an async generator (word-combined and time-smoothed).
 *
 * @param conversationId - Conversation the request is posted to.
 * @param opts - Request options; serialized into the JSON request body.
 * @param abortSignal - Aborting this signal aborts the underlying request.
 * @returns An async generator of message updates.
 * @throws Error when the response is not ok (using the server-provided
 *   `message` when it is a string) or when the response has no body.
 */
export async function fetchMessageUpdates(
  conversationId: string,
  opts: MessageUpdateRequestOptions,
  abortSignal: AbortSignal
): Promise<AsyncGenerator<MessageUpdate>> {
  const abortController = new AbortController();
  // A signal that is already aborted never fires "abort" again, so its state
  // has to be forwarded eagerly or the request would be sent regardless.
  if (abortSignal.aborted) {
    abortController.abort();
  } else {
    abortSignal.addEventListener("abort", () => abortController.abort(), { once: true });
  }

  const response = await fetch(`${opts.base}/conversation/${conversationId}`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      inputs: opts.inputs,
      id: opts.messageId,
      is_retry: opts.isRetry,
      is_continue: opts.isContinue,
      web_search: opts.webSearch,
      files: opts.files,
    }),
    signal: abortController.signal,
  });

  if (!response.ok) {
    let errorMessage = `Request failed with status code ${response.status}: ${response.statusText}`;
    try {
      const payload: unknown = await response.json();
      const message = (payload as { message?: unknown } | null)?.message;
      // Only trust the payload when it actually carries a textual message
      if (typeof message === "string") errorMessage = message;
    } catch {
      // Body was not valid JSON: keep the status-based fallback message
    }
    throw Error(errorMessage);
  }
  if (!response.body) {
    throw Error("Body not defined");
  }
  return smoothAsyncIterator(
    streamMessageUpdatesToFullWords(endpointStreamToIterator(response, abortController))
  );
}
|
||
/**
 * Decodes the response body as text and yields every message update parsed
 * from it, one JSON document per line.
 *
 * @param response - Response whose body is read.
 * @param abortController - Aborted by this function once the stream closes or
 *   errors; aborting it from outside cancels the reader and ends the iteration.
 * @throws Error when the response has no body.
 */
async function* endpointStreamToIterator(
  response: Response,
  abortController: AbortController
): AsyncGenerator<MessageUpdate> {
  const reader = response.body?.pipeThrough(new TextDecoderStream()).getReader();
  if (!reader) throw Error("Response for endpoint had no body");

  const abort = () => abortController.abort();

  // `closed` fulfills when the stream closes and rejects when it errors. Abort
  // in both cases so the rejection is not left unhandled; the error itself
  // still reaches the consumer through `reader.read()` below.
  void reader.closed.then(abort, abort);

  // Cancelling an already-errored stream rejects; there is nothing to recover.
  const cancelReader = () => {
    reader.cancel().catch(() => undefined);
  };
  if (abortController.signal.aborted) {
    // The "abort" event has already fired, so a listener would never run
    cancelReader();
  } else {
    abortController.signal.addEventListener("abort", cancelReader, { once: true });
  }

  // A chunk may end in the middle of a JSON line, e.g. `{"type": "stream", "token":`.
  // The unparsed tail is kept and prepended to the next chunk.
  let prevChunk = "";
  while (!abortController.signal.aborted) {
    const { done, value } = await reader.read();
    if (done) {
      abort();
      break;
    }
    if (!value) continue;

    const { messageUpdates, remainingText } = parseMessageUpdates(prevChunk + value);
    prevChunk = remainingText;
    for (const messageUpdate of messageUpdates) yield messageUpdate;
  }
}
|
||
/**
 * Parses newline-delimited JSON into message updates.
 *
 * Every newline-terminated line is parsed on its own: blank lines and lines
 * that are not valid JSON are skipped, so one bad line can no longer discard
 * the complete lines that follow it. The text after the last newline may be a
 * partial document; it is parsed when possible and otherwise handed back so
 * the caller can prepend it to the next chunk.
 *
 * @param value - Text to parse.
 * @returns The parsed updates and the unparsed trailing text ("" when none).
 */
function parseMessageUpdates(value: string): {
  messageUpdates: MessageUpdate[];
  remainingText: string;
} {
  const lines = value.split("\n");
  // Everything after the last "\n" may still be incomplete
  const trailing = lines.pop() ?? "";
  const messageUpdates: MessageUpdate[] = [];

  for (const line of lines) {
    if (line.trim() === "") continue;
    try {
      messageUpdates.push(JSON.parse(line) as MessageUpdate);
    } catch {
      // A newline-terminated line that is not JSON can never be completed by
      // later chunks, so skip it and keep going with the following lines.
    }
  }

  if (trailing.trim() === "") return { messageUpdates, remainingText: "" };

  try {
    messageUpdates.push(JSON.parse(trailing) as MessageUpdate);
    return { messageUpdates, remainingText: "" };
  } catch {
    // Incomplete document: the caller retries it together with the next chunk
    return { messageUpdates, remainingText: trailing };
  }
}
|
||
/**
 * Re-groups "stream" updates so that tokens are emitted as full words.
 *
 * - Updates whose type is not "stream" are emitted immediately, ahead of any
 *   stream tokens that are still buffered.
 * - Consecutive stream tokens are merged while the previous token ends with,
 *   and the current one begins with, a word character; at most 5 tokens are
 *   merged into one update.
 *   Example: "what" " don" "'t" => "what" " don't"
 * - When the source ends, the tokens still buffered are emitted as one merged
 *   update.
 *
 * Only supports latin languages, ignores others.
 *
 * @param iterator - Source of message updates.
 * @returns The re-grouped message updates.
 */
async function* streamMessageUpdatesToFullWords(
  iterator: AsyncGenerator<MessageUpdate>
): AsyncGenerator<MessageUpdate> {
  let bufferedStreamUpdates: TextStreamUpdate[] = [];

  const endAlphanumeric = /[a-zA-Z0-9À-ž'`]+$/;
  const beginningAlphanumeric = /^[a-zA-Z0-9À-ž'`]+/;
  const maxCombinedTokens = 5;

  const joinTokens = (updates: TextStreamUpdate[]): string =>
    updates.map((update) => update.token).join("");

  for await (const messageUpdate of iterator) {
    if (messageUpdate.type !== "stream") {
      yield messageUpdate;
      continue;
    }
    bufferedStreamUpdates.push(messageUpdate);

    let lastIndexEmitted = 0;
    for (let i = 1; i < bufferedStreamUpdates.length; i++) {
      const prevEndsAlphanumeric = endAlphanumeric.test(bufferedStreamUpdates[i - 1].token);
      const currBeginsAlphanumeric = beginningAlphanumeric.test(bufferedStreamUpdates[i].token);
      const shouldCombine = prevEndsAlphanumeric && currBeginsAlphanumeric;
      const combinedTooMany = i - lastIndexEmitted >= maxCombinedTokens;
      if (shouldCombine && !combinedTooMany) continue;

      // Word boundary (or merge limit) reached: emit everything before it as one update
      yield {
        type: "stream",
        token: joinTokens(bufferedStreamUpdates.slice(lastIndexEmitted, i)),
      };
      lastIndexEmitted = i;
    }
    bufferedStreamUpdates = bufferedStreamUpdates.slice(lastIndexEmitted);
  }

  // Whatever is still buffered is the tail of a single word (otherwise it would
  // have been emitted above), so flush it merged rather than fragment by fragment.
  if (bufferedStreamUpdates.length === 1) {
    yield bufferedStreamUpdates[0];
  } else if (bufferedStreamUpdates.length > 1) {
    yield { type: "stream", token: joinTokens(bufferedStreamUpdates) };
  }
}
|
||
/**
 * Attempts to smooth out the time between values emitted by an async iterator
 * by waiting for the average time between recent values before emitting the
 * next one.
 *
 * The source is drained in the background into a buffer. The wait between two
 * emitted values is the average gap between the last 30 arrivals (gaps longer
 * than 2000ms are excluded from the total), capped at 200ms and never shorter
 * than 5ms.
 *
 * An error thrown by the source is re-thrown to the consumer once the values
 * buffered before it have been emitted.
 *
 * @param iterator - Source of values.
 * @returns The same values, in order, with smoothed timing.
 */
async function* smoothAsyncIterator<T>(iterator: AsyncGenerator<T>): AsyncGenerator<T> {
  const eventTarget = new EventTarget();
  const maxSampledTimes = 30;
  const anomalyThresholdMS = 2000;
  const maxWaitMS = 200;
  const minWaitMS = 5;

  let done = false;
  let stopped = false;
  let failure: { error: unknown } | undefined;
  const valuesBuffer: T[] = [];
  const valueTimesMS: number[] = [];

  // Background pump: pulls from the source and wakes the emit loop on every result
  const next = async (): Promise<void> => {
    try {
      const obj = await iterator.next();
      if (obj.done) {
        done = true;
      } else {
        valuesBuffer.push(obj.value);
        valueTimesMS.push(performance.now());
        // Only the last `maxSampledTimes` arrivals are ever sampled
        if (valueTimesMS.length > maxSampledTimes) valueTimesMS.shift();
        // Stop pulling once the consumer has gone away
        if (!stopped) void next();
      }
    } catch (error) {
      // Remember the failure so the emit loop can end and surface it,
      // instead of leaving an unhandled rejection and a consumer that never finishes
      failure = { error };
      done = true;
    }
    eventTarget.dispatchEvent(new Event("next"));
  };
  void next();

  try {
    let timeOfLastEmitMS = performance.now();
    while (!done || valuesBuffer.length > 0) {
      // With fewer than two arrivals there is no gap to average
      let averageTimeMSBetweenValues = 0;
      const sampleCount = valueTimesMS.length;
      if (sampleCount >= 2) {
        // Total time spent in abnormally long gaps
        let anomalyDurationMS = 0;
        for (let i = 1; i < sampleCount; i++) {
          const gapMS = valueTimesMS[i] - valueTimesMS[i - 1];
          if (gapMS > anomalyThresholdMS) anomalyDurationMS += gapMS;
        }
        const totalTimeMSBetweenValues = valueTimesMS[sampleCount - 1] - valueTimesMS[0];
        averageTimeMSBetweenValues = Math.min(
          maxWaitMS,
          (totalTimeMSBetweenValues - anomalyDurationMS) / (sampleCount - 1)
        );
      }
      const timeSinceLastEmitMS = performance.now() - timeOfLastEmitMS;

      // Emit after waiting duration or cancel if "next" event is emitted
      const listenerAbort = new AbortController();
      const gotNext = await Promise.race([
        sleep(Math.max(minWaitMS, averageTimeMSBetweenValues - timeSinceLastEmitMS)),
        waitForEvent(eventTarget, "next", listenerAbort.signal),
      ]);
      // Drop the listener when the sleep won, so listeners do not pile up
      listenerAbort.abort();

      // Go to next iteration so we can re-calculate when to emit
      if (gotNext) continue;

      // Nothing in buffer to emit
      if (valuesBuffer.length === 0) continue;

      // Emit
      timeOfLastEmitMS = performance.now();
      yield valuesBuffer.shift() as T;
    }
  } finally {
    stopped = true;
  }

  if (failure) throw failure.error;
}

/** Resolves after `ms` milliseconds. */
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

/**
 * Resolves with `true` the next time `eventName` is dispatched on `eventTarget`.
 * When `signal` is given, aborting it removes the listener; the returned
 * promise then never settles.
 */
const waitForEvent = (eventTarget: EventTarget, eventName: string, signal?: AbortSignal) =>
  new Promise<boolean>((resolve) =>
    eventTarget.addEventListener(eventName, () => resolve(true), { once: true, signal })
  );
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters