From 196cab3fe784125047fc13d53c9e27baafc5212a Mon Sep 17 00:00:00 2001 From: Rui <179625410+rpsilva-aws@users.noreply.github.com> Date: Thu, 9 Jan 2025 18:50:07 -0800 Subject: [PATCH] [Computation Hash] Introduce deterministic hash for user computations (#8539) --- torch_xla/csrc/init_python_bindings.cpp | 5 +- torch_xla/csrc/runtime/BUILD | 2 + torch_xla/csrc/runtime/computation_client.cc | 9 ++ torch_xla/csrc/runtime/computation_client.h | 13 +- torch_xla/csrc/runtime/xla_util.cc | 21 +++ torch_xla/csrc/runtime/xla_util.h | 6 + torch_xla/csrc/runtime/xla_util_test.cc | 151 ++++++++++++++++++- torch_xla/csrc/xla_graph_executor.cpp | 21 +-- 8 files changed, 211 insertions(+), 17 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 3c34761da5b5..c4cbd8030927 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1169,9 +1169,8 @@ class PyLoweringContext { // Create a serialized HloModule protobuf from a lowered graph py::bytes GetHlo() { const xla::HloModuleProto& proto = computation.proto(); - std::string result; - proto.SerializeToString(&result); - return result; + return ConsumeValue( + runtime::util::GetDeterministicSerializedModuleProto(proto)); } // Create human-readable HloModule protobuf text from a lowered graph diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index daf93bda3650..213761748c7d 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -50,6 +50,7 @@ cc_library( ":types", ":util", ":xla_coordinator", + ":xla_util", "//torch_xla/csrc:device", "//torch_xla/csrc:dtype", "@com_google_absl//absl/memory", @@ -460,6 +461,7 @@ ptxla_cc_test( size = "small", srcs = ["xla_util_test.cc"], deps = [ + ":debug_macros", ":xla_util", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", diff --git a/torch_xla/csrc/runtime/computation_client.cc 
b/torch_xla/csrc/runtime/computation_client.cc index af304fc3ec6b..b9b7df530d16 100644 --- a/torch_xla/csrc/runtime/computation_client.cc +++ b/torch_xla/csrc/runtime/computation_client.cc @@ -13,6 +13,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/runtime/xla_util.h" #include "tsl/platform/stacktrace_handler.h" #include "xla/status_macros.h" @@ -194,5 +195,13 @@ metrics::Metric* ComputationClient::OutboundDataMetric() { return metric; } +::absl::StatusOr<torch::lazy::hash_t> +ComputationClient::Computation::ComputeHash(const xla::HloModuleProto& proto, + const std::string& name) { + TF_ASSIGN_OR_RETURN(auto serialized_status, + util::GetDeterministicSerializedModuleProto(proto)); + return torch::lazy::MHash(name, serialized_status); + } + } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index fcd1adcf51ee..5bd295031bd6 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -115,8 +115,8 @@ class ComputationClient { computation_(std::move(computation)), devices_(std::move(devices)) { program_shape_ = ConsumeValue(computation_.GetProgramShape()); - hash_ = - torch::lazy::MHash(name, computation_.proto().SerializeAsString()); + const xla::HloModuleProto& proto = computation_.proto(); + hash_ = ConsumeValue(ComputeHash(proto, name)); } Computation(std::string name, xla::XlaComputation computation, @@ -159,7 +159,7 @@ class ComputationClient { // here. 
xla::XlaComputation move_computation() { if (computation_moved_) { - XLA_ERROR() << "Compuation has been moved\n"; + XLA_ERROR() << "Computation has been moved\n"; } computation_moved_ = true; return std::move(const_cast<Computation*>(this)->computation_); @@ -206,6 +206,13 @@ class ComputationClient { torch::lazy::hash_t hash_; std::string name_; + + // Computes a hash for an HLO module using deterministic proto + // serialization. It ensures consistent ordering of Map fields and repeated + // elements during serialization. The resulting hash combines the + // serialized module with its computation name. + static ::absl::StatusOr<torch::lazy::hash_t> ComputeHash( + const xla::HloModuleProto& proto, const std::string& name); }; using ComputationPtr = std::shared_ptr<Computation>; diff --git a/torch_xla/csrc/runtime/xla_util.cc b/torch_xla/csrc/runtime/xla_util.cc index 7a3658c6857d..a91860636dd8 100644 --- a/torch_xla/csrc/runtime/xla_util.cc +++ b/torch_xla/csrc/runtime/xla_util.cc @@ -14,6 +14,7 @@ #include "tsl/platform/errors.h" #include "tsl/platform/stacktrace.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" namespace torch_xla { @@ -115,6 +116,26 @@ torch::lazy::hash_t ShapeHash(const xla::Shape& shape) { return hash; } +absl::StatusOr<std::string> GetDeterministicSerializedModuleProto( + const xla::HloModuleProto& hlo_proto) { + const size_t size = hlo_proto.ByteSizeLong(); + if (size == 0) { + return std::string(); + } + std::string serialized; + // Pre-allocate the string buffer for the serialized result. 
+ serialized.resize(size); + + // Perform deterministic serialization ensuring consistent ordering + // of map fields and repeated elements + if (!tsl::SerializeToBufferDeterministic(hlo_proto, serialized.data(), + size)) { + return absl::InvalidArgumentError("Could not serialize module proto"); + } + + return serialized; +} + } // namespace util } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xla_util.h b/torch_xla/csrc/runtime/xla_util.h index 836a3f730eb1..391da16a4651 100644 --- a/torch_xla/csrc/runtime/xla_util.h +++ b/torch_xla/csrc/runtime/xla_util.h @@ -40,6 +40,12 @@ void CheckComputationStatus( torch::lazy::hash_t ShapeHash(const xla::Shape& shape); +// Return the serialized module proto, using deterministic proto serialization. +// It ensures consistent ordering of Map fields and repeated elements during +// serialization. +absl::StatusOr<std::string> GetDeterministicSerializedModuleProto( + const xla::HloModuleProto& hlo_proto); + } // namespace util } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/runtime/xla_util_test.cc b/torch_xla/csrc/runtime/xla_util_test.cc index 7b45f7daab95..b167cfe4d360 100644 --- a/torch_xla/csrc/runtime/xla_util_test.cc +++ b/torch_xla/csrc/runtime/xla_util_test.cc @@ -3,12 +3,15 @@ #include #include +#include #include +#include #include #include #include "absl/status/status.h" #include "absl/types/span.h" +#include "torch_xla/csrc/runtime/debug_macros.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status_matchers.h" @@ -46,7 +49,7 @@ absl::StatusOr<T> ParseTextProto(const std::string& text_proto) { return parsed_proto; } -TEST(XlaUtilrest, CreateModule) { +TEST(XlaUtilTest, CreateModule) { TF_ASSERT_OK_AND_ASSIGN( xla::HloModuleProto hlo_module_proto, ParseTextProto<xla::HloModuleProto>( @@ -102,7 +105,7 @@ TEST(XlaUtilrest, CreateModule) { EXPECT_EQ((*got)->computation_count(), 1); } -TEST(XlaUtilrest, XlaToHlo) { +TEST(XlaUtilTest, XlaToHlo) { 
xla::Shape input_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2}); xla::XlaBuilder builder("AddComputation"); @@ -116,6 +119,150 @@ TEST(XlaUtilrest, XlaToHlo) { HasSubstr("ROOT %add.3")))); } +TEST(XlaUtilTest, TestDeterministicModuleProtoSerializationEmptyProto) { + xla::HloModuleProto empty_proto; + auto result = + ::ConsumeValue(GetDeterministicSerializedModuleProto(empty_proto)); + // Verify that the result is an empty string + EXPECT_TRUE(result.empty()); +} + +TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) { + // Create a test HLO module with a known structure + TF_ASSERT_OK_AND_ASSIGN( + xla::HloModuleProto hlo_module_proto, + ParseTextProto( + R"pb( + name: "myname" + id: 9 + entry_computation_name: "MyCustomName.9" + entry_computation_id: 9 + computations { + id: 9 + name: "MyCustomName.9" + instructions: { + name: "p0.1" + id: 1 + opcode: "parameter" + shape: { + element_type: S64 + layout { tail_padding_alignment_in_elements: 1 } + } + metadata { + op_type: "xla__device_data" + op_name: "xla__device_data" + source_file: "/ansible/pytorch/xla/small_test.py" + source_line: 14 + stack_frame_id: 1 + } + } + instructions: { + name: "p1.2" + id: 2 + opcode: "parameter" + parameter_number: 1 + shape: { + element_type: S64 + layout { tail_padding_alignment_in_elements: 1 } + } + metadata { + op_type: "xla__device_data" + op_name: "xla__device_data" + source_file: "/ansible/pytorch/xla/small_test.py" + source_line: 13 + stack_frame_id: 2 + } + } + instructions: { + name: "call.7" + id: 7 + opcode: "call" + shape: { + element_type: S64 + layout { tail_padding_alignment_in_elements: 1 } + } + metadata { + op_type: "xla___op_some_op" + op_name: "xla___op_some_op" + source_file: "/ansible/pytorch/xla/torch_xla/core/xla_op_registry.py" + source_line: 44 + stack_frame_id: 4 + } + called_computation_ids: 3 + operand_ids: 2 + operand_ids: 1 + } + instructions: { + name: "tuple.8" + id: 8 + opcode: "tuple" + shape: { + element_type: TUPLE 
+ tuple_shapes { + element_type: S64 + layout { tail_padding_alignment_in_elements: 1 } + } + } + operand_ids: 7 + } + root_id: 8 + } + host_program_shape: { + parameters { + element_type: S64 + layout { tail_padding_alignment_in_elements: 1 } + } + parameters { + element_type: S64 + layout { tail_padding_alignment_in_elements: 1 } + } + result { + element_type: TUPLE + tuple_shapes { + element_type: S64 + layout { tail_padding_alignment_in_elements: 1 } + } + } + parameter_names: "p0" + parameter_names: "p1" + } + )pb")); + + // Define a set of dummy fixed key-value pairs for frontend attributes. + std::vector> attr_pairs = { + {"key1", "value1"}, + {"key2", "value2"}, + {"key3", "value3"}, + {"key4", "value4"}}; + + auto shuffle_and_hash = [&attr_pairs](xla::HloModuleProto hlo_module_proto) { + // Create a random number generator for shuffling. + std::random_device random_device; + std::mt19937 random_generator(random_device()); + + for (auto& computation : *hlo_module_proto.mutable_computations()) { + for (auto& instruction : *computation.mutable_instructions()) { + std::shuffle(attr_pairs.begin(), attr_pairs.end(), random_generator); + auto* frontend_attrs = instruction.mutable_frontend_attributes(); + // Add the dummy shuffled pairs to the frontend attributes. 
+ for (const auto& pair : attr_pairs) { + (*frontend_attrs->mutable_map())[pair.first] = pair.second; + } + } + } + std::string serialized_proto = + ::ConsumeValue(GetDeterministicSerializedModuleProto(hlo_module_proto)); + return torch::lazy::Hash(serialized_proto); + }; + + // Compute hashes with different random orderings of attributes + torch::lazy::hash_t hash1 = shuffle_and_hash(hlo_module_proto); + torch::lazy::hash_t hash2 = shuffle_and_hash(hlo_module_proto); + // Verify that different orderings produce the same hash + ASSERT_EQ(hash1, hash2) + << "Hashes should match regardless of the frontend attribute ordering"; +} + } // namespace util } // namespace runtime } // namespace torch_xla diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index a0aa6e7150d2..5dfc0e4e2b0a 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1200,12 +1200,13 @@ XLAGraphExecutor::LookupCachedCompile(const torch::lazy::hash_t& hash) { TORCH_LAZY_COUNTER("UncachedCompile", 1); return nullptr; } + std::string serialized_computation = + ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto( + cached_computation->computation->computation().proto())); TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(hash) << " is computation hash " - << torch::lazy::HashToString(torch::lazy::Hash( - cached_computation->computation->computation() - .proto() - .SerializeAsString())); + << torch::lazy::HashToString( + torch::lazy::Hash(serialized_computation)); TORCH_LAZY_COUNTER("CachedCompile", 1); return cached_computation; } @@ -1443,11 +1444,13 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile( << coll.device << " done!"; TF_VLOG(5) << "Compiled program shape " << computations.front()->program_shape().ToString() << std::endl; - TF_VLOG(5) - << "Graph hash " << torch::lazy::HashToString(coll.hash) - << " is computation hash " - << torch::lazy::HashToString(torch::lazy::Hash( - 
computations.front()->computation().proto().SerializeAsString())); + std::string serialized_computation = + ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto( + computations.front()->computation().proto())); + TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(coll.hash) + << " is computation hash " + << torch::lazy::HashToString( + torch::lazy::Hash(serialized_computation)); if (use_autosharding) { const xla::HloModuleProto& computation_proto =