Skip to content

Optimized chunk processing logic in Chapter 7 for improved handling o… #763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions chapters/en/chapter7/7.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,10 @@ stride = 128


def preprocess_training_examples(examples):
# cleanup before tokenizing: some of the questions in the SQuAD dataset have extra spaces at the beginning and the end that don’t add anything

questions = [q.strip() for q in examples["question"]]

inputs = tokenizer(
questions,
examples["context"],
Expand All @@ -391,28 +394,38 @@ def preprocess_training_examples(examples):
start_positions = []
end_positions = []

# Track the last processed sample index
last_sample_idx = -1

for i, offset in enumerate(offset_mapping):
# Check which sample the current chunk /belongs to
sample_idx = sample_map[i]

# If we're processing a new sample, reset the answer found flag
if sample_idx != last_sample_idx:
last_sample_idx = sample_idx

# Extract the answer details for the current sample
answer = answers[sample_idx]
start_char = answer["answer_start"][0]
end_char = answer["answer_start"][0] + len(answer["text"][0])

# Identify the context boundaries using sequence IDs
sequence_ids = inputs.sequence_ids(i)

# Find the start and end of the context
# Find where the context starts and ends
idx = 0
while sequence_ids[idx] != 1:
idx += 1
context_start = idx
while sequence_ids[idx] == 1:

while idx < len(sequence_ids) and sequence_ids[idx] == 1:
idx += 1
context_end = idx - 1

# If the answer is not fully inside the context, label is (0, 0)
if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
start_positions.append(0)
end_positions.append(0)
else:
# Otherwise it's the start and end token positions
# Check if the answer is fully inside the chunk
if offset[context_start][0] <= start_char and offset[context_end][1] >= end_char:
# The answer is fully inside the chunk; find its start and end positions
idx = context_start
while idx <= context_end and offset[idx][0] <= start_char:
idx += 1
Expand All @@ -422,9 +435,16 @@ def preprocess_training_examples(examples):
while idx >= context_start and offset[idx][1] >= end_char:
idx -= 1
end_positions.append(idx + 1)
else:
# If the answer is not in this chunk, append (0, 0)
start_positions.append(0)
end_positions.append(0)

start_positions, end_positions

inputs["start_positions"] = start_positions
inputs["end_positions"] = end_positions

return inputs
```

Expand Down