Skip to content

Commit

Permalink
working on getting types to align
Browse files Browse the repository at this point in the history
  • Loading branch information
cfortuner committed Oct 28, 2024
1 parent 42ba106 commit 554c398
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/model/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ describe('ChatModel', () => {
expect(apiResponseEvent).toHaveBeenCalledWith({
timestamp: new Date().toISOString(),
modelType: 'chat',
modelProvider: 'openai',
modelProvider: 'spy', // because we mocked the client
params: {
model: 'gpt-fake',
messages: [{ role: 'user', content: 'content' }],
Expand All @@ -133,7 +133,7 @@ describe('ChatModel', () => {

it('implements extend', async () => {
type ChatContext = { userId: string; cloned?: boolean };
const chatModel = new ChatModel<ChatContext>({
const chatModel = new ChatModel<ChatContext, Model.Chat.Client, Model.Chat.Config<Model.Chat.Client>>({
client: Client,
context: { userId: '123' },
params: { model: 'gpt-fake' },
Expand Down Expand Up @@ -220,7 +220,7 @@ describe('ChatModel', () => {

// Extend the model and make another request
const secondChatModel = chatModel.extend({
params: { model: 'gpt-fake-extended' },
params: { model: 'gpt-fake-extended'},
context: { level: 2 },
events: { onComplete: [newOnComplete] },
});
Expand Down
15 changes: 8 additions & 7 deletions src/model/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export class ChatModel<
CustomCtx
> {
modelType = 'chat' as const;
modelProvider = 'anthropic' as const;
modelProvider = 'openai';

constructor(args: ChatModelArgs<CustomCtx, CustomClient, CustomConfig> = {}) {
const {
Expand Down Expand Up @@ -83,6 +83,8 @@ export class ChatModel<
),
...rest,
});

this.modelProvider = this.client.name;
}

protected async runModel<Cfg extends Model.Chat.Config<CustomClient>>(
Expand Down Expand Up @@ -274,17 +276,16 @@ export class ChatModel<
}

/** Clone the model and merge/override the given properties. */
extend(
args?: PartialChatModelArgs<CustomCtx, CustomClient, CustomConfig>
): this {
extend(args?: PartialChatModelArgs<CustomCtx, CustomClient, Model.Chat.Config<CustomClient>>): this {
const { client, params, ...rest } = args ?? {};
return new ChatModel({
cacheKey: this.cacheKey,
cache: this.cache,
client: this.client,
client: client ?? this.client,
debug: this.debug,
telemetry: this.telemetry,
...args,
params: deepMerge(this.params, args?.params),
...rest,
params: deepMerge(this.params, params),
context:
args?.context && Object.keys(args.context).length === 0
? undefined
Expand Down
39 changes: 29 additions & 10 deletions src/model/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,37 +83,56 @@ export abstract class AbstractModel<
): Promise<MResponse>;

/** Clones the model, optionally modifying its config */
abstract extend<
Args extends PartialModelArgs<MClient, MConfig, MRun, MResponse, CustomCtx>,
>(args?: Args): this;
abstract extend(
args?: PartialModelArgs<
MClient,
Model.Base.Config<MClient>,
MRun,
// Note: this response type maybe change over time as the user
// extends the model
// it should be inferred from some types rather than set to MResponse
MResponse,
CustomCtx
>
): this;

public abstract readonly modelType: Model.Type;
public abstract readonly modelProvider: Model.Provider;
public abstract modelProvider: Model.Provider;

protected readonly cacheKey: CacheKey<MRun & MConfig, string>;
// the cache key can be updated in a call to .extend so it doesn't necessarily conform to MRun & MConfig
protected readonly cacheKey: CacheKey<Model.Base.Run & Model.Base.Config<Model.Base.Client>, string>;
protected readonly cache?: CacheStorage<string, MResponse>;
public readonly client: MClient;
public readonly context: CustomCtx;
public readonly debug: boolean;
public readonly params: MConfig & Partial<MRun>;
public readonly events: Model.Events<
MClient,
MRun & MConfig,
MResponse,
Model.Base.Client,
Model.Base.Run & Model.Base.Config<Model.Base.Client>,
Model.Base.Response,
CustomCtx,
AResponse
>;
public readonly tokenizer: Model.ITokenizer;
public readonly telemetry: Telemetry.Provider;

constructor(args: ModelArgs<MClient, MConfig, MRun, MResponse, CustomCtx>) {
this.cacheKey = args.cacheKey ?? defaultCacheKey;
this.cacheKey = (args.cacheKey ?? defaultCacheKey) as CacheKey<
Model.Base.Run & Model.Base.Config<Model.Base.Client>,
string
>;
this.cache = args.cache;
this.client = args.client;
this.context = args.context ?? ({} as CustomCtx);
this.debug = args.debug ?? false;
this.params = args.params;
this.events = args.events || {};
this.events = (args.events || {}) as Model.Events<
Model.Base.Client,
Model.Base.Run & Model.Base.Config<Model.Base.Client>,
Model.Base.Response,
CustomCtx,
AResponse
>;
this.tokenizer = createTokenizer(args.params.model);
this.telemetry = args.telemetry ?? DefaultTelemetry;
}
Expand Down

0 comments on commit 554c398

Please sign in to comment.