Skip to content

Commit d946c49

Browse files
committed
fix: ai.prompt API now also allows the model parameter to be a string with simply the model's name
as originally intended
1 parent 08f2693 commit d946c49

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

src/quickAddApi.ts

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ export class QuickAddApi {
107107
ai: {
108108
prompt: async (
109109
prompt: string,
110-
model: Model,
110+
model: Model | string,
111111
settings?: Partial<{
112112
variableName: string;
113113
shouldAssignVariables: boolean;
@@ -131,17 +131,29 @@ export class QuickAddApi {
131131
choiceExecutor
132132
).format;
133133

134-
const modelProvider = getModelProvider(model.name);
134+
let _model: Model;
135+
if (typeof model === "string") {
136+
const foundModel = getModelByName(model);
137+
if (!foundModel) {
138+
throw new Error(`Model '${model}' not found.`);
139+
}
140+
141+
_model = foundModel;
142+
} else {
143+
_model = model;
144+
}
145+
146+
const modelProvider = getModelProvider(_model.name);
135147

136148
if (!modelProvider) {
137149
throw new Error(
138-
`Model '${model.name}' not found in any provider`
150+
`Model '${_model.name}' not found in any provider`
139151
);
140152
}
141153

142154
const assistantRes = await Prompt(
143155
{
144-
model,
156+
model: _model,
145157
prompt,
146158
apiKey: modelProvider.apiKey,
147159
modelOptions: settings?.modelOptions ?? {},
@@ -173,7 +185,7 @@ export class QuickAddApi {
173185
chunkedPrompt: async (
174186
text: string,
175187
promptTemplate: string,
176-
model: string,
188+
model: Model | string,
177189
settings?: Partial<{
178190
variableName: string;
179191
shouldAssignVariables: boolean;
@@ -201,13 +213,19 @@ export class QuickAddApi {
201213
choiceExecutor
202214
).format;
203215

204-
const _model = getModelByName(model);
216+
let _model: Model;
217+
if (typeof model === "string") {
218+
const foundModel = getModelByName(model);
219+
if (!foundModel) {
220+
throw new Error(`Model ${model} not found.`);
221+
}
205222

206-
if (!_model) {
207-
throw new Error(`Model ${model} not found.`);
223+
_model = foundModel;
224+
} else {
225+
_model = model;
208226
}
209227

210-
const modelProvider = getModelProvider(model);
228+
const modelProvider = getModelProvider(_model.name);
211229

212230
if (!modelProvider) {
213231
throw new Error(

0 commit comments

Comments
 (0)