Update openai-fetch to 3.3.1 and handle refusals
This commit updates the `openai-fetch` dependency to version 3.3.1.
It also adds handling for message refusals by throwing a
`RefusalError` when the model returns a refusal. The change touches
`chat.ts`, adds tests, and updates the message types and type-interop checks.
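
For orientation, here is a minimal, self-contained sketch of the refusal path described above. The diff below does not show the body of `Msg.fromChatMessage`, so the mapping function and the stand-in types here are assumptions for illustration, not the library's actual implementation:

```ts
// Sketch only — local stand-ins for the real Prompt.Msg and openai-fetch types.

// Assistant message shape as modelled in the updated tests below
// (openai-fetch 3.x includes a `refusal` field on the assistant message).
type ChatMessage = {
  role: 'assistant';
  content: string | null;
  refusal?: string | null;
};

// Stand-in for the library's assistant Prompt.Msg: content is always a string.
type AssistantMsg = { role: 'assistant'; content: string };

class RefusalError extends Error {
  constructor(message: string) {
    super(message);
    this.name = 'RefusalError';
  }
}

// Hypothetical equivalent of Msg.fromChatMessage: throw on refusals so callers
// only ever see messages with real content.
function fromChatMessage(message: ChatMessage): AssistantMsg {
  if (message.refusal) {
    throw new RefusalError(message.refusal);
  }
  return { role: 'assistant', content: message.content ?? '' };
}

// A refusal surfaces as a catchable error instead of a null-content message.
try {
  fromChatMessage({
    role: 'assistant',
    content: null,
    refusal: 'I refuse to answer',
  });
} catch (error) {
  if (error instanceof RefusalError) {
    console.warn(`Model refused: ${error.message}`);
  }
}
```

The key design choice is that a refusal becomes a thrown, catchable error rather than an assistant message with `content: null`, which is why `Prompt.Msg` can keep `content` non-optional.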
rileytomasek committed Oct 13, 2024
1 parent 82fe6d2 commit 53ede7e
Showing 10 changed files with 188 additions and 39 deletions.
2 changes: 1 addition & 1 deletion package.json
@@ -70,7 +70,7 @@
"hash-object": "^5.0.1",
"jsonrepair": "^3.8.1",
"ky": "^1.7.2",
"openai-fetch": "2.0.4",
"openai-fetch": "3.3.1",
"p-map": "^7.0.2",
"p-throttle": "^6.2.0",
"parse-json": "^8.1.0",
52 changes: 45 additions & 7 deletions pnpm-lock.yaml

Some generated files are not rendered by default.

52 changes: 51 additions & 1 deletion src/model/chat.test.ts
@@ -1,5 +1,6 @@
import { describe, expect, it, vi } from 'vitest';

import { RefusalError } from '../prompt/index.js';
import { ChatModel } from './chat.js';
import { type Model } from './types.js';

@@ -27,6 +28,38 @@ const FAKE_RESPONSE: Model.Chat.Response = {
message: {
content: 'Hi from fake AI',
role: 'assistant',
refusal: null,
},
logprobs: null,
},
],
};

const FAKE_REFUSAL_RESPONSE: Model.Chat.Response = {
message: {
content: null,
role: 'assistant',
},
cached: false,
latency: 0,
cost: 0,
created: 0,
id: 'fake-id',
model: 'gpt-fake',
object: 'chat.completion',
usage: {
completion_tokens: 1,
prompt_tokens: 1,
total_tokens: 2,
},
choices: [
{
finish_reason: 'stop',
index: 0,
message: {
content: null,
role: 'assistant',
refusal: 'I refuse to answer',
},
logprobs: null,
},
@@ -54,7 +87,10 @@ describe('ChatModel', () => {
const response = await chatModel.run({
messages: [{ role: 'user', content: 'content' }],
});
expect(response).toEqual(FAKE_RESPONSE);
expect(response).toEqual({
...FAKE_RESPONSE,
message: FAKE_RESPONSE.message,
});
});

it('triggers events', async () => {
@@ -246,4 +282,18 @@ describe('ChatModel', () => {
expect(chatModel.context.userId).toBe('123');
expect(chatModel.params.model).toBe('gpt-fake');
});

it('throws a refusal error when a refusal is returned', async () => {
vi.setSystemTime(new Date());
Client = vi.fn() as unknown as Model.Chat.Client;
Client.createChatCompletion = vi
.fn()
.mockImplementation(() => Promise.resolve(FAKE_REFUSAL_RESPONSE));
const chatModel = new ChatModel({ client: Client });
await expect(() =>
chatModel.run({
messages: [{ role: 'user', content: 'content' }],
})
).rejects.toThrowError(RefusalError);
});
});
17 changes: 11 additions & 6 deletions src/model/chat.ts
@@ -1,5 +1,7 @@
import { type ChatResponse } from 'openai-fetch';
import { type PartialDeep, type SetOptional } from 'type-fest';

import { Msg } from '../prompt/index.js';
import { deepMerge, mergeEvents, type Prettify } from '../utils/helpers.js';
import { createOpenAIClient } from './clients/openai.js';
import { AbstractModel, type ModelArgs } from './model.js';
@@ -97,9 +99,11 @@ export class ChatModel<
) ?? []
);

const message = Msg.fromChatMessage(response.choices[0].message);

const modelResponse: Model.Chat.Response = {
...response,
message: response.choices[0].message,
message,
cached: false,
latency: Date.now() - start,
cost: calculateCost({ model: params.model, tokens: response.usage }),
@@ -200,7 +204,7 @@
finish_reason:
choice.finish_reason as Model.Chat.Response['choices'][0]['finish_reason'],
index: choice.index,
message: choice.delta as Model.Message & { role: 'assistant' },
message: choice.delta as ChatResponse['choices'][0]['message'],
logprobs: choice.logprobs || null,
},
],
@@ -209,9 +213,8 @@
// Calculate the token usage and add it to the response.
// OpenAI doesn't provide token usage for streaming requests.
const promptTokens = this.tokenizer.countTokens(params.messages);
const completionTokens = this.tokenizer.countTokens(
response.choices[0].message
);
const messageContent = response.choices[0].message.content ?? '';
const completionTokens = this.tokenizer.countTokens(messageContent);
response.usage = {
completion_tokens: completionTokens,
prompt_tokens: promptTokens,
@@ -234,9 +237,11 @@
) ?? []
);

const message = Msg.fromChatMessage(response.choices[0].message);

const modelResponse: Model.Chat.Response = {
...response,
message: response.choices[0].message,
message,
cached: false,
latency: Date.now() - start,
cost: calculateCost({ model: params.model, tokens: response.usage }),
8 changes: 4 additions & 4 deletions src/model/types.ts
@@ -1,7 +1,6 @@
/* eslint-disable no-use-before-define */
import { type Options as KYOptions } from 'ky';
import {
type ChatMessage,
type ChatParams,
type ChatResponse,
type ChatStreamResponse,
@@ -12,6 +11,7 @@ import {
type OpenAIClient,
} from 'openai-fetch';

import { type Prompt } from '../prompt/types.js';
import { type ChatModel } from './chat.js';
import { type CompletionModel } from './completion.js';
import { type EmbeddingModel } from './embedding.js';
@@ -79,7 +79,7 @@ export namespace Model {
top_p?: ChatParams['top_p'];
}
export interface Response extends Base.Response, ChatResponse {
message: ChatMessage;
message: Prompt.Msg;
}
/** Streaming response from the OpenAI API. */
type StreamResponse = ChatStreamResponse;
@@ -205,7 +205,7 @@ export namespace Model {
* A single ChatMessage is counted as a completion and an array as a prompt.
* Strings are counted as is.
*/
countTokens(input?: string | ChatMessage | ChatMessage[]): number;
countTokens(input?: string | Prompt.Msg | Prompt.Msg[]): number;
/** Truncate a string to a maximum number of tokens */
truncate(args: {
/** Text to truncate */
@@ -218,7 +218,7 @@
}

/** Primary message type for chat models */
export type Message = ChatMessage;
export type Message = Prompt.Msg;

/** The provider of the model (eg: OpenAI) */
export type Provider = (string & {}) | 'openai' | 'custom';
1 change: 1 addition & 0 deletions src/prompt/index.ts
@@ -10,5 +10,6 @@ export { stringifyForModel } from './functions/stringify-for-model.js';
export { zodToJsonSchema } from './functions/zod-to-json.js';
export type { Prompt } from './types.js';
export * from './utils/errors.js';
export { AbortError, RefusalError } from './utils/errors.js';
export { getErrorMsg } from './utils/get-error-message.js';
export { Msg } from './utils/message.js';
7 changes: 7 additions & 0 deletions src/prompt/types.ts
@@ -146,6 +146,13 @@ export namespace Prompt {
content: string;
};

/** Message with a refusal reason and no content. */
export type Refusal = {
role: 'assistant';
refusal: string;
content?: null;
};

/** Message with arguments to call a function. */
export type FuncCall = {
role: 'assistant';
20 changes: 20 additions & 0 deletions src/prompt/utils/errors.ts
@@ -17,3 +17,23 @@ export class AbortError extends Error {
this.message = message;
}
}

export class RefusalError extends Error {
readonly name: 'RefusalError';
readonly originalError: Error;

constructor(message: string | Error) {
super();

if (message instanceof Error) {
this.originalError = message;
({ message } = message);
} else {
this.originalError = new Error(message);
this.originalError.stack = this.stack;
}

this.name = 'RefusalError';
this.message = message;
}
}
6 changes: 5 additions & 1 deletion src/prompt/utils/message.test.ts
@@ -43,8 +43,12 @@ describe('Msg', () => {
expect(Msg.isToolResult(msg)).toBe(true);
});

// Same as OpenAI.ChatMessage, except we throw a RefusalError if the message is a refusal
// so `refusal` isn't on the object and content can't be optional.
it('prompt message types should interop with openai-fetch message types', () => {
expectTypeOf({} as OpenAI.ChatMessage).toMatchTypeOf<Prompt.Msg>();
expectTypeOf(
{} as Omit<OpenAI.ChatMessage, 'refusal'> & { content: string | null }
).toMatchTypeOf<Prompt.Msg>();
expectTypeOf({} as Prompt.Msg).toMatchTypeOf<OpenAI.ChatMessage>();
expectTypeOf({} as Prompt.Msg.System).toMatchTypeOf<OpenAI.ChatMessage>();
expectTypeOf({} as Prompt.Msg.User).toMatchTypeOf<OpenAI.ChatMessage>();