Skip to content

Commit 19e0905

Browse files
committed
Updated chamfer distance function, now returning nn indices as well.
1 parent 5f7c645 commit 19e0905

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

evaluate_fitting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def evaluate_chamfer(fitting_results_path, scan_path, device, **kwargs):
222222

223223
# compute chamfer
224224
# dist 1 is 1 x 6890 - chamfer_distance does not return euclidean distance but squared distance
225-
dist1, dist2 = chamfer_distance(fit_verts.to(device), scan_vertices.to(device))
225+
dist1, dist2, _, _ = chamfer_distance(fit_verts.to(device), scan_vertices.to(device))
226226
chamfer_standard += (torch.mean(dist1) + torch.mean(dist2)).detach().cpu().item()
227227
chamfer_bidirectional_average += torch.mean(torch.cat([dist1[0],dist2[0]])).detach().cpu().item()
228228
chamfer_bm2scan += torch.mean(torch.sqrt(dist1)).detach().cpu().item()
@@ -259,7 +259,7 @@ def evaluate_chamfer(fitting_results_path, scan_path, device, **kwargs):
259259
fit_verts = torch.from_numpy(fit_verts).unsqueeze(0).float()
260260

261261
# compute chamfer
262-
dist1, dist2 = chamfer_distance(fit_verts.to(device), scan_verts.to(device))
262+
dist1, dist2, _ , _ = chamfer_distance(fit_verts.to(device), scan_verts.to(device))
263263
chamfer_standard = (torch.mean(dist1) + torch.mean(dist2)).detach().cpu().item()
264264
chamfer_bidirectional_average = torch.mean(torch.cat([dist1,dist2])).detach().cpu().item()
265265
chamfer_bm2scan = torch.mean(torch.sqrt(dist1)).detach().cpu().item()

fit_body_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def fit_body_model(input_dict: dict, cfg: dict):
124124
scale)
125125

126126
# compute losses
127-
dist1, dist2 = chamfer_distance(body_model_verts.unsqueeze(0), input_vertices)
127+
dist1, dist2, _ , _ = chamfer_distance(body_model_verts.unsqueeze(0), input_vertices)
128128
data_loss = (torch.mean(dist1)) + (torch.mean(dist2))
129129
data_loss_weighted = data_loss_weight * data_loss
130130
landmark_loss = summed_L2(body_model_verts[body_model_landmark_inds,:], input_landmarks)

losses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def forward(self,scan_vertices,template_vertices,**kwargs):
5858
template to the closest point of the scan closer than
5959
partial_data_threshold
6060
'''
61-
_, template2scan_dist = self.chamfer_dist(scan_vertices,template_vertices)
61+
_, template2scan_dist, _ , _ = self.chamfer_dist(scan_vertices,template_vertices)
6262
return template2scan_dist[template2scan_dist < self.partial_data_threshold].sum()
6363

6464

0 commit comments

Comments
 (0)