Support extra_state attributes in from_pretrained #38154


Open
2 of 4 tasks
pstjohn opened this issue May 15, 2025 · 1 comment · May be fixed by #38155
pstjohn (Contributor) commented May 15, 2025

System Info

transformers main branch, Python 3.12.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Using TransformerEngine layers as an example; these layers store fp8 metadata under the _extra_state key of the state dict:

from transformers import PretrainedConfig, PreTrainedModel
from transformer_engine.pytorch import TransformerLayer


class SimpleTEConfig(PretrainedConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = 320
        self.intermediate_size = 1024
        self.num_attention_heads = 16


class SimpleTEModel(PreTrainedModel):
    config_class = SimpleTEConfig

    def __init__(self, config: SimpleTEConfig):
        super().__init__(config)
        self.te_layer = TransformerLayer(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
        )

    def forward(self, hidden_states, attention_mask):
        return self.te_layer(hidden_states, attention_mask)


def test_simple_te_model(tmp_path):
    config = SimpleTEConfig()
    model = SimpleTEModel(config)

    model.save_pretrained(tmp_path / "simple_te_model")
    del model
    model = SimpleTEModel.from_pretrained(tmp_path / "simple_te_model")
    assert isinstance(model.te_layer, TransformerLayer)

Expected behavior

from_pretrained should pass the deserialized extra_state value through the nn.Module's _load_from_state_dict method, which in turn calls set_extra_state. See https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state.
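The get_extra_state / set_extra_state contract described above can be sketched on a plain Python class, without torch, so the expected round-trip is easy to see. FakeModule and its attribute names here are illustrative stand-ins, not torch or TransformerEngine internals:

```python
# Minimal sketch of the extra-state contract documented for
# torch.nn.Module, reproduced on a plain class. All names are
# hypothetical; only the _extra_state key convention matches torch.

class FakeModule:
    def __init__(self):
        self.params = {"weight": [1.0, 2.0]}
        self.fp8_meta = {"scale": 0.5}  # stand-in for TE's fp8 metadata

    def get_extra_state(self):
        # Called when building the state dict; the result is stored
        # under the "_extra_state" key.
        return dict(self.fp8_meta)

    def set_extra_state(self, state):
        # Called during loading with the deserialized extra-state value.
        self.fp8_meta = dict(state)

    def state_dict(self):
        return {**self.params, "_extra_state": self.get_extra_state()}

    def load_state_dict(self, sd):
        self.params = {k: v for k, v in sd.items() if k != "_extra_state"}
        self.set_extra_state(sd["_extra_state"])


saved = FakeModule().state_dict()
restored = FakeModule()
restored.fp8_meta = {}  # wipe, then restore from the saved state dict
restored.load_state_dict(saved)
assert restored.fp8_meta == {"scale": 0.5}
```

The point of the report is that from_pretrained bypasses this hook: it looks the key up as a parameter or buffer instead of routing it to set_extra_state.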

Instead, loading fails in get_parameter_or_buffer:

>       raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
E       AttributeError: `te_layer.layernorm_mlp._extra_state` is neither a parameter nor a buffer.
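One way a loader could avoid this error is to separate _extra_state entries from tensor entries before attempting parameter/buffer lookups. The helper below is a hypothetical sketch of that split, not transformers' actual loading code:

```python
# Hypothetical sketch: partition a flat checkpoint dict so that
# "_extra_state" keys are routed to set_extra_state on the owning
# submodule rather than treated as parameters or buffers.

EXTRA_STATE_SUFFIX = "._extra_state"

def split_extra_state(state_dict):
    """Return (tensors, extra) where extra maps each owning submodule
    path to its deserialized extra-state payload."""
    tensors, extra = {}, {}
    for key, value in state_dict.items():
        if key.endswith(EXTRA_STATE_SUFFIX):
            extra[key[: -len(EXTRA_STATE_SUFFIX)]] = value
        else:
            tensors[key] = value
    return tensors, extra


state = {
    "te_layer.weight": "tensor-0",
    "te_layer.layernorm_mlp._extra_state": b"fp8-metadata",
}
tensors, extra = split_extra_state(state)
assert "te_layer.layernorm_mlp._extra_state" not in tensors
assert extra["te_layer.layernorm_mlp"] == b"fp8-metadata"
```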
Rocketknight1 (Member) commented:
Flagging @Cyrilvallez who worked on that file recently! See also the PR at #38155
