Skip to content

Commit 7f2c492

Browse files
committed
Fix handling of scalar argument in slice operator (#3596)
- when a scalar argument is provided to slice it tries to obtain its shape which is empty what leads to a crash. Fixes this problem by providing a proper guard Signed-off-by: Janusz Lisiecki <[email protected]>
1 parent b815571 commit 7f2c492

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

dali/operators/generic/slice/slice_attr.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,13 @@ class PositionalSliceAttr {
227227
auto shape_view = view<const ArgsType>(crop_shape);
228228
for (int data_idx = 0; data_idx < curr_batch_size; data_idx++) {
229229
span<const ArgsType> anchor(anchor_view.tensor_data(data_idx),
230-
anchor_view.tensor_shape_span(data_idx)[0]);
230+
anchor_view.shape.ndim != 0 ?
231+
anchor_view.tensor_shape_span(data_idx)[0] :
232+
1);
231233
span<const ArgsType> shape(shape_view.tensor_data(data_idx),
232-
shape_view.tensor_shape_span(data_idx)[0]);
234+
anchor_view.shape.ndim != 0 ?
235+
anchor_view.tensor_shape_span(data_idx)[0] :
236+
1);
233237
ProcessPositionalInputArgs(data_idx, anchor, shape);
234238
}
235239
), DALI_FAIL(make_string("Unsupported type of anchor and shape arguments: ", args_dtype))); // NOLINT

dali/test/python/test_operator_slice.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
# limitations under the License.
1414

1515
from nvidia.dali import Pipeline, pipeline_def, ops, fn, types
16-
import nvidia.dali as dali
17-
from nvidia.dali.backend_impl import TensorListGPU
1816
import numpy as np
19-
from numpy.testing import assert_array_equal, assert_allclose
2017
import os
2118
from functools import partial
2219
from math import floor
@@ -789,3 +786,33 @@ def test_wrong_axes():
789786
for wrong_axes_range in [(-10, -4), (3, 10)]:
790787
for named_args in [False, True]:
791788
yield check_wrong_axes, device, wrong_axes_range, named_args
789+
790+
def check_scalar(device):
791+
batch_size = 5
792+
def get_data():
793+
out = [np.random.ranf(size=[1000]).astype(dtype=np.single) for _ in range(batch_size)]
794+
return out
795+
796+
@pipeline_def(batch_size=batch_size, num_threads=1, device_id=0)
797+
def test_pipe():
798+
data = fn.external_source(source=get_data)
799+
shape = types.ScalarConstant(10)
800+
anchor = types.ScalarConstant(5)
801+
if device != 'cpu':
802+
data = data.gpu()
803+
sliced = fn.slice(data, start = anchor, shape = shape, axes=[0], device=device)
804+
return data, sliced, shape, anchor
805+
806+
pipe = test_pipe()
807+
pipe.build()
808+
ref, data, shape, anchor = pipe.run()
809+
for sample_idx in range(batch_size):
810+
d = as_array(data[sample_idx])
811+
r = as_array(ref[sample_idx])
812+
s = as_array(shape[sample_idx])
813+
a = as_array(anchor[sample_idx])
814+
np.testing.assert_allclose(d, r[a:a+s])
815+
816+
def test_scalar():
817+
for device in ['cpu', 'gpu']:
818+
yield check_scalar, device

0 commit comments

Comments
 (0)