♻️ refactor: Refactor Azure OpenAI Implementation #4619

Open · wants to merge 6 commits into base: main
package.json (2 changes: 0 additions & 2 deletions)
@@ -107,8 +107,6 @@
"@aws-sdk/client-bedrock-runtime": "^3.675.0",
"@aws-sdk/client-s3": "^3.675.0",
"@aws-sdk/s3-request-presigner": "^3.675.0",
"@azure/core-rest-pipeline": "1.16.0",
"@azure/openai": "1.0.0-beta.12",
"@baiducloud/qianfan": "^0.1.9",
"@cfworker/json-schema": "^2.0.1",
"@clerk/localizations": "^3.3.0",
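Both Azure-specific dependencies are dropped here: `@azure/core-rest-pipeline` and the beta `@azure/openai` SDK. Judging from the test changes further down, the runtime now relies on the `AzureOpenAI` client that ships with the existing `openai` dependency, so no replacement package is added. Roughly, the import swap looks like this (a sketch of the direction of the change, not the PR's exact diff):

```ts
// Previously, via the now-removed beta SDK:
// import { AzureKeyCredential, OpenAIClient } from '@azure/openai';

// After the refactor: the Azure-aware client bundled with the official openai package
import { AzureOpenAI } from 'openai';
```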
src/libs/agent-runtime/AgentRuntime.test.ts (7 changes: 4 additions & 3 deletions)
@@ -75,7 +75,7 @@ describe('AgentRuntime', () => {
describe('Azure OpenAI provider', () => {
it('should initialize correctly', async () => {
const jwtPayload = {
apikey: 'user-azure-key',
apiKey: 'user-azure-key',
endpoint: 'user-azure-endpoint',
apiVersion: '2024-06-01',
};
@@ -90,7 +90,7 @@ describe('AgentRuntime', () => {
});
it('should initialize with azureOpenAIParams correctly', async () => {
const jwtPayload = {
apikey: 'user-openai-key',
apiKey: 'user-openai-key',
endpoint: 'user-endpoint',
apiVersion: 'custom-version',
};
@@ -106,7 +106,8 @@ describe('AgentRuntime', () => {

it('should initialize with AzureAI correctly', async () => {
const jwtPayload = {
apikey: 'user-azure-key',
apiKey: 'user-azure-key',
apiVersion: '2024-06-01',
endpoint: 'user-azure-endpoint',
};
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Azure, {
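The JWT payload in these tests switches from `apikey` to `apiKey`, and the Azure AI case now also supplies `apiVersion`. A hedged usage sketch of the resulting provider options is below; the import path and endpoint value are illustrative assumptions, while the option shape mirrors the AgentRuntime.ts change that follows:

```ts
// Illustrative usage only; '@/libs/agent-runtime' as the import path and the
// endpoint URL are assumptions, not taken from this PR.
import { AgentRuntime, ModelProvider } from '@/libs/agent-runtime';

const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Azure, {
  azure: {
    apiKey: 'user-azure-key', // previously `apikey`
    apiVersion: '2024-06-01',
    endpoint: 'https://my-resource.openai.azure.com/',
  },
});
```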
src/libs/agent-runtime/AgentRuntime.ts (4 changes: 2 additions & 2 deletions)
@@ -128,7 +128,7 @@ class AgentRuntime {
ai21: Partial<ClientOptions>;
ai360: Partial<ClientOptions>;
anthropic: Partial<ClientOptions>;
azure: { apiVersion?: string; apikey?: string; endpoint?: string };
azure: { apiKey?: string; apiVersion?: string; endpoint?: string };
baichuan: Partial<ClientOptions>;
bedrock: Partial<LobeBedrockAIParams>;
deepseek: Partial<ClientOptions>;
@@ -171,7 +171,7 @@
case ModelProvider.Azure: {
runtimeModel = new LobeAzureOpenAI(
params.azure?.endpoint,
params.azure?.apikey,
params.azure?.apiKey,
params.azure?.apiVersion,
);
break;
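`AgentRuntime` passes `endpoint`, `apiKey`, and `apiVersion` straight through to `LobeAzureOpenAI`, and the test file that follows expects `instance.client` to be an `AzureOpenAI` instance from the `openai` package. A rough sketch of the constructor this implies is shown below; the validation and error handling are assumptions, and the real class also wires up streaming, error mapping, and debug logging, all omitted here:

```ts
import { AzureOpenAI } from 'openai';

// Minimal sketch, assuming the constructor keeps its (endpoint, apiKey, apiVersion)
// signature and simply hands those values to the AzureOpenAI client.
export class LobeAzureOpenAI {
  client: AzureOpenAI;
  baseURL: string;

  constructor(endpoint?: string, apiKey?: string, apiVersion?: string) {
    if (!apiKey || !endpoint) {
      // The tests reference an 'InvalidProviderAPIKey' error type; the exact
      // error helper used to raise it is an assumption here.
      throw new Error('InvalidProviderAPIKey');
    }

    this.client = new AzureOpenAI({ apiKey, apiVersion, endpoint });
    this.baseURL = endpoint;
  }
}
```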
src/libs/agent-runtime/azureOpenai/index.test.ts (187 changes: 114 additions & 73 deletions)
@@ -1,35 +1,120 @@
// @vitest-environment node
import { AzureKeyCredential, OpenAIClient } from '@azure/openai';
import OpenAI from 'openai';
import { AzureOpenAI } from 'openai';
import type { ChatCompletionToolChoiceOption } from 'openai/resources/chat/completions';
import type {
ResponseFormatJSONObject,
ResponseFormatJSONSchema,
ResponseFormatText,
} from 'openai/resources/shared';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import * as debugStreamModule from '../utils/debugStream';
import { LobeAzureOpenAI } from './index';
import * as AzureOpenAIStreamUtils from '../utils/streams/azureOpenai';
import { LobeAzureOpenAI, convertResponseMode, convertToolChoice } from './index';

declare module './index' {
export function convertResponseMode(
responseMode?: 'streamText' | 'json',
): ResponseFormatText | ResponseFormatJSONObject | ResponseFormatJSONSchema | undefined;

export function convertToolChoice(
tool_choice?: string,
): ChatCompletionToolChoiceOption | undefined;
}

const bizErrorType = 'ProviderBizError';
const invalidErrorType = 'InvalidProviderAPIKey';
const endpoint = 'https://test.openai.azure.com/';
const apiKey = 'test_key';
const apiVersion = '2024-06-01';

// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

describe('helper functions', () => {
describe('convertResponseMode', () => {
it('should return undefined when responseMode is not provided', () => {
const result = convertResponseMode();

expect(result).toBeUndefined();
});

it('should return text when responseMode is streamText', () => {
const result = convertResponseMode('streamText');

expect(result).toEqual({
type: 'text',
});
});

it('should return json_object when responseMode is json', () => {
const result = convertResponseMode('json');

expect(result).toEqual({
type: 'json_object',
});
});
});

describe('convertToolChoice', () => {
it('should return undefined when tool_choice is not provided', () => {
const result = convertToolChoice();

expect(result).toBeUndefined();
});

it('should return undefined when tool_choice is empty', () => {
const result = convertToolChoice('');

expect(result).toBeUndefined();
});

it('should return none when tool_choice is none', () => {
const result = convertToolChoice('none');

expect(result).toEqual('none');
});

it('should return auto when tool_choice is auto', () => {
const result = convertToolChoice('auto');

expect(result).toEqual('auto');
});

it('should return required when tool_choice is required', () => {
const result = convertToolChoice('required');

expect(result).toEqual('required');
});

it('should return function object when tool_choice is provided', () => {
const result = convertToolChoice('test_function');

expect(result).toEqual({
function: {
name: 'test_function',
},
type: 'function',
});
});
});
});

describe('LobeAzureOpenAI', () => {
let instance: LobeAzureOpenAI;

beforeEach(() => {
instance = new LobeAzureOpenAI(
'https://test.openai.azure.com/',
'test_key',
'2023-03-15-preview',
);
instance = new LobeAzureOpenAI(endpoint, apiKey, apiVersion);

// Use vi.spyOn to mock the streamChatCompletions method
vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue(
vi.spyOn(instance.client.chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);

vi.spyOn(AzureOpenAIStreamUtils, 'convertToStream').mockImplementation((x) => x as any);
});

afterEach(() => {
vi.clearAllMocks();
vi.restoreAllMocks();
});

describe('constructor', () => {
@@ -41,14 +126,10 @@ describe('LobeAzureOpenAI', () => {
}
});

it('should create an instance of OpenAIClient with correct parameters', () => {
const endpoint = 'https://test.openai.azure.com/';
const apikey = 'test_key';
const apiVersion = '2023-03-15-preview';
it('should create an instance of AzureOpenAI with correct parameters', () => {
const instance = new LobeAzureOpenAI(endpoint, apiKey, apiVersion);

const instance = new LobeAzureOpenAI(endpoint, apikey, apiVersion);

expect(instance.client).toBeInstanceOf(OpenAIClient);
expect(instance.client).toBeInstanceOf(AzureOpenAI);
expect(instance.baseURL).toBe(endpoint);
});
});
@@ -59,12 +140,12 @@
const mockStream = new ReadableStream();
const mockResponse = Promise.resolve(mockStream);

(instance['client'].streamChatCompletions as Mock).mockResolvedValue(mockResponse);
(instance.client.chat.completions.create as Mock).mockResolvedValue(mockResponse);

// Act
const result = await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
model: 'gpt-4o-mini',
temperature: 0,
});

@@ -164,7 +245,7 @@ describe('LobeAzureOpenAI', () => {
controller.close();
},
});
vi.spyOn(instance['client'], 'streamChatCompletions').mockResolvedValue(mockStream as any);
(instance.client.chat.completions.create as Mock).mockResolvedValue(mockStream as any);

const result = await instance.chat({
stream: true,
@@ -214,13 +295,13 @@ describe('LobeAzureOpenAI', () => {
message: 'Deployment not found',
};

(instance['client'].streamChatCompletions as Mock).mockRejectedValue(error);
(instance.client.chat.completions.create as Mock).mockRejectedValue(error);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
model: 'gpt-4o-mini',
temperature: 0,
});
} catch (e) {
@@ -230,7 +311,7 @@ describe('LobeAzureOpenAI', () => {
error: {
code: 'DeploymentNotFound',
message: 'Deployment not found',
deployId: 'text-davinci-003',
deployId: 'gpt-4o-mini',
},
errorType: bizErrorType,
provider: 'azure',
@@ -242,13 +323,13 @@ describe('LobeAzureOpenAI', () => {
// Arrange
const genericError = new Error('Generic Error');

(instance['client'].streamChatCompletions as Mock).mockRejectedValue(genericError);
(instance.client.chat.completions.create as Mock).mockRejectedValue(genericError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
model: 'gpt-4o-mini',
temperature: 0,
});
} catch (e) {
@@ -279,7 +360,7 @@ describe('LobeAzureOpenAI', () => {
}) as any;
mockDebugStream.toReadableStream = () => mockDebugStream;

(instance['client'].streamChatCompletions as Mock).mockResolvedValue({
(instance.client.chat.completions.create as Mock).mockResolvedValue({
tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
});

@@ -289,7 +370,7 @@ describe('LobeAzureOpenAI', () => {
// Act
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'text-davinci-003',
model: 'gpt-4o-mini',
temperature: 0,
});

@@ -303,53 +384,13 @@ describe('LobeAzureOpenAI', () => {
});

describe('private method', () => {
describe('tocamelCase', () => {
it('should convert string to camel case', () => {
const key = 'image_url';

const camelCaseKey = instance['tocamelCase'](key);

expect(camelCaseKey).toEqual('imageUrl');
});
});
describe('maskSensitiveUrl', () => {
it('should mask endpoint', () => {
const url = 'https://test.openai.azure.com/';

describe('camelCaseKeys', () => {
it('should convert object keys to camel case', () => {
const obj = {
frequency_penalty: 0,
messages: [
{
role: 'user',
content: [
{
type: 'image_url',
image_url: {
url: '<image URL>',
},
},
],
},
],
};
const maskedUrl = instance['maskSensitiveUrl'](url);

const newObj = instance['camelCaseKeys'](obj);

expect(newObj).toEqual({
frequencyPenalty: 0,
messages: [
{
role: 'user',
content: [
{
type: 'image_url',
imageUrl: {
url: '<image URL>',
},
},
],
},
],
});
expect(maskedUrl).toEqual('https://***.openai.azure.com/');
});
});
});
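The new tests exercise two exported helpers, `convertResponseMode` and `convertToolChoice`, whose implementations are not part of this excerpt. Below is a minimal sketch consistent with the expectations above; the actual code in `src/libs/agent-runtime/azureOpenai/index.ts` may differ:

```ts
import type { ChatCompletionToolChoiceOption } from 'openai/resources/chat/completions';
import type {
  ResponseFormatJSONObject,
  ResponseFormatJSONSchema,
  ResponseFormatText,
} from 'openai/resources/shared';

// Maps the runtime's responseMode to an OpenAI response_format payload.
export const convertResponseMode = (
  responseMode?: 'streamText' | 'json',
): ResponseFormatText | ResponseFormatJSONObject | ResponseFormatJSONSchema | undefined => {
  switch (responseMode) {
    case 'streamText': {
      return { type: 'text' };
    }
    case 'json': {
      return { type: 'json_object' };
    }
    default: {
      return undefined;
    }
  }
};

// Maps a plain-string tool_choice to the OpenAI tool_choice option:
// 'none' | 'auto' | 'required' pass through, and any other non-empty
// string is treated as a function name.
export const convertToolChoice = (
  tool_choice?: string,
): ChatCompletionToolChoiceOption | undefined => {
  if (!tool_choice) return undefined;

  switch (tool_choice) {
    case 'none':
    case 'auto':
    case 'required': {
      return tool_choice;
    }
    default: {
      return { function: { name: tool_choice }, type: 'function' };
    }
  }
};
```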