Skip to content

Commit cd4c828

Browse files
[SYSTEMDS-3835] Add additional visual representations
This patch adds new visual (image, video) representations, and a test utility for the image modality.
1 parent b3c6d28 commit cd4c828

File tree

14 files changed

+591
-88
lines changed

14 files changed

+591
-88
lines changed

.github/workflows/python.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ jobs:
171171
h5py \
172172
gensim \
173173
opt-einsum \
174-
nltk
174+
nltk \
175+
fvcore
175176
kill $KA
176177
cd src/main/python
177-
python -m unittest discover -s tests/scuro -p 'test_*.py' -v
178+
python -m unittest discover -s tests/scuro -p 'test_*.py' -v

src/main/python/systemds/scuro/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
)
7878
from systemds.scuro.representations.word2vec import W2V
7979
from systemds.scuro.representations.x3d import X3D
80+
from systemds.scuro.representations.color_histogram import ColorHistogram
8081
from systemds.scuro.models.model import Model
8182
from systemds.scuro.models.discrete_model import DiscreteModel
8283
from systemds.scuro.modality.joined import JoinedModality
@@ -97,7 +98,8 @@
9798
)
9899
from systemds.scuro.drsearch.multimodal_optimizer import MultimodalOptimizer
99100
from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
100-
101+
from systemds.scuro.representations.vgg import VGG19
102+
from systemds.scuro.representations.clip import CLIPText, CLIPVisual
101103

102104
__all__ = [
103105
"BaseLoader",
@@ -120,6 +122,7 @@
120122
"MFCC",
121123
"Hadamard",
122124
"OpticalFlow",
125+
"ColorHistogram",
123126
"Representation",
124127
"NPY",
125128
"JSON",
@@ -169,4 +172,7 @@
169172
"Quantile",
170173
"BandpowerFFT",
171174
"ZeroCrossingRate",
175+
"VGG19",
176+
"CLIPVisual",
177+
"CLIPText",
172178
]

src/main/python/systemds/scuro/modality/type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,16 @@ def create_video_metadata(self, frequency, length, width, height, num_channels):
254254
md["data_layout"]["representation"] = DataLayout.NESTED_LEVEL
255255
md["data_layout"]["type"] = float
256256
md["data_layout"]["shape"] = (width, height, num_channels)
257+
return md
257258

259+
def create_image_metadata(self, width, height, num_channels):
260+
md = deepcopy(self.get_schema())
261+
md["width"] = width
262+
md["height"] = height
263+
md["num_channels"] = num_channels
264+
md["data_layout"]["representation"] = DataLayout.SINGLE_LEVEL
265+
md["data_layout"]["type"] = float
266+
md["data_layout"]["shape"] = (width, height, num_channels)
258267
return md
259268

260269

