Closed
Description
🐛 Describe the bug
AssertionError: The values for attribute 'shape' do not match: torch.Size([4, 2, 2, 12]) != torch.Size([4, 2, 8, 12]).
import torch
import pytest # noqa
from torch.testing._internal.common_utils import TestCase
# Reference device: the test below moves tensors here before comparing them.
cpu_device = torch.device("cpu")
# XPU device handle; not referenced anywhere else in the visible snippet.
dpcpp_device = torch.device("xpu")
class TestTorchMethod(TestCase):
    """Compare ``torch._transform_bias_rescale_qkv`` results on CPU and an accelerator."""

    def test_transform_bias_rescale_qkv_nested(self, device="xpu", dtype=torch.float32):
        """Check that the nested-tensor QKV transform on ``device`` matches CPU.

        The inputs are created on CPU, copied to ``device``, and the
        ``(q, k, v)`` outputs of both runs are compared after being moved
        back to CPU.

        Args:
            device: Device under test; the CPU run is the reference.
            dtype: dtype used for the input tensor and the bias.
        """
        # Each case is (embed_dim, num_heads, bs, sl); bs and sl are the two
        # leading dims of the dense input (presumably batch size and
        # sequence length).
        cases = [
            (64, 4, 16, 8),
            (24, 2, 4, 2),
            (2, 2, 2, 2),
            (24, 4, 4, 2),
            (48, 4, 16, 8),
        ]
        for embed_dim, num_heads, bs, sl in cases:
            # subTest tags a failure with the configuration that triggered it.
            with self.subTest(
                embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl
            ):
                dense = torch.randn(
                    bs, sl, 3 * embed_dim, device=cpu_device, dtype=dtype
                )
                # Build the nested input from the per-sample slices of the
                # dense tensor.
                x = torch.nested.nested_tensor(
                    list(torch.unbind(dense)), device=cpu_device, dtype=dtype
                )
                x_dev = x.to(device)

                # The Linear layer is only used as a source of a bias with
                # 3 * embed_dim elements.
                qkv = torch.nn.Linear(
                    embed_dim, 3 * embed_dim, device=cpu_device, dtype=dtype
                )
                bias = qkv.bias
                bias_dev = bias.to(device)

                # Sanity check: the device copies round-trip unchanged.
                self.assertEqual(x.to(cpu_device), x_dev.to(cpu_device))
                self.assertEqual(bias.to(cpu_device), bias_dev.to(cpu_device))

                q, k, v = torch._transform_bias_rescale_qkv(
                    x, bias, num_heads=num_heads
                )
                q_dev, k_dev, v_dev = torch._transform_bias_rescale_qkv(
                    x_dev, bias_dev, num_heads=num_heads
                )

                self.assertEqual(q.to(cpu_device), q_dev.to(cpu_device))
                self.assertEqual(k.to(cpu_device), k_dev.to(cpu_device))
                self.assertEqual(v.to(cpu_device), v_dev.to(cpu_device))
Versions
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/xpu
reproduce:
pytest test.py