Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: smooth and combine token output #936

Merged
merged 2 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
214 changes: 214 additions & 0 deletions src/lib/utils/messageUpdates.ts
@@ -0,0 +1,214 @@
import type { MessageUpdate, TextStreamUpdate } from "$lib/types/MessageUpdate";

/**
 * Options accepted by `fetchMessageUpdates`.
 * Every field except `base` is serialized into the JSON body of the POST request.
 */
type MessageUpdateRequestOptions = {
// Prefix of the request URL: `${base}/conversation/${conversationId}`.
base: string;
// Sent as `inputs` in the request body.
inputs?: string;
// Sent as `id` in the request body.
messageId?: string;
// Sent as `is_retry` in the request body.
isRetry: boolean;
// Sent as `is_continue` in the request body.
isContinue: boolean;
// Sent as `web_search` in the request body.
webSearch: boolean;
// Sent as `files` in the request body.
files?: string[];
};
/**
 * Posts a message to a conversation endpoint and returns the resulting stream of
 * message updates, with stream tokens merged into whole words and paced evenly.
 *
 * @param conversationId - Conversation the request is posted to.
 * @param opts - Request options; serialized into the JSON request body.
 * @param abortSignal - Aborting this signal aborts the request and the stream.
 *   A signal that is already aborted aborts the request immediately.
 * @returns An async generator of message updates.
 * @throws Error when the response status is not OK (using the `message` of the
 *   JSON error body when it provides one, otherwise a status-based message) or
 *   when the response has no body.
 */
export async function fetchMessageUpdates(
  conversationId: string,
  opts: MessageUpdateRequestOptions,
  abortSignal: AbortSignal
): Promise<AsyncGenerator<MessageUpdate>> {
  const abortController = new AbortController();
  // An "abort" listener never fires for a signal that was aborted beforehand,
  // so that state has to be mirrored explicitly.
  if (abortSignal.aborted) {
    abortController.abort();
  } else {
    abortSignal.addEventListener("abort", () => abortController.abort(), { once: true });
  }

  const response = await fetch(`${opts.base}/conversation/${conversationId}`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      inputs: opts.inputs,
      id: opts.messageId,
      is_retry: opts.isRetry,
      is_continue: opts.isContinue,
      web_search: opts.webSearch,
      files: opts.files,
    }),
    signal: abortController.signal,
  });

  if (!response.ok) {
    let errorMessage = `Request failed with status code ${response.status}: ${response.statusText}`;
    try {
      const payload: unknown = await response.json();
      const message = (payload as { message?: unknown } | null)?.message;
      // Only trust a non-empty string; otherwise keep the status-based message
      // rather than throwing an Error with an empty or "[object Object]" text.
      if (typeof message === "string" && message !== "") errorMessage = message;
    } catch {
      // The error body was not valid JSON: keep the status-based message.
    }
    throw Error(errorMessage);
  }
  if (!response.body) {
    throw Error("Body not defined");
  }
  return smoothAsyncIterator(
    streamMessageUpdatesToFullWords(endpointStreamToIterator(response, abortController))
  );
}

/**
 * Reads the text body of a response chunk by chunk and yields every message
 * update that `parseMessageUpdates` extracts from it.
 *
 * @param response - Response whose body is decoded as text and consumed.
 * @param abortController - Aborted by this function once the stream closes or is
 *   fully read; aborting it from outside cancels the reader and ends the iteration.
 * @throws Error when the response has no body; read errors of the stream propagate.
 */
async function* endpointStreamToIterator(
  response: Response,
  abortController: AbortController
): AsyncGenerator<MessageUpdate> {
  const reader = response.body?.pipeThrough(new TextDecoderStream()).getReader();
  if (!reader) throw Error("Response for endpoint had no body");

  // Abort once the stream has closed. `closed` rejects when the stream errors;
  // that failure already reaches the consumer through `reader.read()` below, so
  // it is ignored here instead of surfacing as an unhandled rejection.
  reader.closed.then(
    () => abortController.abort(),
    () => undefined
  );

  // Cancel the reader on abort. Cancelling an errored stream rejects, and there
  // is nothing useful to do with that rejection.
  abortController.signal.addEventListener(
    "abort",
    () => {
      reader.cancel().catch(() => undefined);
    },
    { once: true }
  );

  // A chunk can end in the middle of a JSON line, e.g. `{"type": "stream", "token":`.
  // The unparsed tail is kept and prepended to the next chunk so the line can
  // be parsed once the rest (e.g. ` "Hello"}`) has arrived.
  let prevChunk = "";
  while (!abortController.signal.aborted) {
    const { done, value } = await reader.read();
    if (done) {
      abortController.abort();
      break;
    }
    if (!value) continue;

    const { messageUpdates, remainingText } = parseMessageUpdates(prevChunk + value);
    prevChunk = remainingText;
    for (const messageUpdate of messageUpdates) yield messageUpdate;
  }
}

