From 861974ad3499601bda7ce4513c0bc34607ab2e16 Mon Sep 17 00:00:00 2001
From: CaoE
Date: Mon, 22 Aug 2022 13:57:49 +0800
Subject: [PATCH] do contiguous for q, k, and v to get better performance for
 normalize

---
 basicsr/models/archs/restormer_arch.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/basicsr/models/archs/restormer_arch.py b/basicsr/models/archs/restormer_arch.py
index a41221e..1874ba8 100644
--- a/basicsr/models/archs/restormer_arch.py
+++ b/basicsr/models/archs/restormer_arch.py
@@ -113,10 +113,10 @@ def forward(self, x):
 
         qkv = self.qkv_dwconv(self.qkv(x))
         q,k,v = qkv.chunk(3, dim=1)
-
-        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
-        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
-        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
+        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
+        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads).contiguous(memory_format=torch.contiguous_format)
 
         q = torch.nn.functional.normalize(q, dim=-1)
         k = torch.nn.functional.normalize(k, dim=-1)
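
For context, a minimal standalone sketch of the pattern this patch applies (the batch size, head count, and spatial sizes below are made-up illustration values, not the Restormer defaults): chunk() along the channel dimension returns strided views, and this rearrange is a pure reshape, so q, k, and v remain non-contiguous views of qkv; calling .contiguous(memory_format=torch.contiguous_format) before torch.nn.functional.normalize hands the normalization a densely packed tensor.

    import torch
    from einops import rearrange

    num_heads, c_per_head, h, w = 8, 16, 32, 32            # illustrative sizes only
    qkv = torch.randn(2, 3 * num_heads * c_per_head, h, w)

    q, k, v = qkv.chunk(3, dim=1)                          # strided views into qkv
    q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=num_heads)
    print(q.is_contiguous())                               # False: still a strided view of qkv

    q = q.contiguous(memory_format=torch.contiguous_format)
    print(q.is_contiguous())                               # True: densely packed copy

    q = torch.nn.functional.normalize(q, dim=-1)           # L2-normalize over the (h w) axis

The same applies to k and v; the explicit copy costs some memory traffic up front but, per the commit message, lets normalize run faster on contiguous data.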