Skip to content

Commit

Permalink
refactor:
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Jul 17, 2024
1 parent b7d19ea commit aac09d0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
30 changes: 22 additions & 8 deletions pyiqa/archs/liqe_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch
import torch.nn as nn
import os

from .constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

Expand Down Expand Up @@ -68,15 +69,29 @@ def __init__(self,

if pretrained == 'mix':
self.mtl = True
text_feat_cache_path = os.path.expanduser("~/.cache/pyiqa/liqe_text_feat_mix.pt")
else:
self.mtl = mtl
text_feat_cache_path = os.path.expanduser("~/.cache/pyiqa/liqe_text_feat.pt")

if self.mtl:
self.joint_texts = torch.cat(
[clip.tokenize(f"a photo of a {c} with {d} artifacts, which is of {q} quality") for q, c, d
in product(qualitys, scenes, dists_map)])
if os.path.exists(text_feat_cache_path):
self.text_features = torch.load(text_feat_cache_path, map_location='cpu')
else:
self.joint_texts = torch.cat([clip.tokenize(f"a photo with {c} quality") for c in qualitys])
print(f'Generating text features for LIQE model, will be cached at {text_feat_cache_path}.')
if self.mtl:
self.joint_texts = torch.cat(
[clip.tokenize(f"a photo of a {c} with {d} artifacts, which is of {q} quality") for q, c, d
in product(qualitys, scenes, dists_map)])
else:
self.joint_texts = torch.cat([clip.tokenize(f"a photo with {c} quality") for c in qualitys])

self.text_features = self.get_text_features(self.joint_texts)
torch.save(self.text_features.to('cpu'), text_feat_cache_path)

def get_text_features(self, x):
text_features = self.clip_model.encode_text(self.joint_texts.to(x.device))
text_features = text_features / text_features.norm(dim=1, keepdim=True)
return text_features

def forward(self, x):
bs = x.size(0)
Expand Down Expand Up @@ -108,9 +123,8 @@ def forward(self, x):
x = x[:, sel, ...]
x = x.reshape(bs, num_patch, x.shape[2], x.shape[3], x.shape[4])

text_features = self.clip_model.encode_text(self.joint_texts.to(x.device))
text_features = text_features / text_features.norm(dim=1, keepdim=True)

text_features = self.text_features.to(x)

x = x.view(bs*x.size(1), x.size(2), x.size(3), x.size(4))
image_features = self.clip_model.encode_image(x, pos_embedding=True)
# normalized features
Expand Down
2 changes: 2 additions & 0 deletions pyiqa/utils/download_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
Returns:
str: The path to the downloaded file.
"""
model_dir = model_dir or os.path.expanduser("~/.cache/pyiqa")

if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
Expand Down

0 comments on commit aac09d0

Please sign in to comment.