Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
rishsriv committed Jan 29, 2025
1 parent 3925bc1 commit c5f549d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
6 changes: 4 additions & 2 deletions runners/anthropic_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def generate_prompt(

with open(prompt_file, "r") as f:
prompt = f.read()

if table_metadata_string == "":
md = dbs[db_name]["table_metadata"]
pruned_metadata_ddl = to_prompt_schema(md, shuffle)
Expand Down Expand Up @@ -104,7 +104,9 @@ def process_row(row, model_name, args):
messages = [{"role": "user", "content": prompt}]
try:
response = chat_anthropic(messages=messages, model=model_name, temperature=0.0)
generated_query = response.content.split("```sql", 1)[-1].split("```", 1)[0].strip()
generated_query = (
response.content.split("```sql", 1)[-1].split("```", 1)[0].strip()
)
try:
generated_query = sqlparse.format(
generated_query, reindent=True, keyword_case="upper"
Expand Down
2 changes: 1 addition & 1 deletion runners/gemini_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def generate_prompt(

with open(prompt_file, "r") as f:
prompt = f.read()

if table_metadata_string == "":
md = dbs[db_name]["table_metadata"]
pruned_metadata_ddl = to_prompt_schema(md, shuffle)
Expand Down
8 changes: 5 additions & 3 deletions runners/openai_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def generate_prompt(

with open(prompt_file, "r") as f:
prompt = json.load(f)

if table_metadata_string == "":
md = dbs[db_name]["table_metadata"]
pruned_metadata_ddl = to_prompt_schema(md, shuffle)
Expand Down Expand Up @@ -69,7 +69,7 @@ def generate_prompt(
pruned_metadata_str = pruned_metadata_ddl + join_str
else:
pruned_metadata_str = table_metadata_string

if prompt[0]["role"] == "system":
prompt[0]["content"] = prompt[0]["content"].format(
db_type=db_type,
Expand Down Expand Up @@ -109,7 +109,9 @@ def process_row(row, model_name, args):
)
try:
response = chat_openai(messages=messages, model=model_name, temperature=0.0)
generated_query = response.content.split("```sql", 1)[-1].split("```", 1)[0].strip()
generated_query = (
response.content.split("```sql", 1)[-1].split("```", 1)[0].strip()
)
try:
generated_query = sqlparse.format(
generated_query, reindent=True, keyword_case="upper"
Expand Down

0 comments on commit c5f549d

Please sign in to comment.