Skip to content

Commit 8075ca8

Browse files
committed
Keep divide by zero check
1 parent 9d59ff2 commit 8075ca8

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2306,7 +2306,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
23062306
// Apply Bessel's correction on the variance.
23072307
int total_input_size = bn_train_input_type_tensor.getNumElements();
23082308
int total_scale_size = scale_type_tensor.getNumElements();
2309-
int sample_size = total_input_size / total_scale_size;
2309+
int sample_size =
2310+
total_scale_size > 0 ? total_input_size / total_scale_size : 0;
23102311
int sample_size_minus_one = std::max(1, sample_size - 1);
23112312
double factor = static_cast<double>(sample_size) /
23122313
static_cast<double>(sample_size_minus_one);

0 commit comments

Comments
 (0)