Skip to content

at::test_transform_bias_rescale_qkv_nested gets wrong result #1426

Closed
@weishi-deng

Description

@weishi-deng

🐛 Describe the bug

AssertionError: The values for attribute 'shape' do not match: torch.Size([4, 2, 2, 12]) != torch.Size([4, 2, 8, 12]).

import torch
import pytest  # noqa
from torch.testing._internal.common_utils import TestCase


# CPU device handle; the test below moves every tensor here before comparing.
cpu_device = torch.device("cpu")
# XPU device handle.
# NOTE(review): not referenced anywhere in the visible snippet — the test
# selects its accelerator through its own ``device`` argument instead.
dpcpp_device = torch.device("xpu")

class TestTorchMethod(TestCase):
    """Compare ``torch._transform_bias_rescale_qkv`` on nested input: CPU vs. accelerator."""

    def test_transform_bias_rescale_qkv_nested(self, device="xpu", dtype=torch.float32):
        """Check that ``device`` yields the same q/k/v as the CPU reference.

        For each configuration a random CPU tensor of shape
        ``(bs, sl, 3 * embed_dim)`` is unbound along dim 0 and repacked as a
        nested tensor. The nested tensor and the bias of a freshly initialised
        ``torch.nn.Linear(embed_dim, 3 * embed_dim)`` are copied to ``device``,
        ``torch._transform_bias_rescale_qkv`` is run on both copies, and the
        three outputs are compared on the CPU.

        Args:
            device: Device under test; the reference always runs on the CPU.
            dtype: Floating-point dtype used for the input and the bias.
        """
        # Each entry is (embed_dim, num_heads, bs, sl).
        tests = [
            (64, 4, 16, 8),
            (24, 2, 4, 2),
            (2, 2, 2, 2),
            (24, 4, 4, 2),
            (48, 4, 16, 8),
        ]
        for embed_dim, num_heads, bs, sl in tests:
            # Random draws happen in the same order as before (input first,
            # then the Linear initialisation) so seeded runs stay comparable.
            dense = torch.randn(
                bs, sl, 3 * embed_dim, device=cpu_device, dtype=dtype
            )
            x = torch.nested.nested_tensor(
                list(torch.unbind(dense)), device=cpu_device, dtype=dtype
            )
            x_dev = x.to(device)

            # Only the bias of the projection layer is needed.
            bias = torch.nn.Linear(
                embed_dim, 3 * embed_dim, device=cpu_device, dtype=dtype
            ).bias
            bias_dev = bias.to(device)

            # Sanity check: the host -> device -> host round trip must not
            # alter the inputs, otherwise the comparison below is meaningless.
            self.assertEqual(x, x_dev.to(cpu_device))
            self.assertEqual(bias, bias_dev.to(cpu_device))

            q, k, v = torch._transform_bias_rescale_qkv(
                x, bias, num_heads=num_heads
            )
            q_dev, k_dev, v_dev = torch._transform_bias_rescale_qkv(
                x_dev, bias_dev, num_heads=num_heads
            )

            self.assertEqual(q, q_dev.to(cpu_device))
            self.assertEqual(k, k_dev.to(cpu_device))
            self.assertEqual(v, v_dev.to(cpu_device))

Versions

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/xpu

reproduce:
pytest test.py

Metadata

Metadata

Assignees

Labels

Type

Projects

No projects

Relationships

None yet

Development

No branches or pull requests

Issue actions