Skip to content

Commit 42ddb73

Browse files
orangecccyangyuantaoinnerlee
authored
[Fix] Fix astype error in function tensor2img (open-mmlab#429)
* fix astype error in function tensor2img * Test out dtype for tensor2img Signed-off-by: lizz <[email protected]> Co-authored-by: yangyuantao <[email protected]> Co-authored-by: lizz <[email protected]>
1 parent 2e7f0c8 commit 42ddb73

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

mmedit/core/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
6767
if out_type == np.uint8:
6868
# Unlike MATLAB, numpy.unit8() WILL NOT round by default.
6969
img_np = (img_np * 255.0).round()
70-
img_np.astype(out_type)
70+
img_np = img_np.astype(out_type)
7171
result.append(img_np)
7272
result = result[0] if len(result) == 1 else result
7373
return result

tests/test_utils/test_tensor2img.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,33 @@ def test_tensor2img():
2828

2929
# 4d
3030
rlt = tensor2img(tensor_4d_1, out_type=np.uint8, min_max=(0, 1))
31+
assert rlt.dtype == np.uint8
3132
tensor_4d_1_np = make_grid(tensor_4d_1, nrow=1, normalize=False).numpy()
3233
tensor_4d_1_np = np.transpose(tensor_4d_1_np[[2, 1, 0], :, :], (1, 2, 0))
3334
np.testing.assert_almost_equal(rlt, (tensor_4d_1_np * 255).round())
3435

3536
rlt = tensor2img(tensor_4d_2, out_type=np.uint8, min_max=(0, 1))
37+
assert rlt.dtype == np.uint8
3638
tensor_4d_2_np = tensor_4d_2.squeeze().numpy()
3739
tensor_4d_2_np = np.transpose(tensor_4d_2_np[[2, 1, 0], :, :], (1, 2, 0))
3840
np.testing.assert_almost_equal(rlt, (tensor_4d_2_np * 255).round())
3941

4042
rlt = tensor2img(tensor_4d_3, out_type=np.uint8, min_max=(0, 1))
43+
assert rlt.dtype == np.uint8
4144
tensor_4d_3_np = make_grid(tensor_4d_3, nrow=1, normalize=False).numpy()
4245
tensor_4d_3_np = np.transpose(tensor_4d_3_np[[2, 1, 0], :, :], (1, 2, 0))
4346
np.testing.assert_almost_equal(rlt, (tensor_4d_3_np * 255).round())
4447

4548
rlt = tensor2img(tensor_4d_4, out_type=np.uint8, min_max=(0, 1))
49+
assert rlt.dtype == np.uint8
4650
tensor_4d_4_np = tensor_4d_4.squeeze().numpy()
4751
np.testing.assert_almost_equal(rlt, (tensor_4d_4_np * 255).round())
4852

4953
# 3d
5054
rlt = tensor2img([tensor_3d_1, tensor_3d_2],
5155
out_type=np.uint8,
5256
min_max=(0, 1))
57+
assert rlt[0].dtype == np.uint8
5358
tensor_3d_1_np = tensor_3d_1.numpy()
5459
tensor_3d_1_np = np.transpose(tensor_3d_1_np[[2, 1, 0], :, :], (1, 2, 0))
5560
tensor_3d_2_np = tensor_3d_2.numpy()
@@ -58,16 +63,20 @@ def test_tensor2img():
5863
np.testing.assert_almost_equal(rlt[1], (tensor_3d_2_np * 255).round())
5964

6065
rlt = tensor2img(tensor_3d_3, out_type=np.uint8, min_max=(0, 1))
66+
assert rlt.dtype == np.uint8
6167
tensor_3d_3_np = tensor_3d_3.squeeze().numpy()
6268
np.testing.assert_almost_equal(rlt, (tensor_3d_3_np * 255).round())
6369

6470
# 2d
6571
rlt = tensor2img(tensor_2d, out_type=np.uint8, min_max=(0, 1))
72+
assert rlt.dtype == np.uint8
6673
tensor_2d_np = tensor_2d.numpy()
6774
np.testing.assert_almost_equal(rlt, (tensor_2d_np * 255).round())
6875
rlt = tensor2img(tensor_2d, out_type=np.float32, min_max=(0, 1))
76+
assert rlt.dtype == np.float32
6977
np.testing.assert_almost_equal(rlt, tensor_2d_np)
7078

7179
rlt = tensor2img(tensor_2d, out_type=np.float32, min_max=(0.1, 0.5))
80+
assert rlt.dtype == np.float32
7281
tensor_2d_np = (np.clip(tensor_2d_np, 0.1, 0.5) - 0.1) / 0.4
7382
np.testing.assert_almost_equal(rlt, tensor_2d_np)

0 commit comments

Comments
 (0)