From 3784b2680b20f72bf702f459e8220d9cb0e45eee Mon Sep 17 00:00:00 2001 From: Rui <179625410+rpsilva-aws@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:14:45 -0800 Subject: [PATCH] Introduce HLO graph bindings (#8551) --- torch_xla/csrc/init_python_bindings.cpp | 4 ++++ torch_xla/csrc/ir_dump_util.cpp | 3 +++ torch_xla/csrc/ir_dump_util.h | 1 + 3 files changed, 8 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c4cbd8030927..04dcbf526ed0 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1318,6 +1318,10 @@ void InitXlaModuleBindings(py::module m) { [](const std::vector& tensors) -> std::string { return GetTensorsHloGraph(tensors, EmitMode::kHloReadable); }); + m.def("_get_xla_tensors_hlo_proto", + [](const std::vector& tensors) -> py::bytes { + return py::bytes(GetTensorsHloGraph(tensors, EmitMode::kHloProto)); + }); m.def("_get_xla_tensor_debug_info", [](const at::Tensor& tensor) -> std::string { return GetXLATensorDebugInfo(tensor); diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index 448cbf63d276..63a3f17e6cb7 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -287,6 +287,9 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, switch (mode) { case EmitMode::kHloReadable: return ConsumeValue(runtime::util::GetComputationHloText(computation)); + case EmitMode::kHloProto: + return ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto( + computation.proto())); case EmitMode::kStableHloReadable: return hloToStablehlo(&computation.proto(), /* emit_bytecode = */ false); diff --git a/torch_xla/csrc/ir_dump_util.h b/torch_xla/csrc/ir_dump_util.h index 8c9124a3c44a..f47c93f82a1a 100644 --- a/torch_xla/csrc/ir_dump_util.h +++ b/torch_xla/csrc/ir_dump_util.h @@ -12,6 +12,7 @@ namespace torch_xla { enum class EmitMode { kHloReadable, + kHloProto, kStableHloReadable, kStableHloBytecode, };