Skip to content

Commit 2b1f23a

Browse files
committed
Updated visualize_fitting.
1 parent 5442825 commit 2b1f23a

File tree

2 files changed

+139
-41
lines changed

2 files changed

+139
-41
lines changed

Readme.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ You can use:
219219
Check [Notes](##-📝-Notes) section to find out the possible landmark definitions.
220220
3. Visualize fitting:
221221
```bash
222-
python visualization.py visualize_fitting --scan_path {path-to-scan} --fitted_npz_file {path-to-.npz-file}
222+
python visualization.py visualize_fitting --scan_path {path-to-scan} --fit_paths {path-to-.npz-file}
223223
```
224224
where the `.npz` is obtained with the fitting scripts.
225225

@@ -336,7 +336,7 @@ python visualization.py visualize_scan_landmarks --scan_path data/demo/tr_scan_0
336336

337337
Visualize the fitted vertices of the BM onto the FAUST scan:
338338
```bash
339-
python visualization.py visualize_fitting --scan_path data/demo/tr_scan_000.ply --fitted_npz_file data/demo/tr_scan_000.npz
339+
python visualization.py visualize_fitting --scan_path data/demo/tr_scan_000.ply --fit_paths data/demo/tr_scan_000.npz
340340
```
341341

342342
<br>

visualization.py

Lines changed: 137 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import smplx
99
import os
1010
from glob import glob
11+
import re
1112

1213
import landmarks
1314
from utils import load_config, load_landmarks, load_scan
@@ -255,6 +256,49 @@ def viz_final_fit(input_scan_verts: torch.tensor,
255256
# VIZ FUNCS #
256257
######################################################
257258

259+
def create_wireframe_plot(verts: np.ndarray,faces: np.ndarray):
260+
'''
261+
Given vertices and faces, creates a wireframe of plotly segments.
262+
Used for visualizing the wireframe.
263+
264+
:param verts: np.array (N,3) of vertices
265+
:param faces: np.array (F,3) of faces connecting the verts
266+
'''
267+
i=faces[:,0]
268+
j=faces[:,1]
269+
k=faces[:,2]
270+
271+
triangles = np.vstack((i,j,k)).T
272+
273+
x=verts[:,0]
274+
y=verts[:,1]
275+
z=verts[:,2]
276+
277+
vertices = np.vstack((x,y,z)).T
278+
tri_points = vertices[triangles]
279+
280+
#extract the lists of x, y, z coordinates of the triangle
281+
# vertices and connect them by a "line" by adding None
282+
# this is a plotly convention for plotting segments
283+
Xe = []
284+
Ye = []
285+
Ze = []
286+
for T in tri_points:
287+
Xe.extend([T[k%3][0] for k in range(4)]+[ None])
288+
Ye.extend([T[k%3][1] for k in range(4)]+[ None])
289+
Ze.extend([T[k%3][2] for k in range(4)]+[ None])
290+
291+
# return Xe, Ye, Ze
292+
wireframe = go.Scatter3d(
293+
x=Xe,
294+
y=Ye,
295+
z=Ze,
296+
mode='lines',
297+
name='wireframe',
298+
line=dict(color= 'rgb(70,70,70)', width=1)
299+
)
300+
return wireframe
301+
258302
def visualize_smpl_landmarks(**kwargs):
259303
cfg = load_config()
260304

@@ -384,31 +428,35 @@ def visualize_scan_landmarks(scan_path,landmark_path, **kwargs):
384428

385429
fig.show()
386430

387-
def visualize_fitting(scan_path, fitted_npz_file, scale_scan=None,
388-
return_fig=False, **kwargs):
431+
def visualize_fitting(scan_path, fit_paths, fit_nicknames=None, scale_scan=None,
432+
return_fig=False, visualize_mesh_texture=False, **kwargs):
389433

390-
experiment_name = os.path.basename(os.path.dirname(fitted_npz_file))
391-
verts, faces = load_scan(scan_path)
434+
if visualize_mesh_texture:
435+
verts, faces, verts_colors = load_scan(scan_path,
436+
return_vertex_colors=visualize_mesh_texture)
437+
else:
438+
verts, faces = load_scan(scan_path)
439+
verts_colors = None
392440

393441
if scale_scan:
394442
verts = verts / scale_scan
395443

396444
fig = go.Figure()
397445

398-
## plot body
446+
## plot scan
399447
if isinstance(faces,type(None)):
400448
plot_body = go.Scatter3d(x = verts[:,0],
401-
y =verts[:,1],
402-
z = verts[:,2],
403-
mode='markers',
404-
marker=dict(
405-
color="lightpink",
406-
size=8,
407-
line=dict(
408-
color='black',
409-
width=1)
410-
),
411-
name="Scan")
449+
y =verts[:,1],
450+
z = verts[:,2],
451+
mode='markers',
452+
marker=dict(
453+
color="lightpink",
454+
size=8,
455+
line=dict(
456+
color='black',
457+
width=1)
458+
),
459+
name="Scan")
412460
else:
413461
plot_body = go.Mesh3d(
414462
x=verts[:,0],
@@ -420,33 +468,75 @@ def visualize_fitting(scan_path, fitted_npz_file, scale_scan=None,
420468
k=faces[:,2],
421469
name='Scan',
422470
showscale=True,
423-
opacity=0.7
471+
opacity=0.7,
472+
vertexcolor=verts_colors,
473+
flatshading=True
424474
)
425-
fig.add_trace(plot_body)
475+
476+
## add wireframe of scan
477+
wireframe_plot = create_wireframe_plot(verts,faces)
478+
fig.add_trace(wireframe_plot)
426479

480+
fig.add_trace(plot_body)
427481

428-
fitted_data = np.load(fitted_npz_file)
429-
fitted_verts = fitted_data["vertices"]
430-
fitted_name = str(fitted_data["name"])
431-
432-
# plot fitted_verts
433-
plot_fitted = go.Scatter3d(x = fitted_verts[:,0],
434-
y = fitted_verts[:,1],
435-
z = fitted_verts[:,2],
436-
mode='markers',
437-
marker=dict(
438-
color="blue",
439-
size=8,
440-
line=dict(
441-
color='black',
442-
width=1)
443-
),
444-
name=fitted_name)
445-
fig.add_trace(plot_fitted)
446482

483+
if not isinstance(fit_nicknames,type(None)):
484+
if len(fit_nicknames) != len(fit_paths):
485+
print("Number of fit_nicknames does not match number of fit_paths.")
486+
print("Using names extracted from paths")
487+
488+
experiment_names = []
489+
fitted_names = []
490+
n_colors = len(fit_paths)
491+
# want colors from 0.1 to 1 because below 0.1 is black
492+
colors = px.colors.sample_colorscale("turbo",
493+
list(np.linspace(0.1, 1, n_colors))
494+
# [0.1 + n/n_colors for n in range(n_colors)]
495+
)
496+
497+
for i, fit_path in enumerate(fit_paths):
498+
499+
# find standard experiment name
500+
experiment_name = re.search(r'\d{4}_\d{2}_\d{2}_\d{2}_\d{2}_\d{2}', fit_path)
501+
if not isinstance(experiment_name,type(None)):
502+
experiment_name = experiment_name.group(0)
503+
else:
504+
experiment_name = os.path.basename(os.path.dirname(fit_path))
505+
experiment_names.append(experiment_name)
506+
507+
if not isinstance(fit_nicknames,type(None)):
508+
viz_name = fit_nicknames[i]
509+
else:
510+
viz_name = experiment_name
511+
512+
fitted_data = np.load(fit_path)
513+
fitted_verts = fitted_data["vertices"]
514+
fitted_name = str(fitted_data["name"])
515+
fitted_names.append(fitted_name)
516+
517+
# plot fitted_verts
518+
plot_fitted = go.Scatter3d(x = fitted_verts[:,0],
519+
y = fitted_verts[:,1],
520+
z = fitted_verts[:,2],
521+
mode='markers',
522+
marker=dict(
523+
color=colors[i],
524+
size=8,
525+
line=dict(
526+
color='black',
527+
width=1)
528+
),
529+
name=viz_name)
530+
fig.add_trace(plot_fitted)
531+
532+
if np.unique(fitted_names).shape[0] > 1:
533+
raise ValueError("Visualizing fits for wrong scan.")
534+
535+
viz_title = f"Viz scan {fitted_name} and fitted bm from " + \
536+
f"experiments {experiment_names}"
447537
fig.update_layout(scene_aspectmode='data',
448-
width=1000, height=700,
449-
title=f"Exp {experiment_name} - Scan {fitted_name} + fitted body model",
538+
width=1000, height=1000,
539+
title=viz_title,
450540
)
451541
if return_fig:
452542
return fig
@@ -515,7 +605,15 @@ def visualize_pve(verts, vert_errors, faces, name=""):
515605

516606
parser_viz_fitting = subparsers.add_parser('visualize_fitting')
517607
parser_viz_fitting.add_argument("-S", "--scan_path", type=str, required=True)
518-
parser_viz_fitting.add_argument("-F", "--fitted_npz_file", type=str, required=True)
608+
parser_viz_fitting.add_argument("-F", "--fit_paths", required=True,
609+
nargs='+', default=[],
610+
help="One or multiple paths to the fitted npz file \
611+
you want to visualize.")
612+
parser_viz_fitting.add_argument("--fit_nicknames", required=False,
613+
nargs='+', default=[],
614+
help="More meaningful names for each experiment in the \
615+
fit_paths you are visualizing.")
616+
parser_viz_fitting.add_argument("--visualize_mesh_texture", action="store_true")
519617
parser_viz_fitting.add_argument("--scale_scan", type=float, required=False, default=1.0)
520618
parser_viz_fitting.set_defaults(func=visualize_fitting)
521619

0 commit comments

Comments
 (0)