Skip to content

Commit 961eab9

Browse files
committed
added expression regularization
1 parent a08e72a commit 961eab9

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

preproc_video.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from glob import glob
66
from pathlib import Path
77

8-
import cv2
8+
import sys
99
import numpy as np
1010
import torch
1111
import torch.backends.cudnn as cudnn
@@ -83,6 +83,8 @@ def render(self, verts, faces, cameras, flame_mask_tex=None):
8383

8484
return debug_view, None
8585

86+
def l2_loss(inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
87+
return torch.sqrt(((inputs - target) ** 2).sum(dim=-1)).mean(dim=1).mean()
8688

8789
class OptimizationLoss(torch.nn.Module):
8890
def __init__(
@@ -124,17 +126,20 @@ def forward(
124126
fan_lmks_tgt: torch.Tensor,
125127
seg_mask: torch.Tensor,
126128
seg_mask_tgt: torch.Tensor,
127-
expresion_vector: torch.Tensor,
129+
expression_vector: torch.Tensor,
128130
iris_lmks: torch.Tensor=None,
129131
iris_lmks_tgt: torch.Tensor=None,
130132
) -> torch.Tensor:
131133
""""""
132-
mp_loss = self.wing_loss(mp_lmks, mp_lmks_tgt)
133-
fan_loss = self.wing_loss(fan_lmks, fan_lmks_tgt)
134-
iris_loss = self.wing_loss(iris_lmks, iris_lmks_tgt) if iris_lmks is not None else torch.zeros([1] ,device=mp_lmks.device, dtype=torch.float32)
134+
mp_loss = l2_loss(mp_lmks, mp_lmks_tgt)
135+
fan_loss = l2_loss(fan_lmks, fan_lmks_tgt)
136+
iris_loss = l2_loss(iris_lmks, iris_lmks_tgt) if iris_lmks is not None else torch.zeros([1] ,device=mp_lmks.device, dtype=torch.float32)
135137

136138
seg_mask_loss = torch.abs(seg_mask - seg_mask_tgt).mean()
137-
output = mp_loss * self.w_mp + fan_loss + iris_loss + seg_mask_loss * self.w_seg + expresion_vector.abs().mean() * self.w_reg
139+
140+
expression_reg = torch.mean(torch.square(expression_vector)) * self.w_reg
141+
expression_reg += torch.mean(torch.square(expression_vector[1:] - expression_vector[:-1])) * 1e-1
142+
output = mp_loss * self.w_mp + fan_loss + iris_loss + seg_mask_loss * self.w_seg + expression_reg
138143
return output
139144

140145

@@ -463,6 +468,23 @@ def process(self, image: np.ndarray, image_size: int=224):
463468
lmk = self.mica.flame.compute_landmarks(meshes)
464469
return meshes[0].detach().cpu(), code, lmk
465470

471+
472+
class Matting:
473+
def __init__(self, script_path: str, chkp_path: str):
474+
self.script_path = script_path
475+
self.chkp_path = chkp_path
476+
477+
def convert(self, video_path: str, output_mask_path: str):
478+
args = "--variant mobilenetv3 "
479+
args += f"--checkpoint {self.chkp_path} "
480+
args += f"--input-source {video_path} "
481+
args += "--output-type png_sequence "
482+
args += f"--output-alpha {output_mask_path} "
483+
args += "--device cuda"
484+
cmd = f"python {self.script_path} {args}"
485+
os.system(cmd)
486+
487+
466488
class DataSaver:
467489
def __init__(self, output_base: str, save_id_mesh: bool=True):
468490
self.output_base = output_base
@@ -541,6 +563,7 @@ def save_state(self,
541563
'cuda:0'
542564
)
543565

566+
matting = Matting(**conf.matting_kwargs)
544567
# create dataset
545568
# dataset = nir.get_dataset("SingleVideoDataset", **conf.video_dataset_kwargs)
546569

@@ -551,13 +574,20 @@ def save_state(self,
551574
for filename in filenames:
552575
if not filename.endswith('mp4'):
553576
continue
577+
578+
554579

555580
filepath = os.path.join(conf.base_dir, filename)
556581
print(f"Processing file: {filename}")
557582
dataset = nir.get_dataset("SingleVideoDataset", filepath=filepath, preload=True)
558583
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, pin_memory=True, num_workers=1, collate_fn=nir.collate_fn)
559584

560585
data_saver.set_output_state(filename.split('.')[0])
586+
587+
# preprocess whole video with matting
588+
matting_alpha_path = data_saver.current_output_dir
589+
print("estimating alpha masks")
590+
matting.convert(filepath, matting_alpha_path)
561591

562592
for frame_idx, data in tqdm(enumerate(dataloader), total=len(dataloader), desc="video progress"):
563593
data_saver.set_frame_index(frame_idx)

preproc_video.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ video_dataset_kwargs:
1111

1212

1313
flame_pose_expression_optimization_kwargs:
14-
optim_iters: 2000
14+
optim_iters: 1000
1515
log_result_only: true
1616
cam_init_z_trans: 0.5
1717

@@ -40,7 +40,7 @@ flame_pose_expression_optimization_kwargs:
4040
loss_kwargs:
4141
w_mp: 0.3
4242
w_seg: 0.6
43-
w_reg: 0.7
43+
w_reg: 0.8
4444
wing_loss_kwargs:
4545
omega: 10.0
4646
eps: 2.0
@@ -55,7 +55,7 @@ flame_pose_expression_optimization_kwargs:
5555
betas: [0.9, 0.999]
5656

5757
sched_kwargs:
58-
milestones: [200, 1000, 1500]
58+
milestones: [200, 800, 1500]
5959
gamma: 0.1
6060

6161
logger_kwargs:
@@ -64,3 +64,7 @@ flame_pose_expression_optimization_kwargs:
6464
port: 8097
6565
experiment_id: "Single video face parsing"
6666
log_iters: 500
67+
68+
matting_kwargs:
69+
script_path: /media/perukas/Home/_dev/_phd/libraries/RobustVideoMatting/inference.py
70+
chkp_path: /media/perukas/Home/_dev/_phd/libraries/RobustVideoMatting/weights/rvm_mobilenetv3.pth

0 commit comments

Comments
 (0)