
Unable to run flex attention and torch.compile #1005

Closed
@lkhphuc

Description

Bug description

CONFIG_FILE=./torchtitan/models/llama/train_configs/debug_model.toml ./run_train.sh --model.use_flex_attn --training.compile

Running with TORCHDYNAMO_VERBOSE=1 produces a long stack trace; the relevant tail is:

    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/_dynamo/variables/builtin.py", line 2108, in call_id
      unimplemented(f"call_id with args {args}")
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/_dynamo/exc.py", line 439, in unimplemented
      raise Unsupported(msg, case_name=case_name)
  torch._dynamo.exc.Unsupported: call_id with args (NestedUserFunctionVariable(),)
  
  from user code:
     File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
      return self.checkpoint_fn(  # type: ignore[misc]
    File "/lustre1/tier2/users/phuc.lekhac/torchtitan/torchtitan/models/llama/model.py", line 364, in forward
      h = x + self.attention(self.attention_norm(x), freqs_cis)
    File "/lustre1/tier2/users/phuc.lekhac/torchtitan/torchtitan/models/llama/model.py", line 231, in forward
      self._init_flex_attn(seqlen=seqlen)
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
      return func(*args, **kwargs)
    File "/lustre1/tier2/users/phuc.lekhac/torchtitan/torchtitan/models/llama/model.py", line 250, in _init_flex_attn
      self.block_mask = compiled_create_block_mask(
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 864, in create_block_mask
      mod_type = _get_mod_type(mask_mod)
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 62, in _get_mod_type
      for param in inspect.signature(fn).parameters.values()
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 3345, in signature
      return Signature.from_callable(obj, follow_wrapped=follow_wrapped,
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 3085, in from_callable
      return _signature_from_callable(obj, sigcls=cls,
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 2538, in _signature_from_callable
      obj = unwrap(obj, stop=(lambda f: hasattr(f, "__signature__")
    File "/lustre1/tier2/users/phuc.lekhac/mambaforge/envs/titan/lib/python3.12/inspect.py", line 773, in unwrap
      memo = {id(f): f}
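
The trace bottoms out in a plain stdlib call: `create_block_mask` calls `inspect.signature(mask_mod)` to inspect the mask function's parameters, and `inspect.unwrap` seeds its cycle-detection memo with `id(f)` (the last frame above). Since the mask_mod here is a nested function, Dynamo hits `id()` on a `NestedUserFunctionVariable`, which it cannot trace. A stdlib-only sketch of that call path (the mask-function names are illustrative, not the actual torchtitan code):

```python
import inspect

# Hypothetical stand-in for a mask_mod built inside a method, mirroring
# how a closure-captured mask function ends up passed to create_block_mask.
def build_mask_mod(seqlen):
    def causal_mask(b, h, q_idx, kv_idx):
        # Standard causal masking: a query may attend to itself and earlier keys.
        return q_idx >= kv_idx
    return causal_mask

mask_mod = build_mask_mod(128)

# create_block_mask counts the mask_mod's parameters via inspect.signature.
# That call goes through _signature_from_callable -> inspect.unwrap, whose
# first statement is `memo = {id(f): f}` -- exactly where the trace ends.
# Eagerly this is harmless; under torch.compile, Dynamo must trace the id()
# call on the nested function and raises Unsupported instead.
params = inspect.signature(mask_mod).parameters
print(len(params))  # 4
```

This also suggests why the failure is specific to compiling the block-mask creation: outside `torch.compile`, the same `inspect.signature` call succeeds without issue.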

Versions

  • torchtitan: main
  • torch: nightly
