Commit 6e5659e

Metadata agnostic user computation hash (pytorch#8550)

1 parent c394d1b commit 6e5659e

File tree

7 files changed: +111 −29 lines

test/neuron/run_tests.sh

Lines changed: 7 additions & 13 deletions

@@ -6,6 +6,9 @@ MAX_GRAPH_SIZE=500
 GRAPH_CHECK_FREQUENCY=100
 VERBOSITY=2
 
+# Utils file
+source "${CDIR}/utils/run_tests_utils.sh"
+
 # Note [Keep Going]
 #
 # Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error.
@@ -93,16 +96,6 @@ function run_eager_debug {
   XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@"
 }
 
-function run_save_tensor_ir {
-  echo "Running in save tensor file mode: $@"
-  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="text" run_test "$@"
-}
-
-function run_save_tensor_hlo {
-  echo "Running in save tensor file mode: $@"
-  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
-}
-
 function run_pt_xla_debug {
   echo "Running in save tensor file mode: $@"
   PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
@@ -166,16 +159,16 @@ function run_xla_op_tests1 {
   run_test "$CDIR/dynamo/test_num_output.py"
   run_test "$CDIR/dynamo/test_graph_input_matcher.py"
   run_test "$CDIR/dynamo/test_dynamo_config.py"
-  run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
+  run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
   #run_test "$CDIR/test_data_type.py"
   run_use_bf16 "$CDIR/test_data_type.py"
   run_downcast_bf16 "$CDIR/test_data_type.py"
   #run_test "$CDIR/test_fp8.py"
   run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
   run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
   run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
-  run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py"
-  run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
+  run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
+  run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
 }
 
 function run_xla_op_tests2 {
@@ -230,6 +223,7 @@ function run_xla_op_tests3 {
   run_torchrun "$CDIR/pjrt/test_torchrun.py"
   run_test "$CDIR/test_persistent_cache.py"
   run_test "$CDIR/test_devices.py"
+  run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"
 
   #python3 examples/data_parallel/train_resnet_xla_ddp.py # compiler error
   #python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
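Note on the changed call sites: `run_save_tensor_ir` and `run_save_tensor_hlo` now live in `test/utils/run_tests_utils.sh` (sourced at the top of this script) and take the runner function as their first argument, which is why `run_test` appears before the test path. A minimal sketch of the convention, with a hypothetical `my_test.py` standing in for a real test file:

# The first argument names the runner to invoke; the remaining arguments are
# forwarded to it with the save-tensor variables set, expanding to roughly:
#   XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="text" run_test "$CDIR/my_test.py"
run_save_tensor_ir run_test "$CDIR/my_test.py"  # my_test.py is illustrative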

test/run_tests.sh

Lines changed: 5 additions & 14 deletions

@@ -85,11 +85,6 @@ function run_test_without_functionalization {
   XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@"
 }
 
-function run_xla_ir_debug {
-  echo "Running with XLA_IR_DEBUG: $@"
-  XLA_IR_DEBUG=1 run_test "$@"
-}
-
 function run_use_bf16 {
   echo "Running with XLA_USE_BF16: $@"
   XLA_USE_BF16=1 run_test "$@"
@@ -100,11 +95,6 @@ function run_downcast_bf16 {
   XLA_DOWNCAST_BF16=1 run_test "$@"
 }
 
-function run_xla_hlo_debug {
-  echo "Running with XLA_IR_DEBUG: $@"
-  XLA_HLO_DEBUG=1 run_test "$@"
-}
-
 function run_dynamic {
   echo "Running in DynamicShape mode: $@"
   XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@"
@@ -191,9 +181,9 @@ function run_xla_op_tests1 {
   run_use_bf16 "$CDIR/test_data_type.py"
   run_downcast_bf16 "$CDIR/test_data_type.py"
   run_test "$CDIR/test_fp8.py"
-  run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
-  run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
-  run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
+  run_xla_ir_debug run_test "$CDIR/test_env_var_mapper.py"
+  run_xla_hlo_debug run_test "$CDIR/test_env_var_mapper.py"
+  run_xla_hlo_debug run_test "$CDIR/stablehlo/test_stablehlo_save_load.py"
   run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
   run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
 }
@@ -224,7 +214,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/stablehlo/test_composite.py"
   run_test "$CDIR/stablehlo/test_pt2e_qdq.py"
   run_test "$CDIR/stablehlo/test_stablehlo_custom_call.py"
-  run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py"
+  run_xla_hlo_debug run_test "$CDIR/stablehlo/test_stablehlo_inference.py"
   run_test "$CDIR/stablehlo/test_stablehlo_compile.py"
   run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
   run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
@@ -252,6 +242,7 @@ function run_xla_op_tests3 {
   # NOTE: this line below is testing export and don't care about GPU
   PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
   run_test "$CDIR/test_pallas.py"
+  run_xla_ir_hlo_debug run_test "$CDIR/test_user_computation_debug_cache.py"
 
   # CUDA tests
   if [ -x "$(command -v nvidia-smi)" ]; then
test/test_user_computation_debug_cache.py

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+import os
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_builder as xb
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+
+parent_folder = os.path.dirname(os.path.dirname(__file__))
+sys.path.append(parent_folder)
+
+
+class TestUserComputationDebugCache(unittest.TestCase):
+
+  def setUp(self):
+    self.assertTrue(
+        os.getenv("XLA_IR_DEBUG") == '1' and os.getenv("XLA_HLO_DEBUG") == '1',
+        "XLA_IR_DEBUG and XLA_HLO_DEBUG must be set for this test.",
+    )
+
+  def test_user_computation_debug_cache(self):
+    """
+    Test that user computations with the same IR but different OpMetadata
+    are cached correctly. The metadata is generated when the environment
+    variables that enable the Python stack trace for the IR nodes and,
+    subsequently, the XLA HLO metadata are set: `XLA_IR_DEBUG` and
+    `XLA_HLO_DEBUG`, respectively.
+    """
+
+    met.clear_all()
+
+    def fn_op(a, b):
+      return xb.Op.tuple([xb.Op.max(a, b) - xb.Op.min(a, b)])
+
+    def input_scope_0(tensor):
+      return [torch.sin(tensor), torch.cos(tensor)]
+
+    def input_scope_1(tensor):
+      return [torch.sin(tensor), torch.cos(tensor)]
+
+    device = xm.xla_device()
+    init_tensor = torch.tensor(10).to(device)
+
+    def create_user_computation(fn):
+      inputs = fn(init_tensor)
+      comp = xb.create_computation("computation", fn_op,
+                                   [xb.tensor_shape(p) for p in inputs])
+      _ = torch_xla._XLAC._xla_user_computation("xla::computation", inputs,
+                                                comp)
+      torch_xla.sync()
+
+    # Create and launch the graph execution with the same IR graph, but with
+    # different input tensor scopes. When `XLA_HLO_DEBUG` and `XLA_IR_DEBUG`
+    # are enabled, this generates different OpMetadata (notably `source_line`)
+    # for the two input scopes `input_scope_0` and `input_scope_1`.
+    create_user_computation(input_scope_0)
+    create_user_computation(input_scope_1)
+
+    # Ensure that we compile only once and hit the cache the second time.
+    # This is expected because the OpMetadata does not affect the hash of
+    # the user computation; the compiled executable is semantically the same.
+    self.assertEqual(met.counter_value("UncachedCompile"), 1)
+    self.assertEqual(met.counter_value("CachedCompile"), 1)
+
+
+if __name__ == "__main__":
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
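Since `setUp` requires both debug variables to be `'1'`, the new test only passes when launched through one of the debug wrappers or with the variables exported by hand. A minimal standalone invocation might look as follows; the `PJRT_DEVICE=CPU` setting is an assumption for a local CPU run, not something this commit specifies:

# Hypothetical local run: XLA_IR_DEBUG and XLA_HLO_DEBUG must be '1' or
# setUp fails; PJRT_DEVICE=CPU is assumed for a machine without accelerators.
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 PJRT_DEVICE=CPU \
    python3 test/test_user_computation_debug_cache.py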

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_back
 python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py"
 python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
 python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
+run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"
 
 # run examples, each test should takes <2 minutes
 python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"

test/utils/run_tests_utils.sh

Lines changed: 21 additions & 0 deletions

@@ -54,3 +54,24 @@ function run_save_tensor_hlo {
   echo "Running in save tensor file mode: $@"
   run_save_tensor "$run_test_func" "hlo" "$@"
 }
+
+function run_xla_ir_debug {
+  local run_test_func="$1"
+  shift
+  echo "Running with XLA_IR_DEBUG: $@"
+  XLA_IR_DEBUG=1 "$run_test_func" "$@"
+}
+
+function run_xla_hlo_debug {
+  local run_test_func="$1"
+  shift
+  echo "Running with XLA_HLO_DEBUG: $@"
+  XLA_HLO_DEBUG=1 "$run_test_func" "$@"
+}
+
+function run_xla_ir_hlo_debug {
+  local run_test_func="$1"
+  shift
+  echo "Running with XLA_IR_DEBUG and XLA_HLO_DEBUG: $@"
+  XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 "$run_test_func" "$@"
+}
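All three new wrappers share the calling convention of the `run_save_tensor_*` helpers above them: the first argument names the runner to invoke (`run_test` in `test/run_tests.sh`, plain `python3` in the TPU script), and the remaining arguments are forwarded to that runner with the corresponding variables exported. A short usage sketch, with paths taken from the scripts in this commit:

# After sourcing the utils file (CDIR points at the test directory)...
source "${CDIR}/utils/run_tests_utils.sh"

# ...run a test through the suite's run_test helper with XLA_IR_DEBUG=1,
run_xla_ir_debug run_test "$CDIR/test_env_var_mapper.py"

# or invoke python3 directly, as the TPU script does, with both
# XLA_IR_DEBUG=1 and XLA_HLO_DEBUG=1 set for the child process.
run_xla_ir_hlo_debug python3 "$CDIR/test_user_computation_debug_cache.py"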

torch_xla/csrc/runtime/computation_client.cc

Lines changed: 6 additions & 1 deletion

@@ -196,8 +196,13 @@ metrics::Metric* ComputationClient::OutboundDataMetric() {
 }
 
 ::absl::StatusOr<torch::lazy::hash_t>
-ComputationClient::Computation::ComputeHash(const xla::HloModuleProto& proto,
+ComputationClient::Computation::ComputeHash(xla::HloModuleProto proto,
                                             const std::string& name) {
+  for (auto& computation : *proto.mutable_computations()) {
+    for (auto& instruction : *computation.mutable_instructions()) {
+      instruction.mutable_metadata()->Clear();
+    }
+  }
   TF_ASSIGN_OR_RETURN(auto serialized_status,
                       util::GetDeterministicSerializedModuleProto(proto));
   return torch::lazy::MHash(name, serialized_status);

torch_xla/csrc/runtime/computation_client.h

Lines changed: 1 addition & 1 deletion

@@ -212,7 +212,7 @@ class ComputationClient {
   // elements during serialization. The resulting hash combines the
   // serialized module with its computation name.
   static ::absl::StatusOr<torch::lazy::hash_t> ComputeHash(
-      const xla::HloModuleProto& proto, const std::string& name);
+      xla::HloModuleProto proto, const std::string& name);
 };
 
 using ComputationPtr = std::shared_ptr<Computation>;