/**
 * Parses newline-separated JSON message updates out of a text chunk.
 *
 * Every segment followed by a newline is a complete line: blank lines are
 * skipped, and a complete line that is not valid JSON is dropped on its own,
 * without discarding the valid updates that follow it. Only the final segment
 * can still be incomplete; when it does not parse it is handed back as
 * `remainingText` so the caller can prepend it to the next chunk.
 *
 * @param value - Text to parse, typically leftover text plus a new chunk.
 * @returns The parsed updates in order, and the unparsed tail ("" when the
 *   final segment was parsed or was empty).
 */
function parseMessageUpdates(value: string): {
  messageUpdates: MessageUpdate[];
  remainingText: string;
} {
  const lines = value.split("\n");
  // split() always yields at least one element; the last one is the only
  // segment that is not terminated by a newline.
  const trailing = lines.pop() ?? "";
  const messageUpdates: MessageUpdate[] = [];

  for (const line of lines) {
    if (line.trim() === "") continue;
    try {
      messageUpdates.push(JSON.parse(line) as MessageUpdate);
    } catch {
      // A newline-terminated line cannot be completed by later data: skip it.
    }
  }

  try {
    messageUpdates.push(JSON.parse(trailing) as MessageUpdate);
    return { messageUpdates, remainingText: "" };
  } catch {
    // Empty or incomplete tail: keep it for the next chunk.
    return { messageUpdates, remainingText: trailing };
  }
}

/**
 * Merges consecutive "stream" updates into whole words before emitting them.
 *
 * - Updates whose type is not "stream" are forwarded as soon as they arrive, so
 *   they can overtake stream tokens that are still buffered.
 * - Two neighbouring tokens are merged when the first ends with, and the second
 *   begins with, a character matched by the patterns below (Latin letters,
 *   digits, apostrophe, backtick). Example: "what" " don" "'t" => "what" " don't"
 * - At most 5 tokens are merged into one emitted update.
 * - Tokens still buffered when the source ends are merged and emitted as one
 *   final update.
 *
 * Text outside the matched character ranges is never merged.
 */
