
Commit 6827e15

mzient authored and stiepan committed
Add tests for operator cast. Revert to plain batched cast kernel until the optimized one is fixed. (#3927)
* Add tests for operator cast. Revert to plain batched cast kernel (the BinSearch is broken). Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent ed559c1 commit 6827e15

File tree

2 files changed: +118 −40 lines changed


dali/operators/generic/cast.cu

Lines changed: 19 additions & 0 deletions
@@ -89,6 +89,24 @@ void CastGPU::RunImpl(DeviceWorkspace &ws) {
   }
 
   auto blocks = block_setup_.Blocks();
+
+  kernels::BlockDesc<1> *blocks_dev;
+  kernels::CastSampleDesc *samples_dev;
+  std::tie(blocks_dev, samples_dev) = scratchpad.ToContiguousGPU(ws.stream(),
+                                                                 blocks, samples_);
+
+  DALIDataType itype = input.type();
+  dim3 grid_dim = block_setup_.GridDim();
+  dim3 block_dim = block_setup_.BlockDim();
+  TYPE_SWITCH(output_type_, type2id, OType, CAST_ALLOWED_TYPES, (
+    TYPE_SWITCH(itype, type2id, IType, CAST_ALLOWED_TYPES, (
+      kernels::BatchedCastKernel<OType, IType>
+          <<<grid_dim, block_dim, 0, ws.stream()>>>(samples_dev, blocks_dev);
+    ), DALI_FAIL(make_string("Invalid input type: ", itype));); // NOLINT(whitespace/parens)
+  ), DALI_FAIL(make_string("Invalid output type: ", output_type_));); // NOLINT(whitespace/parens)
+
+  /*
+  TODO(michalz): Fix the kernel!
   // Calculate id of the earliest block that should process given sample
   for (int block_id = 0, sample_id = -1; block_id < blocks.size(); block_id++) {
     if (blocks[block_id].sample_idx != sample_id) {
@@ -112,6 +130,7 @@ void CastGPU::RunImpl(DeviceWorkspace &ws) {
       num_samples, block_volume_scale);
   ), DALI_FAIL(make_string("Invalid input type: ", itype));); // NOLINT(whitespace/parens)
 ), DALI_FAIL(make_string("Invalid output type: ", output_type_));); // NOLINT(whitespace/parens)
+  */
 }
 
 DALI_REGISTER_OPERATOR(Cast, CastGPU, GPU);
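
The launch above hands the kernel one kernels::CastSampleDesc per sample and one kernels::BlockDesc<1> per CUDA block, both uploaded in a single scratchpad.ToContiguousGPU call, so each block presumably covers a contiguous flat range of one sample (that is what the block_setup_/BlockDesc machinery provides). Below is a minimal, self-contained CUDA sketch of that plain batched-cast pattern; the SampleDesc/BlockDesc structs, the plain static_cast, and all host-side names are simplified stand-ins invented for illustration, not DALI's actual kernels::CastSampleDesc, kernels::BlockDesc<1>, or BatchedCastKernel.

    // cast_sketch.cu - illustrative only; build with: nvcc cast_sketch.cu -o cast_sketch
    #include <cuda_runtime.h>
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Simplified stand-ins for DALI's kernels::CastSampleDesc and kernels::BlockDesc<1>.
    struct SampleDesc {
      void *output;
      const void *input;
    };

    struct BlockDesc {
      int sample_idx;      // which sample in the batch this CUDA block processes
      int64_t start, end;  // flat element range [start, end) within that sample
    };

    // Plain batched cast: one CUDA block per BlockDesc; threads stride over the block's range.
    template <typename OType, typename IType>
    __global__ void BatchedCast(const SampleDesc *samples, const BlockDesc *blocks) {
      BlockDesc blk = blocks[blockIdx.x];
      const IType *in = static_cast<const IType *>(samples[blk.sample_idx].input);
      OType *out = static_cast<OType *>(samples[blk.sample_idx].output);
      for (int64_t i = blk.start + threadIdx.x; i < blk.end; i += blockDim.x)
        out[i] = static_cast<OType>(in[i]);  // DALI applies a saturating conversion here
    }

    int main() {
      const int64_t sizes[2] = {1000, 300000};  // a toy 2-sample batch
      const int64_t block_vol = 65536;          // elements per BlockDesc (mirrors block_setup_)
      std::vector<SampleDesc> samples(2);
      std::vector<BlockDesc> blocks;
      float *d_in[2];
      int *d_out[2];

      for (int s = 0; s < 2; s++) {
        std::vector<float> h(sizes[s]);
        for (int64_t i = 0; i < sizes[s]; i++)
          h[i] = 0.25f * i;
        cudaMalloc(&d_in[s], sizes[s] * sizeof(float));
        cudaMalloc(&d_out[s], sizes[s] * sizeof(int));
        cudaMemcpy(d_in[s], h.data(), sizes[s] * sizeof(float), cudaMemcpyHostToDevice);
        samples[s] = {d_out[s], d_in[s]};
        // split each sample into fixed-volume blocks
        for (int64_t start = 0; start < sizes[s]; start += block_vol)
          blocks.push_back({s, start, std::min(start + block_vol, sizes[s])});
      }

      // Analogous to scratchpad.ToContiguousGPU: ship the descriptors to the device.
      SampleDesc *d_samples;
      BlockDesc *d_blocks;
      cudaMalloc(&d_samples, samples.size() * sizeof(SampleDesc));
      cudaMalloc(&d_blocks, blocks.size() * sizeof(BlockDesc));
      cudaMemcpy(d_samples, samples.data(), samples.size() * sizeof(SampleDesc), cudaMemcpyHostToDevice);
      cudaMemcpy(d_blocks, blocks.data(), blocks.size() * sizeof(BlockDesc), cudaMemcpyHostToDevice);

      int num_blocks = static_cast<int>(blocks.size());
      BatchedCast<int, float><<<num_blocks, 256>>>(d_samples, d_blocks);
      cudaDeviceSynchronize();

      int probe = 0;
      cudaMemcpy(&probe, d_out[1] + 8, sizeof(int), cudaMemcpyDeviceToHost);
      printf("sample 1, element 8: %d\n", probe);  // 0.25f * 8 == 2.0 -> 2
      // cleanup (cudaFree) omitted for brevity
      return 0;
    }

Judging from the code the commit comments out ("Calculate id of the earliest block that should process given sample", the num_samples/block_volume_scale arguments), the reverted BinSearch variant instead derived the block-to-sample mapping on the device; that is the part the commit disables until the kernel is fixed.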
Lines changed: 99 additions & 40 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,46 +12,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nvidia.dali.pipeline import Pipeline
-import nvidia.dali.ops as ops
+import nose_utils
+from nvidia.dali import pipeline_def
+import nvidia.dali as dali
+import nvidia.dali.fn as fn
 import nvidia.dali.types as types
 import numpy as np
+from nose.tools import nottest
 
-from test_utils import compare_pipelines
-from test_utils import RandomlyShapedDataIterator
-
-class CastPipeline(Pipeline):
-    def __init__(self, device, batch_size, iterator, cast_dtypes, num_threads=1, device_id=0):
-        super(CastPipeline, self).__init__(batch_size, num_threads, device_id)
-        self.layout = "HWC"
-        self.device = device
-        self.iterator = iterator
-        self.inputs = ops.ExternalSource()
-        self.cast = [ops.Cast(device=device, dtype=dtype) for dtype in cast_dtypes]
-
-    def define_graph(self):
-        self.data = self.inputs()
-        out = self.data.gpu() if self.device == 'gpu' else self.data
-        for k in range(len(self.cast)):
-            out = self.cast[k](out)
-        return out
-
-    def iter_setup(self):
-        data = self.iterator.next()
-        self.feed_input(self.data, data, layout=self.layout)
-
-def check_cast_operator_float16(device, batch_size, in_type, out_type):
-    input_shape=(300, 400, 3)
-    eii1 = RandomlyShapedDataIterator(batch_size, max_shape=input_shape, dtype=in_type)
-    eii2 = RandomlyShapedDataIterator(batch_size, max_shape=input_shape, dtype=in_type)
-    compare_pipelines(
-        CastPipeline(device, batch_size, iter(eii1), [types.FLOAT16, out_type]),
-        CastPipeline(device, batch_size, iter(eii2), [out_type]),
-        batch_size=batch_size, N_iterations=5)
-
-def test_cast_operator_float16():
+from test_utils import check_batch, np_type_to_dali
+
+
+def ref_cast(x, dtype):
+    if np.issubdtype(dtype, np.integer):
+        lo = np.iinfo(dtype).min
+        hi = np.iinfo(dtype).max
+        if np.issubdtype(x.dtype, np.floating):
+            x = np.round(x)
+        return x.clip(lo, hi).astype(dtype)
+    else:
+        return x.astype(dtype)
+
+def random_shape(rng, ndim: int, max_size: int):
+    if ndim == 0:
+        return []
+    max_size = int(max_size ** (1/ndim))
+    return list(rng.integers(0, max_size, [ndim]))
+
+def generate(rng, ndim: int, batch_size: int, in_dtype: np.dtype, out_dtype: np.dtype):
+    lo, hi = -1000, 1000
+    if np.issubdtype(out_dtype, np.integer):
+        lo = np.iinfo(out_dtype).min
+        hi = np.iinfo(out_dtype).max
+        if hi < np.iinfo(np.int64).max:
+            r = hi - lo
+            hi += r // 2
+            lo -= r // 2
+        if np.issubdtype(in_dtype, np.integer):
+            lo = max(np.iinfo(in_dtype).min, lo)
+            hi = min(np.iinfo(in_dtype).max, hi)
+        else:
+            lo = max(-np.finfo(in_dtype).max, lo)
+            hi = min(np.finfo(in_dtype).max, hi)
+
+    max_size = 100000 // batch_size
+    out = [rng.uniform(lo, hi, size=random_shape(rng, ndim, max_size)).astype(in_dtype) for _ in range(batch_size)]
+    if np.issubdtype(in_dtype, np.floating) and np.issubdtype(out_dtype, np.integer):
+        for x in out:
+            # avoid exactly halfway numbers - rounding is different for CPU and GPU
+            halfway = x[x - np.floor(x) == 0.5]
+            x[x - np.floor(x) == 0.5] = np.nextafter(halfway, np.Infinity)
+    return out
+
+rng = np.random.default_rng(1234)
+
+@nottest
+def _test_operator_cast(ndim, batch_size, in_dtype, out_dtype, device):
+    src = lambda: generate(rng, ndim, batch_size, in_dtype, out_dtype)
+    @pipeline_def(batch_size=batch_size, num_threads=4, device_id=types.CPU_ONLY_DEVICE_ID if device == 'cpu' else 0)
+    def cast_pipe():
+        inp = fn.external_source(src)
+        inp_dev = inp.gpu() if device == 'gpu' else inp
+        return inp, fn.cast(inp_dev, dtype=np_type_to_dali(out_dtype))
+
+    pipe = cast_pipe()
+    pipe.build()
+    for _ in range(10):
+        inp, out = pipe.run()
+        if device == 'gpu':
+            out = out.as_cpu()
+        ref = [ref_cast(np.array(x), out_dtype) for x in inp]
+
+        # work around a bug in numpy: when the argument is a scalar fp32 or fp16, nextafter
+        # promotes it to fp64, resulting in insufficient epsilon - we want an epsilon of the
+        # type specified in out_dtype
+        eps = 0 if np.issubdtype(out_dtype, np.integer) else (np.nextafter(out_dtype([1]), 2) - 1.0)[0]
+
+        for i in range(batch_size):
+            if not np.allclose(out[i], ref[i], eps):
+                print("At sample", i)
+                I = np.array(inp[i])
+                O = np.array(out[i])
+                R = ref[i]
+                print(I)
+                print(R)
+                print(O)
+                mask = np.logical_not(np.isclose(O, R, eps))
+                print("Differences at", mask)
+                print(I[mask])
+                print(R[mask])
+                print(O[mask])
+                print(np.count_nonzero(mask), "wrong values out of", mask.size)
+            assert np.array_equal(out[i], ref[i])
+
+
+def test_operator_cast():
+    types = [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64, np.float16, np.float32]
     for device in ['cpu', 'gpu']:
-        for batch_size in [3]:
-            for in_type in [np.uint8, np.int64]:
-                for out_type in [types.FLOAT, types.INT8]:
-                    yield check_cast_operator_float16, device, batch_size, in_type, out_type
+        for in_type in types:
+            for out_type in types:
+                ndim = rng.integers(0, 4)
+                batch_size = rng.integers(1, 11)
+                yield _test_operator_cast, ndim, batch_size, in_type, out_type, device

0 commit comments
