Skip to content

Commit

Permalink
Keep divide by zero check
Browse files Browse the repository at this point in the history
  • Loading branch information
trevor-m committed May 30, 2024
1 parent 9d59ff2 commit 8075ca8
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2306,7 +2306,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
// Apply Bessel's correction on the variance.
int total_input_size = bn_train_input_type_tensor.getNumElements();
int total_scale_size = scale_type_tensor.getNumElements();
int sample_size = total_input_size / total_scale_size;
int sample_size =
total_scale_size > 0 ? total_input_size / total_scale_size : 0;
int sample_size_minus_one = std::max(1, sample_size - 1);
double factor = static_cast<double>(sample_size) /
static_cast<double>(sample_size_minus_one);
Expand Down

0 comments on commit 8075ca8

Please sign in to comment.