Skip to content

Commit 2b8cee3

Browse files
orionarchercw-tan
authored andcommitted
Remove StateDict from torchsim integration
torch_sim.typing.StateDict has been removed from torch-sim. Update the torchsim integration to accept only ts.SimState, dropping the isinstance check and the dict-to-SimState fallback.
1 parent bac8d1e commit 2b8cee3

1 file changed

Lines changed: 3 additions & 8 deletions

File tree

nequip/integrations/torchsim.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import torch_sim as ts
66
from torch_sim.models.interface import ModelInterface
7-
from torch_sim.typing import StateDict
87

98
from nequip.data import AtomicDataDict
109
from nequip.data._nl import NEIGHBORLIST_BACKEND_ALCHEMIOPS
@@ -134,21 +133,17 @@ def setup_from_system_idx(
134133
self.n_systems = system_idx.max().item() + 1
135134
self.total_atoms = atomic_numbers.shape[0]
136135

137-
def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # noqa: C901
136+
def forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: # noqa: C901
138137
"""Compute energies, forces, and stresses.
139138
140139
Args:
141-
state (:class:`~torch_sim.SimState` | :class:`~torch_sim.typing.StateDict`): state object containing positions, cell,
140+
state (:class:`~torch_sim.SimState`): state object containing positions, cell,
142141
and system information.
143142
144143
Returns:
145144
dict[str, :class:`torch.Tensor`]: computed properties (``"energy"``, ``"forces"``, ``"stress"``).
146145
"""
147-
sim_state = (
148-
state
149-
if isinstance(state, ts.SimState)
150-
else ts.SimState(**state, masses=torch.ones_like(state["positions"]))
151-
)
146+
sim_state = state
152147

153148
# handle input validation for atomic numbers
154149
if sim_state.atomic_numbers is None and not self.atomic_numbers_in_init:

0 commit comments

Comments
 (0)