Expose sampling controls in assistants (#955) #959

Merged
merged 10 commits on Mar 27, 2024
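The heart of the change: each assistant can now store a `generateSettings` object that is spread over the model's default `parameters` at generation time, so author-set values win key by key. A minimal sketch of that precedence (the `GenerationParameters` shape below is illustrative, not the repo's exact type):

```ts
// Sketch of the merge rule each endpoint applies in this PR:
// assistant-level overrides shadow model defaults key by key.
interface GenerationParameters {
	temperature?: number;
	top_p?: number;
	top_k?: number;
	repetition_penalty?: number;
	max_new_tokens?: number;
	stop?: string[];
}

function mergeSettings(
	modelDefaults: GenerationParameters,
	generateSettings?: Partial<GenerationParameters>
): GenerationParameters {
	// Object spread copies own enumerable keys, so a key explicitly set to
	// `undefined` in generateSettings would still shadow the model default;
	// callers should omit unset keys rather than pass `undefined`.
	return { ...modelDefaults, ...generateSettings };
}

// The assistant overrides only temperature; everything else falls
// through to the model defaults.
const merged = mergeSettings(
	{ temperature: 0.9, top_p: 0.95, max_new_tokens: 1024 },
	{ temperature: 0.2 }
);
// merged: { temperature: 0.2, top_p: 0.95, max_new_tokens: 1024 }
```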
149 changes: 133 additions & 16 deletions src/lib/components/AssistantSettings.svelte
@@ -9,11 +9,14 @@
import { base } from "$app/paths";
import CarbonPen from "~icons/carbon/pen";
import CarbonUpload from "~icons/carbon/upload";
+import CarbonHelpFilled from "~icons/carbon/help";
+import CarbonSettingsAdjust from "~icons/carbon/settings-adjust";

import { useSettingsStore } from "$lib/stores/settings";
import { isHuggingChat } from "$lib/utils/isHuggingChat";
import IconInternet from "./icons/IconInternet.svelte";
import TokensCounter from "./TokensCounter.svelte";
+import HoverTooltip from "./HoverTooltip.svelte";

type ActionData = {
error: boolean;
@@ -31,16 +31,26 @@

let files: FileList | null = null;
const settings = useSettingsStore();
-let modelId =
-assistant?.modelId ?? models.find((_model) => _model.id === $settings.activeModel)?.name;
+let modelId = "";
let systemPrompt = assistant?.preprompt ?? "";
let dynamicPrompt = assistant?.dynamicPrompt ?? false;
+let showModelSettings = Object.values(assistant?.generateSettings ?? {}).some((v) => !!v);

let compress: typeof readAndCompressImage | null = null;

onMount(async () => {
const module = await import("browser-image-resizer");
compress = module.readAndCompressImage;

+if (assistant) {
+modelId = assistant.modelId;
+} else {
+if (models.map((model) => model.id).includes($settings.activeModel)) {
+modelId = $settings.activeModel;
+} else {
+modelId = models[0].id;
+}
+}
});

let inputMessage1 = assistant?.exampleInputs[0] ?? "";
@@ -89,6 +102,7 @@

const regex = /{{\s?url=(.+?)\s?}}/g;
$: templateVariables = [...systemPrompt.matchAll(regex)].map((match) => match[1]);
+$: selectedModel = models.find((m) => m.id === modelId);
</script>

<form
@@ -246,21 +260,124 @@

<label>
<div class="mb-1 font-semibold">Model</div>
-<select
-name="modelId"
-class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
-bind:value={modelId}
->
-{#each models.filter((model) => !model.unlisted) as model}
-<option
-value={model.id}
-selected={assistant
-? assistant?.modelId === model.id
-: $settings.activeModel === model.id}>{model.displayName}</option
->
-{/each}
-<p class="text-xs text-red-500">{getError("modelId", form)}</p>
-</select>
+<div class="flex gap-2">
+<select
+name="modelId"
+class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
+bind:value={modelId}
+>
+{#each models.filter((model) => !model.unlisted) as model}
+<option value={model.id}>{model.displayName}</option>
+{/each}
+<p class="text-xs text-red-500">{getError("modelId", form)}</p>
+</select>
+<button
+type="button"
+class="flex aspect-square items-center gap-2 whitespace-nowrap rounded-lg border px-3 {showModelSettings
+? 'border-blue-500/20 bg-blue-50 text-blue-600'
+: ''}"
+on:click={() => (showModelSettings = !showModelSettings)}
+><CarbonSettingsAdjust class="text-xs" /></button
+>
+</div>
+<div
+class="mt-2 rounded-lg border border-blue-500/20 bg-blue-500/5 px-2 py-0.5"
+class:hidden={!showModelSettings}
+>
<p class="text-xs text-red-500">{getError("inputMessage1", form)}</p>
+<div class="my-2 grid grid-cols-1 gap-2.5 sm:grid-cols-2 sm:grid-rows-2">
+<label for="temperature" class="flex justify-between">
+<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+Temperature
+
+<HoverTooltip
+label="Temperature: Controls creativity, higher values allow more variety."
+>
+<CarbonHelpFilled
+class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+/>
+</HoverTooltip>
+</span>
+<input
+type="number"
+name="temperature"
+min="0.1"
+max="2"
+step="0.1"
+class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+placeholder={selectedModel?.parameters?.temperature?.toString() ?? "1"}
+value={assistant?.generateSettings?.temperature ?? ""}
+/>
+</label>
+<label for="top_p" class="flex justify-between">
+<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+Top P
+<HoverTooltip
+align="right"
+label="Top P: Sets word choice boundaries, lower values tighten focus."
+>
+<CarbonHelpFilled
+class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+/>
+</HoverTooltip>
+</span>
+
+<input
+type="number"
+name="top_p"
+class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+min="0.05"
+max="1"
+step="0.05"
+placeholder={selectedModel?.parameters?.top_p?.toString() ?? "1"}
+value={assistant?.generateSettings?.top_p ?? ""}
+/>
+</label>
+<label for="repetition_penalty" class="flex justify-between">
+<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+Repetition penalty
+<HoverTooltip
+label="Repetition penalty: Prevents reuse, higher values decrease repetition."
+>
+<CarbonHelpFilled
+class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+/>
+</HoverTooltip>
+</span>
+<input
+type="number"
+name="repetition_penalty"
+min="0.1"
+max="2"
+class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+placeholder={selectedModel?.parameters?.repetition_penalty?.toString() ?? "1.0"}
+value={assistant?.generateSettings?.repetition_penalty ?? ""}
+/>
+</label>
+<label for="top_k" class="flex justify-between">
+<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
+Top K <HoverTooltip
+align="right"
+label="Top K: Restricts word options, lower values for predictability."
+>
+<CarbonHelpFilled
+class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
+/>
+</HoverTooltip>
+</span>
+<input
+type="number"
+name="top_k"
+min="5"
+max="100"
+step="5"
+class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
+placeholder={selectedModel?.parameters?.top_k?.toString() ?? "50"}
+value={assistant?.generateSettings?.top_k ?? ""}
+/>
+</label>
+</div>
+</div>
</label>

<label>
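Note how each sampling input above leaves `value` empty unless the assistant already stores an override, while `placeholder` surfaces the model default from `selectedModel?.parameters`. The action that consumes this form is outside this diff, so it has to drop empty fields somewhere before persisting `generateSettings`; a hedged sketch of that cleanup (`parseSamplingSettings` is a hypothetical helper, not code from the PR):

```ts
// Hypothetical helper: convert raw form fields into a generateSettings
// object, skipping empty inputs so they don't override model defaults.
function parseSamplingSettings(
	fields: Record<string, string>
): Record<string, number> {
	const keys = ["temperature", "top_p", "repetition_penalty", "top_k"];
	const settings: Record<string, number> = {};
	for (const key of keys) {
		const raw = fields[key]?.trim();
		if (!raw) continue; // empty input => keep the model default
		const parsed = Number(raw);
		if (!Number.isNaN(parsed)) settings[key] = parsed;
	}
	return settings;
}

// parseSamplingSettings({ temperature: "0.7", top_p: "" })
// => { temperature: 0.7 }
```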
15 changes: 15 additions & 0 deletions src/lib/components/HoverTooltip.svelte
@@ -0,0 +1,15 @@
<script lang="ts">
export let label = "";
export let align: "left" | "right" = "left";
</script>

<div class="group/tooltip relative">
<slot />
<div
class="invisible absolute z-10 w-64 items-center whitespace-normal rounded-md bg-black p-2 text-center text-white group-hover/tooltip:visible group-active/tooltip:visible max-sm:top-5 md:top-5"
class:max-sm:left-0={align === "left"}
class:max-sm:right-0={align === "right"}
>
{label}
</div>
</div>
15 changes: 9 additions & 6 deletions src/lib/server/endpoints/anthropic/endpointAnthropic.ts
@@ -32,7 +32,7 @@ export async function endpointAnthropic(
defaultQuery,
});

-return async ({ messages, preprompt }) => {
+return async ({ messages, preprompt, generateSettings }) => {
let system = preprompt;
if (messages?.[0]?.from === "system") {
system = messages[0].content;
@@ -49,15 +49,18 @@
}[];

let tokenId = 0;

+const parameters = { ...model.parameters, ...generateSettings };

return (async function* () {
const stream = anthropic.messages.stream({
model: model.id ?? model.name,
messages: messagesFormatted,
-max_tokens: model.parameters?.max_new_tokens,
-temperature: model.parameters?.temperature,
-top_p: model.parameters?.top_p,
-top_k: model.parameters?.top_k,
-stop_sequences: model.parameters?.stop,
+max_tokens: parameters?.max_new_tokens,
+temperature: parameters?.temperature,
+top_p: parameters?.top_p,
+top_k: parameters?.top_k,
+stop_sequences: parameters?.stop,
system,
});
while (true) {
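The form enforces ranges client-side (`temperature` 0.1–2, `top_p` 0.05–1, `repetition_penalty` 0.1–2, `top_k` 5–100), but nothing in this excerpt re-checks them before the merged values reach a provider such as Anthropic. One possible server-side guard, assuming the same bounds as the form:

```ts
// Hypothetical guard (not in the PR): clamp overrides to the ranges the
// AssistantSettings form enforces client-side.
const BOUNDS: Record<string, [number, number]> = {
	temperature: [0.1, 2],
	top_p: [0.05, 1],
	repetition_penalty: [0.1, 2],
	top_k: [5, 100],
};

function clampSettings(settings: Record<string, number>): Record<string, number> {
	const clamped: Record<string, number> = {};
	for (const [key, value] of Object.entries(settings)) {
		const bounds = BOUNDS[key];
		// Unknown keys pass through untouched; known keys are clamped.
		clamped[key] = bounds ? Math.min(bounds[1], Math.max(bounds[0], value)) : value;
	}
	return clamped;
}
```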
4 changes: 2 additions & 2 deletions src/lib/server/endpoints/aws/endpointAws.ts
@@ -36,7 +36,7 @@ export async function endpointAws(
region,
});

-return async ({ messages, preprompt, continueMessage }) => {
+return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
@@ -46,7 +46,7 @@

return textGenerationStream(
{
-parameters: { ...model.parameters, return_full_text: false },
+parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
model: url,
inputs: prompt,
},
2 changes: 2 additions & 0 deletions src/lib/server/endpoints/endpoints.ts
@@ -10,12 +10,14 @@ import {
endpointAnthropic,
endpointAnthropicParametersSchema,
} from "./anthropic/endpointAnthropic";
+import type { Model } from "$lib/types/Model";

// parameters passed when generating text
export interface EndpointParameters {
messages: Omit<Conversation["messages"][0], "id">[];
preprompt?: Conversation["preprompt"];
continueMessage?: boolean; // used to signal that the last message will be extended
+generateSettings?: Partial<Model["parameters"]>;
}

interface CommonEndpoint {
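With `generateSettings` on `EndpointParameters`, whichever caller runs the conversation can forward an assistant's stored overrides. That call site is not part of this diff; the sketch below only assumes its shape:

```ts
// Assumed call-site shape (the real caller is outside this diff).
import type { EndpointParameters } from "./endpoints";

type Endpoint = (params: EndpointParameters) => Promise<AsyncGenerator<unknown>>;

async function generateForAssistant(
	endpoint: Endpoint,
	params: Omit<EndpointParameters, "generateSettings">,
	assistant?: { generateSettings?: EndpointParameters["generateSettings"] }
) {
	// Spread the base params and attach the assistant's overrides, if any;
	// each endpoint then merges them over its model defaults.
	return endpoint({ ...params, generateSettings: assistant?.generateSettings });
}
```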
16 changes: 9 additions & 7 deletions src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts
@@ -19,14 +19,16 @@
input: z.input<typeof endpointLlamacppParametersSchema>
): Endpoint {
const { url, model } = endpointLlamacppParametersSchema.parse(input);
-return async ({ messages, preprompt, continueMessage }) => {
+return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
preprompt,
model,
});

+const parameters = { ...model.parameters, ...generateSettings };

const r = await fetch(`${url}/completion`, {
method: "POST",
headers: {
@@ -35,12 +37,12 @@
body: JSON.stringify({
prompt,
stream: true,
-temperature: model.parameters.temperature,
-top_p: model.parameters.top_p,
-top_k: model.parameters.top_k,
-stop: model.parameters.stop,
-repeat_penalty: model.parameters.repetition_penalty,
-n_predict: model.parameters.max_new_tokens,
+temperature: parameters.temperature,
+top_p: parameters.top_p,
+top_k: parameters.top_k,
+stop: parameters.stop,
+repeat_penalty: parameters.repetition_penalty,
+n_predict: parameters.max_new_tokens,
cache_prompt: true,
}),
});
16 changes: 9 additions & 7 deletions src/lib/server/endpoints/ollama/endpointOllama.ts
@@ -14,14 +14,16 @@ export const endpointOllamaParametersSchema = z.object({
export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);

-return async ({ messages, preprompt, continueMessage }) => {
+return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
preprompt,
model,
});

+const parameters = { ...model.parameters, ...generateSettings };

const r = await fetch(`${url}/api/generate`, {
method: "POST",
headers: {
@@ -32,12 +34,12 @@
model: ollamaName ?? model.name,
raw: true,
options: {
-top_p: model.parameters.top_p,
-top_k: model.parameters.top_k,
-temperature: model.parameters.temperature,
-repeat_penalty: model.parameters.repetition_penalty,
-stop: model.parameters.stop,
-num_predict: model.parameters.max_new_tokens,
+top_p: parameters.top_p,
+top_k: parameters.top_k,
+temperature: parameters.temperature,
+repeat_penalty: parameters.repetition_penalty,
+stop: parameters.stop,
+num_predict: parameters.max_new_tokens,
},
}),
});
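Both local backends take the same merged object but rename keys on the wire: `repetition_penalty` becomes `repeat_penalty` for llama.cpp and Ollama alike, while `max_new_tokens` becomes `n_predict` (llama.cpp) or `num_predict` (Ollama). A small sketch capturing just the renames shown in this diff:

```ts
// Per-backend key renames as seen in this PR; keys not listed
// (temperature, top_p, top_k, stop) pass through unchanged.
const RENAMES: Record<string, Record<string, string>> = {
	llamacpp: { repetition_penalty: "repeat_penalty", max_new_tokens: "n_predict" },
	ollama: { repetition_penalty: "repeat_penalty", max_new_tokens: "num_predict" },
};

function toBackendOptions(
	backend: "llamacpp" | "ollama",
	parameters: Record<string, unknown>
): Record<string, unknown> {
	const renames = RENAMES[backend];
	return Object.fromEntries(
		Object.entries(parameters).map(([key, value]) => [renames[key] ?? key, value])
	);
}

// toBackendOptions("ollama", { max_new_tokens: 512, top_k: 50 })
// => { num_predict: 512, top_k: 50 }
```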