diff --git a/src/model/chat.test.ts b/src/model/chat.test.ts index 0af890f..d7c4340 100644 --- a/src/model/chat.test.ts +++ b/src/model/chat.test.ts @@ -120,7 +120,7 @@ describe('ChatModel', () => { expect(apiResponseEvent).toHaveBeenCalledWith({ timestamp: new Date().toISOString(), modelType: 'chat', - modelProvider: 'openai', + modelProvider: 'spy', // because we mocked the client params: { model: 'gpt-fake', messages: [{ role: 'user', content: 'content' }], @@ -133,7 +133,7 @@ describe('ChatModel', () => { it('implements extend', async () => { type ChatContext = { userId: string; cloned?: boolean }; - const chatModel = new ChatModel({ + const chatModel = new ChatModel>({ client: Client, context: { userId: '123' }, params: { model: 'gpt-fake' }, @@ -220,7 +220,7 @@ describe('ChatModel', () => { // Extend the model and make another request const secondChatModel = chatModel.extend({ - params: { model: 'gpt-fake-extended' }, + params: { model: 'gpt-fake-extended'}, context: { level: 2 }, events: { onComplete: [newOnComplete] }, }); diff --git a/src/model/chat.ts b/src/model/chat.ts index 27b5b11..22f4fb5 100644 --- a/src/model/chat.ts +++ b/src/model/chat.ts @@ -55,7 +55,7 @@ export class ChatModel< CustomCtx > { modelType = 'chat' as const; - modelProvider = 'anthropic' as const; + modelProvider = 'openai'; constructor(args: ChatModelArgs = {}) { const { @@ -83,6 +83,8 @@ export class ChatModel< ), ...rest, }); + + this.modelProvider = this.client.name; } protected async runModel>( @@ -274,17 +276,16 @@ export class ChatModel< } /** Clone the model and merge/override the given properties. */ - extend( - args?: PartialChatModelArgs - ): this { + extend(args?: PartialChatModelArgs>): this { + const { client, params, ...rest } = args ?? {}; return new ChatModel({ cacheKey: this.cacheKey, cache: this.cache, - client: this.client, + client: client ?? this.client, debug: this.debug, telemetry: this.telemetry, - ...args, - params: deepMerge(this.params, args?.params), + ...rest, + params: deepMerge(this.params, params), context: args?.context && Object.keys(args.context).length === 0 ? undefined diff --git a/src/model/model.ts b/src/model/model.ts index 1d39bdf..ec8b684 100644 --- a/src/model/model.ts +++ b/src/model/model.ts @@ -83,23 +83,33 @@ export abstract class AbstractModel< ): Promise; /** Clones the model, optionally modifying its config */ - abstract extend< - Args extends PartialModelArgs, - >(args?: Args): this; + abstract extend( + args?: PartialModelArgs< + MClient, + Model.Base.Config, + MRun, + // Note: this response type maybe change over time as the user + // extends the model + // it should be inferred from some types rather than set to MResponse + MResponse, + CustomCtx + > + ): this; public abstract readonly modelType: Model.Type; - public abstract readonly modelProvider: Model.Provider; + public abstract modelProvider: Model.Provider; - protected readonly cacheKey: CacheKey; + // the cache key can be updated in a call to .extend so it doesn't necessarily conform to MRun & MConfig + protected readonly cacheKey: CacheKey, string>; protected readonly cache?: CacheStorage; public readonly client: MClient; public readonly context: CustomCtx; public readonly debug: boolean; public readonly params: MConfig & Partial; public readonly events: Model.Events< - MClient, - MRun & MConfig, - MResponse, + Model.Base.Client, + Model.Base.Run & Model.Base.Config, + Model.Base.Response, CustomCtx, AResponse >; @@ -107,13 +117,22 @@ export abstract class AbstractModel< public readonly telemetry: Telemetry.Provider; constructor(args: ModelArgs) { - this.cacheKey = args.cacheKey ?? defaultCacheKey; + this.cacheKey = (args.cacheKey ?? defaultCacheKey) as CacheKey< + Model.Base.Run & Model.Base.Config, + string + >; this.cache = args.cache; this.client = args.client; this.context = args.context ?? ({} as CustomCtx); this.debug = args.debug ?? false; this.params = args.params; - this.events = args.events || {}; + this.events = (args.events || {}) as Model.Events< + Model.Base.Client, + Model.Base.Run & Model.Base.Config, + Model.Base.Response, + CustomCtx, + AResponse + >; this.tokenizer = createTokenizer(args.params.model); this.telemetry = args.telemetry ?? DefaultTelemetry; }