Skip to content

Commit

Permalink
✨ View or set the default model
Browse files Browse the repository at this point in the history
  • Loading branch information
Wytamma committed Mar 15, 2024
1 parent cf257eb commit bad31ed
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 5 deletions.
58 changes: 53 additions & 5 deletions write_the/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typer
import os
from write_the.models import models
from write_the.__about__ import __version__
from write_the.commands import write_the_tests, write_the_mkdocs, write_the_converters
from write_the.utils import list_python_files
Expand All @@ -13,6 +14,7 @@
from functools import wraps

from .tasks import async_cli_task
from .model import get_default_model, set_default_model


class AsyncTyper(typer.Typer):
Expand All @@ -30,6 +32,12 @@ def sync_func(*_args, **_kwargs):

app = AsyncTyper()

def _get_model_callback(value: str):
if value is None:
return get_default_model()
if value not in models:
raise typer.BadParameter(f"Model '{value}' not found!")
return value

def _print_version(ctx: typer.Context, value: bool):
if value:
Expand Down Expand Up @@ -94,10 +102,11 @@ async def docs(
False, "--batch/--no-batch", "-b", help="Send each node as a separate request."
),
model: str = typer.Option(
"gpt-3.5-turbo-instruct",
None,
"--model",
"-m",
help="The model to use for generating the docstrings.",
callback=_get_model_callback,
),
):
"""
Expand Down Expand Up @@ -191,10 +200,11 @@ async def tests(
help="Save empty files if a test creation fails. This will prevent write-the from regenerating failed test creations.",
),
model: str = typer.Option(
"gpt-3.5-turbo-instruct",
None,
"--model",
"-m",
help="The model to use for generating the tests.",
callback=_get_model_callback,
),
):
"""
Expand Down Expand Up @@ -282,10 +292,11 @@ async def convert(
False, "--pretty/--plain", "-p", help="Syntax highlight the output."
),
model: str = typer.Option(
"gpt-3.5-turbo-instruct",
None,
"--model",
"-m",
help="The model to use for generating the tests.",
callback=_get_model_callback,
),
):
"""
Expand Down Expand Up @@ -337,8 +348,45 @@ async def convert(


@app.command()
def models():
raise NotImplementedError()
def model(
desired_model: Optional[str] = typer.Argument(
None,
help="Set the default model.",
),
list: bool = typer.Option(
False,
"--list",
"-l",
help="List all available models.",
),
):
"""
View or set the default model.
"""
default_model = get_default_model()
if list:
from rich import table, print
table_ = table.Table()
table_.add_column("Name", justify="left", style="cyan")
table_.add_column("Context", justify="left", style="magenta")
table_.add_column("Default", justify="left", style="green")
for name, model in models.items():
if name == default_model:
table_.add_row(name, f"{model['context_window']}", "✅")
else:
table_.add_row(name, f"{model['context_window']}", "")
print(table_)
return typer.Exit(0)
if desired_model:
if desired_model not in models:
typer.secho(f"Model '{desired_model}' not found!", fg="red")
return typer.Exit(1)
set_default_model(desired_model)
typer.echo(f"Default model: {desired_model}")
return typer.Exit(0)
typer.echo(default_model)




@app.command()
Expand Down
35 changes: 35 additions & 0 deletions write_the/cli/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import typer
import json
from pathlib import Path


def get_config_path():
APP_NAME = "write-the"
app_dir = typer.get_app_dir(APP_NAME)
config_path: Path = Path(app_dir) / "config.json"
config_path.parent.mkdir(parents=True, exist_ok=True)
if not config_path.exists():
with open(config_path, "w") as f:
json.dump({}, f)
return config_path

def get_default_model():
config_path = get_config_path()
try:
with open(config_path, "r") as f:
config = json.load(f)
return config["default_model"]
except Exception:
return "gpt-3.5-turbo-instruct"

def set_default_model(model: str):
config_path = get_config_path()
config = {}
try:
with open(config_path, "r") as f:
config = json.load(f)
except Exception:
pass
config["default_model"] = model
with open(config_path, "w") as f:
json.dump(config, f)

0 comments on commit bad31ed

Please sign in to comment.