Are there errors in transformers mha.py? #216

Open
lqzzy opened this issue Sep 28, 2023 · 1 comment
lqzzy commented Sep 28, 2023

In the file labml_nn/transformer/mha.py:

    def forward(self, x: torch.Tensor):
        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        # We apply the linear transformation to the last dimension and split that into
        # the heads.
        head_shape = x.shape[:-1]

        # Linear transform
        x = self.linear(x)

        # Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)

        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_model]`
        return x

I have two questions:

The first:

        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_model]`

I think the output should be [seq_len, batch_size, heads, d_k] or [batch_size, heads, d_k]: in both cases the last dimension is d_k, not d_model.
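
For the first question, here is a quick standalone shape check with made-up sizes (a re-creation of the quoted forward, not the actual PrepareForMultiHeadAttention class), which shows that the last dimension of the output is d_k:

import torch
import torch.nn as nn

# Standalone re-creation of the quoted `forward` (illustration only).
heads, d_model = 8, 512
d_k = d_model // heads
linear = nn.Linear(d_model, heads * d_k, bias=True)

def split_heads(x: torch.Tensor) -> torch.Tensor:
    head_shape = x.shape[:-1]           # `[seq_len, batch_size]` or `[batch_size]`
    x = linear(x)                       # transform the last dimension
    return x.view(*head_shape, heads, d_k)

print(split_heads(torch.randn(10, 4, d_model)).shape)  # torch.Size([10, 4, 8, 64])
print(split_heads(torch.randn(4, d_model)).shape)      # torch.Size([4, 8, 64])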

The second:

class MultiHeadAttention(nn.Module):
    r"""
    <a id="MHA"></a>

    ## Multi-Head Attention Module

    This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.

    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

    In simple terms, it finds keys that match the query, and gets the values of those keys.

    It uses dot-product of query and key as the indicator of how matching they are.
    Before taking the $softmax$ the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$.
    This is done to avoid large dot-product values causing softmax to
    give very small gradients when $d_k$ is large.

    Softmax is calculated along the axis of the sequence (or time).
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        """

        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that it can be used for logging, or other computations if needed
        self.attn = None

I think that in the following code:

        # Softmax for attention along the time dimension of `key`
        self.softmax = nn.Softmax(dim=1)

dim should be -1.

The output is [seq_len, batch_size, heads, d_k], and after the dot-product with the keys the result should be [seq_len, batch_size, heads, seq_len], with the key dimension last, so I think dim should be -1.
I am also unsure about the case where the output is [batch_size, heads, d_k]: I have no idea what shape the dot-product produces there, or what dim should be.
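
To make that shape argument concrete, here is a hypothetical check with made-up sizes (this is only one possible way of forming the scores, not necessarily what mha.py actually does): if the scores were laid out as [seq_len_q, batch_size, heads, seq_len_k], the key axis would indeed be the last one.

import torch

# Hypothetical layout check (illustration only; not necessarily how mha.py
# forms the attention scores).
seq_len, batch_size, heads, d_k = 10, 4, 8, 64
q = torch.randn(seq_len, batch_size, heads, d_k)
k = torch.randn(seq_len, batch_size, heads, d_k)

# Dot-product over d_k, keeping the key positions as the last dimension:
scores = torch.einsum('ibhd,jbhd->ibhj', q, k)
print(scores.shape)  # torch.Size([10, 4, 8, 10]) -> key axis is dim=-1 here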

lqzzy commented Sep 28, 2023

After reading the entire code, I think I understand the second problem.
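
For anyone else who hits the same question, here is a minimal shape check of why dim=1 works, assuming get_scores in the full mha.py forms the scores as torch.einsum('ibhd,jbhd->ijbh', query, key) (that is how the rest of the file reads to me):

import torch

# Assumed layout: scores = einsum('ibhd,jbhd->ijbh', query, key),
# i.e. [seq_len_q, seq_len_k, batch_size, heads].
seq_len, batch_size, heads, d_k = 10, 4, 8, 64
q = torch.randn(seq_len, batch_size, heads, d_k)
k = torch.randn(seq_len, batch_size, heads, d_k)

scores = torch.einsum('ibhd,jbhd->ijbh', q, k)
print(scores.shape)  # torch.Size([10, 10, 4, 8]) -> key positions are dim=1

# Softmax over the keys is therefore dim=1 under this layout:
attn = torch.softmax(scores, dim=1)
print(attn.sum(dim=1)[0, 0, 0])  # ~1.0, i.e. normalized over the key positions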
