5
5
from glob import glob
6
6
from pathlib import Path
7
7
8
- import cv2
8
+ import sys
9
9
import numpy as np
10
10
import torch
11
11
import torch .backends .cudnn as cudnn
@@ -83,6 +83,8 @@ def render(self, verts, faces, cameras, flame_mask_tex=None):
83
83
84
84
return debug_view , None
85
85
86
def l2_loss(inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean Euclidean (L2) distance between ``inputs`` and ``target``.

    The distance is taken over the last dimension, averaged over dimension 1,
    and then averaged over whatever remains, yielding a scalar tensor.

    Args:
        inputs: Predicted points. Must have at least three dimensions because
            the last one is reduced by the norm and dimension 1 is then
            averaged (presumably ``(batch, num_points, coords)`` -- confirm
            against the callers).
        target: Reference points, broadcastable against ``inputs``.

    Returns:
        A zero-dimensional tensor holding the mean point-to-point distance.
    """
    # vector_norm is used instead of sqrt(sum(diff ** 2)): the forward value is
    # the same, but its backward pass yields a zero subgradient where the
    # distance is exactly zero, whereas sqrt produces NaN gradients there.
    distances = torch.linalg.vector_norm(inputs - target, ord=2, dim=-1)
    return distances.mean(dim=1).mean()
86
88
87
89
class OptimizationLoss (torch .nn .Module ):
88
90
def __init__ (
@@ -124,17 +126,20 @@ def forward(
124
126
fan_lmks_tgt : torch .Tensor ,
125
127
seg_mask : torch .Tensor ,
126
128
seg_mask_tgt : torch .Tensor ,
127
- expresion_vector : torch .Tensor ,
129
+ expression_vector : torch .Tensor ,
128
130
iris_lmks : torch .Tensor = None ,
129
131
iris_lmks_tgt : torch .Tensor = None ,
130
132
) -> torch .Tensor :
131
133
""""""
132
- mp_loss = self . wing_loss (mp_lmks , mp_lmks_tgt )
133
- fan_loss = self . wing_loss (fan_lmks , fan_lmks_tgt )
134
- iris_loss = self . wing_loss (iris_lmks , iris_lmks_tgt ) if iris_lmks is not None else torch .zeros ([1 ] ,device = mp_lmks .device , dtype = torch .float32 )
134
+ mp_loss = l2_loss (mp_lmks , mp_lmks_tgt )
135
+ fan_loss = l2_loss (fan_lmks , fan_lmks_tgt )
136
+ iris_loss = l2_loss (iris_lmks , iris_lmks_tgt ) if iris_lmks is not None else torch .zeros ([1 ] ,device = mp_lmks .device , dtype = torch .float32 )
135
137
136
138
seg_mask_loss = torch .abs (seg_mask - seg_mask_tgt ).mean ()
137
- output = mp_loss * self .w_mp + fan_loss + iris_loss + seg_mask_loss * self .w_seg + expresion_vector .abs ().mean () * self .w_reg
139
+
140
+ expression_reg = torch .mean (torch .square (expression_vector )) * self .w_reg
141
+ expression_reg += torch .mean (torch .square (expression_vector [1 :] - expression_vector [:- 1 ])) * 1e-1
142
+ output = mp_loss * self .w_mp + fan_loss + iris_loss + seg_mask_loss * self .w_seg + expression_reg
138
143
return output
139
144
140
145
@@ -463,6 +468,23 @@ def process(self, image: np.ndarray, image_size: int=224):
463
468
lmk = self .mica .flame .compute_landmarks (meshes )
464
469
return meshes [0 ].detach ().cpu (), code , lmk
465
470
471
+
472
class Matting:
    """Run an external video-matting inference script to produce alpha masks.

    The script is launched as a separate Python process; this class only
    stores where the script and its checkpoint live and assembles the command
    line for each conversion.
    """

    def __init__(self, script_path: str, chkp_path: str):
        """Store the locations needed to launch the matting script.

        Args:
            script_path: Path to the inference script to execute.
            chkp_path: Path to the model checkpoint passed via ``--checkpoint``.
        """
        self.script_path = script_path
        self.chkp_path = chkp_path

    def convert(self, video_path: str, output_mask_path: str, *, check: bool = False) -> int:
        """Run the matting script on a video and write alpha masks as PNGs.

        Args:
            video_path: Input video, passed via ``--input-source``.
            output_mask_path: Destination passed via ``--output-alpha``.
            check: When true, raise ``subprocess.CalledProcessError`` if the
                script exits with a non-zero status. Defaults to false so a
                failing script does not abort the caller, as before.

        Returns:
            The exit status of the matting process (``0`` on success).
        """
        # Imported locally so this block does not depend on the file's
        # top-level import list.
        import subprocess
        import sys

        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact and prevents them from being interpreted.
        # sys.executable makes the child use the same interpreter/environment
        # as the current process; fall back to "python" if it is unavailable.
        cmd = [
            sys.executable or "python",
            str(self.script_path),
            "--variant", "mobilenetv3",
            "--checkpoint", str(self.chkp_path),
            "--input-source", str(video_path),
            "--output-type", "png_sequence",
            "--output-alpha", str(output_mask_path),
            "--device", "cuda",
        ]
        completed = subprocess.run(cmd, check=check)
        return completed.returncode
486
+
487
+
466
488
class DataSaver :
467
489
def __init__ (self , output_base : str , save_id_mesh : bool = True ):
468
490
self .output_base = output_base
@@ -541,6 +563,7 @@ def save_state(self,
541
563
'cuda:0'
542
564
)
543
565
566
+ matting = Matting (** conf .matting_kwargs )
544
567
# create dataset
545
568
# dataset = nir.get_dataset("SingleVideoDataset", **conf.video_dataset_kwargs)
546
569
@@ -551,13 +574,20 @@ def save_state(self,
551
574
for filename in filenames :
552
575
if not filename .endswith ('mp4' ):
553
576
continue
577
+
578
+
554
579
555
580
filepath = os .path .join (conf .base_dir , filename )
556
581
print (f"Processing file: { filename } " )
557
582
dataset = nir .get_dataset ("SingleVideoDataset" , filepath = filepath , preload = True )
558
583
dataloader = torch .utils .data .DataLoader (dataset , batch_size = 1 , pin_memory = True , num_workers = 1 , collate_fn = nir .collate_fn )
559
584
560
585
data_saver .set_output_state (filename .split ('.' )[0 ])
586
+
587
+ # preprocess whole video with matting
588
+ matting_alpha_path = data_saver .current_output_dir
589
+ print ("estimating alpha masks" )
590
+ matting .convert (filepath , matting_alpha_path )
561
591
562
592
for frame_idx , data in tqdm (enumerate (dataloader ), total = len (dataloader ), desc = "video progress" ):
563
593
data_saver .set_frame_index (frame_idx )
0 commit comments