Skip to content

Commit

Permalink
Expose sampling controls in assistants (#955) (#959)
Browse files Browse the repository at this point in the history
* Expose sampling controls in assistants (#955)

* Make sure all labels have the same font size

* styling

* Add better tooltips

* better padding & wrapping

* Revert "better padding & wrapping"

This reverts commit 1b44086.

* ui update

* tooltip on mobile

* lint

* Update src/lib/components/AssistantSettings.svelte

Co-authored-by: Mishig <[email protected]>

---------

Co-authored-by: Victor Mustar <[email protected]>
Co-authored-by: Mishig <[email protected]>
  • Loading branch information
3 people committed Mar 27, 2024
1 parent e9ad67e commit d4016bc
Show file tree
Hide file tree
Showing 13 changed files with 245 additions and 56 deletions.
145 changes: 128 additions & 17 deletions src/lib/components/AssistantSettings.svelte
Expand Up @@ -9,11 +9,14 @@
import { base } from "$app/paths";
import CarbonPen from "~icons/carbon/pen";
import CarbonUpload from "~icons/carbon/upload";
import CarbonHelpFilled from "~icons/carbon/help";
import CarbonSettingsAdjust from "~icons/carbon/settings-adjust";
import { useSettingsStore } from "$lib/stores/settings";
import { isHuggingChat } from "$lib/utils/isHuggingChat";
import IconInternet from "./icons/IconInternet.svelte";
import TokensCounter from "./TokensCounter.svelte";
import HoverTooltip from "./HoverTooltip.svelte";
type ActionData = {
error: boolean;
Expand All @@ -31,16 +34,22 @@
let files: FileList | null = null;
const settings = useSettingsStore();
let modelId =
assistant?.modelId ?? models.find((_model) => _model.id === $settings.activeModel)?.name;
let modelId = "";
let systemPrompt = assistant?.preprompt ?? "";
let dynamicPrompt = assistant?.dynamicPrompt ?? false;
let showModelSettings = Object.values(assistant?.generateSettings ?? {}).some((v) => !!v);
let compress: typeof readAndCompressImage | null = null;
onMount(async () => {
const module = await import("browser-image-resizer");
compress = module.readAndCompressImage;
if (assistant) {
modelId = assistant.modelId;
} else {
modelId = models.find((model) => model.id === $settings.activeModel)?.id ?? models[0].id;
}
});
let inputMessage1 = assistant?.exampleInputs[0] ?? "";
Expand Down Expand Up @@ -89,11 +98,12 @@
const regex = /{{\s?url=(.+?)\s?}}/g;
$: templateVariables = [...systemPrompt.matchAll(regex)].map((match) => match[1]);
$: selectedModel = models.find((m) => m.id === modelId);
</script>

<form
method="POST"
class="flex h-full flex-col overflow-y-auto p-4 md:p-8"
class="relative flex h-full flex-col overflow-y-auto p-4 md:p-8"
enctype="multipart/form-data"
use:enhance={async ({ formData }) => {
loading = true;
Expand Down Expand Up @@ -246,21 +256,122 @@

<label>
<div class="mb-1 font-semibold">Model</div>
<select
name="modelId"
class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
bind:value={modelId}
<div class="flex gap-2">
<select
name="modelId"
class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
bind:value={modelId}
>
{#each models.filter((model) => !model.unlisted) as model}
<option value={model.id}>{model.displayName}</option>
{/each}
<p class="text-xs text-red-500">{getError("modelId", form)}</p>
</select>
<button
type="button"
class="flex aspect-square items-center gap-2 whitespace-nowrap rounded-lg border px-3 {showModelSettings
? 'border-blue-500/20 bg-blue-50 text-blue-600'
: ''}"
on:click={() => (showModelSettings = !showModelSettings)}
><CarbonSettingsAdjust class="text-xs" /></button
>
</div>
<div
class="mt-2 rounded-lg border border-blue-500/20 bg-blue-500/5 px-2 py-0.5"
class:hidden={!showModelSettings}
>
{#each models.filter((model) => !model.unlisted) as model}
<option
value={model.id}
selected={assistant
? assistant?.modelId === model.id
: $settings.activeModel === model.id}>{model.displayName}</option
>
{/each}
<p class="text-xs text-red-500">{getError("modelId", form)}</p>
</select>
<p class="text-xs text-red-500">{getError("inputMessage1", form)}</p>
<div class="my-2 grid grid-cols-1 gap-2.5 sm:grid-cols-2 sm:grid-rows-2">
<label for="temperature" class="flex justify-between">
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
Temperature

<HoverTooltip
label="Temperature: Controls creativity, higher values allow more variety."
>
<CarbonHelpFilled
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
/>
</HoverTooltip>
</span>
<input
type="number"
name="temperature"
min="0.1"
max="2"
step="0.1"
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
placeholder={selectedModel?.parameters?.temperature?.toString() ?? "1"}
value={assistant?.generateSettings?.temperature ?? ""}
/>
</label>
<label for="top_p" class="flex justify-between">
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
Top P
<HoverTooltip
label="Top P: Sets word choice boundaries, lower values tighten focus."
>
<CarbonHelpFilled
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
/>
</HoverTooltip>
</span>

<input
type="number"
name="top_p"
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
min="0.05"
max="1"
step="0.05"
placeholder={selectedModel?.parameters?.top_p?.toString() ?? "1"}
value={assistant?.generateSettings?.top_p ?? ""}
/>
</label>
<label for="repetition_penalty" class="flex justify-between">
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
Repetition penalty
<HoverTooltip
label="Repetition penalty: Prevents reuse, higher values decrease repetition."
>
<CarbonHelpFilled
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
/>
</HoverTooltip>
</span>
<input
type="number"
name="repetition_penalty"
min="0.1"
max="2"
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
placeholder={selectedModel?.parameters?.repetition_penalty?.toString() ?? "1.0"}
value={assistant?.generateSettings?.repetition_penalty ?? ""}
/>
</label>
<label for="top_k" class="flex justify-between">
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
Top K <HoverTooltip
label="Top K: Restricts word options, lower values for predictability."
>
<CarbonHelpFilled
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
/>
</HoverTooltip>
</span>
<input
type="number"
name="top_k"
min="5"
max="100"
step="5"
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
placeholder={selectedModel?.parameters?.top_k?.toString() ?? "50"}
value={assistant?.generateSettings?.top_k ?? ""}
/>
</label>
</div>
</div>
</label>

<label>
Expand Down
12 changes: 12 additions & 0 deletions src/lib/components/HoverTooltip.svelte
@@ -0,0 +1,12 @@
<script lang="ts">
	// Text shown inside the tooltip bubble; empty string renders an empty bubble.
	export let label = "";
</script>

<!--
	HoverTooltip: wraps its slotted trigger element and reveals a small black
	tooltip bubble on hover (and on :active, so it also works for tap on mobile).
	Visibility is pure CSS via the named `group/tooltip` — no JS state.
	NOTE(review): on mobile (max-sm) the bubble is centered with left-1/2 +
	-translate-x-1/2 relative to the nearest positioned ancestor, since the
	wrapper is only `relative` from the md: breakpoint up — confirm the intended
	positioning context for small screens.
-->
<div class="group/tooltip md:relative">
	<slot />
	<div
		class="invisible absolute z-10 w-64 whitespace-normal rounded-md bg-black p-2 text-center text-white group-hover/tooltip:visible group-active/tooltip:visible max-sm:left-1/2 max-sm:-translate-x-1/2"
	>
		{label}
	</div>
</div>
15 changes: 9 additions & 6 deletions src/lib/server/endpoints/anthropic/endpointAnthropic.ts
Expand Up @@ -32,7 +32,7 @@ export async function endpointAnthropic(
defaultQuery,
});

return async ({ messages, preprompt }) => {
return async ({ messages, preprompt, generateSettings }) => {
let system = preprompt;
if (messages?.[0]?.from === "system") {
system = messages[0].content;
Expand All @@ -49,15 +49,18 @@ export async function endpointAnthropic(
}[];

let tokenId = 0;

const parameters = { ...model.parameters, ...generateSettings };

return (async function* () {
const stream = anthropic.messages.stream({
model: model.id ?? model.name,
messages: messagesFormatted,
max_tokens: model.parameters?.max_new_tokens,
temperature: model.parameters?.temperature,
top_p: model.parameters?.top_p,
top_k: model.parameters?.top_k,
stop_sequences: model.parameters?.stop,
max_tokens: parameters?.max_new_tokens,
temperature: parameters?.temperature,
top_p: parameters?.top_p,
top_k: parameters?.top_k,
stop_sequences: parameters?.stop,
system,
});
while (true) {
Expand Down
4 changes: 2 additions & 2 deletions src/lib/server/endpoints/aws/endpointAws.ts
Expand Up @@ -36,7 +36,7 @@ export async function endpointAws(
region,
});

return async ({ messages, preprompt, continueMessage }) => {
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
Expand All @@ -46,7 +46,7 @@ export async function endpointAws(

return textGenerationStream(
{
parameters: { ...model.parameters, return_full_text: false },
parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
model: url,
inputs: prompt,
},
Expand Down
2 changes: 2 additions & 0 deletions src/lib/server/endpoints/endpoints.ts
Expand Up @@ -10,12 +10,14 @@ import {
endpointAnthropic,
endpointAnthropicParametersSchema,
} from "./anthropic/endpointAnthropic";
import type { Model } from "$lib/types/Model";

// parameters passed when generating text
export interface EndpointParameters {
messages: Omit<Conversation["messages"][0], "id">[];
preprompt?: Conversation["preprompt"];
continueMessage?: boolean; // used to signal that the last message will be extended
generateSettings?: Partial<Model["parameters"]>;
}

interface CommonEndpoint {
Expand Down
16 changes: 9 additions & 7 deletions src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts
Expand Up @@ -19,14 +19,16 @@ export function endpointLlamacpp(
input: z.input<typeof endpointLlamacppParametersSchema>
): Endpoint {
const { url, model } = endpointLlamacppParametersSchema.parse(input);
return async ({ messages, preprompt, continueMessage }) => {
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
preprompt,
model,
});

const parameters = { ...model.parameters, ...generateSettings };

const r = await fetch(`${url}/completion`, {
method: "POST",
headers: {
Expand All @@ -35,12 +37,12 @@ export function endpointLlamacpp(
body: JSON.stringify({
prompt,
stream: true,
temperature: model.parameters.temperature,
top_p: model.parameters.top_p,
top_k: model.parameters.top_k,
stop: model.parameters.stop,
repeat_penalty: model.parameters.repetition_penalty,
n_predict: model.parameters.max_new_tokens,
temperature: parameters.temperature,
top_p: parameters.top_p,
top_k: parameters.top_k,
stop: parameters.stop,
repeat_penalty: parameters.repetition_penalty,
n_predict: parameters.max_new_tokens,
cache_prompt: true,
}),
});
Expand Down
16 changes: 9 additions & 7 deletions src/lib/server/endpoints/ollama/endpointOllama.ts
Expand Up @@ -14,14 +14,16 @@ export const endpointOllamaParametersSchema = z.object({
export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);

return async ({ messages, preprompt, continueMessage }) => {
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
const prompt = await buildPrompt({
messages,
continueMessage,
preprompt,
model,
});

const parameters = { ...model.parameters, ...generateSettings };

const r = await fetch(`${url}/api/generate`, {
method: "POST",
headers: {
Expand All @@ -32,12 +34,12 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
model: ollamaName ?? model.name,
raw: true,
options: {
top_p: model.parameters.top_p,
top_k: model.parameters.top_k,
temperature: model.parameters.temperature,
repeat_penalty: model.parameters.repetition_penalty,
stop: model.parameters.stop,
num_predict: model.parameters.max_new_tokens,
top_p: parameters.top_p,
top_k: parameters.top_k,
temperature: parameters.temperature,
repeat_penalty: parameters.repetition_penalty,
stop: parameters.stop,
num_predict: parameters.max_new_tokens,
},
}),
});
Expand Down

0 comments on commit d4016bc

Please sign in to comment.