The "step unsupported" graph break prevents Dynamo from completely tracing the code after the break #125141

Open
YangQun1 opened this issue Apr 29, 2024 · 0 comments
Labels: module: dynamo, module: graph breaks, oncall: pt2, triaged


YangQun1 commented Apr 29, 2024

🐛 Describe the bug

When executing the following code:

import torch


class MyExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 64, 1,  bias=False)

    def forward(self, x):
        y = self.conv(x)
        z = torch.relu(y)
        # the assert here is just used to cause a graph break, does not make sense.
        assert MyExp != None, "failed to import MyExp"
        loss = z.pow(2.0).sum()
        return loss


model = MyModule()
compiled_model = torch.compile(model, backend="inductor")

input = torch.randn([1, 64, 32, 32])
loss = compiled_model(input)
print("done")

I expected the code before and after the assertion to be traced into two separate FX graphs: tracing the assertion raises "torch._dynamo.exc.Unsupported: comparison AutogradFunctionVariable() ConstantVariable(NoneType)", so a graph break happens there.

But I actually got only one FX graph, which contains the ops before the assertion. The code after the assertion (loss = z.pow(2.0).sum()) is not traced by Dynamo at all.

===== __compiled_fn_0 =====
<eval_with_key>.0 class GraphModule(torch.nn.Module):
   def forward(self, L_x_ : torch.Tensor):
       l_x_ = L_x_
       
       # File: /home/quyang/my_test/experiments/test_graph_break/test_step_unsupported_no_func.py:23, code: y = self.conv(x)
       y = self.L__self___conv(l_x_);  l_x_ = None
       
       # File: /home/quyang/my_test/experiments/test_graph_break/test_step_unsupported_no_func.py:24, code: z = torch.relu(y)
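To check how many graphs Dynamo actually hands to the compiler, a custom backend that only records each FX graph can stand in for Inductor. This is a minimal sketch, assuming PyTorch 2.x, where `torch.compile` accepts any callable taking `(gm, example_inputs)` as a backend; `recording_backend` and `captured` are hypothetical names used only for illustration. Graph breaks are decided by Dynamo rather than by the backend, so the printed graphs should show the same split as the Inductor run, although the exact result may differ between PyTorch versions.

```python
import torch
import torch._dynamo
from torch.fx import GraphModule

captured = []  # hypothetical: every FX graph Dynamo passes to the backend, in order


def recording_backend(gm: GraphModule, example_inputs):
    # Record the graph, then run it unchanged (no Inductor codegen involved).
    captured.append(gm)
    return gm.forward


class MyExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 64, 1, bias=False)

    def forward(self, x):
        y = self.conv(x)
        z = torch.relu(y)
        assert MyExp != None, "failed to import MyExp"  # graph-break trigger
        loss = z.pow(2.0).sum()
        return loss


torch._dynamo.reset()  # start from a clean Dynamo cache
model = MyModule()
compiled_model = torch.compile(model, backend=recording_backend)
x = torch.randn(1, 64, 32, 32)
loss = compiled_model(x)

print(f"{len(captured)} graph(s) captured")
for i, gm in enumerate(captured):
    print(f"--- graph {i} ---")
    print(gm.code)  # Python source of the captured FX graph
```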

But when I change the code after the assertion to call a function, that function is traced correctly into its own graph:

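def loss_func(x):
    return x.pow(2.0).sum()

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 64, 1, bias=False)

    def forward(self, x):
        y = self.conv(x)
        z = torch.relu(y)
        assert MyExp != None, "failed to import MyExp"  # MyExp as defined in the first snippet
        loss = loss_func(z)
        return loss
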
===== __compiled_fn_0 =====
<eval_with_key>.0 class GraphModule(torch.nn.Module):
   def forward(self, L_x_ : torch.Tensor):
       l_x_ = L_x_
       
       # File: /home/quyang/qnpu/pt/src/pytorch-integration/experiments/test_step_unsupported/test.py:27, code: y = self.conv(x)
       y = self.L__self___conv(l_x_);  l_x_ = None
       
       # File: /home/quyang/qnpu/pt/src/pytorch-integration/experiments/test_step_unsupported/test.py:28, code: z = torch.relu(y)
       z = torch.relu(y);  y = None
       return (z,)

===== __compiled_fn_1 =====
<eval_with_key>.40 class GraphModule(torch.nn.Module):
   def forward(self, L_x_ : torch.Tensor):
       l_x_ = L_x_
       
       # File: /home/quyang/qnpu/pt/src/pytorch-integration/experiments/test_step_unsupported/test.py:18, code: return x.pow(2.0).sum()
       pow_1 = l_x_.pow(2.0);  l_x_ = None
       sum_1 = pow_1.sum();  pow_1 = None
       return (sum_1,)

But if there is still other code after the function call that is not itself a function call, that code is still not traced:

    def forward(self, x):
        y = self.conv(x)
        z = torch.relu(y)
        assert MyExp != None, "failed to import MyExp"
        loss = loss_func(z)
        loss = torch.sigmoid(loss)/2.0 # can't be traced to fx graph
        return loss
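Based on the observation above that function calls after the break are still traced, one possible workaround is to move everything after the assertion into a single helper, so that only a call and a `return` follow the break. This is a sketch under that assumption, not a fix for the underlying Dynamo behavior: `after_break` is a hypothetical name, `backend="eager"` is used only to skip Inductor codegen, and whether the helper ends up in its own graph is the behavior reported in this issue for 2.4.0.dev, not something the sketch checks. It only confirms that the compiled module still computes the same result as eager mode.

```python
import torch


class MyExp(torch.autograd.Function):
    # Trimmed: only the class identity matters for the assertion below.
    @staticmethod
    def forward(ctx, i):
        return i.exp()


def after_break(z):
    # Hypothetical helper holding all the code that used to follow the assertion.
    loss = z.pow(2.0).sum()
    return torch.sigmoid(loss) / 2.0


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 64, 1, bias=False)

    def forward(self, x):
        y = self.conv(x)
        z = torch.relu(y)
        assert MyExp != None, "failed to import MyExp"
        # Only a function call and a return remain after the graph break.
        return after_break(z)


model = MyModule()
compiled_model = torch.compile(model, backend="eager")
x = torch.randn(1, 64, 32, 32)
out = compiled_model(x)
print(out)
```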

Error logs

No response

Minified repro

No response

Versions

Collecting environment information...
PyTorch version: 2.4.0.dev20240429+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-102-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6132 CPU @ 2.60GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 5187.81
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 38.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.4.0.dev20240429+cu118
[pip3] torchaudio==2.2.0.dev20240429+cu118
[pip3] torchvision==0.19.0.dev20240429+cu118
[conda] Could not collect

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

@jbschlosser added the triaged, module: dynamo, and module: graph breaks labels on Apr 29, 2024