Skip to content

Commit

Permalink
feat: sync local changes
Browse files Browse the repository at this point in the history
  • Loading branch information
codestory committed Jan 29, 2025
1 parent 47e62ab commit 96bf331
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 68 deletions.
151 changes: 96 additions & 55 deletions runners/anthropic_runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from time import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import os

import pandas as pd
import sqlparse
from tqdm import tqdm

from runners.base_runner import run_eval_in_threadpool
from eval.eval import compare_query_results
from utils.creds import db_creds_all
from utils.dialects import convert_postgres_ddl_to_dialect
Expand All @@ -29,61 +30,79 @@ def generate_prompt(
public_data=True,
shuffle=True,
):
if public_data:
from defog_data.metadata import dbs
import defog_data.supplementary as sup
else:
from defog_data_private.metadata import dbs
import defog_data_private.supplementary as sup

with open(prompt_file, "r") as f:
prompt = f.read()

if table_metadata_string == "":
md = dbs[db_name]["table_metadata"]
pruned_metadata_ddl = to_prompt_schema(md, shuffle)
pruned_metadata_ddl = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_ddl,
to_dialect=db_type,
db_name=db_name,
)
column_join = sup.columns_join.get(db_name, {})
join_list = []
for values in column_join.values():
if isinstance(values[0], tuple):
for col_pair in values:
col_1, col_2 = col_pair
try:
if public_data:
from defog_data.metadata import dbs
import defog_data.supplementary as sup
else:
from defog_data_private.metadata import dbs
import defog_data_private.supplementary as sup

with open(prompt_file, "r") as f:
prompt = f.read()

if table_metadata_string == "":
md = dbs[db_name]["table_metadata"]
pruned_metadata_ddl = to_prompt_schema(md, shuffle)
pruned_metadata_ddl = convert_postgres_ddl_to_dialect(
postgres_ddl=pruned_metadata_ddl,
to_dialect=db_type,
db_name=db_name,
)
column_join = sup.columns_join.get(db_name, {})
join_list = []
for values in column_join.values():
if isinstance(values[0], tuple):
for col_pair in values:
col_1, col_2 = col_pair
join_str = f"{col_1} can be joined with {col_2}"
if join_str not in join_list:
join_list.append(join_str)
else:
col_1, col_2 = values[0]
join_str = f"{col_1} can be joined with {col_2}"
if join_str not in join_list:
join_list.append(join_str)
if len(join_list) > 0:
join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list)
else:
col_1, col_2 = values[0]
join_str = f"{col_1} can be joined with {col_2}"
if join_str not in join_list:
join_list.append(join_str)
if len(join_list) > 0:
join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list)
join_str = ""
pruned_metadata_str = pruned_metadata_ddl + join_str
else:
join_str = ""
pruned_metadata_str = pruned_metadata_ddl + join_str
else:
pruned_metadata_str = table_metadata_string

prompt = prompt.format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=pruned_metadata_str,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
prev_invalid_sql=prev_invalid_sql,
prev_error_msg=prev_error_msg,
)
return prompt
pruned_metadata_str = table_metadata_string

prompt = prompt.format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=pruned_metadata_str,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
prev_invalid_sql=prev_invalid_sql,
prev_error_msg=prev_error_msg,
)
return prompt
except ImportError:
# When defog_data is not available, just format with the existing table_metadata_string
with open(prompt_file, "r") as f:
prompt = f.read()

prompt = prompt.format(
user_question=question,
db_type=db_type,
instructions=instructions,
table_metadata_string=table_metadata_string,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
prev_invalid_sql=prev_invalid_sql,
prev_error_msg=prev_error_msg,
)
return prompt


def process_row(row, model_name, args):
start_time = time()
result_row = row.copy() # Create a copy of the original row to maintain all data
prompt = generate_prompt(
prompt_file=args.prompt_file[0],
question=row["question"],
Expand All @@ -110,21 +129,43 @@ def process_row(row, model_name, args):
)
except:
pass
return {
"query": generated_query,
result_row.update({
"generated_query": generated_query,
"reason": "",
"err": "",
"error_msg": "",
"latency_seconds": time() - start_time,
"tokens_used": response.input_tokens + response.output_tokens,
}
})

# Verify the generated query
try:
exact_match, correct = compare_query_results(
query_gold=row["query"],
query_gen=generated_query,
db_name=row["db_name"],
db_type=args.db_type,
db_creds=db_creds_all[args.db_type],
question=row["question"],
query_category=row["query_category"],
decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2,
)
result_row["exact_match"] = int(exact_match)
result_row["correct"] = int(correct)
result_row["is_correct"] = int(correct)
except Exception as e:
result_row["error_db_exec"] = 1
result_row["error_msg"] = f"EXECUTION ERROR: {str(e)}"
result_row["is_correct"] = 0
except Exception as e:
return {
"query": "",
result_row.update({
"generated_query": "",
"reason": "",
"err": f"GENERATION ERROR: {str(e)}",
"error_msg": f"GENERATION ERROR: {str(e)}",
"latency_seconds": time() - start_time,
"tokens_used": 0,
}
"is_correct": 0,
})
return result_row


def run_anthropic_eval(args):
Expand Down
30 changes: 17 additions & 13 deletions utils/questions.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from typing import Optional
import pandas as pd
from typing import Optional


def get_table_aliases(db_name: str) -> str:
    """Return a prompt instruction listing the table aliases for *db_name*.

    Looks up the database's table metadata in ``defog_data``, generates an
    alias for every table via ``utils.aliases.generate_aliases``, and wraps
    the result in an instruction sentence for inclusion in a prompt.

    Returns an empty string when the optional ``defog_data`` package (or the
    aliases helper) is not importable, so callers can use the result
    unconditionally.
    """
    try:
        # Imported lazily inside the function: defog_data is an optional
        # dependency, and the ImportError fallback below keeps environments
        # without it working.
        from defog_data.metadata import dbs
        from utils.aliases import generate_aliases

        metadata = dbs[db_name]["table_metadata"]
        table_names = list(metadata.keys())
        aliases = generate_aliases(table_names)
        aliases_instruction = (
            "Use the following table aliases when referencing tables in the query:\n"
            + aliases
        )
        return aliases_instruction
    except ImportError:
        # Return empty string when defog_data is not available
        return ""


def prepare_questions_df(
Expand Down Expand Up @@ -142,4 +146,4 @@ def prepare_questions_df(
elif cot_table_alias == "pregen":
question_query_df["cot_pregen"] = True

return question_query_df
return question_query_df

0 comments on commit 96bf331

Please sign in to comment.