Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor: Fix GitHub model fetch #4645

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/config/modelProviders/github.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ const Github: ModelProviderCard = {
vision: true,
},
{
description: '专注于高级推理和解决复杂问题,包括数学和科学任务。非常适合需要深度上下文理解和自主工作流程的应用。',
description:
'专注于高级推理和解决复杂问题,包括数学和科学任务。非常适合需要深度上下文理解和自主工作流程的应用。',
displayName: 'OpenAI o1-preview',
enabled: true,
functionCall: false,
Expand Down Expand Up @@ -45,23 +46,26 @@ const Github: ModelProviderCard = {
vision: true,
},
{
description: '一个52B参数(12B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
description:
'一个52B参数(12B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
displayName: 'AI21 Jamba 1.5 Mini',
functionCall: true,
id: 'ai21-jamba-1.5-mini',
maxOutput: 4096,
tokens: 262_144,
},
{
description: '一个398B参数(94B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
description:
'一个398B参数(94B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
displayName: 'AI21 Jamba 1.5 Large',
functionCall: true,
id: 'ai21-jamba-1.5-large',
maxOutput: 4096,
tokens: 262_144,
},
{
description: 'Command R是一个可扩展的生成模型,旨在针对RAG和工具使用,使企业能够实现生产级AI。',
description:
'Command R是一个可扩展的生成模型,旨在针对RAG和工具使用,使企业能够实现生产级AI。',
displayName: 'Cohere Command R',
id: 'cohere-command-r',
maxOutput: 4096,
Expand All @@ -75,7 +79,8 @@ const Github: ModelProviderCard = {
tokens: 131_072,
},
{
description: 'Mistral Nemo是一种尖端的语言模型(LLM),在其尺寸类别中拥有最先进的推理、世界知识和编码能力。',
description:
'Mistral Nemo是一种尖端的语言模型(LLM),在其尺寸类别中拥有最先进的推理、世界知识和编码能力。',
displayName: 'Mistral Nemo',
id: 'mistral-nemo',
maxOutput: 4096,
Expand All @@ -89,7 +94,8 @@ const Github: ModelProviderCard = {
tokens: 131_072,
},
{
description: 'Mistral的旗舰模型,适合需要大规模推理能力或高度专业化的复杂任务(合成文本生成、代码生成、RAG或代理)。',
description:
'Mistral的旗舰模型,适合需要大规模推理能力或高度专业化的复杂任务(合成文本生成、代码生成、RAG或代理)。',
displayName: 'Mistral Large',
id: 'mistral-large',
maxOutput: 4096,
Expand All @@ -112,21 +118,24 @@ const Github: ModelProviderCard = {
vision: true,
},
{
description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
description:
'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
displayName: 'Meta Llama 3.1 8B',
id: 'meta-llama-3.1-8b-instruct',
maxOutput: 4096,
tokens: 131_072,
},
{
description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
description:
'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
displayName: 'Meta Llama 3.1 70B',
id: 'meta-llama-3.1-70b-instruct',
maxOutput: 4096,
tokens: 131_072,
},
{
description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
description:
'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
displayName: 'Meta Llama 3.1 405B',
id: 'meta-llama-3.1-405b-instruct',
maxOutput: 4096,
Expand Down Expand Up @@ -209,7 +218,7 @@ const Github: ModelProviderCard = {
description: '通过GitHub模型,开发人员可以成为AI工程师,并使用行业领先的AI模型进行构建。',
enabled: true,
id: 'github',
// modelList: { showModelFetcher: true },
modelList: { showModelFetcher: true }, // I'm not sure if it is good to show the model fetcher, as remote list is not complete.
name: 'GitHub',
url: 'https://github.com/marketplace/models',
};
Expand Down
35 changes: 0 additions & 35 deletions src/libs/agent-runtime/github/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,41 +119,6 @@ describe('LobeGithubAI', () => {
}
});

it('should return GithubBizError with an cause response with desensitize Url', async () => {
// Arrange
const errorInfo = {
stack: 'abc',
cause: { message: 'api is undefined' },
};
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

instance = new LobeGithubAI({
apiKey: 'test',
baseURL: 'https://api.abc.com/v1',
});

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'meta-llama-3-70b-instruct',
temperature: 0.7,
});
} catch (e) {
expect(e).toEqual({
endpoint: 'https://api.***.com/v1',
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
},
errorType: bizErrorType,
provider,
});
}
});

