diff --git a/examples/variant-prediction/predict.py b/examples/variant-prediction/predict.py
index e12fd614..8bf6f82c 100644
--- a/examples/variant-prediction/predict.py
+++ b/examples/variant-prediction/predict.py
@@ -135,12 +135,12 @@ def compute_pppl(row, sequence, model, alphabet, offset_idx):
 
     # compute probabilities at each position
     log_probs = []
-    for i in range(1, len(sequence) - 1):
+    for i in range(1, batch_tokens.shape[1] - 1):
         batch_tokens_masked = batch_tokens.clone()
         batch_tokens_masked[0, i] = alphabet.mask_idx
         with torch.no_grad():
             token_probs = torch.log_softmax(model(batch_tokens_masked.cuda())["logits"], dim=-1)
-        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i])].item())  # vocab size
+        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i - 1])].item())  # vocab size
     return sum(log_probs)
 
 