1
1
import typer
2
2
import os
3
+ from write_the .models import models
3
4
from write_the .__about__ import __version__
4
5
from write_the .commands import write_the_tests , write_the_mkdocs , write_the_converters
5
6
from write_the .utils import list_python_files
13
14
from functools import wraps
14
15
15
16
from .tasks import async_cli_task
17
+ from .model import get_default_model , set_default_model
16
18
17
19
18
20
class AsyncTyper (typer .Typer ):
@@ -30,6 +32,12 @@ def sync_func(*_args, **_kwargs):
30
32
31
33
app = AsyncTyper ()
32
34
35
+ def _get_model_callback (value : str ):
36
+ if value is None :
37
+ return get_default_model ()
38
+ if value not in models :
39
+ raise typer .BadParameter (f"Model '{ value } ' not found!" )
40
+ return value
33
41
34
42
def _print_version (ctx : typer .Context , value : bool ):
35
43
if value :
@@ -94,10 +102,11 @@ async def docs(
94
102
False , "--batch/--no-batch" , "-b" , help = "Send each node as a separate request."
95
103
),
96
104
model : str = typer .Option (
97
- "gpt-3.5-turbo-instruct" ,
105
+ None ,
98
106
"--model" ,
99
107
"-m" ,
100
108
help = "The model to use for generating the docstrings." ,
109
+ callback = _get_model_callback ,
101
110
),
102
111
):
103
112
"""
@@ -191,10 +200,11 @@ async def tests(
191
200
help = "Save empty files if a test creation fails. This will prevent write-the from regenerating failed test creations." ,
192
201
),
193
202
model : str = typer .Option (
194
- "gpt-3.5-turbo-instruct" ,
203
+ None ,
195
204
"--model" ,
196
205
"-m" ,
197
206
help = "The model to use for generating the tests." ,
207
+ callback = _get_model_callback ,
198
208
),
199
209
):
200
210
"""
@@ -282,10 +292,11 @@ async def convert(
282
292
False , "--pretty/--plain" , "-p" , help = "Syntax highlight the output."
283
293
),
284
294
model : str = typer .Option (
285
- "gpt-3.5-turbo-instruct" ,
295
+ None ,
286
296
"--model" ,
287
297
"-m" ,
288
298
help = "The model to use for generating the tests." ,
299
+ callback = _get_model_callback ,
289
300
),
290
301
):
291
302
"""
@@ -337,8 +348,45 @@ async def convert(
337
348
338
349
339
350
@app .command ()
340
- def models ():
341
- raise NotImplementedError ()
351
+ def model (
352
+ desired_model : Optional [str ] = typer .Argument (
353
+ None ,
354
+ help = "Set the default model." ,
355
+ ),
356
+ list : bool = typer .Option (
357
+ False ,
358
+ "--list" ,
359
+ "-l" ,
360
+ help = "List all available models." ,
361
+ ),
362
+ ):
363
+ """
364
+ View or set the default model.
365
+ """
366
+ default_model = get_default_model ()
367
+ if list :
368
+ from rich import table , print
369
+ table_ = table .Table ()
370
+ table_ .add_column ("Name" , justify = "left" , style = "cyan" )
371
+ table_ .add_column ("Context" , justify = "left" , style = "magenta" )
372
+ table_ .add_column ("Default" , justify = "left" , style = "green" )
373
+ for name , model in models .items ():
374
+ if name == default_model :
375
+ table_ .add_row (name , f"{ model ['context_window' ]} " , "✅" )
376
+ else :
377
+ table_ .add_row (name , f"{ model ['context_window' ]} " , "" )
378
+ print (table_ )
379
+ return typer .Exit (0 )
380
+ if desired_model :
381
+ if desired_model not in models :
382
+ typer .secho (f"Model '{ desired_model } ' not found!" , fg = "red" )
383
+ return typer .Exit (1 )
384
+ set_default_model (desired_model )
385
+ typer .echo (f"Default model: { desired_model } " )
386
+ return typer .Exit (0 )
387
+ typer .echo (default_model )
388
+
389
+
342
390
343
391
344
392
@app .command ()
0 commit comments