
Commit deca0a1
Author: Xuanyi Dong
Message: UPDATE
Parent: b92a9d6

6 files changed, 13 insertions(+), 8 deletions(-)


README.md

Lines changed: 3 additions & 3 deletions
@@ -51,7 +51,7 @@ python ./exps/basic_main.py [<required arguments>]
 The argument list is loaded by `./lib/config_utils/basic_args.py`.
 An examples script can is `./scripts/300W-DET.sh`, and you can simple run to train the base detector on the `300-W` dataset.
 ```
-sh scripts/300W-DET.sh
+bash scripts/300W-DET.sh
 ```
 
 ### Improving the Detector by SBR
@@ -64,7 +64,7 @@ The argument list is loaded by `./lib/config_utils/lk_args.py`.
 #### An example to train SBR on the unlabeled sequences
 The `init_model` parameter is the path to the detector trained in the `Basic Training` section.
 ```
-sh scripts/demo_sbr.sh
+bash scripts/demo_sbr.sh
 ```
 To see visualization results use the commands in `Visualization`.
 
@@ -104,7 +104,7 @@ ffmpeg -start_number 3 -i cache_data/cache/demo-sbr-vis/image%04d.png -b:v 30000
 supervision-by-registration is released under the [CC-BY-NC license](https://github.com/facebookresearch/supervision-by-registration/blob/master/LICENSE).
 
 
-## Useful information
+## Useful Information
 
 ### 1. train on your own video data
 You should look at the `./lib/datasets/VideoDataset.py` and `./lib/datasets/parse_utils.py`, and add how to find the neighbour frames when giving one image path.
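
For the "train on your own video data" note above, a minimal sketch of the neighbour-frame lookup that `./lib/datasets/parse_utils.py` expects could look as follows. The function name, the fixed-width frame numbering, and the fallback to the current frame at sequence boundaries are illustrative assumptions, not the repository's actual parsing logic.

```python
import os.path as osp
import re

def find_neighbour_frames(image_path, offsets=(-1, 1)):
  # Hypothetical helper: assumes frames are named like .../video-A/image0004.png,
  # i.e. a prefix followed by a fixed-width frame number.
  directory, name = osp.split(image_path)
  stem, ext = osp.splitext(name)
  match = re.match(r'^(.*?)(\d+)$', stem)
  if match is None:                          # not a numbered frame -> no neighbours
    return [image_path for _ in offsets]
  prefix, digits = match.group(1), match.group(2)
  index, width = int(digits), len(digits)
  neighbours = []
  for offset in offsets:
    candidate = osp.join(directory, prefix + str(index + offset).zfill(width) + ext)
    # Fall back to the current frame at sequence boundaries (missing neighbour file).
    neighbours.append(candidate if osp.isfile(candidate) else image_path)
  return neighbours
```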

exps/basic_main.py

Lines changed: 2 additions & 0 deletions
@@ -159,6 +159,7 @@ def main(args):
 'args' : deepcopy(args),
 'arch' : model_config.arch,
 'state_dict': net.state_dict(),
+'detector'  : net.state_dict(),
 'scheduler' : scheduler.state_dict(),
 'optimizer' : optimizer.state_dict(),
 }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)
@@ -169,6 +170,7 @@ def main(args):
 }, logger.last_info(), logger)
 
 eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
+logger.log('NME Results : {:}'.format( eval_results ))
 
 # measure elapsed time
 epoch_time.update(time.time() - start_time)
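
After this change the training checkpoint stores the detector weights twice, under both `state_dict` and the new `detector` key that `exps/eval.py` reads below, and each epoch's NME results are written to the log. A small sketch of inspecting such a checkpoint; the file name is a placeholder for whatever `logger.path('model')` actually produced:

```python
import torch

# Placeholder path: checkpoints are written as '<arch>-<epoch>.pth' under logger.path('model').
snapshot = torch.load('snapshots/basic-checkpoint.pth', map_location='cpu')
print(sorted(snapshot.keys()))   # expect: args, arch, detector, optimizer, scheduler, state_dict
# In checkpoints from basic_main.py both entries hold the same detector weights.
assert snapshot['state_dict'].keys() == snapshot['detector'].keys()
```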

exps/eval.py

Lines changed: 4 additions & 2 deletions
@@ -20,6 +20,7 @@
 from models import obtain_model, remove_module_dict
 from config_utils import load_configure
 
+
 def evaluate(args):
 assert torch.cuda.is_available(), 'CUDA is not available.'
 torch.backends.cudnn.enabled = True
@@ -39,14 +40,15 @@ def evaluate(args):
 std=[0.229, 0.224, 0.225])
 
 param = snapshot['args']
-eval_transform = transforms.Compose([transforms.PreCrop(param.pre_crop_expand), transforms.TrainScale2WH((param.crop_width, param.crop_height)), transforms.ToTensor(), normalize])
+import pdb; pdb.set_trace()
+eval_transform = transforms.Compose([transforms.PreCrop(param.pre_crop_expand), transforms.TrainScale2WH((param.crop_width, param.crop_height)), transforms.ToTensor(), normalize])
 model_config = load_configure(param.model_config, None)
 dataset = Dataset(eval_transform, param.sigma, model_config.downsample, param.heatmap_type, param.data_indicator)
 dataset.reset(param.num_pts)
 
 net = obtain_model(model_config, param.num_pts + 1)
 net = net.cuda()
-weights = remove_module_dict(snapshot['state_dict'])
+weights = remove_module_dict(snapshot['detector'])
 net.load_state_dict(weights)
 print ('Prepare input data')
 [image, _, _, _, _, _, cropped_size], meta = dataset.prepare_input(args.image, args.face)
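
Because `evaluate` now reads `snapshot['detector']`, checkpoints saved before this commit (which only carry `state_dict`) would raise a `KeyError` here. A hypothetical backward-compatible variant, not part of the repository, could fall back to the old key:

```python
def load_detector_weights(net, snapshot):
  # Hypothetical helper: prefer the 'detector' entry introduced by this commit,
  # fall back to 'state_dict' for older checkpoints.
  key = 'detector' if 'detector' in snapshot else 'state_dict'
  weights = remove_module_dict(snapshot[key])  # remove_module_dict is imported from `models` above
  net.load_state_dict(weights)
  return net
```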

exps/lk_main.py

Lines changed: 1 addition & 0 deletions
@@ -169,6 +169,7 @@ def main(args):
 'args' : deepcopy(args),
 'arch' : model_config.arch,
 'state_dict': net.state_dict(),
+'detector'  : detector.state_dict(),
 'scheduler' : scheduler.state_dict(),
 'optimizer' : optimizer.state_dict(),
 }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

lib/procedure/basic_eval.py

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,8 @@ def basic_eval_all(args, loaders, net, criterion, epoch_str, logger, opt_config)
 nme, _, _ = eval_meta.compute_mse(logger)
 meta_path = logger.path('meta') / 'eval-{:}-{:02d}-{:02d}.pth'.format(epoch_str, i, len(loaders))
 eval_meta.save(meta_path)
-nmes.append(nme)
-return ', '.join(['{:.1f}'.format(x) for x in nmes])
+nmes.append(nme*100)
+return ', '.join(['{:.2f}'.format(x) for x in nmes])
 
 
 def basic_eval(args, loader, net, criterion, epoch_str, logger, opt_config):
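
The value returned by `eval_meta.compute_mse` is evidently a fraction (hence the new `*100`), so scaling it reports the NME as a percentage with two decimals instead of a one-decimal fraction that rounds to almost nothing. For example, with made-up NME values:

```python
raw_nmes = [0.0345, 0.0518]                     # illustrative values from compute_mse
nmes = [nme * 100 for nme in raw_nmes]          # as in the new nmes.append(nme*100)
print(', '.join(['{:.1f}'.format(x) for x in raw_nmes]))  # old formatting -> '0.0, 0.1'
print(', '.join(['{:.2f}'.format(x) for x in nmes]))      # new formatting -> '3.45, 5.18'
```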

lib/xvision/common_eval.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def evaluate_normalized_mean_error(predictions, groundtruth, log, extra_faces):
 accuracy_under_007 = np.sum(error_per_image<0.07) * 100. / error_per_image.size
 accuracy_under_008 = np.sum(error_per_image<0.08) * 100. / error_per_image.size
 
-print_log('Compute NME and AUC for {:} images with {:} points :: [(nms): mean={:.3f}, std={:.3f}], auc@0.07={:.3f}, auc@0.08={:.3f}, acc@0.07={:.3f}, acc@0.08={:.3f}'.format(num_images, num_points, normalise_mean_error*100, error_per_image.std()*100, area_under_curve07*100, area_under_curve08*100, accuracy_under_007, accuracy_under_008), log)
+print_log('Compute NME and AUC for {:} images with {:} points :: [(NME): mean={:.3f}, std={:.3f}], auc@0.07={:.3f}, auc@0.08={:.3f}, acc@0.07={:.3f}, acc@0.08={:.3f}'.format(num_images, num_points, normalise_mean_error*100, error_per_image.std()*100, area_under_curve07*100, area_under_curve08*100, accuracy_under_007, accuracy_under_008), log)
 
 for_pck_curve = []
 for x in range(0, 3501, 1):
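
For reference, the quantities in that log line are derived from the per-image normalised errors; a minimal sketch with made-up error values:

```python
import numpy as np

error_per_image = np.array([0.031, 0.045, 0.062, 0.090])   # illustrative per-image NMEs
normalise_mean_error = error_per_image.mean()
accuracy_under_007 = np.sum(error_per_image < 0.07) * 100. / error_per_image.size  # 75.0
accuracy_under_008 = np.sum(error_per_image < 0.08) * 100. / error_per_image.size  # 75.0
print('NME: mean={:.3f}, std={:.3f}, acc@0.07={:.3f}, acc@0.08={:.3f}'.format(
      normalise_mean_error * 100, error_per_image.std() * 100,
      accuracy_under_007, accuracy_under_008))
```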
