Skip to content

Commit f03c6fb

Browse files
authored
[contrib] Use torch.testing.assert_close in test_index_mul_2d.py (#1693)
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent 7b2e71b commit f03c6fb

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

apex/contrib/test/index_mul_2d/test_index_mul_2d.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def test_index_mul_float(self):
6767
loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
6868
loss.backward()
6969

70-
self.assertTrue(torch.allclose(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
71-
self.assertTrue(torch.allclose(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
72-
self.assertTrue(torch.allclose(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
73-
self.assertTrue(torch.allclose(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
70+
torch.testing.assert_close(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True)
71+
torch.testing.assert_close(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True)
72+
torch.testing.assert_close(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)
73+
torch.testing.assert_close(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)
7474

7575
def test_index_mul_half(self):
7676
out = index_mul_2d(self.input1_half, self.input2_half, self.index1)
@@ -95,10 +95,10 @@ def test_index_mul_half(self):
9595
loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
9696
loss.backward()
9797

98-
self.assertTrue(torch.allclose(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
99-
self.assertTrue(torch.allclose(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
100-
self.assertTrue(torch.allclose(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
101-
self.assertTrue(torch.allclose(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
98+
torch.testing.assert_close(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True)
99+
torch.testing.assert_close(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True)
100+
torch.testing.assert_close(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)
101+
torch.testing.assert_close(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)
102102

103103
if __name__ == '__main__':
104104
unittest.main()

0 commit comments

Comments
 (0)