Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor: Fix GitHub model fetch #4645

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/config/modelProviders/github.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ const Github: ModelProviderCard = {
vision: true,
},
{
description: '专注于高级推理和解决复杂问题,包括数学和科学任务。非常适合需要深度上下文理解和自主工作流程的应用。',
description:
'专注于高级推理和解决复杂问题,包括数学和科学任务。非常适合需要深度上下文理解和自主工作流程的应用。',
displayName: 'OpenAI o1-preview',
enabled: true,
functionCall: false,
Expand Down Expand Up @@ -45,23 +46,26 @@ const Github: ModelProviderCard = {
vision: true,
},
{
description: '一个52B参数(12B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
description:
'一个52B参数(12B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
displayName: 'AI21 Jamba 1.5 Mini',
functionCall: true,
id: 'ai21-jamba-1.5-mini',
maxOutput: 4096,
tokens: 262_144,
},
{
description: '一个398B参数(94B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
description:
'一个398B参数(94B活跃)的多语言模型,提供256K长上下文窗口、函数调用、结构化输出和基于事实的生成。',
displayName: 'AI21 Jamba 1.5 Large',
functionCall: true,
id: 'ai21-jamba-1.5-large',
maxOutput: 4096,
tokens: 262_144,
},
{
description: 'Command R是一个可扩展的生成模型,旨在针对RAG和工具使用,使企业能够实现生产级AI。',
description:
'Command R是一个可扩展的生成模型,旨在针对RAG和工具使用,使企业能够实现生产级AI。',
displayName: 'Cohere Command R',
id: 'cohere-command-r',
maxOutput: 4096,
Expand All @@ -75,7 +79,8 @@ const Github: ModelProviderCard = {
tokens: 131_072,
},
{
description: 'Mistral Nemo是一种尖端的语言模型(LLM),在其尺寸类别中拥有最先进的推理、世界知识和编码能力。',
description:
'Mistral Nemo是一种尖端的语言模型(LLM),在其尺寸类别中拥有最先进的推理、世界知识和编码能力。',
displayName: 'Mistral Nemo',
id: 'mistral-nemo',
maxOutput: 4096,
Expand All @@ -89,7 +94,8 @@ const Github: ModelProviderCard = {
tokens: 131_072,
},
{
description: 'Mistral的旗舰模型,适合需要大规模推理能力或高度专业化的复杂任务(合成文本生成、代码生成、RAG或代理)。',
description:
'Mistral的旗舰模型,适合需要大规模推理能力或高度专业化的复杂任务(合成文本生成、代码生成、RAG或代理)。',
displayName: 'Mistral Large',
id: 'mistral-large',
maxOutput: 4096,
Expand All @@ -112,21 +118,24 @@ const Github: ModelProviderCard = {
vision: true,
},
{
description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
description:
'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
displayName: 'Meta Llama 3.1 8B',
id: 'meta-llama-3.1-8b-instruct',
maxOutput: 4096,
tokens: 131_072,
},
{
description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
description:
'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
displayName: 'Meta Llama 3.1 70B',
id: 'meta-llama-3.1-70b-instruct',
maxOutput: 4096,
tokens: 131_072,
},
{
description: 'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
description:
'Llama 3.1指令调优的文本模型,针对多语言对话用例进行了优化,在许多可用的开源和封闭聊天模型中,在常见行业基准上表现优异。',
displayName: 'Meta Llama 3.1 405B',
id: 'meta-llama-3.1-405b-instruct',
maxOutput: 4096,
Expand Down Expand Up @@ -209,7 +218,7 @@ const Github: ModelProviderCard = {
description: '通过GitHub模型,开发人员可以成为AI工程师,并使用行业领先的AI模型进行构建。',
enabled: true,
id: 'github',
// modelList: { showModelFetcher: true },
modelList: { showModelFetcher: true }, // I'm not sure if it is good to show the model fetcher, as remote list is not complete.
name: 'GitHub',
url: 'https://github.com/marketplace/models',
};
Expand Down
109 changes: 68 additions & 41 deletions src/libs/agent-runtime/github/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,10 @@ let instance: LobeOpenAICompatibleRuntime;

beforeEach(() => {
instance = new LobeGithubAI({ apiKey: 'test' });

// Use vi.spyOn to mock the chat.completions.create method
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
});

afterEach(() => {
vi.clearAllMocks();
vi.restoreAllMocks();
});

describe('LobeGithubAI', () => {
Expand All @@ -42,6 +37,13 @@ describe('LobeGithubAI', () => {
});

describe('chat', () => {
beforeEach(() => {
// Use vi.spyOn to mock the chat.completions.create method
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
});

describe('Error', () => {
it('should return GithubBizError with an openai error response when OpenAI.APIError is thrown', async () => {
// Arrange
Expand Down Expand Up @@ -119,41 +121,6 @@ describe('LobeGithubAI', () => {
}
});

it('should return GithubBizError with an cause response with desensitize Url', async () => {
// Arrange
const errorInfo = {
stack: 'abc',
cause: { message: 'api is undefined' },
};
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

instance = new LobeGithubAI({
apiKey: 'test',
baseURL: 'https://api.abc.com/v1',
});

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'meta-llama-3-70b-instruct',
temperature: 0.7,
});
} catch (e) {
expect(e).toEqual({
endpoint: 'https://api.***.com/v1',
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
},
errorType: bizErrorType,
provider,
});
}
});

it('should throw an InvalidGithubToken error type on 401 status code', async () => {
// Mock the API call to simulate a 401 error
const error = new Error('InvalidApiKey') as any;
Expand Down Expand Up @@ -243,4 +210,64 @@ describe('LobeGithubAI', () => {
});
});
});

describe('models', () => {
beforeEach(() => {});

it('should return a list of models', async () => {
// Arrange
const arr = [
{
id: 'azureml://registries/azureml-ai21/models/AI21-Jamba-Instruct/versions/2',
name: 'AI21-Jamba-Instruct',
friendly_name: 'AI21-Jamba-Instruct',
model_version: 2,
publisher: 'AI21 Labs',
model_family: 'AI21 Labs',
model_registry: 'azureml-ai21',
license: 'custom',
task: 'chat-completion',
description:
"Jamba-Instruct is the world's first production-grade Mamba-based LLM model and leverages its hybrid Mamba-Transformer architecture to achieve best-in-class performance, quality, and cost efficiency.\n\n**Model Developer Name**: _AI21 Labs_\n\n## Model Architecture\n\nJamba-Instruct leverages a hybrid Mamba-Transformer architecture to achieve best-in-class performance, quality, and cost efficiency.\nAI21's Jamba architecture features a blocks-and-layers approach that allows Jamba to successfully integrate the two architectures. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers.\n",
summary:
"Jamba-Instruct is the world's first production-grade Mamba-based LLM model and leverages its hybrid Mamba-Transformer architecture to achieve best-in-class performance, quality, and cost efficiency.",
tags: ['chat', 'rag'],
},
{
id: 'azureml://registries/azureml-cohere/models/Cohere-command-r/versions/3',
name: 'Cohere-command-r',
friendly_name: 'Cohere Command R',
model_version: 3,
publisher: 'cohere',
model_family: 'cohere',
model_registry: 'azureml-cohere',
license: 'custom',
task: 'chat-completion',
description:
"Command R is a highly performant generative large language model, optimized for a variety of use cases including reasoning, summarization, and question answering. \n\nThe model is optimized to perform well in the following languages: English, French, Spanish, Italian, German, Brazilian Portuguese, Japanese, Korean, Simplified Chinese, and Arabic.\n\nPre-training data additionally included the following 13 languages: Russian, Polish, Turkish, Vietnamese, Dutch, Czech, Indonesian, Ukrainian, Romanian, Greek, Hindi, Hebrew, Persian.\n\n## Resources\n\nFor full details of this model, [release blog post](https://aka.ms/cohere-blog).\n\n## Model Architecture\n\nThis is an auto-regressive language model that uses an optimized transformer architecture. After pretraining, this model uses supervised fine-tuning (SFT) and preference training to align model behavior to human preferences for helpfulness and safety.\n\n### Tool use capabilities\n\nCommand R has been specifically trained with conversational tool use capabilities. These have been trained into the model via a mixture of supervised fine-tuning and preference fine-tuning, using a specific prompt template. Deviating from this prompt template will likely reduce performance, but we encourage experimentation.\n\nCommand R's tool use functionality takes a conversation as input (with an optional user-system preamble), along with a list of available tools. The model will then generate a json-formatted list of actions to execute on a subset of those tools. Command R may use one of its supplied tools more than once.\n\nThe model has been trained to recognise a special directly_answer tool, which it uses to indicate that it doesn't want to use any of its other tools. The ability to abstain from calling a specific tool can be useful in a range of situations, such as greeting a user, or asking clarifying questions. We recommend including the directly_answer tool, but it can be removed or renamed if required.\n\n### Grounded Generation and RAG Capabilities\n\nCommand R has been specifically trained with grounded generation capabilities. This means that it can generate responses based on a list of supplied document snippets, and it will include grounding spans (citations) in its response indicating the source of the information. This can be used to enable behaviors such as grounded summarization and the final step of Retrieval Augmented Generation (RAG).This behavior has been trained into the model via a mixture of supervised fine-tuning and preference fine-tuning, using a specific prompt template. Deviating from this prompt template may reduce performance, but we encourage experimentation.\n\nCommand R's grounded generation behavior takes a conversation as input (with an optional user-supplied system preamble, indicating task, context and desired output style), along with a list of retrieved document snippets. The document snippets should be chunks, rather than long documents, typically around 100-400 words per chunk. Document snippets consist of key-value pairs. The keys should be short descriptive strings, the values can be text or semi-structured.\n\nBy default, Command R will generate grounded responses by first predicting which documents are relevant, then predicting which ones it will cite, then generating an answer. Finally, it will then insert grounding spans into the answer. See below for an example. This is referred to as accurate grounded generation.\n\nThe model is trained with a number of other answering modes, which can be selected by prompt changes . A fast citation mode is supported in the tokenizer, which will directly generate an answer with grounding spans in it, without first writing the answer out in full. This sacrifices some grounding accuracy in favor of generating fewer tokens.\n\n### Code Capabilities\n\nCommand R has been optimized to interact with your code, by requesting code snippets, code explanations, or code rewrites. It might not perform well out-of-the-box for pure code completion. For better performance, we also recommend using a low temperature (and even greedy decoding) for code-generation related instructions.\n",
summary:
'Command R is a scalable generative model targeting RAG and Tool Use to enable production-scale AI for enterprise.',
tags: ['rag', 'multilingual'],
},
];
vi.spyOn(instance['client'].models, 'list').mockResolvedValue({
body: arr,
} as any);

// Act & Assert
const models = await instance.models();

const modelsCount = models.length;
expect(modelsCount).toBe(arr.length);

for (let i = 0; i < arr.length; i++) {
const model = models[i];
expect(model).toEqual({
description: arr[i].description,
displayName: arr[i].friendly_name,
id: arr[i].name,
});
}
});
});
});
52 changes: 51 additions & 1 deletion src/libs/agent-runtime/github/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders';
import type { ChatModelCard } from '@/types/llm';

import { AgentRuntimeErrorType } from '../error';
import { o1Models, pruneO1Payload } from '../openai';
import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
import {
CHAT_MODELS_BLOCK_LIST,
LobeOpenAICompatibleFactory,
} from '../utils/openaiCompatibleFactory';

enum Task {
'chat-completion',
'embeddings',
}

/* eslint-disable typescript-sort-keys/interface */
type Model = {
id: string;
name: string;
friendly_name: string;
model_version: number;
publisher: string;
model_family: string;
model_registry: string;
license: string;
task: Task;
description: string;
summary: string;
tags: string[];
};
/* eslint-enable typescript-sort-keys/interface */

export const LobeGithubAI = LobeOpenAICompatibleFactory({
baseURL: 'https://models.inference.ai.azure.com',
Expand All @@ -23,5 +51,27 @@ export const LobeGithubAI = LobeOpenAICompatibleFactory({
bizError: AgentRuntimeErrorType.ProviderBizError,
invalidAPIKey: AgentRuntimeErrorType.InvalidGithubToken,
},
models: async ({ client }) => {
const modelsPage = (await client.models.list()) as any;
const modelList: Model[] = modelsPage.body;
return modelList
.filter((model) => {
return CHAT_MODELS_BLOCK_LIST.every(
(keyword) => !model.name.toLowerCase().includes(keyword),
);
})
.map((model) => {
const knownModel = LOBE_DEFAULT_MODEL_LIST.find((m) => m.id === model.name);

if (knownModel) return knownModel;

return {
description: model.description,
displayName: model.friendly_name,
id: model.name,
};
})
.filter(Boolean) as ChatModelCard[];
},
provider: ModelProvider.Github,
});
3 changes: 2 additions & 1 deletion src/libs/agent-runtime/togetherai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ export const LobeTogetherAI = LobeOpenAICompatibleFactory({
debug: {
chatCompletion: () => process.env.DEBUG_TOGETHERAI_CHAT_COMPLETION === '1',
},
models: async ({ apiKey }) => {
models: async ({ client }) => {
const apiKey = client.apiKey;
const data = await fetch(`${baseURL}/api/models`, {
headers: {
Authorization: `Bearer ${apiKey}`,
Expand Down
19 changes: 10 additions & 9 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ import OpenAI, { ClientOptions } from 'openai';
import { Stream } from 'openai/streaming';

import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders';
import { ChatModelCard } from '@/types/llm';
import type { ChatModelCard } from '@/types/llm';

import { LobeRuntimeAI } from '../../BaseAI';
import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../../error';
import {
import type {
ChatCompetitionOptions,
ChatCompletionErrorPayload,
ChatStreamPayload,
Embeddings,
EmbeddingsOptions,
EmbeddingsPayload,
ModelProvider,
TextToImagePayload,
TextToSpeechOptions,
TextToSpeechPayload,
Expand All @@ -26,7 +27,7 @@ import { StreamingResponse } from '../response';
import { OpenAIStream, OpenAIStreamOptions } from '../streams';

// the model contains the following keywords is not a chat model, so we should filter them out
const CHAT_MODELS_BLOCK_LIST = [
export const CHAT_MODELS_BLOCK_LIST = [
'embedding',
'davinci',
'curie',
Expand Down Expand Up @@ -77,7 +78,7 @@ interface OpenAICompatibleFactoryOptions<T extends Record<string, any> = any> {
invalidAPIKey: ILobeAgentRuntimeErrorType;
};
models?:
| ((params: { apiKey: string }) => Promise<ChatModelCard[]>)
| ((params: { client: OpenAI }) => Promise<ChatModelCard[]>)
| {
transformModel?: (model: OpenAI.Model) => ChatModelCard;
};
Expand Down Expand Up @@ -157,7 +158,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
client!: OpenAI;

baseURL!: string;
private _options: ConstructorOptions<T>;
protected _options: ConstructorOptions<T>;

constructor(options: ClientOptions & Record<string, any> = {}) {
const _options = {
Expand Down Expand Up @@ -249,7 +250,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
}

async models() {
if (typeof models === 'function') return models({ apiKey: this.client.apiKey });
if (typeof models === 'function') return models({ client: this.client });

const list = await this.client.models.list();

Expand Down Expand Up @@ -312,7 +313,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
}
}

private handleError(error: any): ChatCompletionErrorPayload {
protected handleError(error: any): ChatCompletionErrorPayload {
let desensitizedEndpoint = this.baseURL;

// refs: https://github.com/lobehub/lobe-chat/issues/842
Expand All @@ -337,7 +338,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
endpoint: desensitizedEndpoint,
error: error as any,
errorType: ErrorType.invalidAPIKey,
provider: provider as any,
provider: provider as ModelProvider,
});
}

Expand All @@ -353,7 +354,7 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
endpoint: desensitizedEndpoint,
error: errorResult,
errorType: RuntimeError || ErrorType.bizError,
provider: provider as any,
provider: provider as ModelProvider,
});
}
};
Expand Down