Skip to content

Commit 77ad3de

Browse files
committed
Remove redundant code in the ImageProcessor
1 parent 214f617 commit 77ad3de

19 files changed

+56
-293
lines changed

configs/unet/stage1.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ data:
1212
num_workers: 12 # 12
1313
num_frames: 16
1414
resolution: 256
15-
mask: fix_mask
1615
mask_image_path: latentsync/utils/mask.png
1716
audio_sample_rate: 16000
1817
video_fps: 25

configs/unet/stage2.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ data:
1212
num_workers: 12 # 12
1313
num_frames: 16
1414
resolution: 256
15-
mask: fix_mask
1615
mask_image_path: latentsync/utils/mask.png
1716
audio_sample_rate: 16000
1817
video_fps: 25

configs/unet/stage2_efficient.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ data:
1212
num_workers: 12 # 12
1313
num_frames: 16
1414
resolution: 256
15-
mask: fix_mask
1615
mask_image_path: latentsync/utils/mask.png
1716
audio_sample_rate: 16000
1817
video_fps: 25

latentsync/data/syncnet_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, data_dir: str, fileslist: str, config):
4242

4343
self.audio_sample_rate = config.data.audio_sample_rate
4444
self.video_fps = config.data.video_fps
45-
self.image_processor = ImageProcessor(resolution=config.data.resolution, mask="half")
45+
self.image_processor = ImageProcessor(resolution=config.data.resolution)
4646
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
4747
os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
4848

latentsync/data/unet_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def worker_init_fn(self, worker_id):
8989
setattr(
9090
self,
9191
f"image_processor_{worker_id}",
92-
ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
92+
ImageProcessor(self.resolution, mask_image=self.mask_image),
9393
)
9494

9595
def __getitem__(self, idx):

