Skip to content

Commit

Permalink
Bugfix for FID (#103)
Browse files Browse the repository at this point in the history
* bugfix(fid): Removed inplace operation
bugfix(fid): Transfer tensor to the device of input

* bugfix(fid): Transfer tensor to the device of input

* minor(fid): Remove unused variable

* minor(fid): Changed float precision
  • Loading branch information
denproc authored Jun 18, 2020
1 parent e380fa4 commit 73a9071
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions photosynthesis_metrics/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ def _sqrtm_newton_schulz(A: torch.Tensor, num_iters: int = 100) -> Tuple[torch.T

if num_iters <= 0:
raise ValueError(f'Number of iteration equals {num_iters}, expected greater than 0')
dtype = A.type()
dim = A.size(0)
normA = A.norm(p='fro')
Y = A.div(normA)
I = torch.eye(dim, dim, requires_grad=False).type(dtype)
Z = torch.eye(dim, dim, requires_grad=False).type(dtype)
I = torch.eye(dim, dim, requires_grad=False).to(A)
Z = torch.eye(dim, dim, requires_grad=False).to(A)

sA = torch.empty_like(A)
error = torch.empty(1)
error = torch.empty(1).to(A)

for i in range(num_iters):
T = 0.5 * (3.0 * I - Z.mm(Y))
Expand All @@ -55,7 +54,7 @@ def _sqrtm_newton_schulz(A: torch.Tensor, num_iters: int = 100) -> Tuple[torch.T

sA = Y * torch.sqrt(normA)
error = _approximation_error(A, sA)
if torch.isclose(error, torch.tensor([0.], device=error.device), atol=1e-5):
if torch.isclose(error, torch.tensor([0.]).to(error), atol=1e-5):
break
return sA, error

Expand All @@ -82,7 +81,7 @@ def _compute_fid(mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, sig
# Product might be almost singular
if not torch.isfinite(covmean).all():
print(f'FID calculation produces singular product; adding {eps} to diagonal of cov estimates')
offset = torch.eye(sigma1.size(0)) * eps
offset = torch.eye(sigma1.size(0)).to(mu1) * eps
covmean, _ = _sqrtm_newton_schulz((sigma1 + offset).mm(sigma2 + offset))

tr_covmean = torch.trace(covmean)
Expand Down Expand Up @@ -117,7 +116,7 @@ def _cov(m: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
if not rowvar and m.size(0) != 1:
m = m.t()
fact = 1.0 / (m.size(1) - 1)
m -= torch.mean(m, dim=1, keepdim=True)
m = m - torch.mean(m, dim=1, keepdim=True)
mt = m.t()
return fact * m.matmul(mt).squeeze()

Expand Down Expand Up @@ -173,9 +172,9 @@ def compute_metric(self, predicted_features: torch.Tensor, target_features: torc
-- : The Frechet Distance.
"""
# GPU -> CPU
m_pred, s_pred = _compute_statistics(predicted_features.detach())
m_targ, s_targ = _compute_statistics(target_features.detach())
m_pred, s_pred = _compute_statistics(predicted_features.detach().type(torch.float64))
m_targ, s_targ = _compute_statistics(target_features.detach().type(torch.float64))

score = _compute_fid(m_pred, s_pred, m_targ, s_targ)

return torch.tensor(score, device=predicted_features.device)
return score.type(torch.float32)

0 comments on commit 73a9071

Please sign in to comment.