Unable to run flex attention and torch.compile #1005

Closed
lkhphuc opened this issue Mar 22, 2025 · 3 comments

Comments

@lkhphuc
Contributor

lkhphuc commented Mar 22, 2025

Bug description

CONFIG_FILE=./torchtitan/models/llama/train_configs/debug_model.toml ./run_train.sh --model.use_flex_attn --training.compile
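
For reference, the same two switches can also be set in the TOML config instead of on the command line. A sketch, assuming torchtitan's usual flag-to-TOML mapping (section name = flag prefix; key names inferred from the CLI flags above, not re-verified against the current config schema):

    # debug_model.toml overrides (sketch; keys inferred from the CLI flags)
    [model]
    use_flex_attn = true

    [training]
    compile = true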

With TORCHDYNAMO_VERBOSE=1 I get a long stack trace:

    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py", line 2108, in call_id
      unimplemented(f"call_id with args {args}")
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/_dynamo/exc.py", line 439, in unimplemented
      raise Unsupported(msg, case_name=case_name)
  torch._dynamo.exc.Unsupported: call_id with args (NestedUserFunctionVariable(),)
  
  from user code:
     File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forwa
rd
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/lustre1/tier2/users/phuc.lekhac/torchtitan/torchtitan/models/llama/model.py", line 364, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/lustre1/tier2/users/phuc.lekhac/torchtitan/torchtitan/models/llama/model.py", line 231, in forward
      self._init_flex_attn(seqlen=seqlen)
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/lustre1/tier2/users/phuc.lekhac/torchtitan/torchtitan/models/llama/model.py", line 250, in _init_flex_attn
      self.block_mask = compiled_create_block_mask(
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 864, in create_block_mask
      mod_type = _get_mod_type(mask_mod)
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 62, in _get_mod_type
      for param in inspect.signature(fn).parameters.values()
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 3345, in signature
      return Signature.from_callable(obj, follow_wrapped=follow_wrapped,
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 3085, in from_callable
      return _signature_from_callable(obj, sigcls=cls,
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 2538, in _signature_from_callable
      obj = unwrap(obj, stop=(lambda f: hasattr(f, "__signature__")
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 773, in unwrap
      memo = {id(f): f}
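
The failure point is create_block_mask calling _get_mod_type, which runs inspect.signature on the mask_mod; inspect.unwrap then evaluates id(f), and Dynamo rejects that as call_id on a NestedUserFunctionVariable, because the mask_mod is a function defined inside the region being traced. A minimal sketch of that pattern (simplified and hypothetical, not the torchtitan code; on the affected nightly it is expected to hit the same Unsupported error, while fixed builds should just return a BlockMask):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    @torch.compile
    def build_causal_block_mask(seqlen: int):
        # Defined inside the compiled region, so Dynamo sees this as a
        # NestedUserFunctionVariable when create_block_mask inspects it.
        def causal(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        # create_block_mask -> _get_mod_type -> inspect.signature(causal)
        # -> inspect.unwrap -> memo = {id(f): f}   (the line in the trace)
        return create_block_mask(causal, B=None, H=None,
                                 Q_LEN=seqlen, KV_LEN=seqlen)

    block_mask = build_causal_block_mask(128)  # needs a CUDA device by default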

Versions

  • torchtitan: main
  • torch: nightly
@tianyu-l
Contributor

cc @fegin @drisspg for help

@fegin
Contributor

fegin commented Apr 14, 2025

I used the same command with the latest torchtitan and PyTorch for llama3 and could not reproduce the issue; same for llama4. @lkhphuc, can you try again?

@lkhphuc
Contributor Author

lkhphuc commented Apr 23, 2025

Confirmed it works on current main now, both without AC and with full AC. Thanks.
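
(AC here is activation checkpointing. A sketch of how the two configurations were toggled, assuming torchtitan's [activation_checkpoint] section and its mode key, which I have not re-verified against the current schema:)

    [activation_checkpoint]
    mode = "none"   # or "full" for full AC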

lkhphuc closed this as completed Apr 23, 2025