Update openai-fetch and handle refusal responses #51

Closed
wants to merge 3 commits into from
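Taken together, the diff below bumps openai-fetch from 2.0.4 to 3.3.1, normalizes OpenAI chat messages into Prompt.Msg via Msg.fromChatMessage, and throws a new RefusalError when the assistant returns a refusal instead of content. A minimal sketch of how calling code might handle the new error (the import paths and model name are assumptions, not part of this PR):

```ts
import { ChatModel } from '@dexaai/dexter'; // assumed package entry point
import { RefusalError } from '@dexaai/dexter'; // assumed to be re-exported alongside ChatModel

async function main() {
  const chatModel = new ChatModel({ params: { model: 'gpt-4o-mini' } });
  try {
    const { message } = await chatModel.run({
      messages: [{ role: 'user', content: 'Hello!' }],
    });
    console.log(message.content);
  } catch (error) {
    if (error instanceof RefusalError) {
      // With openai-fetch 3.x the model can refuse; the refusal text is
      // surfaced as the error message instead of message.content.
      console.warn('Model refused:', error.message);
    } else {
      throw error;
    }
  }
}

main();
```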
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -13,7 +13,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/install-node-pnpm
- run: pnpm run typecheck
- run: pnpm run build && pnpm run typecheck

lint:
name: Lint
2 changes: 1 addition & 1 deletion package.json
@@ -70,7 +70,7 @@
"hash-object": "^5.0.1",
"jsonrepair": "^3.8.1",
"ky": "^1.7.2",
"openai-fetch": "2.0.4",
"openai-fetch": "3.3.1",
"p-map": "^7.0.2",
"p-throttle": "^6.2.0",
"parse-json": "^8.1.0",
52 changes: 45 additions & 7 deletions pnpm-lock.yaml

Some generated files are not rendered by default.

52 changes: 51 additions & 1 deletion src/model/chat.test.ts
@@ -1,5 +1,6 @@
import { describe, expect, it, vi } from 'vitest';

import { RefusalError } from '../prompt/index.js';
import { ChatModel } from './chat.js';
import { type Model } from './types.js';

@@ -27,6 +28,38 @@ const FAKE_RESPONSE: Model.Chat.Response = {
message: {
content: 'Hi from fake AI',
role: 'assistant',
refusal: null,
},
logprobs: null,
},
],
};

const FAKE_REFUSAL_RESPONSE: Model.Chat.Response = {
message: {
content: null,
role: 'assistant',
},
cached: false,
latency: 0,
cost: 0,
created: 0,
id: 'fake-id',
model: 'gpt-fake',
object: 'chat.completion',
usage: {
completion_tokens: 1,
prompt_tokens: 1,
total_tokens: 2,
},
choices: [
{
finish_reason: 'stop',
index: 0,
message: {
content: null,
role: 'assistant',
refusal: 'I refuse to answer',
},
logprobs: null,
},
@@ -54,7 +87,10 @@ describe('ChatModel', () => {
const response = await chatModel.run({
messages: [{ role: 'user', content: 'content' }],
});
expect(response).toEqual(FAKE_RESPONSE);
expect(response).toEqual({
...FAKE_RESPONSE,
message: FAKE_RESPONSE.message,
});
});

it('triggers events', async () => {
@@ -246,4 +282,18 @@ describe('ChatModel', () => {
expect(chatModel.context.userId).toBe('123');
expect(chatModel.params.model).toBe('gpt-fake');
});

it('throws a refusal error when a refusal is returned', async () => {
vi.setSystemTime(new Date());
Client = vi.fn() as unknown as Model.Chat.Client;
Client.createChatCompletion = vi
.fn()
.mockImplementation(() => Promise.resolve(FAKE_REFUSAL_RESPONSE));
const chatModel = new ChatModel({ client: Client });
await expect(() =>
chatModel.run({
messages: [{ role: 'user', content: 'content' }],
})
).rejects.toThrowError(RefusalError);
});
});
17 changes: 11 additions & 6 deletions src/model/chat.ts
@@ -1,5 +1,7 @@
import { type ChatResponse } from 'openai-fetch';
import { type PartialDeep, type SetOptional } from 'type-fest';

import { Msg } from '../prompt/index.js';
import { deepMerge, mergeEvents, type Prettify } from '../utils/helpers.js';
import { createOpenAIClient } from './clients/openai.js';
import { AbstractModel, type ModelArgs } from './model.js';
@@ -97,9 +99,11 @@ export class ChatModel<
) ?? []
);

const message = Msg.fromChatMessage(response.choices[0].message);

const modelResponse: Model.Chat.Response = {
...response,
message: response.choices[0].message,
message,
cached: false,
latency: Date.now() - start,
cost: calculateCost({ model: params.model, tokens: response.usage }),
@@ -200,7 +204,7 @@
finish_reason:
choice.finish_reason as Model.Chat.Response['choices'][0]['finish_reason'],
index: choice.index,
message: choice.delta as Model.Message & { role: 'assistant' },
message: choice.delta as ChatResponse['choices'][0]['message'],
logprobs: choice.logprobs || null,
},
],
@@ -209,9 +213,8 @@
// Calculate the token usage and add it to the response.
// OpenAI doesn't provide token usage for streaming requests.
const promptTokens = this.tokenizer.countTokens(params.messages);
const completionTokens = this.tokenizer.countTokens(
response.choices[0].message
);
const messageContent = response.choices[0].message.content ?? '';
const completionTokens = this.tokenizer.countTokens(messageContent);
response.usage = {
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
@@ -234,9 +237,11 @@
) ?? []
);

const message = Msg.fromChatMessage(response.choices[0].message);

const modelResponse: Model.Chat.Response = {
...response,
message: response.choices[0].message,
message,
cached: false,
latency: Date.now() - start,
cost: calculateCost({ model: params.model, tokens: response.usage }),
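The diff calls Msg.fromChatMessage without showing its implementation. Judging from the test and type changes, it presumably maps an openai-fetch assistant message onto Prompt.Msg and throws the new RefusalError when `refusal` is set; a rough sketch under that assumption (the function body and relative import paths are inferred, not taken from this PR):

```ts
import { type ChatResponse } from 'openai-fetch';

import { type Prompt } from '../types.js'; // assumed path, as in src/prompt
import { RefusalError } from './errors.js'; // assumed path, as in src/prompt/utils

type ChatCompletionMessage = ChatResponse['choices'][0]['message'];

/** Hypothetical sketch of what Msg.fromChatMessage might do. */
export function fromChatMessage(message: ChatCompletionMessage): Prompt.Msg {
  // openai-fetch 3.x assistant messages can carry `refusal` instead of
  // `content`; surface that as a typed error the caller can catch.
  if (message.refusal) {
    throw new RefusalError(message.refusal);
  }
  // Strip the `refusal: null` field so the result matches Prompt.Msg,
  // which has no `refusal` property and a non-optional `content`.
  const { refusal: _refusal, ...rest } = message;
  return rest as Prompt.Msg;
}
```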
8 changes: 4 additions & 4 deletions src/model/types.ts
@@ -1,7 +1,6 @@
/* eslint-disable no-use-before-define */
import { type Options as KYOptions } from 'ky';
import {
type ChatMessage,
type ChatParams,
type ChatResponse,
type ChatStreamResponse,
@@ -12,6 +11,7 @@ import {
type OpenAIClient,
} from 'openai-fetch';

import { type Prompt } from '../prompt/types.js';
import { type ChatModel } from './chat.js';
import { type CompletionModel } from './completion.js';
import { type EmbeddingModel } from './embedding.js';
@@ -79,7 +79,7 @@ export namespace Model {
top_p?: ChatParams['top_p'];
}
export interface Response extends Base.Response, ChatResponse {
message: ChatMessage;
message: Prompt.Msg;
}
/** Streaming response from the OpenAI API. */
type StreamResponse = ChatStreamResponse;
@@ -205,7 +205,7 @@ export namespace Model {
* A single ChatMessage is counted as a completion and an array as a prompt.
* Strings are counted as is.
*/
countTokens(input?: string | ChatMessage | ChatMessage[]): number;
countTokens(input?: string | Prompt.Msg | Prompt.Msg[]): number;
/** Truncate a string to a maximum number of tokens */
truncate(args: {
/** Text to truncate */
@@ -218,7 +218,7 @@
}

/** Primary message type for chat models */
export type Message = ChatMessage;
export type Message = Prompt.Msg;

/** The provider of the model (eg: OpenAI) */
export type Provider = (string & {}) | 'openai' | 'custom';
2 changes: 1 addition & 1 deletion src/model/utils/tokenizer.ts
@@ -35,8 +35,8 @@ class Tokenizer implements Model.ITokenizer {
this.model = model;
try {
this.tiktoken = encoding_for_model(model as TiktokenModel);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
} catch (e) {
console.error(`Failed to create tokenizer for model ${model}`, e);
this.tiktoken = encoding_for_model('gpt-3.5-turbo');
}
}
1 change: 1 addition & 0 deletions src/prompt/index.ts
@@ -10,5 +10,6 @@ export { stringifyForModel } from './functions/stringify-for-model.js';
export { zodToJsonSchema } from './functions/zod-to-json.js';
export type { Prompt } from './types.js';
export * from './utils/errors.js';
export { AbortError, RefusalError } from './utils/errors.js';
export { getErrorMsg } from './utils/get-error-message.js';
export { Msg } from './utils/message.js';
7 changes: 7 additions & 0 deletions src/prompt/types.ts
@@ -146,6 +146,13 @@ export namespace Prompt {
content: string;
};

/** Message with a refusal reason and no content. */
export type Refusal = {
role: 'assistant';
refusal: string;
content?: null;
};

/** Message with arguments to call a function. */
export type FuncCall = {
role: 'assistant';
20 changes: 20 additions & 0 deletions src/prompt/utils/errors.ts
@@ -17,3 +17,23 @@ export class AbortError extends Error {
this.message = message;
}
}

export class RefusalError extends Error {
readonly name: 'RefusalError';
readonly originalError: Error;

constructor(message: string | Error) {
super();

if (message instanceof Error) {
this.originalError = message;
({ message } = message);
} else {
this.originalError = new Error(message);
this.originalError.stack = this.stack;
}

this.name = 'RefusalError';
this.message = message;
}
}
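Like the existing AbortError above it, RefusalError accepts either a refusal string or an Error: a string becomes both the message and a synthesized originalError, while a passed-in Error is kept on originalError and its message is reused. A small usage sketch (the refusal text is illustrative):

```ts
import { RefusalError } from './errors.js';

// Constructed from a refusal string (the common case):
const fromString = new RefusalError('I refuse to answer');
console.log(fromString.name); // 'RefusalError'
console.log(fromString.message); // 'I refuse to answer'
console.log(fromString.originalError.message); // 'I refuse to answer'

// Constructed from an existing Error, which is preserved as originalError:
const cause = new Error('refused by the upstream model');
const fromError = new RefusalError(cause);
console.log(fromError.message); // 'refused by the upstream model'
console.log(fromError.originalError === cause); // true
```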
6 changes: 5 additions & 1 deletion src/prompt/utils/message.test.ts
@@ -43,8 +43,12 @@ describe('Msg', () => {
expect(Msg.isToolResult(msg)).toBe(true);
});

// Same as OpenAI.ChatMessage, except we throw a RefusalError if the message is a refusal
// so `refusal` isn't on the object and content can't be optional.
it('prompt message types should interop with openai-fetch message types', () => {
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<Prompt.Msg>();
expectTypeOf(
{} as Omit<OpenAI.ChatMessage, 'refusal'> & { content: string | null }
).toMatchTypeOf<Prompt.Msg>();
expectTypeOf({} as Prompt.Msg).toMatchTypeOf<OpenAI.ChatMessage>();
expectTypeOf({} as Prompt.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>();
expectTypeOf({} as Prompt.Msg.User).toMatchTypeOf<OpenAI.ChatMessage>();