Anthropic Endpoint Support #923

Merged (1 commit, Mar 15, 2024)
1 change: 1 addition & 0 deletions .env
@@ -9,6 +9,7 @@ COOKIE_NAME=hf-chat
HF_TOKEN=#hf_<token> from https://huggingface.co/settings/token
HF_API_ROOT=https://api-inference.huggingface.co/models
OPENAI_API_KEY=#your openai api key here
ANTHROPIC_API_KEY=#your anthropic api key here

HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead

45 changes: 45 additions & 0 deletions README.md
@@ -459,6 +459,51 @@ MODELS=`[
]`
```

#### Anthropic

We also support Anthropic models through the official SDK. You can provide your API key via the `ANTHROPIC_API_KEY` environment variable, or per endpoint via `endpoints.apiKey`, as in the following example.

```
MODELS=`[
  {
    "name": "claude-3-sonnet-20240229",
    "displayName": "Claude 3 Sonnet",
    "description": "Ideal balance of intelligence and speed",
    "parameters": {
      "max_new_tokens": 4096
    },
    "endpoints": [
      {
        "type": "anthropic",
        // optional
        "apiKey": "sk-ant-...",
        "baseURL": "https://api.anthropic.com",
        "defaultHeaders": {},
        "defaultQuery": {}
      }
    ]
  },
  {
    "name": "claude-3-opus-20240229",
    "displayName": "Claude 3 Opus",
    "description": "Most powerful model for highly complex tasks",
    "parameters": {
      "max_new_tokens": 4096
    },
    "endpoints": [
      {
        "type": "anthropic",
        // optional
        "apiKey": "sk-ant-...",
        "baseURL": "https://api.anthropic.com",
        "defaultHeaders": {},
        "defaultQuery": {}
      }
    ]
  }
]`
```
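
Since `baseURL`, `apiKey` (read from `ANTHROPIC_API_KEY`), and `weight` all have schema defaults, the per-endpoint fields can be omitted entirely when the key comes from the environment. A minimal sketch, not taken verbatim from this PR:

```
MODELS=`[
  {
    "name": "claude-3-sonnet-20240229",
    "parameters": { "max_new_tokens": 4096 },
    "endpoints": [{ "type": "anthropic" }]
  }
]`
```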

#### Amazon

You can also specify your Amazon SageMaker instance as an endpoint for chat-ui. The config goes like this:
54 changes: 30 additions & 24 deletions package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions package.json
@@ -82,6 +82,7 @@
"zod": "^3.22.3"
},
"optionalDependencies": {
"@anthropic-ai/sdk": "^0.17.1",
"aws4fetch": "^1.0.17",
"openai": "^4.14.2"
}
95 changes: 95 additions & 0 deletions src/lib/server/endpoints/anthropic/endpointAnthropic.ts
@@ -0,0 +1,95 @@
import { z } from "zod";
import { ANTHROPIC_API_KEY } from "$env/static/private";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";

export const endpointAnthropicParametersSchema = z.object({
  weight: z.number().int().positive().default(1),
  model: z.any(),
  type: z.literal("anthropic"),
  baseURL: z.string().url().default("https://api.anthropic.com"),
  apiKey: z.string().default(ANTHROPIC_API_KEY ?? "sk-"),
  defaultHeaders: z.record(z.string()).optional(),
  defaultQuery: z.record(z.string()).optional(),
});

export async function endpointAnthropic(
  input: z.input<typeof endpointAnthropicParametersSchema>
): Promise<Endpoint> {
  const { baseURL, apiKey, model, defaultHeaders, defaultQuery } =
    endpointAnthropicParametersSchema.parse(input);

  // The SDK is an optional dependency, so import it lazily and fail with a
  // descriptive error when it is not installed.
  let Anthropic;
  try {
    Anthropic = (await import("@anthropic-ai/sdk")).default;
  } catch (e) {
    throw new Error("Failed to import @anthropic-ai/sdk", { cause: e });
  }

  const anthropic = new Anthropic({
    apiKey,
    baseURL,
    defaultHeaders,
    defaultQuery,
  });

  return async ({ messages, preprompt }) => {
    // Anthropic expects the system prompt as a dedicated `system` field rather
    // than as a conversation message, so extract it from the message list.
    let system = preprompt;
    if (messages?.[0]?.from === "system") {
      system = messages[0].content;
    }

    const messagesFormatted = messages
      .filter((message) => message.from !== "system")
      .map((message) => ({
        role: message.from,
        content: message.content,
      })) as unknown as {
      role: "user" | "assistant";
      content: string;
    }[];

    let tokenId = 0;
    return (async function* () {
      const stream = anthropic.messages.stream({
        model: model.id ?? model.name,
        messages: messagesFormatted,
        max_tokens: model.parameters?.max_new_tokens,
        temperature: model.parameters?.temperature,
        top_p: model.parameters?.top_p,
        top_k: model.parameters?.top_k,
        stop_sequences: model.parameters?.stop,
        system,
      });
      while (true) {
        // Race the next text delta against the end of the stream; the "end"
        // event carries no payload, so `undefined` signals completion.
        const result = await Promise.race([stream.emitted("text"), stream.emitted("end")]);

        // Stream end: emit one final, special token carrying the full text.
        if (result === undefined) {
          yield {
            token: {
              id: tokenId++,
              text: "",
              logprob: 0,
              special: true,
            },
            generated_text: await stream.finalText(),
            details: null,
          } satisfies TextGenerationStreamOutput;
          return;
        }

        // Text delta: forward it as a single pseudo-token.
        yield {
          token: {
            id: tokenId++,
            text: result as unknown as string,
            special: false,
            logprob: 0,
          },
          generated_text: null,
          details: null,
        } satisfies TextGenerationStreamOutput;
      }
    })();
  };
}
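
For illustration, a minimal sketch of how the generator returned by this endpoint could be consumed. The `endpointAnthropic` import and the streaming contract come from the file above; the model config and message shape here are simplified assumptions, not chat-ui's actual types.

```
import { endpointAnthropic } from "./endpointAnthropic";

async function demo() {
  // Hypothetical model config; chat-ui normally derives this from MODELS.
  const endpoint = await endpointAnthropic({
    type: "anthropic",
    model: { name: "claude-3-sonnet-20240229", parameters: { max_new_tokens: 4096 } },
  });

  // Message shape simplified; chat-ui passes full Message objects.
  const generator = await endpoint({
    preprompt: "You are a helpful assistant.",
    messages: [{ from: "user", content: "Hello!" }],
  });

  for await (const output of generator) {
    if (output.generated_text !== null) {
      // Final yield: a special empty token plus the complete text.
      console.log("\nfinal:", output.generated_text);
    } else {
      process.stdout.write(output.token.text);
    }
  }
}

demo();
```
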
6 changes: 6 additions & 0 deletions src/lib/server/endpoints/endpoints.ts
@@ -6,6 +6,10 @@ import endpointAws, { endpointAwsParametersSchema } from "./aws/endpointAws";
import { endpointOAIParametersSchema, endpointOai } from "./openai/endpointOai";
import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/endpointLlamacpp";
import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
import {
  endpointAnthropic,
  endpointAnthropicParametersSchema,
} from "./anthropic/endpointAnthropic";

// parameters passed when generating text
export interface EndpointParameters {
@@ -28,13 +32,15 @@ export type EndpointGenerator<T extends CommonEndpoint> = (parameters: T) => End
// list of all endpoint generators
export const endpoints = {
  tgi: endpointTgi,
  anthropic: endpointAnthropic,
  aws: endpointAws,
  openai: endpointOai,
  llamacpp: endpointLlamacpp,
  ollama: endpointOllama,
};

export const endpointSchema = z.discriminatedUnion("type", [
  endpointAnthropicParametersSchema,
  endpointAwsParametersSchema,
  endpointOAIParametersSchema,
  endpointTgiParametersSchema,
2 changes: 2 additions & 0 deletions src/lib/server/models.ts
@@ -109,6 +109,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
  switch (args.type) {
    case "tgi":
      return endpoints.tgi(args);
    case "anthropic":
      return endpoints.anthropic(args);
    case "aws":
      return await endpoints.aws(args);
    case "openai":