diff --git a/basicsr/models/archs/restormer_arch.py b/basicsr/models/archs/restormer_arch.py index a41221e..1874ba8 100644 --- a/basicsr/models/archs/restormer_arch.py +++ b/basicsr/models/archs/restormer_arch.py @@ -113,10 +113,10 @@ def forward(self, x): qkv = self.qkv_dwconv(self.qkv(x)) q,k,v = qkv.chunk(3, dim=1) - - q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) - k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) - v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format) q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1)