async function* streamMessageUpdatesToFullWords(
  iterator: AsyncGenerator<MessageUpdate>
): AsyncGenerator<MessageUpdate> {
  let bufferedStreamUpdates: TextStreamUpdate[] = [];

  const endAlphanumeric = /[a-zA-Z0-9À-ž'`]+$/;
  const beginningAlphanumeric = /^[a-zA-Z0-9À-ž'`]+/;
  const maxCombinedTokens = 5;

  // Builds one stream update out of several buffered ones.
  const combine = (updates: TextStreamUpdate[]): MessageUpdate => ({
    type: "stream",
    token: updates.map((update) => update.token).join(""),
  });

  for await (const messageUpdate of iterator) {
    if (messageUpdate.type !== "stream") {
      yield messageUpdate;
      continue;
    }
    bufferedStreamUpdates.push(messageUpdate);

    let lastIndexEmitted = 0;
    for (let i = 1; i < bufferedStreamUpdates.length; i++) {
      const prevEndsAlphanumeric = endAlphanumeric.test(bufferedStreamUpdates[i - 1].token);
      const currBeginsAlphanumeric = beginningAlphanumeric.test(bufferedStreamUpdates[i].token);
      const shouldCombine = prevEndsAlphanumeric && currBeginsAlphanumeric;
      const combinedTooMany = i - lastIndexEmitted >= maxCombinedTokens;
      if (shouldCombine && !combinedTooMany) continue;

      // Word boundary (or size limit) reached: emit everything before token i.
      yield combine(bufferedStreamUpdates.slice(lastIndexEmitted, i));
      lastIndexEmitted = i;
    }
    bufferedStreamUpdates = bufferedStreamUpdates.slice(lastIndexEmitted);
  }

  // Whatever is left passed the merge check pairwise, so it forms one word.
  if (bufferedStreamUpdates.length > 0) yield combine(bufferedStreamUpdates);
}

/**
 * Smooths out the time between the values emitted by an async iterator: values
 * are pulled from `iterator` in the background as fast as it produces them and
 * re-emitted spaced by the average time between recent arrivals.
 *
 * - The average is computed over the arrival times of the last 30 values, gaps
 *   longer than 2000ms are subtracted from the elapsed time, and the result is
 *   capped at 200ms. With fewer than two samples the average is 0.
 * - Each wait lasts at least 5ms and is restarted whenever a new value arrives.
 * - If `iterator` throws, the values received so far are still emitted and the
 *   error is then rethrown to the consumer.
 *
 * @param iterator - Source of values.
 * @returns A generator yielding the same values in the same order.
 */
async function* smoothAsyncIterator<T>(iterator: AsyncGenerator<T>): AsyncGenerator<T> {
  const maxSampledValues = 30;
  const anomalyThresholdMS = 2000;
  const maxAverageTimeMS = 200;
  const minWaitMS = 5;

  const eventTarget = new EventTarget();
  const upstream: { done: boolean; failed: boolean; error: unknown } = {
    done: false,
    failed: false,
    error: undefined,
  };
  const valuesBuffer: T[] = [];
  const valueTimesMS: number[] = [];

  // Pulls values in the background. A failure of the source is recorded
  // instead of being lost, so the loop below can stop and rethrow it.
  const pump = async () => {
    try {
      for (;;) {
        const result = await iterator.next();
        if (result.done) break;
        valuesBuffer.push(result.value);
        valueTimesMS.push(performance.now());
        // Only the most recent arrival times are ever sampled.
        if (valueTimesMS.length > maxSampledValues) valueTimesMS.shift();
        eventTarget.dispatchEvent(new Event("next"));
      }
    } catch (error) {
      upstream.failed = true;
      upstream.error = error;
    } finally {
      upstream.done = true;
      eventTarget.dispatchEvent(new Event("next"));
    }
  };
  void pump();

  let timeOfLastEmitMS = performance.now();
  while (!upstream.done || valuesBuffer.length > 0) {
    let averageTimeMSBetweenValues = 0;
    if (valueTimesMS.length >= 2) {
      // Total time spent in abnormally long gaps
      let anomalyDurationMS = 0;
      for (let i = 1; i < valueTimesMS.length; i++) {
        const gapMS = valueTimesMS[i] - valueTimesMS[i - 1];
        if (gapMS > anomalyThresholdMS) anomalyDurationMS += gapMS;
      }
      const totalTimeMSBetweenValues = valueTimesMS[valueTimesMS.length - 1] - valueTimesMS[0];
      averageTimeMSBetweenValues = Math.min(
        maxAverageTimeMS,
        (totalTimeMSBetweenValues - anomalyDurationMS) / (valueTimesMS.length - 1)
      );
    }
    const timeSinceLastEmitMS = performance.now() - timeOfLastEmitMS;

    // Emit after waiting duration, or stop waiting if a "next" event is dispatched
    const listenerAbort = new AbortController();
    const gotNext = await Promise.race([
      sleep(Math.max(minWaitMS, averageTimeMSBetweenValues - timeSinceLastEmitMS)),
      waitForEvent(eventTarget, "next", listenerAbort.signal),
    ]);
    // Detach the listener when the sleep won, so listeners do not pile up.
    listenerAbort.abort();

    // Go to next iteration so we can re-calculate when to emit
    if (gotNext) continue;

    // Nothing in buffer to emit
    if (valuesBuffer.length === 0) continue;

    // Emit
    timeOfLastEmitMS = performance.now();
    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
    yield valuesBuffer.shift()!;
  }

  if (upstream.failed) throw upstream.error;
}

/** Resolves after `ms` milliseconds. */
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));

/**
 * Resolves with `true` the next time `eventName` is dispatched on `eventTarget`.
 * When `signal` is given, aborting it removes the listener; the returned
 * promise then never settles.
 */
const waitForEvent = (eventTarget: EventTarget, eventName: string, signal?: AbortSignal) =>
  new Promise<boolean>((resolve) =>
    eventTarget.addEventListener(eventName, () => resolve(true), { once: true, signal })
  );
153 changes: 50 additions & 103 deletions src/routes/conversation/[id]/+page.svelte
Expand Up @@ -16,6 +16,7 @@
import file2base64 from "$lib/utils/file2base64";
import { addChildren } from "$lib/utils/tree/addChildren";
import { addSibling } from "$lib/utils/tree/addSibling";
import { fetchMessageUpdates } from "$lib/utils/messageUpdates";
import { createConvTreeStore } from "$lib/stores/convTree";
import type { v4 } from "uuid";

