[BUG] State_dict error when packaging/compiling fine-tuned models #572

@kavanase

Description

When I try to run nequip-compile or nequip-package on a checkpoint from a fine-tuning run (one that used ModelFromPackage with a package file created by an earlier nequip version), I get this error:

(pytorch_2.7) Perlmutter: > nequip-package build ckpts/nequip_MPA/*-0.ckpt test.nequip.zip
[2025-11-16 10:19:41,024][nequip.utils.versions.package_versions][INFO] - [rank: 0] Version Information:
[2025-11-16 10:19:41,025][nequip.utils.versions.package_versions][INFO] - [rank: 0] torch 2.7.0+cu128
[2025-11-16 10:19:41,025][nequip.utils.versions.package_versions][INFO] - [rank: 0] e3nn 0.5.6
[2025-11-16 10:19:41,025][nequip.utils.versions.package_versions][INFO] - [rank: 0] nequip 0.15.0
[2025-11-16 10:19:41,025][nequip.utils.versions.package_versions][INFO] - [rank: 0] talaria 0.1.0
[2025-11-16 10:19:41,025][nequip.utils.versions.package_versions][INFO] - [rank: 0] allegro 0.7.1
[2025-11-16 10:19:41,047][nequip.scripts.package][INFO] - [rank: 0] Building `eager` model for packaging ...
[2025-11-16 10:19:41,047][nequip.model.saved_models.load_utils][INFO] - [rank: 0] Loading model from ckpts/nequip_MPA/weighted_metric_0.0533-0.ckpt ...
[2025-11-16 10:19:41,047][nequip.model.saved_models.checkpoint][INFO] - [rank: 0] Loading model from checkpoint file: ckpts/nequip_MPA/weighted_metric_0.0533-0.ckpt ...
[2025-11-16 10:19:42,976][nequip.model.saved_models.package][INFO] - [rank: 0] Loading model from package file: XL_rtol_0.1_netF_l4_32_WM_0.0690-11.nequip.zip ...
Traceback (most recent call last):
  File "/global/homes/k/kavanase/miniconda3/envs/pytorch_2.7/bin/nequip-package", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/global/u2/k/kavanase/Packages/nequip/nequip/scripts/package.py", line 208, in main
    eager_model, data = load_saved_model(
                        ^^^^^^^^^^^^^^^^^
  File "/global/u2/k/kavanase/Packages/nequip/nequip/model/saved_models/load_utils.py", line 126, in load_saved_model
    model = ModelFromCheckpoint(
            ^^^^^^^^^^^^^^^^^^^^
  File "/global/u2/k/kavanase/Packages/nequip/nequip/model/saved_models/checkpoint.py", line 78, in ModelFromCheckpoint
    lightning_module = training_module.load_from_checkpoint(checkpoint_path)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/homes/k/kavanase/miniconda3/envs/pytorch_2.7/lib/python3.12/site-packages/lightning/pytorch/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/homes/k/kavanase/miniconda3/envs/pytorch_2.7/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1581, in load_from_checkpoint
    loaded = _load_from_checkpoint(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/global/homes/k/kavanase/miniconda3/envs/pytorch_2.7/lib/python3.12/site-packages/lightning/pytorch/core/saving.py", line 91, in _load_from_checkpoint
    model = _load_state(cls, checkpoint, strict=strict, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/homes/k/kavanase/miniconda3/envs/pytorch_2.7/lib/python3.12/site-packages/lightning/pytorch/core/saving.py", line 187, in _load_state
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/homes/k/kavanase/miniconda3/envs/pytorch_2.7/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2593, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for EMALightningModule:
	Missing key(s) in state_dict: "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant0", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant1", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant2", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant3", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant4", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant5", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant6", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant7", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant8", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant9", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant10", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant11", "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant12", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant0", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant1", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant2", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant3", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant4", 
"model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant5", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant6", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant7", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant8", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant9", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant10", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant11", "model.sole_model.model.func.layer2_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant12", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant0", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant1", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant2", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant3", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant4", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant5", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant6", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant7", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant8", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant9", 
"model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant10", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant11", "model.sole_model.model.func.layer3_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant12", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant0", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant1", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant2", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant3", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant4", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant5", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant6", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant7", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant8", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant9", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant10", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant11", "model.sole_model.model.func.layer4_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant12".

As mentioned, the original package file was created with an earlier nequip version (still 0.15.0, but from before some recent commits that changed the model structure). Fine-tuning from it with ModelFromPackage ran without problems, but compiling or packaging the checkpoints from that fine-tuning run fails with the error above.
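One way to check whether the mismatch is limited to the layer*_convnet ... _compiled_main_left_right._tensor_constant* entries (rather than any learned weights) is to diff the key sets directly. The following is a minimal, pure-Python sketch; diff_keys, parse_missing_keys and summarize_constant_keys are hypothetical helpers, not nequip or Lightning APIs. It assumes the checkpoint keys come from something like torch.load(path, map_location="cpu", weights_only=False)["state_dict"].keys() (the "state_dict" entry is the one Lightning reads in the traceback), and the module keys from state_dict().keys() on the module being loaded into.

```python
import re
from collections import defaultdict

# Hypothetical helpers for inspecting state_dict key mismatches (not part of nequip).

# Captures the layer block and constant index from keys such as
# "model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant0"
_CONST_KEY = re.compile(r"\.(layer\d+_convnet)\..*\._tensor_constant(\d+)$")


def diff_keys(model_keys, checkpoint_keys):
    """Mirror torch's strict-loading semantics.

    missing    = expected by the module but absent from the checkpoint
    unexpected = present in the checkpoint but unknown to the module
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


def parse_missing_keys(error_message):
    """Pull the quoted names out of the 'Missing key(s)' part of a load_state_dict error."""
    _, found, tail = error_message.partition("Missing key(s) in state_dict:")
    if not found:
        return []
    # torch joins the error sections with "\n\t", so keep only the first one
    # (this drops any following "Unexpected key(s)" section).
    section = tail.split("\n\t", 1)[0]
    return re.findall(r'"([^"]+)"', section)


def summarize_constant_keys(keys):
    """Group _tensor_constant keys by layer; return (summary, unmatched_keys)."""
    summary = defaultdict(list)
    unmatched = []
    for key in keys:
        match = _CONST_KEY.search(key)
        if match:
            summary[match.group(1)].append(int(match.group(2)))
        else:
            unmatched.append(key)
    return {layer: sorted(idx) for layer, idx in sorted(summary.items())}, unmatched


# Shortened excerpt of the error message from the traceback above.
example = (
    "Error(s) in loading state_dict for EMALightningModule:\n\t"
    "Missing key(s) in state_dict: "
    '"model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant0", '
    '"model.sole_model.model.func.layer1_convnet.conv.tp_scatter.tp._compiled_main_left_right._tensor_constant1". '
)
summary, other = summarize_constant_keys(parse_missing_keys(example))
print(summary, other)  # {'layer1_convnet': [0, 1]} []
```

If every missing key falls into a layer*_convnet group with indices 0-12 and nothing lands in the unmatched list, as the traceback suggests, the checkpoint and the rebuilt model would differ only in these constants.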

Environment:

  • OS: Linux
  • See output above for package versions

Metadata

Assignees

No one assigned

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
