Skip to content

Commit 222b186

Browse files
cheahjs and nsarrazin authored
feat: add support for anthropic on vertex (huggingface#958)
Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>
1 parent 16524b2 commit 222b186

File tree

6 files changed

+188
-4
lines changed

6 files changed

+188
-4
lines changed

README.md

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,8 @@ MODELS=`[
480480
// optionals
481481
"apiKey": "sk-ant-...",
482482
"baseURL": "https://api.anthropic.com",
483-
defaultHeaders: {},
484-
defaultQuery: {}
483+
"defaultHeaders": {},
484+
"defaultQuery": {}
485485
}
486486
]
487487
},
@@ -498,8 +498,51 @@ MODELS=`[
498498
// optionals
499499
"apiKey": "sk-ant-...",
500500
"baseURL": "https://api.anthropic.com",
501-
defaultHeaders: {},
502-
defaultQuery: {}
501+
"defaultHeaders": {},
502+
"defaultQuery": {}
503+
}
504+
]
505+
}
506+
]`
507+
```
508+
509+
We also support Anthropic models running on Vertex AI. Authentication uses Google Application Default Credentials. The project ID is provided through `endpoints.projectId`, as in the following example:
510+
511+
```
512+
MODELS=`[
513+
{
514+
"name": "claude-3-sonnet@20240229",
515+
"displayName": "Claude 3 Sonnet",
516+
"description": "Ideal balance of intelligence and speed",
517+
"parameters": {
518+
"max_new_tokens": 4096
519+
},
520+
"endpoints": [
521+
{
522+
"type": "anthropic-vertex",
523+
"region": "us-central1",
524+
"projectId": "gcp-project-id",
525+
// optionals
526+
"defaultHeaders": {},
527+
"defaultQuery": {}
528+
}
529+
]
530+
},
531+
{
532+
"name": "claude-3-haiku@20240307",
533+
"displayName": "Claude 3 Haiku",
534+
"description": "Fastest, most compact model for near-instant responsiveness",
535+
"parameters": {
536+
"max_new_tokens": 4096
537+
},
538+
"endpoints": [
539+
{
540+
"type": "anthropic-vertex",
541+
"region": "us-central1",
542+
"projectId": "gcp-project-id",
543+
// optionals
544+
"defaultHeaders": {},
545+
"defaultQuery": {}
503546
}
504547
]
505548
}

package-lock.json

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
},
9090
"optionalDependencies": {
9191
"@anthropic-ai/sdk": "^0.17.1",
92+
"@anthropic-ai/vertex-sdk": "^0.3.0",
9293
"@google-cloud/vertexai": "^1.1.0",
9394
"aws4fetch": "^1.0.17",
9495
"cohere-ai": "^7.9.0",
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import { z } from "zod";
2+
import type { Endpoint } from "../endpoints";
3+
import type { TextGenerationStreamOutput } from "@huggingface/inference";
4+
5+
/**
 * Zod schema for the configuration of an "anthropic-vertex" endpoint.
 *
 * Fields:
 * - `weight`: positive integer, defaults to 1.
 * - `model`: accepted as-is (`z.any()`), not validated here.
 * - `type`: must be the literal "anthropic-vertex"; serves as the endpoint discriminator.
 * - `region`: string, defaults to "us-central1".
 * - `projectId`: required string.
 * - `defaultHeaders` / `defaultQuery`: optional string-to-string records.
 */
export const endpointAnthropicVertexParametersSchema = z.object({
6+
weight: z.number().int().positive().default(1),
7+
model: z.any(),
8+
type: z.literal("anthropic-vertex"),
9+
region: z.string().default("us-central1"),
10+
projectId: z.string(),
11+
defaultHeaders: z.record(z.string()).optional(),
12+
defaultQuery: z.record(z.string()).optional(),
13+
});
14+
15+
/**
 * Builds a chat endpoint backed by Anthropic models served through Google Vertex AI.
 *
 * The raw configuration is validated with `endpointAnthropicVertexParametersSchema`,
 * `@anthropic-ai/vertex-sdk` is imported lazily, and a client is created for the
 * configured region and project.
 *
 * @param input - Endpoint configuration in the schema's input shape.
 * @returns An endpoint function. When called with `{ messages, preprompt }` it returns an
 *   async generator yielding one `TextGenerationStreamOutput` per text delta, followed by
 *   a final special token whose `generated_text` is the complete response.
 * @throws Error when the schema rejects `input` (thrown by `parse`) or when
 *   `@anthropic-ai/vertex-sdk` cannot be imported.
 */
export async function endpointAnthropicVertex(
	input: z.input<typeof endpointAnthropicVertexParametersSchema>
): Promise<Endpoint> {
	const { region, projectId, model, defaultHeaders, defaultQuery } =
		endpointAnthropicVertexParametersSchema.parse(input);

	// Import on demand so a missing SDK only fails when this endpoint type is actually used.
	let AnthropicVertex;
	try {
		AnthropicVertex = (await import("@anthropic-ai/vertex-sdk")).AnthropicVertex;
	} catch (e) {
		throw new Error("Failed to import @anthropic-ai/vertex-sdk", { cause: e });
	}

	const anthropic = new AnthropicVertex({
		baseURL: `https://${region}-aiplatform.googleapis.com/v1`,
		region,
		projectId,
		defaultHeaders,
		defaultQuery,
	});

	return async ({ messages, preprompt }) => {
		// A leading "system" message takes precedence over the preprompt.
		let system = preprompt;
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
		}

		// System messages are passed via the dedicated `system` field, so drop them here.
		// The cast narrows `from` to the two roles that remain after filtering.
		const messagesFormatted = messages
			.filter((message) => message.from !== "system")
			.map((message) => ({
				role: message.from,
				content: message.content,
			})) as unknown as {
			role: "user" | "assistant";
			content: string;
		}[];

		let tokenId = 0;
		return (async function* () {
			const stream = anthropic.messages.stream({
				model: model.id ?? model.name,
				messages: messagesFormatted,
				max_tokens: model.parameters?.max_new_tokens,
				temperature: model.parameters?.temperature,
				top_p: model.parameters?.top_p,
				top_k: model.parameters?.top_k,
				stop_sequences: model.parameters?.stop,
				system,
			});

			// Subscribe to "end" exactly once, before consuming anything. Re-subscribing on
			// every iteration would miss an "end" emitted while this generator is suspended
			// at a `yield` (the loop would then wait forever) and would also pile up one
			// pending listener per token on the stream.
			const streamEnded = stream.emitted("end");

			while (true) {
				const result = await Promise.race([stream.emitted("text"), streamEnded]);

				// Stream end: "end" resolves with no payload, unlike "text".
				if (result === undefined) {
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
						},
						generated_text: await stream.finalText(),
						details: null,
					} satisfies TextGenerationStreamOutput;
					return;
				}

				// Text delta
				yield {
					token: {
						id: tokenId++,
						text: result as unknown as string,
						special: false,
						logprob: 0,
					},
					generated_text: null,
					details: null,
				} satisfies TextGenerationStreamOutput;
			}
		})();
	};
}

