Commit 3c31946

fixed iris lmks, + iris optimization

1 parent 667c608 commit 3c31946

2 files changed: +181 -60 lines changed

preproc_video.py

Lines changed: 121 additions & 60 deletions
@@ -220,7 +220,7 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
         expression_param = torch.nn.Parameter(self.prev_expression.detach(), requires_grad=True)
         jaw_param = torch.nn.Parameter(self.prev_jaw_pose.detach(), requires_grad=True)
         neck_pose_param = torch.nn.Parameter(self.prev_neck_pose.detach(), requires_grad=True)
-        eye_pose_param = self.prev_eye_pose.detach()
+        eye_pose_param = self.prev_eye_pose.detach().requires_grad_(False)
 
         camera_trans = torch.nn.Parameter(self.prev_camera_trans.detach(), requires_grad=True)
         camera_quat = torch.nn.Parameter(self.prev_camera_quat, requires_grad=True)
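
A note on the change above: `.detach()` already returns a tensor with `requires_grad == False`, so the appended `.requires_grad_(False)` changes nothing functionally; it only makes explicit that the eye pose stays frozen during this joint FLAME/camera stage (it is re-optimized separately below). A minimal PyTorch sketch with illustrative names:

    import torch

    prev_eye_pose = torch.zeros(1, 6, requires_grad=True)
    frozen = prev_eye_pose.detach()
    assert frozen.requires_grad is False                    # detach() already freezes
    frozen = prev_eye_pose.detach().requires_grad_(False)   # equivalent, but explicit
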
@@ -229,17 +229,18 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
         betas = self.optim_kwargs['betas']
         if not first_frame:
             lr = lr * 0.1
+
+        # flame optimizer
         optim = torch.optim.Adam(
             [expression_param, jaw_param, neck_pose_param],
             lr=lr, betas=betas
         )
         sched = torch.optim.lr_scheduler.MultiStepLR(optim, **self.sched_kwargs)
 
+        # camera optimizer
         cam_optim = torch.optim.Adam([camera_trans, camera_quat], lr=lr, betas=betas)
         cam_sched = torch.optim.lr_scheduler.MultiStepLR(cam_optim, **self.sched_kwargs)
 
-
-
         # estimate mediapipe landmarks
         mp_lmks_ref, fan_lmks_ref = self.face_parsing.parse_lmks((image * 255).to(torch.uint8))
         iris_lmks_ref = self.face_parsing.parse_iris_lmlks(mp_lmks_ref)
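
The comments added above label a standard PyTorch pattern: two Adam instances over disjoint parameter groups (FLAME expression/jaw/neck vs. camera extrinsics), each with its own MultiStepLR schedule. A minimal sketch using the lr/betas and milestones/gamma values from the config file below; the parameter shapes are stand-ins:

    import torch

    face_params = [torch.nn.Parameter(torch.zeros(1, 100))]  # expression/jaw/neck stand-in
    cam_params = [torch.nn.Parameter(torch.zeros(1, 3))]     # camera translation stand-in

    optim = torch.optim.Adam(face_params, lr=1.0e-2, betas=(0.9, 0.999))
    cam_optim = torch.optim.Adam(cam_params, lr=1.0e-2, betas=(0.9, 0.999))
    sched = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[200, 1500], gamma=0.1)
    cam_sched = torch.optim.lr_scheduler.MultiStepLR(cam_optim, milestones=[200, 1500], gamma=0.1)
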
@@ -250,7 +251,7 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
 
         iris_lmks_ref = iris_lmks_ref[..., 0:2]
         iris_lmks_ref = self.lmks2d_to_screen(iris_lmks_ref, image.shape[1], image.shape[2]).clone().detach().to(self.device)
-        iris_lmks_center_ref = iris_lmks_ref[:, [0, 5], :]
+        iris_lmks_center_ref = iris_lmks_ref[:, [5, 0], :]
 
         # get segmentation mask
         segmentation_mask, lebeled_mask = self.face_parsing.parse_mask((image[0].cpu().numpy() * 255).astype(np.uint8))
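
This index swap is the "fixed iris lmks" half of the commit message: the two entries picked from the ten iris landmarks are the pupil centers, and reversing them to `[5, 0]` evidently matches the reference order to the order of FLAME's iris vertices in `nir.k_iris_vert_idxs`. A sketch under that assumption; MediaPipe's refined face mesh places its ten iris landmarks at indices 468 through 477, with the two centers at 468 and 473:

    import torch

    mp_lmks = torch.rand(1, 478, 2)         # hypothetical refined MediaPipe landmarks
    iris_lmks = mp_lmks[:, 468:478, :]      # ten iris landmarks; centers land at 0 and 5
    iris_centers = iris_lmks[:, [5, 0], :]  # reorder to match the FLAME iris vertex order
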
@@ -266,15 +267,13 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
         # get shape and landmarks
         pose_param = torch.cat([self.prev_global_rot, jaw_param], dim=-1)
         verts, lmks, mp_lmks = self.flame_model(self.shapecode, expression_param, pose_param, neck_pose_param, eye_pose_param)
-        iris_lmks = verts[:, nir.k_iris_vert_idxs, :]
 
         # with the current camera extrinsics
         # transform landmarks to screen
         rot = quaternion_to_matrix(camera_quat)
         cameras = FoVPerspectiveCameras(0.01, 1000, 1, R=rot, T=camera_trans).to(self.device)
         lmks2d = cameras.transform_points_screen(lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
         mp_lmks2d = cameras.transform_points_screen(mp_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
-        iris_lmks2d = cameras.transform_points_screen(iris_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
 
         # render segmentation mask and debug view
         rendered, rendered_mask = flame_renderer.render(verts, self.flame_model.faces_tensor, cameras, flame_mask_texture)
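
For reference, the screen-space projection used throughout this loop is the usual PyTorch3D recipe: turn the camera quaternion into a rotation matrix, build `FoVPerspectiveCameras`, and call `transform_points_screen`. A self-contained sketch with made-up landmarks, assuming `pytorch3d` is installed:

    import torch
    from pytorch3d.renderer import FoVPerspectiveCameras
    from pytorch3d.transforms import quaternion_to_matrix

    quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]])   # identity rotation (w, x, y, z)
    trans = torch.tensor([[0.0, 0.0, 0.5]])       # camera translation
    rot = quaternion_to_matrix(quat)
    cameras = FoVPerspectiveCameras(znear=0.01, zfar=1000, aspect_ratio=1, R=rot, T=trans)

    lmks = torch.rand(1, 68, 3) * 0.1             # fake 3D landmarks
    lmks2d = cameras.transform_points_screen(lmks, 1e-8, image_size=(512, 512))[..., 0:2]
    print(lmks2d.shape)                           # torch.Size([1, 68, 2])
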
@@ -297,47 +296,8 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
                 self.logger.log_msg(f"{iter} | loss: {loss.detach().cpu().item()}")
                 self.logger.log_image_w_lmks(image.permute(0, 3, 1, 2), [mp_lmks_ref, mp_lmks2d], 'mediapipe lmks', radius=1)
                 self.logger.log_image_w_lmks(image.permute(0, 3, 1, 2), [fan_lmks_ref, lmks2d], 'retina lmks', radius=1)
-                self.logger.log_image_w_lmks(rendered[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_center_ref, iris_lmks2d], 'retina lmks', radius=1)
                 self.logger.log_image(rendered_mask[..., 0:3].permute(0, 3, 1, 2), 'rendered mask')
                 self.logger.log_image(lebeled_mask.permute(0, 3, 1, 2), "face mask")
-
-
-        eye_pose_param = torch.nn.Parameter(self.prev_eye_pose.clone().detach(), requires_grad=True)
-        eye_optim = torch.optim.Adam([eye_pose_param], lr=lr * 0.1, betas=betas)
-        eye_sched = torch.optim.lr_scheduler.MultiStepLR(eye_optim, **self.sched_kwargs)
-
-        expression_param = expression_param.clone().detach().requires_grad_(False)
-        pose_param = pose_param.clone().detach().requires_grad_(False)
-        neck_pose_param = neck_pose_param.clone().detach().requires_grad_(False)
-        rot = quaternion_to_matrix(camera_quat.clone().detach())
-        cameras = FoVPerspectiveCameras(0.01, 1000, 1, R=rot, T=camera_trans.clone().detach).to(self.device)
-
-        for iter in range(self.optim_iters):
-            eye_optim.zero_grad()
-
-            # get shape and landmarks
-            pose_param = torch.cat([self.prev_global_rot, jaw_param], dim=-1)
-            verts, lmks, mp_lmks = self.flame_model(
-                self.shapecode,
-                expression_param,
-                pose_param,
-                neck_pose_param,
-                eye_pose_param
-            )
-
-            iris_lmks = verts[:, nir.k_iris_vert_idxs, :]
-            iris_lmks2d = cameras.transform_points_screen(iris_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
-
-            # compute los
-            iris_loss = self.criterion.wing_loss(iris_lmks2d, iris_lmks_center_ref)
-
-            iris_loss.backward()
-            eye_optim.step()
-            eye_sched.step()
-
-            if (iter % self.logger.log_iters == 0) and not self.log_result:
-                self.logger.log_msg(f"{iter} | loss: {loss.detach().cpu().item()}")
-                self.logger.log_image_w_lmks(rendered[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_center_ref, iris_lmks2d], 'retina lmks', radius=1)
 
 
         if self.log_result:
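
Besides removing the duplicated inline eye stage (it moves into the new `IrisOptimization` class below), this deletion also drops a latent bug: `T=camera_trans.clone().detach` passes the bound method object (the call parentheses are missing), not a detached tensor. A tiny illustration:

    import torch

    t = torch.randn(1, 3)
    print(type(t.clone().detach))    # <class 'builtin_function_or_method'>
    print(type(t.clone().detach()))  # <class 'torch.Tensor'>
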
@@ -349,23 +309,101 @@ def optimization_loop(self, image: torch.Tensor, first_frame: bool=False):
             self.logger.log_image(rendered[..., 0:3].permute(0, 3, 1, 2), 'rendered')
             self.logger.log_image_w_lmks(rendered[..., 0:3].permute(0, 3, 1, 2), mp_lmks2d, 'lmks on flame', radius=1)
             self.logger.log_image(lebeled_mask.permute(0, 3, 1, 2), "face mask")
+
         self.prev_expression = expression_param.detach()
         self.prev_global_rot = pose_param[:, 0:3].detach()
         self.prev_jaw_pose = pose_param[:, 3:].detach()
         self.prev_neck_pose = neck_pose_param.detach()
-        self.prev_eye_pose = eye_pose_param.detach()
         self.prev_camera_trans = camera_trans.detach()
         self.prev_camera_quat = camera_quat.detach()
         # intrinsics = cameras.get_projection_transform()
         return {
-            "camera_intrinsics": cameras.get_projection_transform()._matrix.detach(),
-            "camera_translation": camera_trans.detach(),
-            "camera_quaternion": camera_quat.detach(),
+            "cam_intrinsics_p3d": cameras.get_projection_transform()._matrix.detach(),
+            "cam_position": camera_trans.detach(),
+            "cam_quaternion": camera_quat.detach(),
             "flame_expression": expression_param.detach(),
             "flame_pose": pose_param.detach(),
             "flame_neck_pose": neck_pose_param.detach(),
-            "flame_eyes_pose": eye_pose_param.detach()
-        }
+        }, iris_lmks_center_ref
+
+
+class IrisOptimization:
+    def __init__(self,
+        flame_model,
+        face_parsing_module,
+        logger,
+        optim_kwargs,
+        sched_kwargs,
+        loss_kwargs,
+        log_result_only: bool=False,
+        optim_iters: int=5000,
+        device: str="cuda:0"
+    ):
+        self.flame_model = flame_model
+        self.logger = logger
+        self.face_parsing = face_parsing_module
+
+        self.optim_kwargs = optim_kwargs
+        self.sched_kwargs = sched_kwargs
+
+        # configure loss
+        self.criterion = OptimizationLoss(**loss_kwargs)
+        self.log_results_only = log_result_only
+        self.optim_iters = optim_iters
+        self.device = torch.device(device)
+
+
+        self.prev_eye_pose = torch.zeros([1, 6], device=self.device, dtype=torch.float32)
+
+    def lmks2d_to_screen(self, lmks2d, width, height):
+        lmks2d[..., 0] = torch.ceil(lmks2d[..., 0] * height)
+        lmks2d[..., 1] = torch.ceil(lmks2d[..., 1] * width)
+        return lmks2d.long()
+
+    def optimization_loop(
+        self,
+        image,
+        iris_lmks_ref,
+        flame_shape,
+        flame_expression,
+        flame_pose,
+        flame_neck_pose,
+        camera_quaternion,
+        camera_trans
+    ):
+        image = torch.from_numpy(image)[None].to(self.device, dtype=torch.float32) / 255.0
+
+        # create parameters
+        eye_pose_param = torch.nn.Parameter(self.prev_eye_pose, requires_grad=True)
+
+        optim = torch.optim.Adam([eye_pose_param], lr=self.optim_kwargs['lr'] * 0.1, betas=self.optim_kwargs['betas'])
+        sched = torch.optim.lr_scheduler.MultiStepLR(optim, **self.sched_kwargs)
+
+        for iter in tqdm(range(self.optim_iters), total=self.optim_iters, desc="iris optimization"):
+            optim.zero_grad()
+
+            verts, lmks, mp_lmks = self.flame_model(
+                flame_shape, flame_expression, flame_pose, flame_neck_pose, eye_pose_param
+            )
+            iris_lmks = verts[:, nir.k_iris_vert_idxs, :]
+            rot = quaternion_to_matrix(camera_quaternion)
+            cameras = FoVPerspectiveCameras(0.01, 1000, 1, R=rot, T=camera_trans).to(self.device)
+            iris_lmks2d = cameras.transform_points_screen(iris_lmks, 1e-8, image_size=(image.shape[1], image.shape[2]))[..., 0:2]
+
+            loss = torch.nn.functional.l1_loss(iris_lmks2d, iris_lmks_ref)
+
+            loss.backward(retain_graph=True)
+            optim.step()
+            sched.step()
+
+            if (iter % self.logger.log_iters == 0) and not self.log_results_only:
+                self.logger.log_msg(f"{iter} | loss {loss.detach().cpu().item()}")
+                self.logger.log_image_w_lmks(image[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_ref, iris_lmks2d], 'retina lmks', radius=1)
+        if self.log_results_only:
+            self.logger.log_image_w_lmks(image[..., 0:3].permute(0, 3, 1, 2), [iris_lmks_ref, iris_lmks2d], 'retina lmks', radius=1)
+
+        self.prev_eye_pose = eye_pose_param.detach()
+        return eye_pose_param.detach()
 
 
 
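One detail in the new loop worth flagging: `retain_graph=True` is only required when `backward()` traverses the same graph twice, and here the FLAME forward pass (and therefore the graph) is rebuilt on every iteration, so a plain `backward()` would likely suffice. A minimal sketch of the same optimize-one-parameter pattern:

    import torch

    eye_pose = torch.nn.Parameter(torch.zeros(1, 6))
    optim = torch.optim.Adam([eye_pose], lr=1.0e-3)
    target = torch.rand(1, 2, 2)

    for it in range(100):
        optim.zero_grad()
        pred = eye_pose.view(1, 2, 3)[..., 0:2]  # stand-in for FLAME + projection
        loss = torch.nn.functional.l1_loss(pred, target)
        loss.backward()                          # fresh graph each iteration
        optim.step()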

@@ -444,9 +482,9 @@ def save_state(self,
        flame_pose: torch.Tensor,
        flame_neck_pose: torch.Tensor,
        flame_eyes_pose: torch.Tensor,
-       camera_intrinsics: torch.Tensor,
-       camera_quaternion: torch.Tensor,
-       camera_translation: torch.Tensor,
+       cam_intrinsics_p3d: torch.Tensor,
+       cam_quaternion: torch.Tensor,
+       cam_position: torch.Tensor,
     ):
         rgb_path = os.path.join(self.current_output_dir, self.video_id + f"_frm{frame_idx}.png")
         nir.save_image(rgb_path, rgb)
@@ -457,10 +495,10 @@ def save_state(self,
             "flame_expression": flame_expression.cpu().numpy(),
             "flame_pose": flame_pose.cpu().numpy(),
             "flame_neck_pose": flame_neck_pose.cpu().numpy(),
-            "flame_eyes_pose": flame_eyes_pose.cpu().numpy(),
-            "cam_intrinsics": camera_intrinsics.cpu().numpy(),
-            "cam_quaternion": camera_quaternion.cpu().numpy(),
-            "cam_position": camera_translation.cpu().numpy()
+            'flame_eyes_pose': flame_eyes_pose,
+            "cam_intrinsics_p3d": cam_intrinsics_p3d.cpu().numpy(),
+            "cam_quaternion": cam_quaternion.cpu().numpy(),
+            "cam_position": cam_position.cpu().numpy()
         }
         with open(npz_path, 'wb') as outfd:
             np.savez(npz_path, **npz_data)
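
A note on the unchanged tail of this hunk: `np.savez` accepts either a filename or an open file object, so wrapping the call in `with open(npz_path, 'wb')` while passing the path again is redundant; either form below stands on its own (file name and keys are illustrative):

    import numpy as np

    npz_data = {"flame_eyes_pose": np.zeros((1, 6), dtype=np.float32)}
    np.savez("frame.npz", **npz_data)      # path form
    with open("frame.npz", "wb") as outfd:
        np.savez(outfd, **npz_data)        # file-object form
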
@@ -486,14 +524,27 @@ def save_state(self,
 # create the estimators
 mica_estimator = MicaEstimator(**conf.mica_estimator_kwargs)
 flame_optimizer = FLAMEPoseExpressionOptimization(**conf.flame_pose_expression_optimization_kwargs)
+iris_optimizer = IrisOptimization(
+    flame_optimizer.flame_model,
+    flame_optimizer.face_parsing,
+    flame_optimizer.logger,
+    conf.flame_pose_expression_optimization_kwargs['optim_kwargs'],
+    conf.flame_pose_expression_optimization_kwargs['sched_kwargs'],
+    conf.flame_pose_expression_optimization_kwargs['loss_kwargs'],
+    conf.flame_pose_expression_optimization_kwargs['log_result_only'],
+    conf.flame_pose_expression_optimization_kwargs['optim_iters'],
+    'cuda:0'
+)
 
 # create dataset
-dataset = nir.get_dataset("SingleVideoDataset", **conf.video_dataset_kwargs)
+# dataset = nir.get_dataset("SingleVideoDataset", **conf.video_dataset_kwargs)
 
 # Get all video filepaths
 filenames = os.listdir(conf.base_dir)
 print("Starting preprocessing")
 for filename in filenames:
+    if not filename.endswith('mp4'):
+        continue
     filepath = os.path.join(conf.base_dir, filename)
     print(f"Processing file: {filename}")
     dataset = nir.get_dataset("SingleVideoDataset", filepath=filepath, preload=True)
@@ -507,7 +558,17 @@ def save_state(self,
 
         # image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
         flame_optimizer.reset(shapecode, lmks)
-        optimized_data = flame_optimizer.optimization_loop(image, True if frame_idx == 0 else False)
+        optimized_data, iris_lmks = flame_optimizer.optimization_loop(image, True if frame_idx == 0 else False)
+        flame_eye_pose = iris_optimizer.optimization_loop(
+            image, iris_lmks, shapecode,
+            optimized_data['flame_expression'],
+            optimized_data['flame_pose'],
+            optimized_data['flame_neck_pose'],
+            optimized_data['cam_quaternion'],
+            optimized_data['cam_position']
+        )
+
+        optimized_data['flame_eyes_pose'] = flame_eye_pose.detach().cpu().numpy()
         optimized_data['flame_shape'] = shapecode.detach().cpu().numpy()
         optimized_data['rgb'] = data.rgb
         optimized_data['frame_idx'] = frame_idx
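
The new extension filter above is slightly loose: `str.endswith('mp4')` also matches names that merely end in the letters `mp4`. Including the dot tightens it:

    names = ['talk.mp4', 'notes.txt', 'clip_mp4']
    print([n for n in names if n.endswith('mp4')])   # ['talk.mp4', 'clip_mp4']
    print([n for n in names if n.endswith('.mp4')])  # ['talk.mp4']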

preproc_video_certh.yaml

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+base_dir: '/media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/datasets/2d/moi/'
+output_dir: '/media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/datasets/2d/moi/processed'
+
+mica_estimator_kwargs:
+  chkp_path: "/media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/libs/MICA/data/pretrained/mica.tar"
+
+flame_pose_expression_optimization_kwargs:
+  optim_iters: 2000
+  log_result_only: true
+  cam_init_z_trans: 0.5
+
+  face_parsing_kwargs:
+    use_fan: false
+    mp_face_mesh_detector_kwargs:
+      max_num_faces: 1
+      min_detection_conf: 0.5
+    face_segmentor_kwargs:
+      threshold: 0.8
+      chkp: /media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/libs/face_parsing/ibug/face_parsing/rtnet/weights/rtnet50-fcn-11.torch
+      nclasses: 11
+
+  flame_model_cfg:
+    flame_model_path: /media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/3dmms/FLAME/FLAME2023/flame2023.pkl
+    batch_size: 1
+    use_face_contour: true
+    shape_params: 300
+    expression_params: 100
+    use_3D_translation: false
+    static_landmark_embedding_path: /media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/3dmms/FLAME/flame_static_embedding.pkl
+    dynamic_landmark_embedding_path: /media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/3dmms/FLAME/flame_dynamic_embedding.npy
+    mediapipe_landmark_embedding_path: /media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/3dmms/FLAME/mediapipe_landmark_embedding/mediapipe_landmark_embedding.npz
+    flame_masks: /media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/3dmms/FLAME/FLAME_masks/FLAME_masks.pkl
+
+  loss_kwargs:
+    w_mp: 0.3
+    w_seg: 0.5
+    w_reg: 1.0
+    wing_loss_kwargs:
+      omega: 10.0
+      eps: 2.0
+    adaptive_wing_loss_kwargs: null
+    # omega: 24
+    # theta: 0.5
+    # eps: 1.0
+    # alpha: 2.1
+
+  optim_kwargs:
+    lr: 1.0e-2
+    betas: [0.9, 0.999]
+
+  sched_kwargs:
+    milestones: [200, 1500]
+    gamma: 0.1
+
+  logger_kwargs:
+    filepath: /media/vcl3d/d1d61452-6b35-42dd-ad70-5036e8c2cfc8/ankarako/dev/datasets/2d/moi/log.txt
+    address: "127.0.0.1"
+    port: 8097
+    experiment_id: "Single video face parsing"
+    log_iters: 10
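
The `wing_loss_kwargs` above (omega, eps) parameterize the wing loss used by the pose/expression stage's `OptimizationLoss`; that class is not part of this diff, but a standard wing loss after Feng et al. (CVPR 2018), with these config values as defaults, looks roughly like:

    import torch

    def wing_loss(pred: torch.Tensor, target: torch.Tensor,
                  omega: float = 10.0, eps: float = 2.0) -> torch.Tensor:
        # Logarithmic regime for small residuals, L1 regime for large ones;
        # the constant C makes the two pieces meet at |x| == omega.
        x = (pred - target).abs()
        C = omega - omega * torch.log(torch.tensor(1.0 + omega / eps))
        return torch.where(x < omega, omega * torch.log(1.0 + x / eps), x - C).mean()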
