@@ -220,7 +220,7 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
         expression_param = torch.nn.Parameter(self.prev_expression.detach(), requires_grad=True)
         jaw_param = torch.nn.Parameter(self.prev_jaw_pose.detach(), requires_grad=True)
         neck_pose_param = torch.nn.Parameter(self.prev_neck_pose.detach(), requires_grad=True)
-        eye_pose_param = self.prev_eye_pose.detach()
+        eye_pose_param = self.prev_eye_pose.detach().requires_grad_(False)
 
         camera_trans = torch.nn.Parameter(self.prev_camera_trans.detach(), requires_grad=True)
         camera_quat = torch.nn.Parameter(self.prev_camera_quat, requires_grad=True)
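Aside on the `+` line above: `Tensor.detach()` already returns a tensor with `requires_grad=False`, so the chained `requires_grad_(False)` is redundant, though harmless; it reads as explicit documentation that the eye pose is frozen during this stage. A minimal sketch:

    import torch

    x = torch.ones(3, requires_grad=True)
    y = x.detach()                        # grad tracking is already off here
    assert y.requires_grad is False
    z = x.detach().requires_grad_(False)  # explicit no-op, same result
    assert z.requires_grad is False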
@@ -229,17 +229,18 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
         betas = self.optim_kwargs['betas']
         if not first_frame:
             lr = lr * 0.1
+
+        # flame optimizer
         optim = torch.optim.Adam(
             [expression_param, jaw_param, neck_pose_param],
             lr=lr, betas=betas
         )
         sched = torch.optim.lr_scheduler.MultiStepLR(optim, **self.sched_kwargs)
 
+        # camera optimizer
         cam_optim = torch.optim.Adam([camera_trans, camera_quat], lr=lr, betas=betas)
         cam_sched = torch.optim.lr_scheduler.MultiStepLR(cam_optim, **self.sched_kwargs)
 
-
-
         # estimate mediapipe landmarks
         mp_lmks_ref, fan_lmks_ref = self.face_parsing.parse_lmks((image * 255).to(torch.uint8))
         iris_lmks_ref = self.face_parsing.parse_iris_lmlks(mp_lmks_ref)
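The two optimizers above follow the standard PyTorch pattern of stepping separate parameter groups on independent `MultiStepLR` schedules. A self-contained sketch of the same setup; the `lr`, `betas`, `milestones`, and `gamma` values here are illustrative placeholders, the real ones live in the config's `optim_kwargs`/`sched_kwargs`:

    import torch

    # dummy stand-ins for the FLAME and camera parameters
    expr = torch.nn.Parameter(torch.zeros(1, 50))
    cam_t = torch.nn.Parameter(torch.zeros(1, 3))

    optim = torch.optim.Adam([expr], lr=1e-2, betas=(0.9, 0.999))
    cam_optim = torch.optim.Adam([cam_t], lr=1e-2, betas=(0.9, 0.999))

    sched_kwargs = {"milestones": [1000, 3000], "gamma": 0.1}  # assumed shape of self.sched_kwargs
    sched = torch.optim.lr_scheduler.MultiStepLR(optim, **sched_kwargs)
    cam_sched = torch.optim.lr_scheduler.MultiStepLR(cam_optim, **sched_kwargs)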
@@ -250,7 +251,7 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
 
         iris_lmks_ref = iris_lmks_ref[..., 0:2]
         iris_lmks_ref = self.lmks2d_to_screen(iris_lmks_ref, image.shape[1], image.shape[2]).clone().detach().to(self.device)
-        iris_lmks_center_ref = iris_lmks_ref[:, [0, 5], :]
+        iris_lmks_center_ref = iris_lmks_ref[:, [5, 0], :]
 
         # get segmentation mask
         segmentation_mask, lebeled_mask = self.face_parsing.parse_mask((image[0].cpu().numpy() * 255).astype(np.uint8))
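The `[0, 5]` to `[5, 0]` change only reorders the two iris centers. Assuming `parse_iris_lmlks` returns ten iris points with one eye's center at index 0 and the other's at index 5 (the MediaPipe refined-landmark layout), the swap flips which eye comes first so the reference order matches the order of the FLAME iris vertices in `nir.k_iris_vert_idxs`. Illustrated on dummy data:

    import torch

    # hypothetical layout: 10 iris points per face, centers at indices 0 and 5
    iris = torch.arange(20, dtype=torch.float32).reshape(1, 10, 2)
    centers = iris[:, [5, 0], :]   # pick the two centers, in swapped eye order
    assert torch.equal(centers.flip(1), iris[:, [0, 5], :])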
@@ -266,15 +267,13 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
         # get shape and landmarks
         pose_param = torch.cat([self.prev_global_rot, jaw_param], dim=-1)
         verts, lmks, mp_lmks = self.flame_model(self.shapecode, expression_param, pose_param, neck_pose_param, eye_pose_param)
-        iris_lmks = verts[:, nir.k_iris_vert_idxs, :]
 
         # with the current camera extrinsics
         # transform landmarks to screen
         rot = quaternion_to_matrix(camera_quat)
         cameras = FoVPerspectiveCameras(0.01, 1000, 1, R=rot, T=camera_trans).to(self.device)
         lmks2d = cameras.transform_points_screen(lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
         mp_lmks2d = cameras.transform_points_screen(mp_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
-        iris_lmks2d = cameras.transform_points_screen(iris_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
 
         # render segmentation mask and debug view
         rendered, rendered_mask = flame_renderer.render(verts, self.flame_model.faces_tensor, cameras, flame_mask_texture)
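For reference, the projection path used here, quaternion to rotation matrix followed by PyTorch3D's screen-space transform, runs standalone as below; all values are dummies:

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras
    from pytorch3d.transforms import quaternion_to_matrix

    quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]])   # identity rotation, (w, x, y, z)
    trans = torch.tensor([[0.0, 0.0, 2.0]])       # camera translation
    rot = quaternion_to_matrix(quat)

    cameras = FoVPerspectiveCameras(0.01, 1000, 1, R=rot, T=trans)
    lmks3d = torch.rand(1, 68, 3)                  # a batch of 3D landmarks
    lmks2d = cameras.transform_points_screen(lmks3d, 1e-8, image_size=(512, 512))[..., 0:2]
    print(lmks2d.shape)                            # torch.Size([1, 68, 2])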
@@ -297,47 +296,8 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
                 self.logger.log_msg(f"{iter} | loss: {loss.detach().cpu().item()}")
                 self.logger.log_image_w_lmks(image.permute(0, 3, 1, 2), [mp_lmks_ref, mp_lmks2d], 'mediapipe lmks', radius=1)
                 self.logger.log_image_w_lmks(image.permute(0, 3, 1, 2), [fan_lmks_ref, lmks2d], 'retina lmks', radius=1)
-                self.logger.log_image_w_lmks(rendered[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_center_ref, iris_lmks2d], 'retina lmks', radius=1)
                 self.logger.log_image(rendered_mask[..., 0:3].permute(0, 3, 1, 2), 'rendered mask')
                 self.logger.log_image(lebeled_mask.permute(0, 3, 1, 2), "face mask")
-
-
-        eye_pose_param = torch.nn.Parameter(self.prev_eye_pose.clone().detach(), requires_grad=True)
-        eye_optim = torch.optim.Adam([eye_pose_param], lr=lr * 0.1, betas=betas)
-        eye_sched = torch.optim.lr_scheduler.MultiStepLR(eye_optim, **self.sched_kwargs)
-
-        expression_param = expression_param.clone().detach().requires_grad_(False)
-        pose_param = pose_param.clone().detach().requires_grad_(False)
-        neck_pose_param = neck_pose_param.clone().detach().requires_grad_(False)
-        rot = quaternion_to_matrix(camera_quat.clone().detach())
-        cameras = FoVPerspectiveCameras(0.01, 1000, 1, R=rot, T=camera_trans.clone().detach).to(self.device)
-
-        for iter in range(self.optim_iters):
-            eye_optim.zero_grad()
-
-            # get shape and landmarks
-            pose_param = torch.cat([self.prev_global_rot, jaw_param], dim=-1)
-            verts, lmks, mp_lmks = self.flame_model(
-                self.shapecode,
-                expression_param,
-                pose_param,
-                neck_pose_param,
-                eye_pose_param
-            )
-
-            iris_lmks = verts[:, nir.k_iris_vert_idxs, :]
-            iris_lmks2d = cameras.transform_points_screen(iris_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
-
-            # compute loss
-            iris_loss = self.criterion.wing_loss(iris_lmks2d, iris_lmks_center_ref)
-
-            iris_loss.backward()
-            eye_optim.step()
-            eye_sched.step()
-
-            if (iter % self.logger.log_iters == 0) and not self.log_result:
-                self.logger.log_msg(f"{iter} | loss: {loss.detach().cpu().item()}")
-                self.logger.log_image_w_lmks(rendered[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_center_ref, iris_lmks2d], 'retina lmks', radius=1)
 
 
         if self.log_result:
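The deleted block was the old in-method eye stage; it scored the projected iris landmarks with a wing loss, whereas the replacement class below switches to plain L1. For context, wing loss (Feng et al., CVPR 2018) behaves like a scaled log near zero and like L1 far from it. A generic sketch, which may differ from the repository's `OptimizationLoss.wing_loss` in defaults and reduction:

    import math
    import torch

    def wing_loss(pred, target, w=10.0, eps=2.0):
        x = (pred - target).abs()
        c = w - w * math.log(1.0 + w / eps)  # keeps the two pieces continuous at |x| = w
        return torch.where(x < w, w * torch.log(1.0 + x / eps), x - c).mean()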
@@ -349,23 +309,101 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
             self.logger.log_image(rendered[..., 0:3].permute(0, 3, 1, 2), 'rendered')
             self.logger.log_image_w_lmks(rendered[..., 0:3].permute(0, 3, 1, 2), mp_lmks2d, 'lmks on flame', radius=1)
             self.logger.log_image(lebeled_mask.permute(0, 3, 1, 2), "face mask")
+
         self.prev_expression = expression_param.detach()
         self.prev_global_rot = pose_param[:, 0:3].detach()
         self.prev_jaw_pose = pose_param[:, 3:].detach()
         self.prev_neck_pose = neck_pose_param.detach()
-        self.prev_eye_pose = eye_pose_param.detach()
         self.prev_camera_trans = camera_trans.detach()
         self.prev_camera_quat = camera_quat.detach()
         # intrinsics = cameras.get_projection_transform()
         return {
-            "camera_intrinsics": cameras.get_projection_transform()._matrix.detach(),
-            "camera_translation": camera_trans.detach(),
-            "camera_quaternion": camera_quat.detach(),
+            "cam_intrinsics_p3d": cameras.get_projection_transform()._matrix.detach(),
+            "cam_position": camera_trans.detach(),
+            "cam_quaternion": camera_quat.detach(),
             "flame_expression": expression_param.detach(),
             "flame_pose": pose_param.detach(),
             "flame_neck_pose": neck_pose_param.detach(),
-            "flame_eyes_pose": eye_pose_param.detach()
-        }
+        }, iris_lmks_center_ref
+
+
+class IrisOptimization:
+    def __init__(self,
+        flame_model,
+        face_parsing_module,
+        logger,
+        optim_kwargs,
+        sched_kwargs,
+        loss_kwargs,
+        log_result_only: bool=False,
+        optim_iters: int=5000,
+        device: str="cuda:0"
+    ):
+        self.flame_model = flame_model
+        self.logger = logger
+        self.face_parsing = face_parsing_module
+
+        self.optim_kwargs = optim_kwargs
+        self.sched_kwargs = sched_kwargs
+
+        # configure loss
+        self.criterion = OptimizationLoss(**loss_kwargs)
+        self.log_results_only = log_result_only
+        self.optim_iters = optim_iters
+        self.device = torch.device(device)
+
+
+        self.prev_eye_pose = torch.zeros([1, 6], device=self.device, dtype=torch.float32)
+
+    def lmks2d_to_screen(self, lmks2d, width, height):
+        lmks2d[..., 0] = torch.ceil(lmks2d[..., 0] * height)
+        lmks2d[..., 1] = torch.ceil(lmks2d[..., 1] * width)
+        return lmks2d.long()
+
+    def optimization_loop(
+        self,
+        image,
+        iris_lmks_ref,
+        flame_shape,
+        flame_expression,
+        flame_pose,
+        flame_neck_pose,
+        camera_quaternion,
+        camera_trans
+    ):
+        image = torch.from_numpy(image)[None].to(self.device, dtype=torch.float32) / 255.0
+
+        # create parameters
+        eye_pose_param = torch.nn.Parameter(self.prev_eye_pose, requires_grad=True)
+
+        optim = torch.optim.Adam([eye_pose_param], lr=self.optim_kwargs['lr'] * 0.1, betas=self.optim_kwargs['betas'])
+        sched = torch.optim.lr_scheduler.MultiStepLR(optim, **self.sched_kwargs)
+
+        for iter in tqdm(range(self.optim_iters), total=self.optim_iters, desc="iris optimization"):
+            optim.zero_grad()
+
+            verts, lmks, mp_lmks = self.flame_model(
+                flame_shape, flame_expression, flame_pose, flame_neck_pose, eye_pose_param
+            )
+            iris_lmks = verts[:, nir.k_iris_vert_idxs, :]
+            rot = quaternion_to_matrix(camera_quaternion)
+            cameras = FoVPerspectiveCameras(0.01, 1000, 1, R=rot, T=camera_trans).to(self.device)
+            iris_lmks2d = cameras.transform_points_screen(iris_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
+
+            loss = torch.nn.functional.l1_loss(iris_lmks2d, iris_lmks_ref)
+
+            loss.backward(retain_graph=True)
+            optim.step()
+            sched.step()
+
+            if (iter % self.logger.log_iters == 0) and not self.log_results_only:
+                self.logger.log_msg(f"{iter} | loss {loss.detach().cpu().item()}")
+                self.logger.log_image_w_lmks(image[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_ref, iris_lmks2d], 'retina lmks', radius=1)
+        if self.log_results_only:
+            self.logger.log_image_w_lmks(image[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_ref, iris_lmks2d], 'retina lmks', radius=1)
+
+        self.prev_eye_pose = eye_pose_param.detach()
+        return eye_pose_param.detach()
 
 
 
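In `IrisOptimization.optimization_loop`, the FLAME forward pass is rebuilt on every iteration, so `retain_graph=True` should only be necessary if the inputs (`flame_expression` and friends) still carry a shared autograd graph; with detached inputs the flag can normally be dropped. A minimal repro of the pattern:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    opt = torch.optim.Adam([p], lr=1e-2)
    for _ in range(3):
        opt.zero_grad()
        loss = (p - 1.0).pow(2).sum()   # a fresh graph is built each iteration
        loss.backward()                 # retain_graph is not required in this case
        opt.step()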
@@ -444,9 +482,9 @@ def save_state(self,
         flame_pose: torch.Tensor,
         flame_neck_pose: torch.Tensor,
         flame_eyes_pose: torch.Tensor,
-        camera_intrinsics: torch.Tensor,
-        camera_quaternion: torch.Tensor,
-        camera_translation: torch.Tensor,
+        cam_intrinsics_p3d: torch.Tensor,
+        cam_quaternion: torch.Tensor,
+        cam_position: torch.Tensor,
     ):
         rgb_path = os.path.join(self.current_output_dir, self.video_id + f"_frm{frame_idx}.png")
         nir.save_image(rgb_path, rgb)
@@ -457,10 +495,10 @@ def save_state(self,
             "flame_expression": flame_expression.cpu().numpy(),
             "flame_pose": flame_pose.cpu().numpy(),
             "flame_neck_pose": flame_neck_pose.cpu().numpy(),
-            "flame_eyes_pose": flame_eyes_pose.cpu().numpy(),
-            "cam_intrinsics": camera_intrinsics.cpu().numpy(),
-            "cam_quaternion": camera_quaternion.cpu().numpy(),
-            "cam_position": camera_translation.cpu().numpy()
+            'flame_eyes_pose': flame_eyes_pose,
+            "cam_intrinsics_p3d": cam_intrinsics_p3d.cpu().numpy(),
+            "cam_quaternion": cam_quaternion.cpu().numpy(),
+            "cam_position": cam_position.cpu().numpy()
         }
         with open(npz_path, 'wb') as outfd:
             np.savez(npz_path, **npz_data)
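Two small observations on the save path: `flame_eyes_pose` is now stored without `.cpu().numpy()` because the caller already converts it (see the driver change further down), and the `with open(npz_path, 'wb')` context is redundant since `np.savez` is handed the path rather than the file handle `outfd`. A hypothetical round-trip check, with key names taken from `npz_data` above and a placeholder filename:

    import numpy as np

    state = np.load("frame_0000.npz")             # placeholder path
    print(state["cam_intrinsics_p3d"].shape)      # PyTorch3D projection matrix
    print(state["cam_quaternion"], state["cam_position"])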
@@ -486,14 +524,27 @@ def save_state(self,
     # create the estimators
     mica_estimator = MicaEstimator(**conf.mica_estimator_kwargs)
     flame_optimizer = FLAMEPoseExpressionOptimization(**conf.flame_pose_expression_optimization_kwargs)
+    iris_optimizer = IrisOptimization(
+        flame_optimizer.flame_model,
+        flame_optimizer.face_parsing,
+        flame_optimizer.logger,
+        conf.flame_pose_expression_optimization_kwargs['optim_kwargs'],
+        conf.flame_pose_expression_optimization_kwargs['sched_kwargs'],
+        conf.flame_pose_expression_optimization_kwargs['loss_kwargs'],
+        conf.flame_pose_expression_optimization_kwargs['log_result_only'],
+        conf.flame_pose_expression_optimization_kwargs['optim_iters'],
+        'cuda:0'
+    )
 
     # create dataset
-    dataset = nir.get_dataset("SingleVideoDataset", **conf.video_dataset_kwargs)
+    # dataset = nir.get_dataset("SingleVideoDataset", **conf.video_dataset_kwargs)
 
     # Get all video filepaths
     filenames = os.listdir(conf.base_dir)
     print("Starting preprocessing")
     for filename in filenames:
+        if not filename.endswith('mp4'):
+            continue
         filepath = os.path.join(conf.base_dir, filename)
         print(f"Processing file: {filename}")
         dataset = nir.get_dataset("SingleVideoDataset", filepath=filepath, preload=True)
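Since the positional arguments to `IrisOptimization` are pulled one by one out of the FLAME optimizer's kwargs dict, an equivalent and slightly safer construction is to splat the shared keys; this sketch assumes the dict contains exactly these entries, which the positional version requires anyway:

    keys = ('optim_kwargs', 'sched_kwargs', 'loss_kwargs', 'log_result_only', 'optim_iters')
    iris_optimizer = IrisOptimization(
        flame_optimizer.flame_model,
        flame_optimizer.face_parsing,
        flame_optimizer.logger,
        device='cuda:0',
        **{k: conf.flame_pose_expression_optimization_kwargs[k] for k in keys},
    )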
@@ -507,7 +558,17 @@ def save_state(self,
 
         # image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
         flame_optimizer.reset(shapecode, lmks)
-        optimized_data = flame_optimizer.optimization_loop(image, True if frame_idx == 0 else False)
+        optimized_data, iris_lmks = flame_optimizer.optimization_loop(image, True if frame_idx == 0 else False)
+        flame_eye_pose = iris_optimizer.optimization_loop(
+            image, iris_lmks, shapecode,
+            optimized_data['flame_expression'],
+            optimized_data['flame_pose'],
+            optimized_data['flame_neck_pose'],
+            optimized_data['cam_quaternion'],
+            optimized_data['cam_position']
+        )
+
+        optimized_data['flame_eyes_pose'] = flame_eye_pose.detach().cpu().numpy()
         optimized_data['flame_shape'] = shapecode.detach().cpu().numpy()
         optimized_data['rgb'] = data.rgb
         optimized_data['frame_idx'] = frame_idx