-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_similarity.py
184 lines (141 loc) · 4.97 KB
/
image_similarity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import argparse
from PIL import Image
import torch.nn.functional as F
import open_clip
from sd_ext.files import get_files
from sd_ext.torch import torch_args, get_device
from sd_ext.format import to_csv, to_json, format_args
from sd_ext.clip import VISION_TRANSFORMER_MODELS, get_image_features
# File extensions treated as images when --exts is not given on the command line.
IMAGE_EXTS = [
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".tiff",
    ".avif",
]
def _exclude_matching(files, pattern):
    """Return the entries of *files* whose string form does not contain *pattern*."""
    return [file for file in files if pattern not in str(file)]


def _embed_images(files, image_processor, clip_model, device):
    """Return one CLIP image embedding per entry of *files*, in the same order."""
    embeddings = []
    for file in files:
        # Open inside a context manager so the file handle is released as soon
        # as the embedding has been computed, instead of leaking one per image.
        with Image.open(file) as image:
            embeddings.append(
                get_image_features(image_processor, clip_model, image, device)
            )
    return embeddings


def main(args):
    """Compare images by the cosine similarity of their CLIP embeddings.

    Loads the requested CLIP model, embeds every image found under
    ``args.image_file_or_dir`` and compares each of them with every image
    found under ``args.other_file_or_dir`` (or with the same set of images
    when that option is not given). Pairs whose two paths are equal are
    skipped. Results are printed when ``args.verbose`` is set and written with
    ``to_csv`` / ``to_json`` when ``args.csv`` / ``args.json`` are set.

    Args:
        args: Parsed command-line namespace. Reads ``device``, ``clip_model``,
            ``pretrained_clip``, ``image_file_or_dir``, ``other_file_or_dir``,
            ``exts``, ``recursive``, ``filter``, ``verbose``, ``csv`` and
            ``json``.

    Raises:
        AssertionError: If no pair of images was compared.
    """
    device = get_device(args.device)
    print(device)

    print(f"Loading {args.pretrained_clip} CLIP {args.clip_model}...")
    clip_model, _, image_processor = open_clip.create_model_and_transforms(
        args.clip_model, pretrained=args.pretrained_clip
    )
    clip_model.to(device)
    # Inference only: make sure train-mode layers (dropout, batch norm,
    # stochastic depth) cannot perturb the embeddings.
    clip_model.eval()

    # One flag decides both the file list and the embedding list for the
    # comparison set, so the two can never get out of step.
    has_other = args.other_file_or_dir is not None

    print("Getting files list...")
    files = get_files(
        args.image_file_or_dir,
        file_ext=args.exts,
        recursive=args.recursive,
    )

    print("Getting comparision files list...")
    other_files = files
    if has_other:
        other_files = get_files(
            args.other_file_or_dir,
            file_ext=args.exts,
            recursive=args.recursive,
        )

    if args.filter:
        files = _exclude_matching(files, args.filter)
        # Filter the comparison list itself; deriving it from `files` here
        # would discard the --other_file_or_dir images.
        other_files = _exclude_matching(other_files, args.filter)

    print(f"Found files {len(files)}")
    print(f"Other files {len(other_files)}")

    print("Getting image embeddings from CLIP...")
    embeddings = _embed_images(files, image_processor, clip_model, device)
    print(f"Image embeddings: {len(embeddings)}")

    print("Getting comparision image embeddings from CLIP...")
    if has_other:
        other_embeddings = _embed_images(
            other_files, image_processor, clip_model, device
        )
    else:
        # Comparing the set with itself: reuse the embeddings already computed.
        other_embeddings = embeddings
    print(f"Other image embeddings: {len(other_embeddings)}")

    print("Comparing image embeddings using cosine similarity...")
    # similarity tests
    scored_pairs = []
    for file1, emb1 in zip(files, embeddings):
        for file2, emb2 in zip(other_files, other_embeddings):
            if file1 == file2:
                print(file1, file2, "skip")
                continue
            score = F.cosine_similarity(emb1, emb2).cpu().item()
            scored_pairs.append((file1, file2, score))

    print(f"Similarities: {len(scored_pairs)}")

    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if not scored_pairs:
        raise AssertionError(
            "Did not get any similarities. Check the paths to make sure we found "
            + "the images. Check --exts and --filter too."
        )

    if args.verbose:
        for file1, file2, score in scored_pairs:
            # `.4f` = four decimal places (`4f` would only set a minimum width).
            print(file1.name, file2.name, f"{score:.4f}")

    # convert over file names to strings
    similarities = [
        {"file1": str(file1), "file2": str(file2), "similarity": score}
        for file1, file2, score in scored_pairs
    ]

    if args.csv:
        to_csv(similarities, args.csv)
        print(f"Saved to: {args.csv}")

    if args.json:
        to_json(similarities, args.json)
        print(f"Saved to: {args.json}")
if __name__ == "__main__":
    # Script entry point: declare this tool's options, let the shared sd_ext
    # helpers extend the parser, then hand the parsed namespace to main().
    argparser = argparse.ArgumentParser(
        description="Find the image similarities between an image or directory of images"
    )

    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = [
        ("--image_file_or_dir", dict(help="Images to use as inputs.")),
        ("--other_file_or_dir", dict(help="Image dataset to compare with.")),
        (
            "--recursive",
            dict(
                action="store_true",
                help="Recursively go through the directories.",
            ),
        ),
        (
            "--exts",
            dict(
                nargs="+",
                default=IMAGE_EXTS,
                help=f"Extensions for images. Default [{', '.join(IMAGE_EXTS)}]",
            ),
        ),
        (
            "--filter",
            dict(
                help="Filter to use to exclude images. Checks if the filter is in the filename. Ex: --filter mask",
            ),
        ),
        (
            "--verbose",
            dict(
                action="store_true",
                help="Output the similarity between the images",
            ),
        ),
        (
            "--clip_model",
            dict(
                choices=VISION_TRANSFORMER_MODELS,
                default="ViT-L-14",
                help="CLIP model",
            ),
        ),
        (
            "--pretrained_clip",
            dict(
                choices=["openai", "openclip"],
                default="openai",
                help="Pretrained model producer",
            ),
        ),
    ]
    for option_name, option_kwargs in cli_options:
        argparser.add_argument(option_name, **option_kwargs)

    # NOTE(review): main() reads args.device, args.csv and args.json, which are
    # not declared above - presumably torch_args / format_args add them;
    # confirm in sd_ext.
    argparser = format_args(torch_args(argparser))

    args = argparser.parse_args()
    main(args)