11import typer
22import os
3+ from write_the .models import models
34from write_the .__about__ import __version__
45from write_the .commands import write_the_tests , write_the_mkdocs , write_the_converters
56from write_the .utils import list_python_files
1314from functools import wraps
1415
1516from .tasks import async_cli_task
17+ from .model import get_default_model , set_default_model
1618
1719
1820class AsyncTyper (typer .Typer ):
@@ -30,6 +32,12 @@ def sync_func(*_args, **_kwargs):
3032
3133app = AsyncTyper ()
3234
35+ def _get_model_callback (value : str ):
36+ if value is None :
37+ return get_default_model ()
38+ if value not in models :
39+ raise typer .BadParameter (f"Model '{ value } ' not found!" )
40+ return value
3341
3442def _print_version (ctx : typer .Context , value : bool ):
3543 if value :
@@ -94,10 +102,11 @@ async def docs(
94102 False , "--batch/--no-batch" , "-b" , help = "Send each node as a separate request."
95103 ),
96104 model : str = typer .Option (
97- "gpt-3.5-turbo-instruct" ,
105+ None ,
98106 "--model" ,
99107 "-m" ,
100108 help = "The model to use for generating the docstrings." ,
109+ callback = _get_model_callback ,
101110 ),
102111):
103112 """
@@ -191,10 +200,11 @@ async def tests(
191200 help = "Save empty files if a test creation fails. This will prevent write-the from regenerating failed test creations." ,
192201 ),
193202 model : str = typer .Option (
194- "gpt-3.5-turbo-instruct" ,
203+ None ,
195204 "--model" ,
196205 "-m" ,
197206 help = "The model to use for generating the tests." ,
207+ callback = _get_model_callback ,
198208 ),
199209):
200210 """
@@ -282,10 +292,11 @@ async def convert(
282292 False , "--pretty/--plain" , "-p" , help = "Syntax highlight the output."
283293 ),
284294 model : str = typer .Option (
285- "gpt-3.5-turbo-instruct" ,
295+ None ,
286296 "--model" ,
287297 "-m" ,
288298 help = "The model to use for generating the tests." ,
299+ callback = _get_model_callback ,
289300 ),
290301):
291302 """
@@ -337,8 +348,45 @@ async def convert(
337348
338349
339350@app .command ()
340- def models ():
341- raise NotImplementedError ()
351+ def model (
352+ desired_model : Optional [str ] = typer .Argument (
353+ None ,
354+ help = "Set the default model." ,
355+ ),
356+ list : bool = typer .Option (
357+ False ,
358+ "--list" ,
359+ "-l" ,
360+ help = "List all available models." ,
361+ ),
362+ ):
363+ """
364+ View or set the default model.
365+ """
366+ default_model = get_default_model ()
367+ if list :
368+ from rich import table , print
369+ table_ = table .Table ()
370+ table_ .add_column ("Name" , justify = "left" , style = "cyan" )
371+ table_ .add_column ("Context" , justify = "left" , style = "magenta" )
372+ table_ .add_column ("Default" , justify = "left" , style = "green" )
373+ for name , model in models .items ():
374+ if name == default_model :
375+ table_ .add_row (name , f"{ model ['context_window' ]} " , "✅" )
376+ else :
377+ table_ .add_row (name , f"{ model ['context_window' ]} " , "" )
378+ print (table_ )
379+ return typer .Exit (0 )
380+ if desired_model :
381+ if desired_model not in models :
382+ typer .secho (f"Model '{ desired_model } ' not found!" , fg = "red" )
383+ return typer .Exit (1 )
384+ set_default_model (desired_model )
385+ typer .echo (f"Default model: { desired_model } " )
386+ return typer .Exit (0 )
387+ typer .echo (default_model )
388+
389+
342390
343391
344392@app .command ()
0 commit comments