Skip to content

Commit b33c5ac

Browse files
committed
Fix scalar batch handling in arithmetic ops (#1449)
Adjust test for checking if the Tensor should be considered a scalar - uniform shape of 1-dim scalar elements. Allow Scalar batch to be broadcasted in the operation properly - no offset in tiles for scalar-like data, same code-path as constants. Return shape is always a batch - fixed for Constant/Scalar inputs. Add C++ unit tests for Scalar Batch. Test if Pipeline behaves properly when switching between scalar and non-scalar inputs. Python test was extended to cover Scalar inputs, size of some of the tests was reduced. Signed-off-by: Krzysztof Lecki <[email protected]>
1 parent a5c275a commit b33c5ac

File tree

6 files changed

+400
-128
lines changed

6 files changed

+400
-128
lines changed

dali/operators/expressions/arithmetic.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ inline std::vector<ExprImplTask> CreateExecutionTasks(const ExprNode &expr, Expr
145145
return result;
146146
}
147147

148-
inline TensorListShape<> ShapePromotion(std::string op, span<const TensorListShape<> *> shapes) {
148+
inline TensorListShape<> ShapePromotion(std::string op, span<const TensorListShape<> *> shapes,
149+
int batch_size) {
149150
const TensorListShape<> *out_shape = nullptr;
150151
for (int i = 0; i < shapes.size(); i++) {
151152
if (IsScalarLike(*shapes[i]))
@@ -159,14 +160,15 @@ inline TensorListShape<> ShapePromotion(std::string op, span<const TensorListSha
159160
*out_shape, ", ", *shapes[i], ")."));
160161
}
161162
}
162-
return out_shape ? *out_shape : TensorListShape<>{{1}};
163+
return out_shape ? *out_shape : uniform_list_shape(batch_size, {1});
163164
}
164165

165166
template <typename Backend>
166167
DLL_PUBLIC inline const TensorListShape<> &PropagateShapes(ExprNode &expr,
167-
const workspace_t<Backend> &ws) {
168+
const workspace_t<Backend> &ws,
169+
int batch_size) {
168170
if (expr.GetNodeType() == NodeType::Constant) {
169-
expr.SetShape(TensorListShape<>{{1}});
171+
expr.SetShape(uniform_list_shape(batch_size, {1}));
170172
return expr.GetShape();
171173
}
172174
if (expr.GetNodeType() == NodeType::Tensor) {
@@ -182,9 +184,9 @@ DLL_PUBLIC inline const TensorListShape<> &PropagateShapes(ExprNode &expr,
182184
SmallVector<const TensorListShape<> *, kMaxArity> shapes;
183185
shapes.resize(subexpression_count);
184186
for (int i = 0; i < subexpression_count; i++) {
185-
shapes[i] = &PropagateShapes<Backend>(func[i], ws);
187+
shapes[i] = &PropagateShapes<Backend>(func[i], ws, batch_size);
186188
}
187-
func.SetShape(ShapePromotion(func.GetFuncName(), make_span(shapes)));
189+
func.SetShape(ShapePromotion(func.GetFuncName(), make_span(shapes), batch_size));
188190
return func.GetShape();
189191
}
190192

@@ -242,7 +244,7 @@ class ArithmeticGenericOp : public Operator<Backend> {
242244
types_layout_inferred_ = true;
243245
}
244246

245-
result_shape_ = PropagateShapes<Backend>(*expr_, ws);
247+
result_shape_ = PropagateShapes<Backend>(*expr_, ws, batch_size_);
246248
AllocateIntermediateNodes();
247249
exec_order_ = CreateExecutionTasks<Backend>(*expr_, cache_, ws.has_stream() ? ws.stream() : 0);
248250

dali/operators/expressions/arithmetic_meta.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ inline ArithmeticOp NameToOp(const std::string &op_name) {
429429
}
430430

431431
inline bool IsScalarLike(const TensorListShape<> &shape) {
432-
return shape.num_samples() == 1 && shape.num_elements() == 1;
432+
return is_uniform(shape) && shape.sample_dim() == 1 && shape.tensor_shape_span(0)[0] == 1;
433433
}
434434

435435
} // namespace dali

0 commit comments

Comments
 (0)