Skip to content

Commit 7268fe1

Browse files
JanuszLstiepan
authored andcommitted
Fix OpticalFlow test premature exit on sm < 8 (#4933)
- fixes premature exit of OpticalFlow test on sm < 8 Signed-off-by: Janusz Lisiecki <[email protected]>
1 parent 7fc2671 commit 7268fe1

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

dali/test/python/test_optical_flow.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nvidia.dali import fn, types
2222
from test_utils import get_dali_extra_path, get_arch
2323
from nose_utils import raises, assert_raises
24+
from nose import SkipTest
2425

2526
test_data_root = get_dali_extra_path()
2627
images_dir = os.path.join(test_data_root, 'db', 'imgproc')
@@ -234,38 +235,39 @@ def check_optflow(output_grid=1, hint_grid=1, use_temporal_hints=False):
234235
hint_grid=hint_grid, use_temporal_hints=use_temporal_hints)
235236
pipe.build()
236237
if get_arch() < 8:
237-
if output_grid != 4 and (hint_grid in [4, 8, None]):
238+
if output_grid != 4:
238239
assert_raises(RuntimeError, pipe.run,
239240
glob="grid size: * is not supported, supported are:")
240-
if output_grid == 4 and hint_grid not in [4, 8, None]:
241+
raise SkipTest('Skipped as grid size is not supported for this arch')
242+
elif hint_grid not in [4, 8, None]:
241243
assert_raises(RuntimeError, pipe.run,
242244
glob="hint grid size: * is not supported, supported are:")
243-
else:
244-
for _ in range(2):
245-
out = pipe.run()
246-
for i in range(batch_size):
247-
seq = out[0].at(i)
248-
out_field = out[1].as_cpu().at(i)[0]
249-
_, ref_field = get_mapping(seq.shape[1:3])
250-
dsize = (out_field.shape[1], out_field.shape[0])
251-
ref_field = cv2.resize(ref_field, dsize=dsize, interpolation=cv2.INTER_AREA)
252-
if interactive:
253-
cv2.imshow("out", flow_to_color(out_field, None, True))
254-
cv2.imshow("ref", flow_to_color(ref_field, None, True))
255-
print(np.max(out_field))
256-
print(np.max(ref_field))
257-
cv2.imshow("dif", flow_to_color(ref_field - out_field, None, True))
258-
cv2.waitKey(0)
259-
err = np.linalg.norm(ref_field - out_field, ord=2, axis=2)
260-
assert np.mean(err) < 1 # average error of less than one pixel
261-
assert np.max(err) < 100 # no point more than 100px off
262-
assert np.sum(err > 1) / np.prod(err.shape) < 0.1 # 90% are within 1px
263-
assert np.sum(err > 2) / np.prod(err.shape) < 0.05 # 95% are within 2px
245+
raise SkipTest('Skipped as hint grid size is not supported for this arch')
246+
247+
for _ in range(2):
248+
out = pipe.run()
249+
for i in range(batch_size):
250+
seq = out[0].at(i)
251+
out_field = out[1].as_cpu().at(i)[0]
252+
_, ref_field = get_mapping(seq.shape[1:3])
253+
dsize = (out_field.shape[1], out_field.shape[0])
254+
ref_field = cv2.resize(ref_field, dsize=dsize, interpolation=cv2.INTER_AREA)
255+
if interactive:
256+
cv2.imshow("out", flow_to_color(out_field, None, True))
257+
cv2.imshow("ref", flow_to_color(ref_field, None, True))
258+
print(np.max(out_field))
259+
print(np.max(ref_field))
260+
cv2.imshow("dif", flow_to_color(ref_field - out_field, None, True))
261+
cv2.waitKey(0)
262+
err = np.linalg.norm(ref_field - out_field, ord=2, axis=2)
263+
assert np.mean(err) < 1 # average error of less than one pixel
264+
assert np.max(err) < 100 # no point more than 100px off
265+
assert np.sum(err > 1) / np.prod(err.shape) < 0.1 # 90% are within 1px
266+
assert np.sum(err > 2) / np.prod(err.shape) < 0.05 # 95% are within 2px
264267

265268

266269
def test_optflow():
267270
if not is_of_supported():
268-
from nose import SkipTest
269271
raise SkipTest('Optical Flow is not supported on this platform')
270272
for output_grid in [1, 2, 4]:
271273
hint_grid = random.choice([None, 1, 2, 4, 8])

0 commit comments

Comments
 (0)