Skip to content

Commit 29a433c

Browse files
author
Rui Zhang
committed
add cosql editsql
1 parent 05ded8f commit 29a433c

File tree

4 files changed

+99
-4
lines changed

4 files changed

+99
-4
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Download Pretrained BERT model from [here](https://drive.google.com/file/d/1f_LE
5959
### Run Spider experiment
6060
First, download [Spider](https://yale-lily.github.io/spider). Then please follow
6161

62-
- `run_spider_editsql.sh`. We saved our experimental logs at `logs/logs_spider_editsql`
62+
- `run_spider_editsql.sh`. We saved our experimental logs at `logs/logs_spider_editsql`. The dev results can be reproduced by `test_spider_editsql.sh` with the trained model `logs/logs_spider_editsql/save_12`.
6363

6464
This reproduces the Spider result in "Editing-Based SQL Query Generation for Cross-Domain Context-Dependent Questions".
6565

@@ -140,7 +140,8 @@ This reproduces the SParC result in "Editing-Based SQL Query Generation for Cros
140140

141141
First, download CoSQL from [here](https://yale-lily.github.io/cosql). Then please follow
142142

143-
- `run_cosql_cdseq2seq.sh`. We saved our experimental logs at `logs/logs_cosql_cdseq2seq`
143+
- `run_cosql_cdseq2seq.sh`. We saved our experimental logs at `logs/logs_cosql_cdseq2seq`.
144+
- `run_cosql_editsql.sh`. We saved our experimental logs at `logs/logs_cosql_editsql`. The dev results can be reproduced by `test_cosql_editsql.sh` with the trained model downloaded from [here](https://drive.google.com/file/d/1ggf05rLVUpqamkEFbhu2CX35-PTGpFx4/view?usp=sharing) and put under `logs/logs_cosql_editsql/save_12_cosql_editsql`.
144145

145146
This reproduces the SQL-grounded dialog state tracking result in "CoSQL: A Conversational Text-to-SQL Challenge Towards Cross-Domain Natural Language Interfaces to Databases".
146147

preprocess.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def check_oov(format_sql_final, output_vocab, schema_tokens):
244244
for sql_tok in format_sql_final.split():
245245
if not (sql_tok in schema_tokens or sql_tok in output_vocab):
246246
print('OOV!', sql_tok)
247+
raise Exception('OOV')
247248

248249

249250
def normalize_space(format_sql):
@@ -401,7 +402,11 @@ def read_data_json(split_json, interaction_list, database_schemas, column_names,
401402
continue
402403

403404
if remove_from:
404-
turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
405+
try:
406+
turn_sql_parse = parse_sql(turn_sql, db_id, column_names[db_id], output_vocab, schema_tokens[db_id], database_schemas[db_id])
407+
except:
408+
print('continue')
409+
continue
405410
else:
406411
turn_sql_parse = turn_sql
407412

@@ -411,7 +416,8 @@ def read_data_json(split_json, interaction_list, database_schemas, column_names,
411416
turn_utterance = turn['utterance']
412417

413418
interaction['interaction'].append({'utterance': turn_utterance, 'sql': turn_sql_parse})
414-
interaction_list[db_id].append(interaction)
419+
if len(interaction['interaction']) > 0:
420+
interaction_list[db_id].append(interaction)
415421

416422
return interaction_list
417423

run_cosql_editsql.sh

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#! /bin/bash
2+
3+
# 1. preprocess dataset by the following. It will produce data/cosql_data_removefrom/
4+
5+
python3 preprocess.py --dataset=cosql --remove_from
6+
7+
# 2. train and evaluate.
8+
# the result (models, logs, prediction outputs) are saved in $LOGDIR
9+
10+
GLOVE_PATH="/home/lily/rz268/dialog2sql/word_emb/glove.840B.300d.txt" # you need to change this
11+
LOGDIR="logs_cosql_editsql"
12+
13+
CUDA_VISIBLE_DEVICES=0 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
14+
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
15+
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
16+
--embedding_filename=$GLOVE_PATH \
17+
--data_directory="processed_data_cosql_removefrom" \
18+
--input_key="utterance" \
19+
--state_positional_embeddings=1 \
20+
--discourse_level_lstm=1 \
21+
--use_utterance_attention=1 \
22+
--use_previous_query=1 \
23+
--use_query_attention=1 \
24+
--use_copy_switch=1 \
25+
--use_schema_encoder=1 \
26+
--use_schema_attention=1 \
27+
--use_encoder_attention=1 \
28+
--use_bert=1 \
29+
--bert_type_abb=uS \
30+
--fine_tune_bert=1 \
31+
--use_schema_self_attention=1 \
32+
--use_schema_encoder_2=1 \
33+
--interaction_level=1 \
34+
--reweight_batch=1 \
35+
--freeze=1 \
36+
--train=1 \
37+
--logdir=$LOGDIR \
38+
--evaluate=1 \
39+
--evaluate_split="valid" \
40+
--use_predicted_queries=1
41+
42+
# 3. get evaluation result
43+
44+
python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from

test_cosql_editsql.sh

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#! /bin/bash
2+
3+
# 1. preprocess dataset by the following. It will produce data/cosql_data_removefrom/
4+
5+
# python3 preprocess.py --dataset=cosql --remove_from
6+
7+
# 2. train and evaluate.
8+
# the result (models, logs, prediction outputs) are saved in $LOGDIR
9+
10+
GLOVE_PATH="/home/lily/rz268/dialog2sql/word_emb/glove.840B.300d.txt" # you need to change this
11+
LOGDIR="logs/logs_cosql_editsql"
12+
13+
CUDA_VISIBLE_DEVICES=0 python3 run.py --raw_train_filename="data/cosql_data_removefrom/train.pkl" \
14+
--raw_validation_filename="data/cosql_data_removefrom/dev.pkl" \
15+
--database_schema_filename="data/cosql_data_removefrom/tables.json" \
16+
--embedding_filename=$GLOVE_PATH \
17+
--data_directory="processed_data_cosql_removefrom" \
18+
--input_key="utterance" \
19+
--state_positional_embeddings=1 \
20+
--discourse_level_lstm=1 \
21+
--use_utterance_attention=1 \
22+
--use_previous_query=1 \
23+
--use_query_attention=1 \
24+
--use_copy_switch=1 \
25+
--use_schema_encoder=1 \
26+
--use_schema_attention=1 \
27+
--use_encoder_attention=1 \
28+
--use_bert=1 \
29+
--bert_type_abb=uS \
30+
--fine_tune_bert=1 \
31+
--use_schema_self_attention=1 \
32+
--use_schema_encoder_2=1 \
33+
--interaction_level=1 \
34+
--reweight_batch=1 \
35+
--freeze=1 \
36+
--logdir=$LOGDIR \
37+
--evaluate=1 \
38+
--evaluate_split="valid" \
39+
--use_predicted_queries=1 \
40+
--save_file="$LOGDIR/save_12_cosql_editsql"
41+
42+
# 3. get evaluation result
43+
44+
python3 postprocess_eval.py --dataset=cosql --split=dev --pred_file $LOGDIR/valid_use_predicted_queries_predictions.json --remove_from

0 commit comments

Comments
 (0)