We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9d59ff2 commit 8075ca8Copy full SHA for 8075ca8
tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
@@ -2306,7 +2306,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
2306
// Apply Bessel's correction on the variance.
2307
int total_input_size = bn_train_input_type_tensor.getNumElements();
2308
int total_scale_size = scale_type_tensor.getNumElements();
2309
- int sample_size = total_input_size / total_scale_size;
+ int sample_size =
2310
+ total_scale_size > 0 ? total_input_size / total_scale_size : 0;
2311
int sample_size_minus_one = std::max(1, sample_size - 1);
2312
double factor = static_cast<double>(sample_size) /
2313
static_cast<double>(sample_size_minus_one);
0 commit comments