@@ -67,10 +67,10 @@ def test_index_mul_float(self):
67
67
loss = (out_ .float ()** 2 ).sum () / out_ .numel () + (force_ .float ()** 2 ).sum ()
68
68
loss .backward ()
69
69
70
- self . assertTrue ( torch .allclose (self .input1_float , self .input1_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
71
- self . assertTrue ( torch .allclose (self .input2_float , self .input2_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
72
- self . assertTrue ( torch .allclose (self .input1_float .grad , self .input1_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
73
- self . assertTrue ( torch .allclose (self .input2_float .grad , self .input2_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
70
+ torch .testing . assert_close (self .input1_float , self .input1_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
71
+ torch .testing . assert_close (self .input2_float , self .input2_float_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
72
+ torch .testing . assert_close (self .input1_float .grad , self .input1_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
73
+ torch .testing . assert_close (self .input2_float .grad , self .input2_float_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
74
74
75
75
def test_index_mul_half (self ):
76
76
out = index_mul_2d (self .input1_half , self .input2_half , self .index1 )
@@ -95,10 +95,10 @@ def test_index_mul_half(self):
95
95
loss = (out_ .float ()** 2 ).sum () / out_ .numel () + (force_ .float ()** 2 ).sum ()
96
96
loss .backward ()
97
97
98
- self . assertTrue ( torch .allclose (self .input1_half , self .input1_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
99
- self . assertTrue ( torch .allclose (self .input2_half , self .input2_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
100
- self . assertTrue ( torch .allclose (self .input1_half .grad , self .input1_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
101
- self . assertTrue ( torch .allclose (self .input2_half .grad , self .input2_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True ) )
98
+ torch .testing . assert_close (self .input1_half , self .input1_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
99
+ torch .testing . assert_close (self .input2_half , self .input2_half_ , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
100
+ torch .testing . assert_close (self .input1_half .grad , self .input1_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
101
+ torch .testing . assert_close (self .input2_half .grad , self .input2_half_ .grad , atol = 1e-3 , rtol = 1e-3 , equal_nan = True )
102
102
103
103
if __name__ == '__main__' :
104
104
unittest .main ()
0 commit comments