diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index 264f2e1..77b509b 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -1,11 +1,12 @@ -import os from time import time from concurrent.futures import ThreadPoolExecutor, as_completed +import os import pandas as pd import sqlparse from tqdm import tqdm +from runners.base_runner import run_eval_in_threadpool from eval.eval import compare_query_results from utils.creds import db_creds_all from utils.dialects import convert_postgres_ddl_to_dialect @@ -29,61 +30,79 @@ def generate_prompt( public_data=True, shuffle=True, ): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - with open(prompt_file, "r") as f: - prompt = f.read() - - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair + try: + if public_data: + from defog_data.metadata import dbs + import defog_data.supplementary as sup + else: + from defog_data_private.metadata import dbs + import defog_data_private.supplementary as sup + + with open(prompt_file, "r") as f: + prompt = f.read() + + if table_metadata_string == "": + md = dbs[db_name]["table_metadata"] + pruned_metadata_ddl = to_prompt_schema(md, shuffle) + pruned_metadata_ddl = convert_postgres_ddl_to_dialect( + postgres_ddl=pruned_metadata_ddl, + to_dialect=db_type, + db_name=db_name, + ) + column_join = sup.columns_join.get(db_name, {}) + join_list = [] + for values in column_join.values(): + if isinstance(values[0], tuple): + for col_pair in values: + col_1, col_2 = col_pair + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + else: + col_1, col_2 = values[0] join_str = f"{col_1} can be joined with {col_2}" if join_str not in join_list: join_list.append(join_str) + if len(join_list) > 0: + join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) else: - col_1, col_2 = values[0] - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) + join_str = "" + pruned_metadata_str = pruned_metadata_ddl + join_str else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - - prompt = prompt.format( - user_question=question, - db_type=db_type, - instructions=instructions, - table_metadata_string=pruned_metadata_str, - k_shot_prompt=k_shot_prompt, - glossary=glossary, - prev_invalid_sql=prev_invalid_sql, - prev_error_msg=prev_error_msg, - ) - return prompt + pruned_metadata_str = table_metadata_string + + prompt = prompt.format( + user_question=question, + db_type=db_type, + instructions=instructions, + table_metadata_string=pruned_metadata_str, + k_shot_prompt=k_shot_prompt, + glossary=glossary, + prev_invalid_sql=prev_invalid_sql, + prev_error_msg=prev_error_msg, + ) + return prompt + except ImportError: + # When defog_data is not available, just format with the existing table_metadata_string + with open(prompt_file, "r") as f: + prompt = f.read() + + prompt = prompt.format( + user_question=question, + db_type=db_type, + instructions=instructions, + table_metadata_string=table_metadata_string, + k_shot_prompt=k_shot_prompt, + glossary=glossary, + prev_invalid_sql=prev_invalid_sql, + prev_error_msg=prev_error_msg, + ) + return prompt def process_row(row, model_name, args): start_time = time() + result_row = row.copy() # Create a copy of the original row to maintain all data prompt = generate_prompt( prompt_file=args.prompt_file[0], question=row["question"], @@ -110,21 +129,43 @@ def process_row(row, model_name, args): ) except: pass - return { - "query": generated_query, + result_row.update({ + "generated_query": generated_query, "reason": "", - "err": "", + "error_msg": "", "latency_seconds": time() - start_time, "tokens_used": response.input_tokens + response.output_tokens, - } + }) + + # Verify the generated query + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=args.db_type, + db_creds=db_creds_all[args.db_type], + question=row["question"], + query_category=row["query_category"], + decimal_points=args.decimal_points if hasattr(args, 'decimal_points') else 2, + ) + result_row["exact_match"] = int(exact_match) + result_row["correct"] = int(correct) + result_row["is_correct"] = int(correct) + except Exception as e: + result_row["error_db_exec"] = 1 + result_row["error_msg"] = f"EXECUTION ERROR: {str(e)}" + result_row["is_correct"] = 0 except Exception as e: - return { - "query": "", + result_row.update({ + "generated_query": "", "reason": "", - "err": f"GENERATION ERROR: {str(e)}", + "error_msg": f"GENERATION ERROR: {str(e)}", "latency_seconds": time() - start_time, "tokens_used": 0, - } + "is_correct": 0, + }) + return result_row def run_anthropic_eval(args): diff --git a/utils/questions.py b/utils/questions.py index 414a89d..bc89354 100644 --- a/utils/questions.py +++ b/utils/questions.py @@ -1,19 +1,23 @@ -from typing import Optional import pandas as pd +from typing import Optional def get_table_aliases(db_name: str) -> str: - from defog_data.metadata import dbs - from utils.aliases import generate_aliases - - metadata = dbs[db_name]["table_metadata"] - table_names = list(metadata.keys()) - aliases = generate_aliases(table_names) - aliases_instruction = ( - "Use the following table aliases when referencing tables in the query:\n" - + aliases - ) - return aliases_instruction + try: + from defog_data.metadata import dbs + from utils.aliases import generate_aliases + + metadata = dbs[db_name]["table_metadata"] + table_names = list(metadata.keys()) + aliases = generate_aliases(table_names) + aliases_instruction = ( + "Use the following table aliases when referencing tables in the query:\n" + + aliases + ) + return aliases_instruction + except ImportError: + # Return empty string when defog_data is not available + return "" def prepare_questions_df( @@ -142,4 +146,4 @@ def prepare_questions_df( elif cot_table_alias == "pregen": question_query_df["cot_pregen"] = True - return question_query_df + return question_query_df \ No newline at end of file