it('should throw an InvalidGithubToken error type on 401 status code', async () => {
// Mock the API call to simulate a 401 error
const error = new Error('InvalidApiKey') as any;
Expand Down
52 changes: 51 additions & 1 deletion src/libs/agent-runtime/github/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders';
import type { ChatModelCard } from '@/types/llm';

import { AgentRuntimeErrorType } from '../error';
import { o1Models, pruneO1Payload } from '../openai';
import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
import {
CHAT_MODELS_BLOCK_LIST,
LobeOpenAICompatibleFactory,
} from '../utils/openaiCompatibleFactory';

/**
 * Task types reported by the GitHub models catalog API.
 *
 * String-valued so that members compare equal to the raw `task` field in the
 * remote JSON payload (e.g. "chat-completion"). A plain enum would assign the
 * members the numeric values 0 and 1, which could never match the API's
 * string values.
 */
enum Task {
  'chat-completion' = 'chat-completion',
  'embeddings' = 'embeddings',
}

/* eslint-disable typescript-sort-keys/interface */
// Shape of one model entry as returned by the GitHub models catalog endpoint.
// Field order mirrors the raw API payload, hence the sort-keys disable above.
// NOTE(review): field semantics inferred from usage below — `name` is the
// registry identifier matched against local card ids, `friendly_name` is the
// human-readable display name; confirm against the API reference.
type Model = {
  id: string;
  name: string; // registry name, e.g. 'meta-llama-3.1-8b-instruct'; used as card id
  friendly_name: string; // display name shown to users
  model_version: number;
  publisher: string;
  model_family: string;
  model_registry: string;
  license: string;
  task: Task; // chat-completion | embeddings
  description: string;
  summary: string;
  tags: string[];
};
/* eslint-enable typescript-sort-keys/interface */

export const LobeGithubAI = LobeOpenAICompatibleFactory({
baseURL: 'https://models.inference.ai.azure.com',
Expand All @@ -23,5 +51,27 @@ export const LobeGithubAI = LobeOpenAICompatibleFactory({
bizError: AgentRuntimeErrorType.ProviderBizError,
invalidAPIKey: AgentRuntimeErrorType.InvalidGithubToken,
},
// Fetch the live model catalog from the GitHub models endpoint and map it
// onto ChatModelCard entries.
// NOTE(review): the OpenAI SDK types `models.list()` as an OpenAI.Model page;
// the GitHub endpoint returns a different payload shape, hence the `as any`
// escape hatch and the manual `.body` access — confirm against the endpoint.
models: async ({ client }) => {
const modelsPage = (await client.models.list()) as any;
const modelList: Model[] = modelsPage.body;
return modelList
.filter((model) => {
// Drop entries whose name contains a known non-chat keyword
// (embedding etc.) so only chat-capable models are listed.
return CHAT_MODELS_BLOCK_LIST.every(
(keyword) => !model.name.toLowerCase().includes(keyword),
);
})
.map((model) => {
// Prefer the locally curated card (richer metadata such as token
// limits) when the remote model is already in the default list.
const knownModel = LOBE_DEFAULT_MODEL_LIST.find((m) => m.id === model.name);

if (knownModel) return knownModel;

// Otherwise build a minimal card from the remote payload.
return {
description: model.description,
displayName: model.friendly_name,
id: model.name,
};
})
.filter(Boolean) as ChatModelCard[];
},
provider: ModelProvider.Github,
});
3 changes: 2 additions & 1 deletion src/libs/agent-runtime/togetherai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ export const LobeTogetherAI = LobeOpenAICompatibleFactory({
debug: {
chatCompletion: () => process.env.DEBUG_TOGETHERAI_CHAT_COMPLETION === '1',
},
models: async ({ apiKey }) => {
models: async ({ client }) => {
const apiKey = client.apiKey;
const data = await fetch(`${baseURL}/api/models`, {
headers: {
Authorization: `Bearer ${apiKey}`,
Expand Down
21 changes: 11 additions & 10 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ import OpenAI, { ClientOptions } from 'openai';
import { Stream } from 'openai/streaming';

import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders';
import { ChatModelCard } from '@/types/llm';
import type { ChatModelCard } from '@/types/llm';

import { LobeRuntimeAI } from '../../BaseAI';
import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../../error';
import {
import type {
ChatCompetitionOptions,
ChatCompletionErrorPayload,
ChatStreamPayload,
Embeddings,
EmbeddingsOptions,
EmbeddingsPayload,
ModelProvider,
TextToImagePayload,
TextToSpeechOptions,
TextToSpeechPayload,
Expand All @@ -26,7 +27,7 @@ import { StreamingResponse } from '../response';
import { OpenAIStream, OpenAIStreamOptions } from '../streams';

// the model contains the following keywords is not a chat model, so we should filter them out
const CHAT_MODELS_BLOCK_LIST = [
export const CHAT_MODELS_BLOCK_LIST = [
'embedding',
'davinci',
'curie',
Expand Down Expand Up @@ -77,7 +78,7 @@ interface OpenAICompatibleFactoryOptions<T extends Record<string, any> = any> {
invalidAPIKey: ILobeAgentRuntimeErrorType;
};
models?:
| ((params: { apiKey: string }) => Promise<ChatModelCard[]>)
| ((params: { client: OpenAI }) => Promise<ChatModelCard[]>)
| {
transformModel?: (model: OpenAI.Model) => ChatModelCard;
};
Expand Down Expand Up @@ -157,7 +158,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
client!: OpenAI;

baseURL!: string;
private _options: ConstructorOptions<T>;
protected _options: ConstructorOptions<T>;

constructor(options: ClientOptions & Record<string, any> = {}) {
const _options = {
Expand Down Expand Up @@ -249,7 +250,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
}

async models() {
if (typeof models === 'function') return models({ apiKey: this.client.apiKey });
if (typeof models === 'function') return models({ client: this.client });

const list = await this.client.models.list();

Expand Down Expand Up @@ -312,7 +313,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
}
}

private handleError(error: any): ChatCompletionErrorPayload {
protected handleError(error: any): ChatCompletionErrorPayload {
let desensitizedEndpoint = this.baseURL;

// refs: https://github.com/lobehub/lobe-chat/issues/842
Expand All @@ -326,7 +327,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
if (errorResult)
return AgentRuntimeError.chat({
...errorResult,
provider,
BrandonStudio marked this conversation as resolved.
Show resolved Hide resolved
provider: provider,
} as ChatCompletionErrorPayload);
}

Expand All @@ -337,7 +338,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
endpoint: desensitizedEndpoint,
error: error as any,
errorType: ErrorType.invalidAPIKey,
provider: provider as any,
provider: provider as ModelProvider,
});
}

Expand All @@ -353,7 +354,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
endpoint: desensitizedEndpoint,
error: errorResult,
errorType: RuntimeError || ErrorType.bizError,
provider: provider as any,
provider: provider as ModelProvider,
});
}
};
Expand Down
Loading