# See the License for the specific language governing permissions and
# limitations under the License.

+from latentsync.utils.util import read_video, write_video
from torchvision import transforms
import cv2
from einops import rearrange
-import mediapipe as mp
import torch
import numpy as np
from typing import Union
@@ -32,90 +32,31 @@ def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png"


class ImageProcessor:
-    def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
+    def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
        self.resolution = resolution
        self.resize = transforms.Resize(
-            (resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
+            (resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
        )
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
-        self.mask = mask

-        if mask in ["mouth", "face", "eye"]:
-            self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True)  # Process single image
-        if mask == "fix_mask":
-            self.face_mesh = None
-            self.restorer = AlignRestore(device=device)
+        self.restorer = AlignRestore(resolution=resolution, device=device)

-        if mask_image is None:
-            self.mask_image = load_fixed_mask(resolution)
-        else:
-            self.mask_image = mask_image
-
-        if device != "cpu":
-            self.face_detector = FaceDetector(device=device)
-            self.face_mesh = None
-        else:
-            # self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True)  # Process single image
-            self.face_mesh = None
-            self.face_detector = None
-
-    def detect_facial_landmarks(self, image: np.ndarray):
-        height, width, _ = image.shape
-        results = self.face_mesh.process(image)
-        if not results.multi_face_landmarks:  # Face not detected
-            raise RuntimeError("Face not detected")
-        face_landmarks = results.multi_face_landmarks[0]  # Only use the first face in the image
-        landmark_coordinates = [
-            (int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
-        ]  # x means width, y means height
-        return landmark_coordinates
-
-    def preprocess_one_masked_image(self, image: torch.Tensor) -> np.ndarray:
-        image = self.resize(image)
-
-        if self.mask == "mouth" or self.mask == "face":
-            landmark_coordinates = self.detect_facial_landmarks(image)
-            if self.mask == "mouth":
-                surround_landmarks = mouth_surround_landmarks
-            else:
-                surround_landmarks = face_surround_landmarks
-
-            points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
-            points = np.array(points)
-            mask = np.ones((self.resolution, self.resolution))
-            mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
-            mask = torch.from_numpy(mask)
-            mask = mask.unsqueeze(0)
-        elif self.mask == "half":
-            mask = torch.ones((self.resolution, self.resolution))
-            height = mask.shape[0]
-            mask[height // 2 :, :] = 0
-            mask = mask.unsqueeze(0)
-        elif self.mask == "eye":
-            mask = torch.ones((self.resolution, self.resolution))
-            landmark_coordinates = self.detect_facial_landmarks(image)
-            y = landmark_coordinates[195][1]
-            mask[y:, :] = 0
-            mask = mask.unsqueeze(0)
+        if mask_image is None:
+            self.mask_image = load_fixed_mask(resolution)
        else:
-            raise ValueError("Invalid mask type")
+            self.mask_image = mask_image

-        image = image.to(dtype=torch.float32)
-        pixel_values = self.normalize(image / 255.0)
-        masked_pixel_values = pixel_values * mask
-        mask = 1 - mask
-
-        return pixel_values, masked_pixel_values, mask
+        if device == "cpu":
+            self.face_detector = None
+        else:
+            self.face_detector = FaceDetector(device=device)

    def affine_transform(self, image: torch.Tensor) -> np.ndarray:
-        # image = rearrange(image, "c h w-> h w c").numpy()
        if self.face_detector is None:
-            landmark_coordinates = np.array(self.detect_facial_landmarks(image))
-            lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
-        else:
-            bbox, landmark_2d_106 = self.face_detector(image)
-            if bbox is None:
-                raise RuntimeError("Face not detected")
+            raise NotImplementedError("Using the CPU for face detection is not supported")
+        bbox, landmark_2d_106 = self.face_detector(image)
+        if bbox is None:
+            raise RuntimeError("Face not detected")

        pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0)  # left eyebrow center
        pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0)  # right eyebrow center
@@ -153,10 +94,8 @@ def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray
            images = torch.from_numpy(images)
        if images.shape[3] == 3:
            images = rearrange(images, "f h w c -> f c h w")
-        if self.mask == "fix_mask":
-            results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
-        else:
-            results = [self.preprocess_one_masked_image(image) for image in images]
+
+        results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]

        pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
@@ -170,172 +109,24 @@ def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        pixel_values = self.normalize(images / 255.0)
        return pixel_values

-    def close(self):
-        if self.face_mesh is not None:
-            self.face_mesh.close()
-
-
-def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
-    """
-    lm478: [B, 478, 3] or [478,3]
-    """
-    # lm478[..., 0] *= W
-    # lm478[..., 1] *= H
-    landmarks_extracted = []
-    for index in landmark_points_68:
-        x = lm478[index][0]
-        y = lm478[index][1]
-        landmarks_extracted.append((x, y))
-    return np.array(landmarks_extracted)
-

-landmark_points_68 = [
-    162, 234, 93, 58, 172, 136, 149, 148, 152, 377, 378, 365, 397, 288, 323, 454, 389,
-    71, 63, 105, 66, 107, 336, 296, 334, 293, 301, 168, 197, 5, 4, 75, 97, 2, 326, 305,
-    33, 160, 158, 133, 153, 144, 362, 385, 387, 263, 373, 380, 61, 39, 37, 0, 267, 269,
-    291, 405, 314, 17, 84, 181, 78, 82, 13, 312, 308, 317, 14, 87,
-]
+class VideoProcessor:
+    def __init__(self, resolution: int = 512, device: str = "cpu"):
+        self.image_processor = ImageProcessor(resolution, device)

+    def affine_transform_video(self, video_path):
+        video_frames = read_video(video_path, change_fps=False)
+        results = []
+        for frame in video_frames:
+            frame, _, _ = self.image_processor.affine_transform(frame)
+            results.append(frame)
+        results = torch.stack(results)

-# Refer to https://storage.googleapis.com/mediapipe-assets/documentation/mediapipe_face_landmark_fullsize.png
-mouth_surround_landmarks = [
-    164, 165, 167, 92, 186, 57, 43, 106, 182, 83, 18, 313, 406, 335,
-    273, 287, 410, 322, 391, 393,
-]
+        results = rearrange(results, "f c h w -> f h w c").numpy()
+        return results

-face_surround_landmarks = [
-    152, 377, 400, 378, 379, 365, 397, 288, 435, 433, 411, 425, 423, 327, 326, 94, 97, 98,
-    203, 205, 187, 213, 215, 58, 172, 136, 150, 149, 176, 148,
-]

if __name__ == "__main__":
-    image_processor = ImageProcessor(512, mask="fix_mask")
-    video = cv2.VideoCapture("assets/demo1_video.mp4")
-    while True:
-        ret, frame = video.read()
-        # if not ret:
-        #     break
-
-        # cv2.imwrite("image.jpg", frame)
-
-        frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
-        # face, masked_face, _ = image_processor.preprocess_fixed_mask_image(frame, affine_transform=True)
-        face, _, _ = image_processor.affine_transform(frame)
-
-        break
-
-    face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
-    cv2.imwrite("face.jpg", face)
-
-    # masked_face = (rearrange(masked_face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
-    # cv2.imwrite("masked_face.jpg", masked_face)
+    video_processor = VideoProcessor(256, "cuda")
+    video_frames = video_processor.affine_transform_video("validation/flux.mp4")
+    write_video("output.mp4", video_frames, fps=25)