Skip to content

Commit bad31ed

Browse files
committed
✨ View or set the default model
1 parent cf257eb commit bad31ed

File tree

2 files changed

+88
-5
lines changed

2 files changed

+88
-5
lines changed

write_the/cli/main.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typer
22
import os
3+
from write_the.models import models
34
from write_the.__about__ import __version__
45
from write_the.commands import write_the_tests, write_the_mkdocs, write_the_converters
56
from write_the.utils import list_python_files
@@ -13,6 +14,7 @@
1314
from functools import wraps
1415

1516
from .tasks import async_cli_task
17+
from .model import get_default_model, set_default_model
1618

1719

1820
class AsyncTyper(typer.Typer):
@@ -30,6 +32,12 @@ def sync_func(*_args, **_kwargs):
3032

3133
app = AsyncTyper()
3234

35+
def _get_model_callback(value: str):
36+
if value is None:
37+
return get_default_model()
38+
if value not in models:
39+
raise typer.BadParameter(f"Model '{value}' not found!")
40+
return value
3341

3442
def _print_version(ctx: typer.Context, value: bool):
3543
if value:
@@ -94,10 +102,11 @@ async def docs(
94102
False, "--batch/--no-batch", "-b", help="Send each node as a separate request."
95103
),
96104
model: str = typer.Option(
97-
"gpt-3.5-turbo-instruct",
105+
None,
98106
"--model",
99107
"-m",
100108
help="The model to use for generating the docstrings.",
109+
callback=_get_model_callback,
101110
),
102111
):
103112
"""
@@ -191,10 +200,11 @@ async def tests(
191200
help="Save empty files if a test creation fails. This will prevent write-the from regenerating failed test creations.",
192201
),
193202
model: str = typer.Option(
194-
"gpt-3.5-turbo-instruct",
203+
None,
195204
"--model",
196205
"-m",
197206
help="The model to use for generating the tests.",
207+
callback=_get_model_callback,
198208
),
199209
):
200210
"""
@@ -282,10 +292,11 @@ async def convert(
282292
False, "--pretty/--plain", "-p", help="Syntax highlight the output."
283293
),
284294
model: str = typer.Option(
285-
"gpt-3.5-turbo-instruct",
295+
None,
286296
"--model",
287297
"-m",
288298
help="The model to use for generating the tests.",
299+
callback=_get_model_callback,
289300
),
290301
):
291302
"""
@@ -337,8 +348,45 @@ async def convert(
337348

338349

339350
@app.command()
340-
def models():
341-
raise NotImplementedError()
351+
def model(
352+
desired_model: Optional[str] = typer.Argument(
353+
None,
354+
help="Set the default model.",
355+
),
356+
list: bool = typer.Option(
357+
False,
358+
"--list",
359+
"-l",
360+
help="List all available models.",
361+
),
362+
):
363+
"""
364+
View or set the default model.
365+
"""
366+
default_model = get_default_model()
367+
if list:
368+
from rich import table, print
369+
table_ = table.Table()
370+
table_.add_column("Name", justify="left", style="cyan")
371+
table_.add_column("Context", justify="left", style="magenta")
372+
table_.add_column("Default", justify="left", style="green")
373+
for name, model in models.items():
374+
if name == default_model:
375+
table_.add_row(name, f"{model['context_window']}", "✅")
376+
else:
377+
table_.add_row(name, f"{model['context_window']}", "")
378+
print(table_)
379+
return typer.Exit(0)
380+
if desired_model:
381+
if desired_model not in models:
382+
typer.secho(f"Model '{desired_model}' not found!", fg="red")
383+
return typer.Exit(1)
384+
set_default_model(desired_model)
385+
typer.echo(f"Default model: {desired_model}")
386+
return typer.Exit(0)
387+
typer.echo(default_model)
388+
389+
342390

343391

344392
@app.command()

write_the/cli/model.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import typer
2+
import json
3+
from pathlib import Path
4+
5+
6+
def get_config_path():
7+
APP_NAME = "write-the"
8+
app_dir = typer.get_app_dir(APP_NAME)
9+
config_path: Path = Path(app_dir) / "config.json"
10+
config_path.parent.mkdir(parents=True, exist_ok=True)
11+
if not config_path.exists():
12+
with open(config_path, "w") as f:
13+
json.dump({}, f)
14+
return config_path
15+
16+
def get_default_model():
17+
config_path = get_config_path()
18+
try:
19+
with open(config_path, "r") as f:
20+
config = json.load(f)
21+
return config["default_model"]
22+
except Exception:
23+
return "gpt-3.5-turbo-instruct"
24+
25+
def set_default_model(model: str):
26+
config_path = get_config_path()
27+
config = {}
28+
try:
29+
with open(config_path, "r") as f:
30+
config = json.load(f)
31+
except Exception:
32+
pass
33+
config["default_model"] = model
34+
with open(config_path, "w") as f:
35+
json.dump(config, f)

0 commit comments

Comments
 (0)