latentsync/pipelines/lipsync_pipeline.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,6 @@ def __call__(
328328
guidance_scale: float = 1.5,
329329
weight_dtype: Optional[torch.dtype] = torch.float16,
330330
eta: float = 0.0,
331-
mask: str = "fix_mask",
332331
mask_image_path: str = "latentsync/utils/mask.png",
333332
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
334333
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -344,7 +343,7 @@ def __call__(
344343
batch_size = 1
345344
device = self._execution_device
346345
mask_image = load_fixed_mask(height, mask_image_path)
347-
self.image_processor = ImageProcessor(height, mask=mask, device="cuda", mask_image=mask_image)
346+
self.image_processor = ImageProcessor(height, device="cuda", mask_image=mask_image)
348347
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
349348

350349
# 1. Default height and width to unet

latentsync/utils/affine_transform.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99

1010
class AlignRestore(object):
11-
def __init__(self, align_points=3, device="cpu", dtype=torch.float32):
11+
def __init__(self, align_points=3, resolution=256, device="cpu", dtype=torch.float32):
1212
if align_points == 3:
1313
self.upscale_factor = 1
14-
ratio = 2.8
14+
ratio = resolution / 256 * 2.8
1515
self.crop_ratio = (ratio, ratio)
1616
self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
1717
self.face_template = self.face_template * ratio

latentsync/utils/image_processor.py

+32-241
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from latentsync.utils.util import read_video, write_video
1516
from torchvision import transforms
1617
import cv2
1718
from einops import rearrange
18-
import mediapipe as mp
1919
import torch
2020
import numpy as np
2121
from typing import Union
@@ -32,90 +32,31 @@ def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png"
3232

3333

3434
class ImageProcessor:
35-
def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
35+
def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
3636
self.resolution = resolution
3737
self.resize = transforms.Resize(
38-
(resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
38+
(resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
3939
)
4040
self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
41-
self.mask = mask
4241

43-
if mask in ["mouth", "face", "eye"]:
44-
self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
45-
if mask == "fix_mask":
46-
self.face_mesh = None
47-
self.restorer = AlignRestore(device=device)
42+
self.restorer = AlignRestore(resolution=resolution, device=device)
4843

49-
if mask_image is None:
50-
self.mask_image = load_fixed_mask(resolution)
51-
else:
52-
self.mask_image = mask_image
53-
54-
if device != "cpu":
55-
self.face_detector = FaceDetector(device=device)
56-
self.face_mesh = None
57-
else:
58-
# self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
59-
self.face_mesh = None
60-
self.face_detector = None
61-
62-
def detect_facial_landmarks(self, image: np.ndarray):
63-
height, width, _ = image.shape
64-
results = self.face_mesh.process(image)
65-
if not results.multi_face_landmarks: # Face not detected
66-
raise RuntimeError("Face not detected")
67-
face_landmarks = results.multi_face_landmarks[0] # Only use the first face in the image
68-
landmark_coordinates = [
69-
(int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
70-
] # x means width, y means height
71-
return landmark_coordinates
72-
73-
def preprocess_one_masked_image(self, image: torch.Tensor) -> np.ndarray:
74-
image = self.resize(image)
75-
76-
if self.mask == "mouth" or self.mask == "face":
77-
landmark_coordinates = self.detect_facial_landmarks(image)
78-
if self.mask == "mouth":
79-
surround_landmarks = mouth_surround_landmarks
80-
else:
81-
surround_landmarks = face_surround_landmarks
82-
83-
points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
84-
points = np.array(points)
85-
mask = np.ones((self.resolution, self.resolution))
86-
mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
87-
mask = torch.from_numpy(mask)
88-
mask = mask.unsqueeze(0)
89-
elif self.mask == "half":
90-
mask = torch.ones((self.resolution, self.resolution))
91-
height = mask.shape[0]
92-
mask[height // 2 :, :] = 0
93-
mask = mask.unsqueeze(0)
94-
elif self.mask == "eye":
95-
mask = torch.ones((self.resolution, self.resolution))
96-
landmark_coordinates = self.detect_facial_landmarks(image)
97-
y = landmark_coordinates[195][1]
98-
mask[y:, :] = 0
99-
mask = mask.unsqueeze(0)
44+
if mask_image is None:
45+
self.mask_image = load_fixed_mask(resolution)
10046
else:
101-
raise ValueError("Invalid mask type")
47+
self.mask_image = mask_image
10248

103-
image = image.to(dtype=torch.float32)
104-
pixel_values = self.normalize(image / 255.0)
105-
masked_pixel_values = pixel_values * mask
106-
mask = 1 - mask
107-
108-
return pixel_values, masked_pixel_values, mask
49+
if device == "cpu":
50+
self.face_detector = None
51+
else:
52+
self.face_detector = FaceDetector(device=device)
10953

11054
def affine_transform(self, image: torch.Tensor) -> np.ndarray:
111-
# image = rearrange(image, "c h w-> h w c").numpy()
11255
if self.face_detector is None:
113-
landmark_coordinates = np.array(self.detect_facial_landmarks(image))
114-
lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
115-
else:
116-
bbox, landmark_2d_106 = self.face_detector(image)
117-
if bbox is None:
118-
raise RuntimeError("Face not detected")
56+
raise NotImplementedError("Using the CPU for face detection is not supported")
57+
bbox, landmark_2d_106 = self.face_detector(image)
58+
if bbox is None:
59+
raise RuntimeError("Face not detected")
11960

12061
pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0) # left eyebrow center
12162
pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0) # right eyebrow center
@@ -153,10 +94,8 @@ def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray
15394
images = torch.from_numpy(images)
15495
if images.shape[3] == 3:
15596
images = rearrange(images, "f h w c -> f c h w")
156-
if self.mask == "fix_mask":
157-
results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
158-
else:
159-
results = [self.preprocess_one_masked_image(image) for image in images]
97+
98+
results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
16099

161100
pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
162101
return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
@@ -170,172 +109,24 @@ def process_images(self, images: Union[torch.Tensor, np.ndarray]):
170109
pixel_values = self.normalize(images / 255.0)
171110
return pixel_values
172111

173-
def close(self):
174-
if self.face_mesh is not None:
175-
self.face_mesh.close()
176-
177-
178-
def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
179-
"""
180-
lm478: [B, 478, 3] or [478,3]
181-
"""
182-
# lm478[..., 0] *= W
183-
# lm478[..., 1] *= H
184-
landmarks_extracted = []
185-
for index in landmark_points_68:
186-
x = lm478[index][0]
187-
y = lm478[index][1]
188-
landmarks_extracted.append((x, y))
189-
return np.array(landmarks_extracted)
190-
191112

192-
landmark_points_68 = [
193-
162,
194-
234,
195-
93,
196-
58,
197-
172,
198-
136,
199-
149,
200-
148,
201-
152,
202-
377,
203-
378,
204-
365,
205-
397,
206-
288,
207-
323,
208-
454,
209-
389,
210-
71,
211-
63,
212-
105,
213-
66,
214-
107,
215-
336,
216-
296,
217-
334,
218-
293,
219-
301,
220-
168,
221-
197,
222-
5,
223-
4,
224-
75,
225-
97,
226-
2,
227-
326,
228-
305,
229-
33,
230-
160,
231-
158,
232-
133,
233-
153,
234-
144,
235-
362,
236-
385,
237-
387,
238-
263,
239-
373,
240-
380,
241-
61,
242-
39,
243-
37,
244-
0,
245-
267,
246-
269,
247-
291,
248-
405,
249-
314,
250-
17,
251-
84,
252-
181,
253-
78,
254-
82,
255-
13,
256-
312,
257-
308,
258-
317,
259-
14,
260-
87,
261-
]
113+
class VideoProcessor:
114+
def __init__(self, resolution: int = 512, device: str = "cpu"):
115+
self.image_processor = ImageProcessor(resolution, device)
262116

117+
def affine_transform_video(self, video_path):
118+
video_frames = read_video(video_path, change_fps=False)
119+
results = []
120+
for frame in video_frames:
121+
frame, _, _ = self.image_processor.affine_transform(frame)
122+
results.append(frame)
123+
results = torch.stack(results)
263124

264-
# Refer to https://storage.googleapis.com/mediapipe-assets/documentation/mediapipe_face_landmark_fullsize.png
265-
mouth_surround_landmarks = [
266-
164,
267-
165,
268-
167,
269-
92,
270-
186,
271-
57,
272-
43,
273-
106,
274-
182,
275-
83,
276-
18,
277-
313,
278-
406,
279-
335,
280-
273,
281-
287,
282-
410,
283-
322,
284-
391,
285-
393,
286-
]
125+
results = rearrange(results, "f c h w -> f h w c").numpy()
126+
return results
287127

288-
face_surround_landmarks = [
289-
152,
290-
377,
291-
400,
292-
378,
293-
379,
294-
365,
295-
397,
296-
288,
297-
435,
298-
433,
299-
411,
300-
425,
301-
423,
302-
327,
303-
326,
304-
94,
305-
97,
306-
98,
307-
203,
308-
205,
309-
187,
310-
213,
311-
215,
312-
58,
313-
172,
314-
136,
315-
150,
316-
149,
317-
176,
318-
148,
319-
]
320128

321129
if __name__ == "__main__":
322-
image_processor = ImageProcessor(512, mask="fix_mask")
323-
video = cv2.VideoCapture("assets/demo1_video.mp4")
324-
while True:
325-
ret, frame = video.read()
326-
# if not ret:
327-
# break
328-
329-
# cv2.imwrite("image.jpg", frame)
330-
331-
frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
332-
# face, masked_face, _ = image_processor.preprocess_fixed_mask_image(frame, affine_transform=True)
333-
face, _, _ = image_processor.affine_transform(frame)
334-
335-
break
336-
337-
face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
338-
cv2.imwrite("face.jpg", face)
339-
340-
# masked_face = (rearrange(masked_face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
341-
# cv2.imwrite("masked_face.jpg", masked_face)
130+
video_processor = VideoProcessor(256, "cuda")
131+
video_frames = video_processor.affine_transform_video("validation/flux.mp4")
132+
write_video("output.mp4", video_frames, fps=25)

0 commit comments

Comments
 (0)