src/main/python/systemds/scuro/modality/unimodal_modality.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,9 @@ def apply_representation(self, representation):
165165
padded = np.pad(
166166
embeddings,
167167
pad_width=(
168-
(0, padding_needed),
169-
(0, 0),
168+
(0, padding_needed)
169+
if len(embeddings.shape) == 1
170+
else ((0, padding_needed), (0, 0))
170171
),
171172
mode="constant",
172173
constant_values=0,
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
import numpy as np
22+
from torchvision import transforms
23+
24+
from systemds.scuro.modality.transformed import TransformedModality
25+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
26+
import torch
27+
from systemds.scuro.representations.utils import save_embeddings
28+
from systemds.scuro.modality.type import ModalityType
29+
from systemds.scuro.drsearch.operator_registry import register_representation
30+
from transformers import CLIPProcessor, CLIPModel
31+
32+
from systemds.scuro.utils.converter import numpy_dtype_to_torch_dtype
33+
from systemds.scuro.utils.static_variables import get_device
34+
from systemds.scuro.utils.torch_dataset import CustomDataset
35+
36+
37+
@register_representation(ModalityType.VIDEO)
38+
class CLIPVisual(UnimodalRepresentation):
39+
def __init__(self, output_file=None):
40+
parameters = {}
41+
super().__init__("CLIPVisual", ModalityType.EMBEDDING, parameters)
42+
self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(
43+
get_device()
44+
)
45+
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
46+
self.output_file = output_file
47+
48+
def transform(self, modality):
49+
transformed_modality = TransformedModality(modality, self)
50+
self.data_type = numpy_dtype_to_torch_dtype(modality.data_type)
51+
if next(self.model.parameters()).dtype != self.data_type:
52+
self.model = self.model.to(self.data_type)
53+
54+
embeddings = self.create_visual_embeddings(modality)
55+
56+
if self.output_file is not None:
57+
save_embeddings(embeddings, self.output_file)
58+
59+
transformed_modality.data = list(embeddings.values())
60+
return transformed_modality
61+
62+
def create_visual_embeddings(self, modality):
63+
tf = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
64+
dataset = CustomDataset(
65+
modality.data,
66+
self.data_type,
67+
get_device(),
68+
(modality.metadata[0]["width"], modality.metadata[0]["height"]),
69+
tf=tf,
70+
)
71+
embeddings = {}
72+
for instance in torch.utils.data.DataLoader(dataset):
73+
id = int(instance["id"][0])
74+
frames = instance["data"][0]
75+
embeddings[id] = []
76+
batch_size = 64
77+
78+
for start_index in range(0, len(frames), batch_size):
79+
end_index = min(start_index + batch_size, len(frames))
80+
frame_ids_range = range(start_index, end_index)
81+
frame_batch = frames[frame_ids_range]
82+
83+
inputs = self.processor(images=frame_batch, return_tensors="pt")
84+
with torch.no_grad():
85+
output = self.model.get_image_features(**inputs)
86+
87+
if len(output.shape) > 2:
88+
output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
89+
90+
embeddings[id].extend(
91+
torch.flatten(output, 1)
92+
.detach()
93+
.cpu()
94+
.float()
95+
.numpy()
96+
.astype(modality.data_type)
97+
)
98+
99+
embeddings[id] = np.array(embeddings[id])
100+
return embeddings
101+
102+
103+
@register_representation(ModalityType.TEXT)
104+
class CLIPText(UnimodalRepresentation):
105+
def __init__(self, output_file=None):
106+
parameters = {}
107+
super().__init__("CLIPText", ModalityType.EMBEDDING, parameters)
108+
self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(
109+
get_device()
110+
)
111+
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
112+
self.output_file = output_file
113+
114+
def transform(self, modality):
115+
transformed_modality = TransformedModality(modality, self)
116+
117+
embeddings = self.create_text_embeddings(modality.data, self.model)
118+
119+
if self.output_file is not None:
120+
save_embeddings(embeddings, self.output_file)
121+
122+
transformed_modality.data = embeddings
123+
return transformed_modality
124+
125+
def create_text_embeddings(self, data, model):
126+
embeddings = []
127+
for d in data:
128+
inputs = self.processor(text=d, return_tensors="pt", padding=True)
129+
with torch.no_grad():
130+
text_embedding = model.get_text_features(**inputs)
131+
embeddings.append(text_embedding.squeeze().numpy().reshape(1, -1))
132+
133+
return embeddings
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
22+
import numpy as np
23+
import cv2
24+
25+
from systemds.scuro.modality.type import ModalityType
26+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
27+
from systemds.scuro.modality.transformed import TransformedModality
28+
29+
30+
class ColorHistogram(UnimodalRepresentation):
31+
def __init__(
32+
self,
33+
color_space="RGB",
34+
bins=32,
35+
normalize=True,
36+
aggregation="mean",
37+
output_file=None,
38+
):
39+
super().__init__(
40+
"ColorHistogram", ModalityType.EMBEDDING, self._get_parameters()
41+
)
42+
self.color_space = color_space
43+
self.bins = bins
44+
self.normalize = normalize
45+
self.aggregation = aggregation
46+
self.output_file = output_file
47+
48+
def _get_parameters(self):
49+
return {
50+
"color_space": ["RGB", "HSV", "GRAY"],
51+
"bins": [8, 16, 32, 64, 128, 256, (8, 8, 8), (16, 16, 16)],
52+
"normalize": [True, False],
53+
"aggregation": ["mean", "max", "concat"],
54+
}
55+
56+
def compute_histogram(self, image):
57+
if self.color_space == "HSV":
58+
img = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
59+
channels = [0, 1, 2]
60+
elif self.color_space == "GRAY":
61+
img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
62+
channels = [0]
63+
else:
64+
img = image
65+
channels = [0, 1, 2]
66+
67+
hist = self._region_histogram(img, channels)
68+
return hist
69+
70+
def _region_histogram(self, img, channels):
71+
if isinstance(self.bins, tuple):
72+
bins = self.bins
73+
elif len(channels) > 1:
74+
bins = [self.bins] * len(channels)
75+
else:
76+
bins = [self.bins]
77+
hist = cv2.calcHist([img], channels, None, bins, [0, 256] * len(channels))
78+
hist = hist.flatten()
79+
if self.normalize:
80+
hist_sum = np.sum(hist)
81+
if hist_sum > 0:
82+
hist /= hist_sum
83+
return hist.astype(np.float32)
84+
85+
def transform(self, modality):
86+
if modality.modality_type == ModalityType.IMAGE:
87+
images = modality.data
88+
hist_list = [self.compute_histogram(img) for img in images]
89+
transformed_modality = TransformedModality(
90+
modality, self, ModalityType.EMBEDDING
91+
)
92+
transformed_modality.data = hist_list
93+
return transformed_modality
94+
elif modality.modality_type == ModalityType.VIDEO:
95+
embeddings = []
96+
for vid in modality.data:
97+
frame_hists = [self.compute_histogram(frame) for frame in vid]
98+
if self.aggregation == "mean":
99+
hist = np.mean(frame_hists, axis=0)
100+
elif self.aggregation == "max":
101+
hist = np.max(frame_hists, axis=0)
102+
elif self.aggregation == "concat":
103+
hist = np.concatenate(frame_hists)
104+
embeddings.append(hist)
105+
transformed_modality = TransformedModality(
106+
modality, self, ModalityType.EMBEDDING
107+
)
108+
transformed_modality.data = embeddings
109+
return transformed_modality
110+
else:
111+
raise ValueError("Unsupported data format for HistogramRepresentation")

src/main/python/systemds/scuro/representations/resnet.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,17 +144,21 @@ def hook(
144144
embeddings[video_id] = []
145145
batch_size = 64
146146

147+
if modality.modality_type == ModalityType.IMAGE:
148+
frames = frames.unsqueeze(0)
149+
147150
for start_index in range(0, len(frames), batch_size):
148151
end_index = min(start_index + batch_size, len(frames))
149152
frame_ids_range = range(start_index, end_index)
150153
frame_batch = frames[frame_ids_range]
151154

152155
_ = self.model(frame_batch)
153-
values = res5c_output
154-
pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
156+
output = res5c_output
157+
if len(output.shape) > 2:
158+
output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
155159

156160
embeddings[video_id].extend(
157-
torch.flatten(pooled, 1)
161+
torch.flatten(output, 1)
158162
.detach()
159163
.cpu()
160164
.float()

src/main/python/systemds/scuro/representations/swin_video_transformer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from systemds.scuro.utils.static_variables import get_device
3535

3636

37-
# @register_representation([ModalityType.VIDEO])
37+
@register_representation([ModalityType.VIDEO])
3838
class SwinVideoTransformer(UnimodalRepresentation):
3939
def __init__(self, layer_name="avgpool"):
4040
parameters = {
@@ -50,7 +50,7 @@ def __init__(self, layer_name="avgpool"):
5050
],
5151
}
5252
self.data_type = torch.float
53-
super().__init__("SwinVideoTransformer", ModalityType.TIMESERIES, parameters)
53+
super().__init__("SwinVideoTransformer", ModalityType.EMBEDDING, parameters)
5454
self.layer_name = layer_name
5555
self.model = swin3d_t(weights=models.video.Swin3D_T_Weights.KINETICS400_V1).to(
5656
get_device()
@@ -95,6 +95,7 @@ def hook(
9595
.detach()
9696
.cpu()
9797
.numpy()
98+
.flatten()
9899
.astype(modality.data_type)
99100
)
100101

0 commit comments

Comments
 (0)