feat: smooth and combine token output (#936)
* feat: smooth and combine token output

* fix: stop generating button not triggering message updates abort
Saghen committed Mar 22, 2024
1 parent 999407a commit 8583cf1
Showing 2 changed files with 264 additions and 103 deletions.
214 changes: 214 additions & 0 deletions src/lib/utils/messageUpdates.ts
@@ -0,0 +1,214 @@
import type { MessageUpdate, TextStreamUpdate } from "$lib/types/MessageUpdate";

type MessageUpdateRequestOptions = {
base: string;
inputs?: string;
messageId?: string;
isRetry: boolean;
isContinue: boolean;
webSearch: boolean;
files?: string[];
};
export async function fetchMessageUpdates(
conversationId: string,
opts: MessageUpdateRequestOptions,
abortSignal: AbortSignal
): Promise<AsyncGenerator<MessageUpdate>> {
const abortController = new AbortController();
abortSignal.addEventListener("abort", () => abortController.abort());

const response = await fetch(`${opts.base}/conversation/${conversationId}`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
inputs: opts.inputs,
id: opts.messageId,
is_retry: opts.isRetry,
is_continue: opts.isContinue,
web_search: opts.webSearch,
files: opts.files,
}),
signal: abortController.signal,
});

if (!response.ok) {
const errorMessage = await response
.json()
.then((obj) => obj.message)
.catch(() => `Request failed with status code ${response.status}: ${response.statusText}`);
throw Error(errorMessage);
}
if (!response.body) {
throw Error("Body not defined");
}
return smoothAsyncIterator(
streamMessageUpdatesToFullWords(endpointStreamToIterator(response, abortController))
);
}

async function* endpointStreamToIterator(
response: Response,
abortController: AbortController
): AsyncGenerator<MessageUpdate> {
const reader = response.body?.pipeThrough(new TextDecoderStream()).getReader();
if (!reader) throw Error("Response for endpoint had no body");

// If the stream closes for any reason, abort the whole request
reader.closed.then(() => abortController.abort());

// If the request is aborted, cancel the reader
abortController.signal.addEventListener("abort", () => reader.cancel());

// A read may end mid-object, e.g. {"type": "stream", "token":
// so we buffer the partial chunk in prevChunk and prepend it to the next read
// to recover the complete object, e.g. {"type": "stream", "token": "Hello"}
let prevChunk = "";
while (!abortController.signal.aborted) {
const { done, value } = await reader.read();
if (done) {
abortController.abort();
break;
}
if (!value) continue;

const { messageUpdates, remainingText } = parseMessageUpdates(prevChunk + value);
prevChunk = remainingText;
for (const messageUpdate of messageUpdates) yield messageUpdate;
}
}

function parseMessageUpdates(value: string): {
messageUpdates: MessageUpdate[];
remainingText: string;
} {
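// Split on newlines and parse each line as JSON; if a line fails to parse
// (a partial object at the end of the chunk), return it as remainingText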
const inputs = value.split("\n");
const messageUpdates: MessageUpdate[] = [];
for (const input of inputs) {
try {
messageUpdates.push(JSON.parse(input) as MessageUpdate);
} catch (error) {
// in case of parsing error, we return what we were able to parse
if (error instanceof SyntaxError) {
return {
messageUpdates,
remainingText: inputs.at(-1) ?? "",
};
}
}
}
return { messageUpdates, remainingText: "" };
}

/**
 * Immediately re-emits every message update that isn't of the "stream" type.
 * Buffers "stream" updates and emits them concatenated once a full word is detected.
 * Example: "what" " don" "'t" => "what" " don't"
 * Only combines Latin-script tokens; other tokens are passed through unchanged.
 */
async function* streamMessageUpdatesToFullWords(
iterator: AsyncGenerator<MessageUpdate>
): AsyncGenerator<MessageUpdate> {
let bufferedStreamUpdates: TextStreamUpdate[] = [];

const endAlphanumeric = /[a-zA-Z0-9À-ž'`]+$/;
const beginningAlphanumeric = /^[a-zA-Z0-9À-ž'`]+/;

for await (const messageUpdate of iterator) {
if (messageUpdate.type !== "stream") {
yield messageUpdate;
continue;
}
bufferedStreamUpdates.push(messageUpdate);

let lastIndexEmitted = 0;
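// Merge runs of tokens that split a word across updates, capping each
// combined emit at 5 tokens so output never stalls on a long run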
for (let i = 1; i < bufferedStreamUpdates.length; i++) {
const prevEndsAlphanumeric = endAlphanumeric.test(bufferedStreamUpdates[i - 1].token);
const currBeginsAlphanumeric = beginningAlphanumeric.test(bufferedStreamUpdates[i].token);
const shouldCombine = prevEndsAlphanumeric && currBeginsAlphanumeric;
const combinedTooMany = i - lastIndexEmitted >= 5;
if (shouldCombine && !combinedTooMany) continue;

// Combine tokens together and emit
yield {
type: "stream",
token: bufferedStreamUpdates
.slice(lastIndexEmitted, i)
.map((_) => _.token)
.join(""),
};
lastIndexEmitted = i;
}
bufferedStreamUpdates = bufferedStreamUpdates.slice(lastIndexEmitted);
}
for (const messageUpdate of bufferedStreamUpdates) yield messageUpdate;
}

/**
* Attempts to smooth out the time between values emitted by an async iterator
* by waiting for the average time between values to emit the next value
*/
async function* smoothAsyncIterator<T>(iterator: AsyncGenerator<T>): AsyncGenerator<T> {
const eventTarget = new EventTarget();
let done = false;
const valuesBuffer: T[] = [];
const valueTimesMS: number[] = [];

const next = async () => {
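// Pull values from the source as fast as they arrive, buffering them and
// recording arrival times so the emit loop below can pace itself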
const obj = await iterator.next();
if (obj.done) {
done = true;
} else {
valuesBuffer.push(obj.value);
valueTimesMS.push(performance.now());
next();
}
eventTarget.dispatchEvent(new Event("next"));
};
next();

let timeOfLastEmitMS = performance.now();
while (!done || valuesBuffer.length > 0) {
// Only consider the most recent 30 arrival times between tokens
const sampledTimesMS = valueTimesMS.slice(-30);

// Get the total time spent in abnormal periods
const anomalyThresholdMS = 2000;
const anomalyDurationMS = sampledTimesMS
.map((time, i, times) => time - times[i - 1])
.slice(1)
.filter((time) => time > anomalyThresholdMS)
.reduce((a, b) => a + b, 0);

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const totalTimeMSBetweenValues = sampledTimesMS.at(-1)! - sampledTimesMS[0];
const timeMSBetweenValues = totalTimeMSBetweenValues - anomalyDurationMS;

const averageTimeMSBetweenValues = Math.min(
200,
timeMSBetweenValues / (sampledTimesMS.length - 1)
);
const timeSinceLastEmitMS = performance.now() - timeOfLastEmitMS;

// Emit after waiting duration or cancel if "next" event is emitted
const gotNext = await Promise.race([
sleep(Math.max(5, averageTimeMSBetweenValues - timeSinceLastEmitMS)),
waitForEvent(eventTarget, "next"),
]);

// Go to next iteration so we can re-calculate when to emit
if (gotNext) continue;

// Nothing in buffer to emit
if (valuesBuffer.length === 0) continue;

// Emit
timeOfLastEmitMS = performance.now();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
yield valuesBuffer.shift()!;
}
}

const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
const waitForEvent = (eventTarget: EventTarget, eventName: string) =>
new Promise<boolean>((resolve) =>
eventTarget.addEventListener(eventName, () => resolve(true), { once: true })
);
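For reference, here is a minimal sketch of how a consumer might drive the new helper (the function name `streamConversation` and the option values are illustrative, not part of the commit):

```ts
import { fetchMessageUpdates } from "$lib/utils/messageUpdates";

async function streamConversation(conversationId: string): Promise<string> {
	const abortController = new AbortController();

	// Updates arrive already merged into whole words and paced by smoothAsyncIterator
	const updates = await fetchMessageUpdates(
		conversationId,
		{ base: "", inputs: "Hello there", isRetry: false, isContinue: false, webSearch: false },
		abortController.signal
	);

	let content = "";
	for await (const update of updates) {
		if (update.type === "stream") content += update.token;
		else if (update.type === "finalAnswer") break;
	}
	return content;
}
```

This mirrors how the conversation page consumes the iterator in the diff below.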
153 changes: 50 additions & 103 deletions src/routes/conversation/[id]/+page.svelte
@@ -16,6 +16,7 @@
import file2base64 from "$lib/utils/file2base64";
import { addChildren } from "$lib/utils/tree/addChildren";
import { addSibling } from "$lib/utils/tree/addSibling";
+import { fetchMessageUpdates } from "$lib/utils/messageUpdates";
import { createConvTreeStore } from "$lib/stores/convTree";
import type { v4 } from "uuid";
@@ -181,125 +182,71 @@
messages = [...messages];
const messageToWriteTo = messages.find((message) => message.id === messageToWriteToId);
if (!messageToWriteTo) {
throw new Error("Message to write to not found");
}
// disable websearch if assistant is present
const hasAssistant = !!$page.data.assistant;
-const response = await fetch(`${base}/conversation/${$page.params.id}`, {
-method: "POST",
-headers: { "Content-Type": "application/json" },
-body: JSON.stringify({
+const messageUpdatesAbortController = new AbortController();
+const messageUpdatesIterator = await fetchMessageUpdates(
+$page.params.id,
+{
+base,
inputs: prompt,
-id: messageId,
-is_retry: isRetry,
-is_continue: isContinue,
-web_search: !hasAssistant && $webSearchParameters.useSearch,
+messageId,
+isRetry,
+isContinue,
+webSearch: !hasAssistant && $webSearchParameters.useSearch,
files: isRetry ? undefined : resizedImages,
-}),
+},
+messageUpdatesAbortController.signal
+).catch((err) => {
+error.set(err.message);
});
+if (messageUpdatesIterator === undefined) return;
files = [];
-if (!response.body) {
-throw new Error("Body not defined");
-}
-if (!response.ok) {
-error.set((await response.json())?.message);
-return;
-}
-// eslint-disable-next-line no-undef
-const encoder = new TextDecoderStream();
-const reader = response?.body?.pipeThrough(encoder).getReader();
-let finalAnswer = "";
const messageUpdates: MessageUpdate[] = [];
-// set str queue
-// ex) if the last response is => {"type": "stream", "token":
-// It should be => {"type": "stream", "token": "Hello"} = prev_input_chunk + "Hello"}
-let prev_input_chunk = [""];
-// this is a bit ugly
-// we read the stream until we get the final answer
-let readerClosed = false;
-reader.closed.then(() => {
-readerClosed = true;
-});
-while (finalAnswer === "") {
-// check for abort
-if ($isAborted || $error || readerClosed) {
-reader?.cancel();
+for await (const update of messageUpdatesIterator) {
+if ($isAborted) {
+messageUpdatesAbortController.abort();
+return;
+}
+if (update.type === "finalAnswer") {
+loading = false;
+pending = false;
break;
}
-// if there is something to read
-await reader?.read().then(async ({ done, value }) => {
-// we read, if it's done we cancel
-if (done) {
-reader.cancel();
-}
-if (!value) {
-return;
-}
-value = prev_input_chunk.pop() + value;
-// if it's not done we parse the value, which contains all messages
-const inputs = value.split("\n");
-inputs.forEach(async (el: string) => {
-try {
-const update = JSON.parse(el) as MessageUpdate;
-if (update.type !== "stream") {
-messageUpdates.push(update);
-}
-if (update.type === "finalAnswer") {
-finalAnswer = update.text;
-loading = false;
-pending = false;
-} else if (update.type === "stream") {
-pending = false;
-messageToWriteTo.content += update.token;
-messages = [...messages];
-} else if (update.type === "webSearch") {
-messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
-messages = [...messages];
-} else if (update.type === "status") {
-if (update.status === "title" && update.message) {
-const convInData = data.conversations.find(({ id }) => id === $page.params.id);
-if (convInData) {
-convInData.title = update.message;
-$titleUpdate = {
-title: update.message,
-convId: $page.params.id,
-};
-}
-} else if (update.status === "error") {
-$error = update.message ?? "An error has occurred";
-}
-} else if (update.type === "error") {
-error.set(update.message);
-reader.cancel();
-}
-} catch (parseError) {
-// in case of parsing error we wait for the next message
-if (el === inputs[inputs.length - 1]) {
-prev_input_chunk.push(el);
-}
-return;
+messageUpdates.push(update);
+if (update.type === "stream") {
+pending = false;
+messageToWriteTo.content += update.token;
+messages = [...messages];
+} else if (update.type === "webSearch") {
+messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
+messages = [...messages];
+} else if (update.type === "status") {
+if (update.status === "title" && update.message) {
+const convInData = data.conversations.find(({ id }) => id === $page.params.id);
+if (convInData) {
+convInData.title = update.message;
+$titleUpdate = {
+title: update.message,
+convId: $page.params.id,
+};
+}
-});
-});
+} else if (update.status === "error") {
+$error = update.message ?? "An error has occurred";
+}
+} else if (update.type === "error") {
+error.set(update.message);
+messageUpdatesAbortController.abort();
+}
}
messageToWriteTo.updates = messageUpdates;
