Skip to content

Commit

Permalink
Support custom context type with generic Ctx (#34)
Browse files Browse the repository at this point in the history
* Support custom context type with generic `Ctx`

This allows a `Ctx` type to be specified when creating a `Model`
instance so that the context values are type safe.

I used `any` instead of typing the `Datastore` types because we aren't
using them and they will either be deprecated or revamped in a future
release.

* Update tests to more realistic generic ctx usage

This shows a more realistic and clearer way to specify the context generic
for a model.
  • Loading branch information
rileytomasek authored Apr 24, 2024
1 parent 599d79b commit 87174eb
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 77 deletions.
1 change: 1 addition & 0 deletions examples/caching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async function main() {
// Pinecone datastore with cache
const store = new PineconeDatastore<{ content: string }>({
contentKey: 'content',
// @ts-ignore
embeddingModel,
events: { onQueryComplete: [console.log] },
cache: new Map(),
Expand Down
3 changes: 1 addition & 2 deletions src/datastore/datastore.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import type { Model } from '../model/index.js';
import type { Datastore } from './types.js';
import {
type CacheKey,
Expand Down Expand Up @@ -40,7 +39,7 @@ export abstract class AbstractDatastore<
>;

public readonly contentKey: keyof DocMeta;
public readonly embeddingModel: Model.Embedding.Model;
public readonly embeddingModel: any;
public readonly namespace?: string;
public readonly events: Datastore.Events<DocMeta, Filter>;
public readonly context: Datastore.Ctx;
Expand Down
3 changes: 1 addition & 2 deletions src/datastore/hybrid-datastore.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import type { Model } from '../model/index.js';
import { AbstractDatastore } from './datastore.js';
import type { Datastore } from './types.js';

export abstract class AbstractHybridDatastore<
DocMeta extends Datastore.BaseMeta,
Filter extends Datastore.BaseFilter<DocMeta>,
> extends AbstractDatastore<DocMeta, Filter> {
protected spladeModel: Model.SparseVector.Model;
protected spladeModel: any;

constructor(args: Datastore.OptsHybrid<DocMeta, Filter>) {
const { spladeModel, ...rest } = args;
Expand Down
2 changes: 2 additions & 0 deletions src/datastore/pinecone/datastore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export class PineconeDatastore<
}

// Query Pinecone
// @ts-ignore
const response = await this.pinecone.query({
topK: query.topK ?? 10,
...(typeof query.minScore === 'number'
Expand All @@ -67,6 +68,7 @@ export class PineconeDatastore<

const queryResult: Datastore.QueryResult<DocMeta> = {
query: query.query,
// @ts-ignore
docs: response.matches,
};

Expand Down
4 changes: 2 additions & 2 deletions src/datastore/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export namespace Datastore {
*/
contentKey: keyof DocMeta;
namespace?: string;
embeddingModel: Model.Embedding.Model;
embeddingModel: any;
/**
* A function that returns a cache key for the given params.
*
Expand Down Expand Up @@ -107,7 +107,7 @@ export namespace Datastore {
Filter extends BaseFilter<DocMeta>,
> extends Opts<DocMeta, Filter> {
/** Splade instance for creating sparse vectors */
spladeModel: Model.SparseVector.Model;
spladeModel: any;
}

/** The provider of the vector database. */
Expand Down
3 changes: 2 additions & 1 deletion src/model/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ describe('ChatModel', () => {
});

it('implements extend', async () => {
const chatModel = new ChatModel({
type ChatContext = { userId: string; cloned?: boolean };
const chatModel = new ChatModel<ChatContext>({
client: Client,
context: { userId: '123' },
params: { model: 'gpt-fake' },
Expand Down
24 changes: 14 additions & 10 deletions src/model/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,36 @@ import { createOpenAIClient } from './clients/openai.js';
import { AbstractModel } from './model.js';
import { deepMerge, mergeEvents, type Prettify } from '../utils/helpers.js';

export type ChatModelArgs = SetOptional<
export type ChatModelArgs<CustomCtx extends Model.Ctx> = SetOptional<
ModelArgs<
Model.Chat.Client,
Model.Chat.Config,
Model.Chat.Run,
Model.Chat.Response
Model.Chat.Response,
CustomCtx
>,
'client' | 'params'
>;

export type PartialChatModelArgs = Prettify<
PartialDeep<Pick<ChatModelArgs, 'params'>> &
Partial<Omit<ChatModelArgs, 'params'>>
export type PartialChatModelArgs<CustomCtx extends Model.Ctx> = Prettify<
PartialDeep<Pick<ChatModelArgs<Partial<CustomCtx>>, 'params'>> &
Partial<Omit<ChatModelArgs<Partial<CustomCtx>>, 'params'>>
>;

export class ChatModel extends AbstractModel<
export class ChatModel<
CustomCtx extends Model.Ctx = Model.Ctx,
> extends AbstractModel<
Model.Chat.Client,
Model.Chat.Config,
Model.Chat.Run,
Model.Chat.Response,
Model.Chat.ApiResponse
Model.Chat.ApiResponse,
CustomCtx
> {
modelType = 'chat' as const;
modelProvider = 'openai' as const;

constructor(args: ChatModelArgs = {}) {
constructor(args: ChatModelArgs<CustomCtx> = {}) {
const {
// Add a default client if none is provided
client = createOpenAIClient(),
Expand Down Expand Up @@ -62,7 +66,7 @@ export class ChatModel extends AbstractModel<

protected async runModel(
{ handleUpdate, ...params }: Model.Chat.Run & Model.Chat.Config,
context: Model.Ctx
context: CustomCtx
): Promise<Model.Chat.Response> {
const start = Date.now();

Expand Down Expand Up @@ -204,7 +208,7 @@ export class ChatModel extends AbstractModel<
}

/** Clone the model and merge/override the given properties. */
extend(args?: PartialChatModelArgs): this {
extend(args?: PartialChatModelArgs<CustomCtx>): this {
return new ChatModel({
cacheKey: this.cacheKey,
cache: this.cache,
Expand Down
24 changes: 14 additions & 10 deletions src/model/completion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,36 @@ import { createOpenAIClient } from './clients/openai.js';
import { AbstractModel } from './model.js';
import { deepMerge, mergeEvents, type Prettify } from '../index.js';

export type CompletionModelArgs = SetOptional<
export type CompletionModelArgs<CustomCtx extends Model.Ctx> = SetOptional<
ModelArgs<
Model.Completion.Client,
Model.Completion.Config,
Model.Completion.Run,
Model.Completion.Response
Model.Completion.Response,
CustomCtx
>,
'client' | 'params'
>;

export type PartialCompletionModelArgs = Prettify<
PartialDeep<Pick<CompletionModelArgs, 'params'>> &
Partial<Omit<CompletionModelArgs, 'params'>>
export type PartialCompletionModelArgs<CustomCtx extends Model.Ctx> = Prettify<
PartialDeep<Pick<CompletionModelArgs<Partial<CustomCtx>>, 'params'>> &
Partial<Omit<CompletionModelArgs<Partial<CustomCtx>>, 'params'>>
>;

export class CompletionModel extends AbstractModel<
export class CompletionModel<
CustomCtx extends Model.Ctx = Model.Ctx,
> extends AbstractModel<
Model.Completion.Client,
Model.Completion.Config,
Model.Completion.Run,
Model.Completion.Response,
Model.Completion.ApiResponse
Model.Completion.ApiResponse,
CustomCtx
> {
modelType = 'completion' as const;
modelProvider = 'openai' as const;

constructor(args?: CompletionModelArgs) {
constructor(args?: CompletionModelArgs<CustomCtx>) {
let { client, params, ...rest } = args ?? {};
// Add a default client if none is provided
client = client ?? createOpenAIClient();
Expand All @@ -43,7 +47,7 @@ export class CompletionModel extends AbstractModel<

protected async runModel(
params: Model.Completion.Run & Model.Completion.Config,
context: Model.Ctx
context: CustomCtx
): Promise<Model.Completion.Response> {
const start = Date.now();

Expand Down Expand Up @@ -77,7 +81,7 @@ export class CompletionModel extends AbstractModel<
}

/** Clone the model and merge/override the given properties. */
extend(args?: PartialCompletionModelArgs): this {
extend(args?: PartialCompletionModelArgs<CustomCtx>): this {
return new CompletionModel({
cacheKey: this.cacheKey,
cache: this.cache,
Expand Down
3 changes: 2 additions & 1 deletion src/model/embedding.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ describe('EmbeddingModel', () => {
});

it('implements extend', async () => {
const model = new EmbeddingModel({
type EmbeddingContext = { userId: string; cloned?: boolean };
const model = new EmbeddingModel<EmbeddingContext>({
client: Client,
context: { userId: '123' },
params: { model: 'gpt-fake' },
Expand Down
34 changes: 19 additions & 15 deletions src/model/embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,25 @@ import { createOpenAIClient } from './clients/openai.js';
import { AbstractModel } from './model.js';
import { deepMerge, mergeEvents, type Prettify } from '../utils/helpers.js';

export type EmbeddingModelArgs = SetOptional<
export type EmbeddingModelArgs<CustomCtx extends Model.Ctx> = SetOptional<
ModelArgs<
Model.Embedding.Client,
Model.Embedding.Config,
Model.Embedding.Run,
Model.Embedding.Response
Model.Embedding.Response,
CustomCtx
>,
'client' | 'params'
>;

export type PartialSparseVectorModelArgs = Prettify<
PartialDeep<Pick<EmbeddingModelArgs, 'params'>> &
Partial<Omit<EmbeddingModelArgs, 'params'>>
export type PartialEmbeddingModelArgs<CustomCtx extends Model.Ctx> = Prettify<
PartialDeep<Pick<EmbeddingModelArgs<Partial<CustomCtx>>, 'params'>> &
Partial<Omit<EmbeddingModelArgs<Partial<CustomCtx>>, 'params'>>
>;

type BulkEmbedder = (
type BulkEmbedder<CustomCtx extends Model.Ctx> = (
params: Model.Embedding.Run & Model.Embedding.Config,
context: Model.Ctx
context: CustomCtx
) => Promise<Model.Embedding.Response>;

const DEFAULTS = {
Expand All @@ -39,18 +40,21 @@ const DEFAULTS = {
model: 'text-embedding-ada-002',
} as const;

export class EmbeddingModel extends AbstractModel<
export class EmbeddingModel<
CustomCtx extends Model.Ctx = Model.Ctx,
> extends AbstractModel<
Model.Embedding.Client,
Model.Embedding.Config,
Model.Embedding.Run,
Model.Embedding.Response,
Model.Embedding.ApiResponse
Model.Embedding.ApiResponse,
CustomCtx
> {
modelType = 'embedding' as const;
modelProvider = 'openai' as const;
throttledModel: BulkEmbedder;
throttledModel: BulkEmbedder<CustomCtx>;

constructor(args: EmbeddingModelArgs = {}) {
constructor(args: EmbeddingModelArgs<CustomCtx> = {}) {
const {
client = createOpenAIClient(),
params = { model: DEFAULTS.model },
Expand All @@ -66,7 +70,7 @@ export class EmbeddingModel extends AbstractModel<
this.throttledModel = pThrottle({ limit, interval })(
async (
params: Model.Embedding.Run & Model.Embedding.Config,
context: Model.Ctx
context: CustomCtx
) => {
const start = Date.now();

Expand Down Expand Up @@ -112,7 +116,7 @@ export class EmbeddingModel extends AbstractModel<

protected async runModel(
params: Model.Embedding.Run & Model.Embedding.Config,
context: Model.Ctx
context: CustomCtx
): Promise<Model.Embedding.Response> {
const start = Date.now();
// Batch the inputs for the requests
Expand Down Expand Up @@ -168,8 +172,8 @@ export class EmbeddingModel extends AbstractModel<
}

/** Clone the model and merge/override the given properties. */
extend(args?: PartialSparseVectorModelArgs): this {
return new EmbeddingModel({
extend(args?: PartialEmbeddingModelArgs<CustomCtx>): this {
return new EmbeddingModel<CustomCtx>({
cacheKey: this.cacheKey,
cache: this.cache,
client: this.client,
Expand Down
Loading

0 comments on commit 87174eb

Please sign in to comment.