From 536b15592b2a7112ae012721a01cb8bd7e8872ca Mon Sep 17 00:00:00 2001
From: Henning Redestig
Date: Fri, 4 Nov 2022 13:02:14 +0100
Subject: [PATCH] fix: iterate over tokenized sequence / index-error

Iteration happens over all residue tokens of the tokenized sequence
(starting at index 1, after the prepended start token). The
corresponding residue in `sequence` excludes the start token, hence
`i - 1`.
---
 examples/variant-prediction/predict.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/variant-prediction/predict.py b/examples/variant-prediction/predict.py
index e12fd614..8bf6f82c 100644
--- a/examples/variant-prediction/predict.py
+++ b/examples/variant-prediction/predict.py
@@ -135,12 +135,12 @@ def compute_pppl(row, sequence, model, alphabet, offset_idx):
 
     # compute probabilities at each position
     log_probs = []
-    for i in range(1, len(sequence) - 1):
+    for i in range(1, batch_tokens.shape[1] - 1):
         batch_tokens_masked = batch_tokens.clone()
         batch_tokens_masked[0, i] = alphabet.mask_idx
         with torch.no_grad():
             token_probs = torch.log_softmax(model(batch_tokens_masked.cuda())["logits"], dim=-1)
-        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i])].item())  # vocab size
+        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i - 1])].item())  # vocab size
     return sum(log_probs)
 
 
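
Note (not part of the commit): a minimal sketch of the indexing
relationship this fix relies on, assuming the fair-esm package is
installed; the checkpoint choice and example sequence below are
illustrative only.

    import esm

    # Load a small pretrained model and its alphabet (illustrative
    # choice; ESM-1v and ESM-2 alphabets tokenize the same way here).
    model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    batch_converter = alphabet.get_batch_converter()

    sequence = "MKTAYIAKQR"  # hypothetical example sequence
    _, _, batch_tokens = batch_converter([("protein1", sequence)])

    # The converter prepends a BOS token and appends an EOS token, so
    # the token row is two longer than the raw sequence ...
    assert batch_tokens.shape[1] == len(sequence) + 2

    # ... and the token at position i corresponds to residue i - 1.
    for i in range(1, batch_tokens.shape[1] - 1):
        assert batch_tokens[0, i].item() == alphabet.get_idx(sequence[i - 1])

The new loop bound `batch_tokens.shape[1] - 1` thus visits every
residue token exactly once, skipping BOS and EOS; the old bound
`len(sequence) - 1` stopped two positions early, and `sequence[i]`
read the residue one position off.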