Skip to content

Commit

Permalink
Make Model and Datastore immutable (#17)
Browse files Browse the repository at this point in the history
* feat: Make Model and Datastore immutable and switch clone to extend

* fix: use deepMerge for datastore.extend context

* feat: make event params Readonly

* Rebase and apply updated formatting

* Make `Ctx` ReadonlyDeep for extra safety

* Support clearing context/events & add tests

---------

Co-authored-by: Riley Tomasek <[email protected]>
  • Loading branch information
transitive-bullshit and rileytomasek authored Apr 21, 2024
1 parent ab42016 commit 75ad931
Show file tree
Hide file tree
Showing 15 changed files with 427 additions and 273 deletions.
81 changes: 35 additions & 46 deletions src/datastore/datastore.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import type { Model } from '../model/index.js';
import type { Datastore } from './types.js';
import { deepMerge } from '../utils/helpers.js';
import {
type CacheKey,
type CacheStorage,
defaultCacheKey,
} from '../utils/cache.js';
import { mergeEvents } from '../utils/helpers.js';

export abstract class AbstractDatastore<
DocMeta extends Datastore.BaseMeta,
Expand All @@ -22,16 +22,29 @@ export abstract class AbstractDatastore<
abstract delete(docIds: string[]): Promise<void>;
abstract deleteAll(): Promise<void>;

abstract datastoreType: Datastore.Type;
abstract datastoreProvider: Datastore.Provider;

protected contentKey: keyof DocMeta;
protected embeddingModel: Model.Embedding.Model;
protected namespace?: string;
protected cacheKey: CacheKey<Datastore.Query<DocMeta, Filter>, string>;
protected cache?: CacheStorage<string, Datastore.QueryResult<DocMeta>>;
protected events: Datastore.Events<DocMeta, Filter>;
protected context: Datastore.Ctx;
/** Clones the datastore, optionally modifying its config */
abstract extend<Args extends Datastore.Opts<DocMeta, Filter>>(
args?: Partial<Args>
): this;

public abstract readonly datastoreType: Datastore.Type;
public abstract readonly datastoreProvider: Datastore.Provider;

protected readonly cacheKey: CacheKey<
Datastore.Query<DocMeta, Filter>,
string
>;
protected readonly cache?: CacheStorage<
string,
Datastore.QueryResult<DocMeta>
>;

public readonly contentKey: keyof DocMeta;
public readonly embeddingModel: Model.Embedding.Model;
public readonly namespace?: string;
public readonly events: Datastore.Events<DocMeta, Filter>;
public readonly context: Datastore.Ctx;
public readonly debug: boolean;

constructor(args: Datastore.Opts<DocMeta, Filter>) {
this.namespace = args.namespace;
Expand All @@ -40,14 +53,17 @@ export abstract class AbstractDatastore<
this.cacheKey = args.cacheKey ?? defaultCacheKey;
this.cache = args.cache;
this.context = args.context ?? {};
this.events = args.events ?? {};
if (args.debug) {
this.addEvents({
onQueryStart: [console.debug],
onQueryComplete: [console.debug],
onError: [console.error],
});
}
this.debug = args.debug ?? false;
this.events = mergeEvents(
args.events,
args.debug
? {
onQueryStart: [console.debug],
onQueryComplete: [console.debug],
onError: [console.error],
}
: {}
);
}

async query(
Expand Down Expand Up @@ -143,31 +159,4 @@ export abstract class AbstractDatastore<
throw error;
}
}

/** Get the current event handlers */
getEvents() {
return this.events;
}

/** Add event handlers to the datastore. */
addEvents(events: typeof this.events): this {
this.events = this.mergeEvents(this.events, events);
return this;
}

/**
* Set the event handlers to a new set of events. Removes all existing event handlers.
* Set to empty object `{}` to remove all events.
*/
setEvents(events: typeof this.events): this {
this.events = events;
return this;
}

protected mergeEvents(
existingEvents: typeof this.events,
newEvents: typeof this.events
): typeof this.events {
return deepMerge(existingEvents, newEvents);
}
}
32 changes: 25 additions & 7 deletions src/datastore/pinecone/datastore.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import { deepMerge, mergeEvents } from '../../utils/helpers.js';
import { AbstractDatastore } from '../datastore.js';
import type { Datastore, Prettify } from '../types.js';
import type { PineconeClient } from './client.js';
import { createPineconeClient } from './client.js';
import type { Pinecone } from './types.js';

export type PineconeDatastoreArgs<DocMeta extends Datastore.BaseMeta> =
Prettify<
Datastore.Opts<DocMeta, Pinecone.QueryFilter<DocMeta>> & {
pinecone?: PineconeClient<DocMeta>;
}
>;

