
Issue with attention mask unsqueeze in attention weights #4

Open
meadewaking opened this issue Jan 24, 2024 · 1 comment

@meadewaking

After running batch inference, I found a bug in the attention weight computation. The attention mask is added to the attention weights after an `unsqueeze`, but that `unsqueeze` inserts the new axis at the wrong position: it should use dimension 1 (the second dimension), not dimension 0 (the first). Here is the corrected line:

```python
attn_weights = attn_weights + attention_mask[position_ids, :attn_weights.shape[3]].unsqueeze(1)
```
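To see why only batch sizes above 1 are affected, it helps to trace the tensor shapes. The sketch below assumes, as the corrected line suggests, that `attn_weights` has shape `(batch, n_heads, q_len, kv_len)` and that `attention_mask` is a 2-D `(max_len, max_len)` mask, so indexing it with `position_ids` of shape `(batch, q_len)` yields `(batch, q_len, kv_len)`. The helpers `unsqueeze_shape`, `broadcast_shape` and `mask_result_shape` are hypothetical, written only for illustration; they follow the PyTorch/NumPy broadcasting rules in pure Python so the check runs without a GPU or model weights. The head count of 32 is likewise an assumed value.

```python
def unsqueeze_shape(shape, dim):
    """Shape after tensor.unsqueeze(dim), for a non-negative dim."""
    return shape[:dim] + (1,) + shape[dim:]


def broadcast_shape(a, b):
    """Broadcast two shapes with PyTorch/NumPy rules; raise ValueError if incompatible."""
    result = []
    # Compare dimensions from the trailing end; missing leading dims count as 1
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(da if db == 1 else db)
    return tuple(reversed(result))


def mask_result_shape(batch, n_heads, q_len, kv_len, dim):
    """Shape of attn_weights + mask_rows.unsqueeze(dim), or None if broadcasting fails."""
    attn_shape = (batch, n_heads, q_len, kv_len)
    # attention_mask[position_ids, :kv_len] with position_ids of shape (batch, q_len)
    mask_rows = (batch, q_len, kv_len)
    try:
        return broadcast_shape(attn_shape, unsqueeze_shape(mask_rows, dim))
    except ValueError:
        return None


n_heads = 32  # assumed head count, for illustration only
for batch in (1, 2):
    for dim in (0, 1):
        # q_len=1 matches the one-token-per-step generation loop below
        print(f"batch={batch} unsqueeze({dim}) ->", mask_result_shape(batch, n_heads, 1, 10, dim))
```

With batch size 1, both `unsqueeze(0)` and `unsqueeze(1)` give `(1, 32, 1, 10)`, which is consistent with the bug going unnoticed until a batch was used. With batch size 2, `unsqueeze(0)` produces a `(1, 2, 1, 10)` mask that cannot broadcast against `(2, 32, 1, 10)`, while `unsqueeze(1)` gives the expected `(2, 32, 1, 10)`. If the batch size happened to equal the number of heads, `unsqueeze(0)` would not fail at all: head `h` of every sequence would silently receive the mask of sequence `h`, which is harder to spot than an error.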

Batch inference code:

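```python
import torch

# load_safetensors, ChatTokenizer and llama are defined in this project's code

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = None

model_filename = "data/model.safetensors"
tokenizer_filename = "data/tokenizer.model"

state_dict = load_safetensors(model_filename, device, dtype)
tokenizer = ChatTokenizer(tokenizer_filename)
prompts = ['import numpy as np', '#include<stdio.h>']

token_ids = [tokenizer.encode(i) for i in prompts]

cache = {}
max_text_len = 100
pad_id = -1

# Right-pad every prompt to max_text_len so the batch fits in one tensor
token_ids_torch = torch.full((len(token_ids), max_text_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(token_ids):
    token_ids_torch[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
prompt_tokens_mask = token_ids_torch != pad_id

# Tracks which sequences have already generated the end token
finished = torch.zeros(len(token_ids), dtype=torch.bool, device=device)

# Generate tokens one by one
for cur_pos in range(1, max_text_len):

    # Feed one token per sequence; all sequences are at the same position
    inputs = token_ids_torch[:, cur_pos-1:cur_pos]
    position_ids = torch.full(inputs.size(), cur_pos-1, dtype=torch.long, device=device)

    # Predict logits
    logits = llama(inputs, position_ids, cache, state_dict)

    # Choose the most likely token
    new_token_ids = logits[:, -1].argmax(-1)

    # Keep the prompt token while a sequence is still inside its prompt
    new_token_ids = torch.where(prompt_tokens_mask[:, cur_pos], token_ids_torch[:, cur_pos], new_token_ids)
    token_ids_torch[:, cur_pos] = new_token_ids

    # Stop once every sequence has generated the special end token
    finished |= (new_token_ids == tokenizer.end_token_id) & ~prompt_tokens_mask[:, cur_pos]
    if finished.all():
        break

token_ids = token_ids_torch.cpu().tolist()

# Drop padding and everything from the end token onward before decoding
response = []
for ids in token_ids:
    ids = [i for i in ids if i != pad_id]
    if tokenizer.end_token_id in ids:
        ids = ids[:ids.index(tokenizer.end_token_id)]
    response.append(tokenizer.decode(ids))
print(response)
```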
@99991
Owner

99991 commented Jan 24, 2024

Good catch! And thank you for the correction. I must have forgotten to test with batch size > 1 at some point.
