Skip to content

Commit

Permalink
Fix QA example (#30580)
Browse files Browse the repository at this point in the history
* Handle cases when the CLS token is absent

* Use the BOS token as a fallback

Loading branch information
Rocketknight1 committed May 1, 2024
1 parent 4b4da18 commit 1e05671
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 6 deletions.
7 changes: 6 additions & 1 deletion examples/pytorch/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,12 @@ def prepare_train_features(examples):
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0

# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples.sequence_ids(i)
Expand Down
14 changes: 12 additions & 2 deletions examples/pytorch/question-answering/run_qa_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,12 @@ def prepare_train_features(examples):
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)

# Grab the sequence corresponding to that example (to know what is the context and what is the question).
Expand Down Expand Up @@ -534,7 +539,12 @@ def prepare_validation_features(examples):

for i, input_ids in enumerate(tokenized_examples["input_ids"]):
# Find the CLS token in the input ids.
cls_index = input_ids.index(tokenizer.cls_token_id)
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)

# Grab the sequence corresponding to that example (to know what is the context and what is the question).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,12 @@ def prepare_train_features(examples):
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)

# Grab the sequence corresponding to that example (to know what is the context and what is the question).
Expand Down Expand Up @@ -563,7 +568,12 @@ def prepare_validation_features(examples):

for i, input_ids in enumerate(tokenized_examples["input_ids"]):
# Find the CLS token in the input ids.
cls_index = input_ids.index(tokenizer.cls_token_id)
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0
tokenized_examples["cls_index"].append(cls_index)

# Grab the sequence corresponding to that example (to know what is the context and what is the question).
Expand Down
7 changes: 6 additions & 1 deletion examples/pytorch/question-answering/run_qa_no_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,12 @@ def prepare_train_features(examples):
for i, offsets in enumerate(offset_mapping):
# We will label impossible answers with the index of the CLS token.
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
if tokenizer.cls_token_id in input_ids:
cls_index = input_ids.index(tokenizer.cls_token_id)
elif tokenizer.bos_token_id in input_ids:
cls_index = input_ids.index(tokenizer.bos_token_id)
else:
cls_index = 0

# Grab the sequence corresponding to that example (to know what is the context and what is the question).
sequence_ids = tokenized_examples.sequence_ids(i)
Expand Down

0 comments on commit 1e05671

Please sign in to comment.