Revert "fix batch_norm amp autocast" (pytorch#8547)
tengyifei authored Jan 9, 2025
1 parent de75686 commit 05ca224
Showing 2 changed files with 19 additions and 43 deletions.
14 changes: 0 additions & 14 deletions test/test_autocast.py
@@ -484,20 +484,6 @@ def test_autocast_tpu_check_dtype(self):
       assert not torch.is_autocast_xla_enabled()
 
 
-class TestOtherOps(unittest.TestCase):
-
-  @unittest.skipIf(not (xm.get_xla_supported_devices("TPU") or
-                        xm.get_xla_supported_devices("GPU")),
-                   f"bfloat16 is only enabled for TPU and GPU")
-  def test_batch_norm(self):
-    device = xm.xla_device()
-    data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16)
-    with autocast(device, dtype=torch.bfloat16):
-      output = torch.nn.BatchNorm2d(16)(data)
-    xm.mark_step()
-    self.assertEqual(output.dtype, torch.bfloat16)
-
-
 if __name__ == "__main__":
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
   sys.exit(0 if test.result.wasSuccessful() else 1)
48 changes: 19 additions & 29 deletions torch_xla/csrc/batch_norm.cpp
@@ -8,17 +8,10 @@
 namespace torch_xla {
 namespace {
 
-bool IsF32BatchNormWithLowerFPInputs(const xla::XlaOp& input,
-                                     const xla::XlaOp& weight) {
-  static constexpr std::array<xla::PrimitiveType, 9> lowerPrecistionTypes = {
-      xla::PrimitiveType::F8E5M2,     xla::PrimitiveType::F8E4M3,
-      xla::PrimitiveType::F8E4M3FN,   xla::PrimitiveType::F8E4M3B11FNUZ,
-      xla::PrimitiveType::F8E3M4,     xla::PrimitiveType::F8E5M2FNUZ,
-      xla::PrimitiveType::F8E4M3FNUZ, xla::PrimitiveType::F16,
-      xla::PrimitiveType::BF16};
-  if (std::find(lowerPrecistionTypes.begin(), lowerPrecistionTypes.end(),
-                ShapeHelper::ShapeOfXlaOp(input).element_type()) !=
-          lowerPrecistionTypes.end() &&
+bool IsF32BatchNormWithFP16Inputs(const xla::XlaOp& input,
+                                  const xla::XlaOp& weight) {
+  if (ShapeHelper::ShapeOfXlaOp(input).element_type() ==
+          xla::PrimitiveType::F16 &&
       ShapeHelper::ShapeOfXlaOp(weight).element_type() ==
           xla::PrimitiveType::F32) {
     return true;
@@ -46,39 +39,37 @@ xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) {
 
 BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
                                        xla::XlaOp bias, float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value,
                                               /*feature_index=*/1);
   xla::XlaOp output = xla::GetTupleElement(outputs, 0);
   xla::XlaOp batch_mean = xla::GetTupleElement(outputs, 1);
   xla::XlaOp batch_variance = xla::GetTupleElement(outputs, 2);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    output = xla::ConvertElementType(
-        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
   }
   return {output, batch_mean, batch_variance};
 }
 
 xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
                                    xla::XlaOp bias, xla::XlaOp mean,
                                    xla::XlaOp variance, float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp output =
       xla::BatchNormInference(input, weight, bias, mean, variance, eps_value,
                               /*feature_index=*/1);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    output = xla::ConvertElementType(
-        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
   }
   return output;
 }
@@ -87,10 +78,10 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
                                       xla::XlaOp weight, xla::XlaOp save_mean,
                                       xla::XlaOp save_invstd, bool training,
                                       float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
     grad = xla::ConvertElementType(grad, xla::PrimitiveType::F32);
   }
@@ -100,9 +91,8 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
   xla::XlaOp grad_input = xla::GetTupleElement(grads, 0);
   xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1);
   xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    grad_input = xla::ConvertElementType(
-        grad_input, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    grad_input = xla::ConvertElementType(grad_input, xla::PrimitiveType::F16);
   }
   return {grad_input, grad_weight, grad_bias};
 }
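
For context, the deleted test drove these builders with a bfloat16 input while BatchNorm2d kept float32 weight and bias, the mixed-precision combination matched by the removed IsF32BatchNormWithLowerFPInputs(); after this revert, only float16 inputs paired with float32 parameters take the upcast path. The sketch below is adapted from the removed test and is purely illustrative, not part of this commit; the torch_xla.amp.autocast import is an assumption based on how test/test_autocast.py is usually set up.

```python
# Illustrative sketch adapted from the test removed above (not part of this
# commit). The autocast import is assumed to mirror test/test_autocast.py.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
# bfloat16 input with float32 BatchNorm weight/bias: matched by the removed
# IsF32BatchNormWithLowerFPInputs(), but not by the restored
# IsF32BatchNormWithFP16Inputs(), which only special-cases float16 inputs.
data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16)
with autocast(device, dtype=torch.bfloat16):
  output = torch.nn.BatchNorm2d(16)(data)
xm.mark_step()
print(output.dtype)  # the deleted test asserted torch.bfloat16 here
```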
