add opt for passing abort signal to chatCompletion function (#37)
* add opt for passing abort signal to chatCompletion function
cfortuner authored May 31, 2024
1 parent 72bba04 commit fb20026
Showing 10 changed files with 152 additions and 14 deletions.
53 changes: 53 additions & 0 deletions examples/abort-chat-completion.ts
@@ -0,0 +1,53 @@
import 'dotenv/config';
import { ChatModel, Msg, type Prompt } from '@dexaai/dexter';

/**
* npx tsx examples/abort-chat-completion.ts
*/
async function main() {
const chatModel = new ChatModel({
debug: true,
params: {
model: 'gpt-3.5-turbo',
temperature: 0.5,
max_tokens: 1000,
},
});

const messages: Prompt.Msg[] = [Msg.user(`Write a short story`)];

{
const abortController = new AbortController();
abortController.signal.addEventListener('abort', () => {
console.log('\n\nAborted');
});

try {
setTimeout(() => {
abortController.abort();
}, 2000);

// Invoke the chat model with streaming (handleUpdate) and the abort signal
await chatModel.run({
messages,
requestOpts: {
signal: abortController.signal,
},
handleUpdate: (c) => {
// Note: aborting doesn't always cancel the in-flight request immediately,
// so we also stop handling stream updates once the signal is aborted.
if (abortController.signal.aborted) {
return;
}
process.stdout.write(c);
},
});

} catch (error) {
console.error('Error during chat model run:', error);
}
}
}

main();
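
Note: on runtimes that support it (Node 17.3+ and modern browsers), the same timeout-based abort can be written without manual controller wiring. A minimal sketch under that assumption, reusing the chatModel and messages from the example above:

const signal = AbortSignal.timeout(2000); // aborts itself after 2s
await chatModel.run({
  messages,
  requestOpts: { signal },
  handleUpdate: (c) => {
    if (signal.aborted) return; // same guard as in the example above
    process.stdout.write(c);
  },
});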
16 changes: 13 additions & 3 deletions src/model/chat.ts
@@ -65,15 +65,22 @@ export class ChatModel<
}

protected async runModel(
{ handleUpdate, ...params }: Model.Chat.Run & Model.Chat.Config,
{
handleUpdate,
requestOpts,
...params
}: Model.Chat.Run & Model.Chat.Config,
context: CustomCtx
): Promise<Model.Chat.Response> {
const start = Date.now();

// Use non-streaming API if no handler is provided
if (!handleUpdate) {
// Make the OpenAI API request
const response = await this.client.createChatCompletion(params);
const response = await this.client.createChatCompletion(
params,
requestOpts
);

await Promise.allSettled(
this.events?.onApiResponse?.map((event) =>
@@ -102,7 +109,10 @@
return modelResponse;
} else {
// Use the streaming API if a handler is provided
const stream = await this.client.streamChatCompletion(params);
const stream = await this.client.streamChatCompletion(
params,
requestOpts
);

// Keep track of the stream's output
let chunk = {} as Model.Chat.CompletionChunk;
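Both branches now thread requestOpts through to the openai-fetch client, so per-call headers can be passed the same way as signals (see the widened Run type in src/model/types.ts below). A hedged usage sketch; the header name and id generation are illustrative, not part of this diff:

const response = await chatModel.run({
  messages: [Msg.user('Hello!')],
  requestOpts: {
    signal: AbortSignal.timeout(10_000),
    headers: { 'X-Request-Id': crypto.randomUUID() },
  },
});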
6 changes: 5 additions & 1 deletion src/model/clients/splade.ts
@@ -1,11 +1,14 @@
import ky from 'ky';
import ky, { type Options as KYOptions } from 'ky';
import type { Model } from '../types.js';

export const createSpladeClient = () => ({
async createSparseVector(
params: {
input: string;
model: string;
requestOpts?: {
headers?: KYOptions['headers'];
};
},
serviceUrl: string
): Promise<Model.SparseVector.Vector> {
@@ -14,6 +17,7 @@ export const createSpladeClient = () => ({
.post(serviceUrl, {
timeout: 1000 * 60,
json: { text: params.input },
headers: params.requestOpts?.headers,
})
.json<Model.SparseVector.Vector>();
return sparseValues;
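A hedged sketch of the new headers pass-through on the SPLADE client; the service URL, auth header, env var, and model name are illustrative assumptions, only the parameter shape comes from this diff:

const splade = createSpladeClient();
const vector = await splade.createSparseVector(
  {
    input: 'sparse vectors for hybrid search',
    model: 'naver/splade-cocondenser-ensembledistil',
    requestOpts: {
      headers: { Authorization: `Bearer ${process.env.SPLADE_API_KEY}` },
    },
  },
  'https://splade.example.com/embed' // illustrative serviceUrl
);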
4 changes: 2 additions & 2 deletions src/model/completion.ts
@@ -46,13 +46,13 @@ export class CompletionModel<
}

protected async runModel(
params: Model.Completion.Run & Model.Completion.Config,
{ requestOpts, ...params }: Model.Completion.Run & Model.Completion.Config,
context: CustomCtx
): Promise<Model.Completion.Response> {
const start = Date.now();

// Make the OpenAI API request
const response = await this.client.createCompletions(params);
const response = await this.client.createCompletions(params, requestOpts);

await Promise.allSettled(
this.events?.onApiResponse?.map((event) =>
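The completion model gets the same treatment: requestOpts is destructured out of the run params and forwarded to createCompletions. A minimal sketch, assuming a CompletionModel export and a prompt run param mirroring the OpenAI completions API:

const completionModel = new CompletionModel({
  params: { model: 'gpt-3.5-turbo-instruct' },
});
const controller = new AbortController();
const res = await completionModel.run({
  prompt: ['Write a haiku about timeouts'],
  requestOpts: { signal: controller.signal },
});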
3 changes: 2 additions & 1 deletion src/model/embedding.ts
@@ -115,7 +115,7 @@ export class EmbeddingModel<
}

protected async runModel(
params: Model.Embedding.Run & Model.Embedding.Config,
{ requestOpts, ...params }: Model.Embedding.Run & Model.Embedding.Config,
context: CustomCtx
): Promise<Model.Embedding.Response> {
const start = Date.now();
@@ -136,6 +136,7 @@
{
input: batch,
model: this.params.model,
requestOpts,
},
mergedContext
);
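For embeddings, requestOpts is forwarded into each batched call, so a single signal can cancel any in-flight batch. A hedged usage sketch, assuming an EmbeddingModel export with an input: string[] run param:

const embeddingModel = new EmbeddingModel({
  params: { model: 'text-embedding-ada-002' },
});
const controller = new AbortController();
const res = await embeddingModel.run({
  input: ['first passage', 'second passage'],
  requestOpts: { signal: controller.signal },
});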
49 changes: 46 additions & 3 deletions src/model/model.test.ts
@@ -6,13 +6,23 @@ import type { Model } from './types.js';
class Test extends AbstractModel<
any,
{ model: string },
{ input: string },
{
input: string;
requestOpts?: {
signal?: AbortSignal;
};
},
{ output: string; cached: boolean }
> {
modelType = 'completion' as Model.Type;
modelProvider = 'custom' as Model.Provider;
protected async runModel(
params: { input: string },
async runModel(
params: {
input: string;
requestOpts?: {
signal?: AbortSignal;
};
},
context: Model.Ctx
): Promise<{ output: string; cached: boolean }> {
if (params.input === 'throw error') {
@@ -93,4 +103,37 @@ describe('AbstractModel', () => {
expect(completeEvent).toHaveBeenCalledTimes(2);
expect(completeEvent.mock.lastCall[0].cached).toBe(true);
});

it('can take in a signal', async () => {
const abortController = new AbortController();
const test = new Test({ params: { model: 'testmodel' }, client: false });
const result = await test.run(
{
input: 'fooin',
requestOpts: {
signal: abortController.signal,
},
},
{ userId: '123' }
);

const runModelSpy = vi.spyOn(test, 'runModel');
await test.run(
{ input: 'fooin', requestOpts: { signal: abortController.signal } },
{ userId: '123' }
);
expect(runModelSpy).toHaveBeenCalledWith(
expect.objectContaining({
requestOpts: expect.objectContaining({
signal: abortController.signal,
}),
}),
expect.objectContaining({ userId: '123' })
);

expect(result).toEqual({
output: 'fooin > AI response with context: {"userId":"123"}',
cached: false,
});
});
});
6 changes: 6 additions & 0 deletions src/model/model.ts
@@ -110,9 +110,15 @@ export abstract class AbstractModel<
context?: CustomCtx
): Promise<MResponse> {
const start = Date.now();

const mergedContext = deepMerge(this.context, context);
const mergedParams = deepMerge(this.params, params);

// Handle the signal separately since it's an instance of AbortSignal,
// which the structural deep merge above does not preserve
if (params.requestOpts?.signal && mergedParams.requestOpts) {
mergedParams.requestOpts.signal = params.requestOpts.signal;
}

await Promise.allSettled(
this.events.onStart?.map((event) =>
Promise.resolve(
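The special case exists because a structural deep merge copies objects property by property, which strips an AbortSignal of its prototype and live internal state; reassigning the original instance afterwards keeps the signal functional. A standalone sketch of the failure mode, using @fastify/deepmerge directly as helpers.ts does:

import { deepmerge } from '@fastify/deepmerge';

const merge = deepmerge();
const controller = new AbortController();
const merged = merge(
  { requestOpts: {} },
  { requestOpts: { signal: controller.signal } }
);
// merged.requestOpts.signal is likely a cloned plain object rather than the
// live signal, so controller.abort() would not be observable through it,
// hence the explicit reassignment in run() above.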
14 changes: 12 additions & 2 deletions src/model/sparse-vector.ts
@@ -1,6 +1,7 @@
import type { PartialDeep } from 'type-fest';
import pThrottle from 'p-throttle';
import pMap from 'p-map';
import { type Options as KYOptions } from 'ky';
import type { ModelArgs } from './model.js';
import { AbstractModel } from './model.js';
import type { Model } from './types.js';
@@ -54,7 +55,10 @@ export class SparseVectorModel<
}

protected async runModel(
params: Model.SparseVector.Run & Model.SparseVector.Config,
{
requestOpts,
...params
}: Model.SparseVector.Run & Model.SparseVector.Config,
context: CustomCtx
): Promise<Model.SparseVector.Response> {
const start = Date.now();
@@ -83,7 +87,13 @@
}

protected async runSingle(
params: { input: string; model: string },
params: {
input: string;
model: string;
requestOpts?: {
headers?: KYOptions['headers'];
};
},
context: CustomCtx
): Promise<{
vector: Model.SparseVector.Vector;
12 changes: 11 additions & 1 deletion src/model/types.ts
@@ -9,6 +9,7 @@ import type {
EmbeddingResponse,
OpenAIClient,
} from 'openai-fetch';
import { type Options as KYOptions } from 'ky';
import type { AbstractModel } from './model.js';
import type { ChatModel } from './chat.js';
import type { CompletionModel } from './completion.js';
@@ -30,7 +31,13 @@ export namespace Model {
export interface Config {
model: string;
}
export interface Run {}
export interface Run {
[key: string]: any;
requestOpts?: {
signal?: AbortSignal;
headers?: KYOptions['headers'];
};
}
export interface Params extends Config, Run {}
export interface Response {
cached: boolean;
@@ -228,6 +235,9 @@ export namespace Model {
params: {
input: string;
model: string;
requestOpts?: {
headers?: KYOptions['headers'];
};
},
serviceUrl: string
) => Promise<SparseVector.Vector>;
3 changes: 2 additions & 1 deletion src/utils/helpers.ts
@@ -4,7 +4,8 @@ import { deepmerge as deepmergeInit } from '@fastify/deepmerge';
export type Prettify<T> = { [K in keyof T]: T[K] } & {};

type DeepMerge = ReturnType<typeof deepmergeInit>;
const deepMergeImpl: DeepMerge = deepmergeInit();
export const deepMergeImpl: DeepMerge = deepmergeInit();

const deepMergeEventsImpl: DeepMerge = deepmergeInit({
// Note: array items aren't merged recursively here since this is only used
// for event handler arrays, which are concatenated and deduped.
mergeArray: () => (a: any[], b: any[]) => stableDedupe([...a, ...b]),