src/lib/server/endpoints/endpoints.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ import {
1212
endpointAnthropic,
1313
endpointAnthropicParametersSchema,
1414
} from "./anthropic/endpointAnthropic";
15+
import {
16+
endpointAnthropicVertex,
17+
endpointAnthropicVertexParametersSchema,
18+
} from "./anthropic/endpointAnthropicVertex";
1519
import type { Model } from "$lib/types/Model";
1620
import endpointCloudflare, {
1721
endpointCloudflareParametersSchema,
@@ -44,6 +48,7 @@ export type EndpointGenerator<T extends CommonEndpoint> = (parameters: T) => End
4448
export const endpoints = {
4549
tgi: endpointTgi,
4650
anthropic: endpointAnthropic,
51+
anthropicvertex: endpointAnthropicVertex,
4752
aws: endpointAws,
4853
openai: endpointOai,
4954
llamacpp: endpointLlamacpp,
@@ -56,6 +61,7 @@ export const endpoints = {
5661

5762
export const endpointSchema = z.discriminatedUnion("type", [
5863
endpointAnthropicParametersSchema,
64+
endpointAnthropicVertexParametersSchema,
5965
endpointAwsParametersSchema,
6066
endpointOAIParametersSchema,
6167
endpointTgiParametersSchema,

src/lib/server/models.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
159159
return endpoints.tgi(args);
160160
case "anthropic":
161161
return endpoints.anthropic(args);
162+
case "anthropic-vertex":
163+
return endpoints.anthropicvertex(args);
162164
case "aws":
163165
return await endpoints.aws(args);
164166
case "openai":

0 commit comments

Comments
 (0)