export class PineconeDatastore<
DocMeta extends Datastore.BaseMeta,
> extends AbstractDatastore<DocMeta, Pinecone.QueryFilter<DocMeta>> {
datastoreType = 'embedding' as const;
datastoreProvider = 'pinecone' as const;
private readonly pinecone: PineconeClient<DocMeta>;

constructor(
args: Prettify<
Datastore.Opts<DocMeta, Pinecone.QueryFilter<DocMeta>> & {
pinecone?: PineconeClient<DocMeta>;
}
>
) {
constructor(args: PineconeDatastoreArgs<DocMeta>) {
const { pinecone, ...rest } = args;
super(rest);
this.pinecone =
Expand Down Expand Up @@ -157,4 +159,20 @@ export class PineconeDatastore<
async deleteAll(): Promise<void> {
return this.pinecone.delete({ deleteAll: true });
}

/**
 * Clones the datastore, optionally modifying its config.
 *
 * The current instance's options act as defaults and any key present in
 * `args` replaces the corresponding default via the spread. `context` and
 * `events` are listed after `...args`, so they always take the merged value
 * (existing + incoming) rather than a wholesale replacement.
 * NOTE(review): the chat.test.ts changes in this commit suggest that passing
 * an empty object clears context/events — confirm against the
 * deepMerge/mergeEvents implementations in utils/helpers.
 */
extend(args?: Partial<PineconeDatastoreArgs<DocMeta>>): this {
return new PineconeDatastore({
contentKey: this.contentKey,
namespace: this.namespace,
embeddingModel: this.embeddingModel,
cacheKey: this.cacheKey,
cache: this.cache,
debug: this.debug,
pinecone: this.pinecone,
...args,
context: deepMerge(this.context, args?.context),
events: mergeEvents(this.events, args?.events),
// Double assertion: the constructor yields PineconeDatastore<DocMeta>,
// but the abstract signature declares the polymorphic `this` type.
}) as unknown as this;
}
}
33 changes: 26 additions & 7 deletions src/datastore/pinecone/hybrid-datastore.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import { deepMerge, mergeEvents } from '../../utils/helpers.js';
import type { Model } from '../../model/index.js';
import { AbstractHybridDatastore } from '../hybrid-datastore.js';
import type { Datastore, Prettify } from '../types.js';
import type { PineconeClient } from './client.js';
import { createPineconeClient } from './client.js';
import type { Pinecone } from './types.js';

export type PineconeHybridDatastoreArgs<DocMeta extends Datastore.BaseMeta> =
Prettify<
Datastore.OptsHybrid<DocMeta, Pinecone.QueryFilter<DocMeta>> & {
pinecone?: PineconeClient<DocMeta>;
}
>;

export class PineconeHybridDatastore<
DocMeta extends Datastore.BaseMeta,
> extends AbstractHybridDatastore<DocMeta, Pinecone.QueryFilter<DocMeta>> {
datastoreType = 'hybrid' as const;
datastoreProvider = 'pinecone' as const;
private readonly pinecone: PineconeClient<DocMeta>;

constructor(
args: Prettify<
Datastore.OptsHybrid<DocMeta, Pinecone.QueryFilter<DocMeta>> & {
pinecone?: PineconeClient<DocMeta>;
}
>
) {
constructor(args: PineconeHybridDatastoreArgs<DocMeta>) {
const { pinecone, ...rest } = args;
super(rest);
this.pinecone =
Expand Down Expand Up @@ -180,4 +182,21 @@ export class PineconeHybridDatastore<
async deleteAll(): Promise<void> {
return this.pinecone.delete({ deleteAll: true });
}

/**
 * Clones the datastore, optionally modifying its config.
 *
 * The current instance's options (including the hybrid `spladeModel`) act as
 * defaults and any key present in `args` replaces the corresponding default
 * via the spread. `context` and `events` are listed after `...args`, so they
 * always take the merged value (existing + incoming) rather than a wholesale
 * replacement.
 * NOTE(review): the chat.test.ts changes in this commit suggest that passing
 * an empty object clears context/events — confirm against the
 * deepMerge/mergeEvents implementations in utils/helpers.
 */
extend(args?: Partial<PineconeHybridDatastoreArgs<DocMeta>>): this {
return new PineconeHybridDatastore({
contentKey: this.contentKey,
namespace: this.namespace,
embeddingModel: this.embeddingModel,
cacheKey: this.cacheKey,
cache: this.cache,
debug: this.debug,
pinecone: this.pinecone,
spladeModel: this.spladeModel,
...args,
context: deepMerge(this.context, args?.context),
events: mergeEvents(this.events, args?.events),
// Double assertion: the constructor yields PineconeHybridDatastore<DocMeta>,
// but the abstract signature declares the polymorphic `this` type.
}) as unknown as this;
}
}
}
113 changes: 108 additions & 5 deletions src/model/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,26 +90,26 @@ describe('ChatModel', () => {
});
});

it('implements clone', async () => {
it('implements extend', async () => {
const chatModel = new ChatModel({
client: Client,
context: { userId: '123' },
params: { model: 'gpt-fake' },
events: { onApiResponse: [() => {}] },
});
const clonedModel = chatModel.clone({
const clonedModel = chatModel.extend({
context: { cloned: true },
params: { model: 'gpt-fake-cloned' },
events: { onApiResponse: [() => {}] },
});
expect(clonedModel.getContext()).toEqual({
expect(clonedModel.context).toEqual({
userId: '123',
cloned: true,
});
expect(clonedModel.getParams()).toEqual({
expect(clonedModel.params).toEqual({
model: 'gpt-fake-cloned',
});
expect(clonedModel.getEvents()?.onApiResponse?.length).toBe(2);
expect(clonedModel.events.onApiResponse?.length).toBe(2);
});

it('can cache responses', async () => {
Expand Down Expand Up @@ -141,4 +141,107 @@ describe('ChatModel', () => {
expect(completeEvent).toHaveBeenCalledTimes(2);
expect(Client.createChatCompletion).toHaveBeenCalledOnce();
});

it('can be extended', async () => {
  // Verifies that extend() keeps inherited config (cache, events), merges
  // context/params, leaves the parent model untouched, and that an empty
  // object clears context/events.
  // Create a mocked cache (Map) to ensure that the cache is passed down
  const cache = new Map();
  const getSpy = vi.spyOn(cache, 'get');
  const onComplete1 = vi.fn();
  const onComplete2 = vi.fn();
  const onError = vi.fn();
  const onApiResponse = vi.fn();

  // Create a ChatModel instance and make a request
  const chatModel = new ChatModel({
    cache,
    client: Client,
    params: { model: 'gpt-fake' },
    context: { level: 1, userId: '123' },
    events: {
      onApiResponse: [onApiResponse],
      onComplete: [onComplete1, onComplete2],
      onError: [onError],
    },
  });
  await chatModel.run({ messages: [{ role: 'user', content: 'content2' }] });

  // Ensure the base model works as expected
  expect(getSpy).toHaveBeenCalledOnce();
  expect(onApiResponse).toHaveBeenCalledOnce();
  expect(onComplete1).toHaveBeenCalledOnce();
  // BUGFIX: previously asserted onComplete1 twice; onComplete2 was never checked here.
  expect(onComplete2).toHaveBeenCalledOnce();
  expect(onError).not.toHaveBeenCalled();
  expect(chatModel.params.model).toBe('gpt-fake');
  expect(chatModel.context).toEqual({ level: 1, userId: '123' });

  const newOnComplete = vi.fn();

  // Extend the model and make another request
  const secondChatModel = chatModel.extend({
    params: { model: 'gpt-fake-extended' },
    context: { level: 2 },
    events: { onComplete: [newOnComplete] },
  });
  await secondChatModel.run({
    messages: [{ role: 'user', content: 'content' }],
  });

  // Ensure the old model is unchanged
  expect(chatModel.params.model).toBe('gpt-fake');
  expect(chatModel.context).toEqual({ level: 1, userId: '123' });

  // Ensure the new model works as expected
  expect(onApiResponse).toHaveBeenCalledTimes(2);
  expect(onComplete1).toHaveBeenCalledTimes(2); // these are kept when extending
  expect(onComplete2).toHaveBeenCalledTimes(2); // these are kept when extending
  expect(newOnComplete).toHaveBeenCalledOnce();
  expect(onError).not.toHaveBeenCalled();
  expect(getSpy).toHaveBeenCalledTimes(2);
  expect(secondChatModel.params.model).toBe('gpt-fake-extended');
  expect(secondChatModel.context).toEqual({ level: 2, userId: '123' });

  const cache2 = new Map();
  const getSpy2 = vi.spyOn(cache2, 'get');

  // Extend again to clear properties
  const thirdChatModel = secondChatModel.extend({
    cache: cache2,
    params: { model: 'gpt-fake-extended-2' },
    context: {},
    events: {},
  });
  await thirdChatModel.run({
    messages: [{ role: 'user', content: 'content3' }],
  });

  // Empty objects clear inherited context/events entirely.
  expect(thirdChatModel.params).toEqual({ model: 'gpt-fake-extended-2' });
  expect(thirdChatModel.context).toEqual({});
  expect(thirdChatModel.events).toEqual({});

  // The new cache was consulted; the old cache saw no additional reads.
  expect(getSpy2).toHaveBeenCalledOnce();
  expect(getSpy).toHaveBeenCalledTimes(2);
  expect(newOnComplete).toHaveBeenCalledOnce();
  expect(onError).not.toHaveBeenCalled();
  // The intermediate model remains unchanged by the third extension.
  expect(secondChatModel.params.model).toBe('gpt-fake-extended');
});

it(`mutating event data doesn't impact the models context and params`, async () => {
const onComplete = vi.fn().mockImplementation((e: any) => {
e.userId = 'mutated';
e.params.model = 'mutated';
});

const chatModel = new ChatModel({
client: Client,
params: { model: 'gpt-fake' },
events: { onComplete: [onComplete] },
context: { userId: '123' },
});
await chatModel.run({ messages: [{ role: 'user', content: 'content2' }] });

expect(onComplete).toHaveBeenCalledOnce();
expect(chatModel.context.userId).toBe('123');
expect(chatModel.params.model).toBe('gpt-fake');
});
});
Loading

0 comments on commit 75ad931

Please sign in to comment.