Expand Down Expand Up @@ -181,125 +182,71 @@

messages = [...messages];
const messageToWriteTo = messages.find((message) => message.id === messageToWriteToId);

if (!messageToWriteTo) {
throw new Error("Message to write to not found");
}

// disable websearch if assistant is present
const hasAssistant = !!$page.data.assistant;

const response = await fetch(`${base}/conversation/${$page.params.id}`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
const messageUpdatesAbortController = new AbortController();
const messageUpdatesIterator = await fetchMessageUpdates(
$page.params.id,
{
base,
inputs: prompt,
id: messageId,
is_retry: isRetry,
is_continue: isContinue,
web_search: !hasAssistant && $webSearchParameters.useSearch,
messageId,
isRetry,
isContinue,
webSearch: !hasAssistant && $webSearchParameters.useSearch,
files: isRetry ? undefined : resizedImages,
}),
},
messageUpdatesAbortController.signal
).catch((err) => {
error.set(err.message);
});
if (messageUpdatesIterator === undefined) return;

files = [];
if (!response.body) {
throw new Error("Body not defined");
}

if (!response.ok) {
error.set((await response.json())?.message);
return;
}

// eslint-disable-next-line no-undef
const encoder = new TextDecoderStream();
const reader = response?.body?.pipeThrough(encoder).getReader();
let finalAnswer = "";
const messageUpdates: MessageUpdate[] = [];

// set str queue
// ex) if the last response is => {"type": "stream", "token":
// It should be => {"type": "stream", "token": "Hello"} = prev_input_chunk + "Hello"}
let prev_input_chunk = [""];

// this is a bit ugly
// we read the stream until we get the final answer

let readerClosed = false;

reader.closed.then(() => {
readerClosed = true;
});

while (finalAnswer === "") {
// check for abort
if ($isAborted || $error || readerClosed) {
reader?.cancel();
for await (const update of messageUpdatesIterator) {
if ($isAborted) {
messageUpdatesAbortController.abort();
return;
}
if (update.type === "finalAnswer") {
loading = false;
pending = false;
break;
}

// if there is something to read
await reader?.read().then(async ({ done, value }) => {
// we read, if it's done we cancel
if (done) {
reader.cancel();
}

if (!value) {
return;
}

value = prev_input_chunk.pop() + value;

// if it's not done we parse the value, which contains all messages
const inputs = value.split("\n");
inputs.forEach(async (el: string) => {
try {
const update = JSON.parse(el) as MessageUpdate;

if (update.type !== "stream") {
messageUpdates.push(update);
}

if (update.type === "finalAnswer") {
finalAnswer = update.text;
loading = false;
pending = false;
} else if (update.type === "stream") {
pending = false;
messageToWriteTo.content += update.token;
messages = [...messages];
} else if (update.type === "webSearch") {
messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
messages = [...messages];
} else if (update.type === "status") {
if (update.status === "title" && update.message) {
const convInData = data.conversations.find(({ id }) => id === $page.params.id);
if (convInData) {
convInData.title = update.message;

$titleUpdate = {
title: update.message,
convId: $page.params.id,
};
}
} else if (update.status === "error") {
$error = update.message ?? "An error has occurred";
}
} else if (update.type === "error") {
error.set(update.message);
reader.cancel();
}
} catch (parseError) {
// in case of parsing error we wait for the next message

if (el === inputs[inputs.length - 1]) {
prev_input_chunk.push(el);
}
return;
messageUpdates.push(update);

if (update.type === "stream") {
pending = false;
messageToWriteTo.content += update.token;
messages = [...messages];
} else if (update.type === "webSearch") {
messageToWriteTo.updates = [...(messageToWriteTo.updates ?? []), update];
messages = [...messages];
} else if (update.type === "status") {
if (update.status === "title" && update.message) {
const convInData = data.conversations.find(({ id }) => id === $page.params.id);
if (convInData) {
convInData.title = update.message;

$titleUpdate = {
title: update.message,
convId: $page.params.id,
};
}
});
});
} else if (update.status === "error") {
$error = update.message ?? "An error has occurred";
}
} else if (update.type === "error") {
error.set(update.message);
messageUpdatesAbortController.abort();
}
}

messageToWriteTo.updates = messageUpdates;
Expand Down