Skip to content

Commit b84401e

Browse files
committed
make data transforms nn.Modules for device assignment
1 parent f65488d commit b84401e

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

nequip/data/transforms/neighborlist.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Optional, Dict, Union, List
66

77

8-
class NeighborListTransform:
8+
class NeighborListTransform(torch.nn.Module):
99
"""Constructs a neighborlist and adds it to the ``AtomicDataDict``.
1010
1111
Args:
@@ -23,6 +23,8 @@ def __init__(
2323
type_names: Optional[List[str]] = None,
2424
**kwargs,
2525
):
26+
super().__init__()
27+
2628
self.r_max = r_max
2729
self.type_names = type_names
2830
self.per_edge_type_cutoff = per_edge_type_cutoff
@@ -40,7 +42,7 @@ def __init__(
4042
type_names=type_names,
4143
)
4244

43-
def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
45+
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
4446
data = compute_neighborlist_(data, self.r_max, **self.kwargs)
4547

4648
# prune based on per-edge-type cutoffs if specified
@@ -50,7 +52,7 @@ def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
5052
return data
5153

5254

53-
class NeighborListPruneTransform:
55+
class NeighborListPruneTransform(torch.nn.Module):
5456
"""Prunes a neighborlist based on per-edge-type cutoffs.
5557
5658
Args:
@@ -65,6 +67,8 @@ def __init__(
6567
per_edge_type_cutoff: Dict[str, Union[float, Dict[str, float]]],
6668
type_names: List[str],
6769
):
70+
super().__init__()
71+
6872
self.r_max = r_max
6973
self.per_edge_type_cutoff = per_edge_type_cutoff
7074
self.type_names = type_names
@@ -78,7 +82,7 @@ def __init__(
7882
per_edge_type_cutoff=per_edge_type_cutoff,
7983
)
8084

81-
def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
85+
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
8286
"""Prune neighbor list based on per-edge-type cutoffs."""
8387

8488
if AtomicDataDict.ATOM_TYPE_KEY not in data:
@@ -114,9 +118,9 @@ def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
114118
class SortedNeighborListTransform(NeighborListTransform):
115119
"""Behaves like :class:`NeighborListTransform` but additionally sorts the neighborlist and provides transpose permutation indices."""
116120

117-
def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
121+
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
118122
# first compute the basic neighborlist
119-
data = super().__call__(data)
123+
data = super().forward(data)
120124

121125
# sort the edge index and corresponding edge attributes
122126
edge_idxs = data[AtomicDataDict.EDGE_INDEX_KEY]

nequip/data/transforms/stress_utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from nequip.data import AtomicDataDict
44

55

6-
class VirialToStressTransform:
6+
class VirialToStressTransform(torch.nn.Module):
77
r"""Converts virials to stress and adds the stress to the ``AtomicDataDict``.
88
99
Specifically implements
@@ -15,9 +15,9 @@ class VirialToStressTransform:
1515
"""
1616

1717
def __init__(self):
18-
pass
18+
super().__init__()
1919

20-
def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
20+
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
2121
# see discussion in https://github.com/libAtoms/QUIP/issues/227 about sign convention
2222
# they say the standard convention is virial = -stress x volume
2323
# we assume that the AtomicDataDict contains virials
@@ -29,23 +29,23 @@ def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
2929
return data
3030

3131

32-
class StressSignFlipTransform:
32+
class StressSignFlipTransform(torch.nn.Module):
3333
r"""Flips the sign of stress in the ``AtomicDataDict``.
3434
3535
In the NequIP convention, positive diagonal components of the stress tensor implies that the system is under tensile strain and wants to compress, while a negative value implies that the system is under compressive strain and wants to expand.
3636
This transform can be applied to datasets that follow the opposite sign convention, so that the necessary sign flip happens on-the-fly during training and users can avoid having to generate a copy of the dataset with NequIP stress sign conventions.
3737
"""
3838

3939
def __init__(self):
40-
pass
40+
super().__init__()
4141

42-
def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
42+
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
4343
# see discussion in https://github.com/libAtoms/QUIP/issues/227 about sign convention
4444
data[AtomicDataDict.STRESS_KEY] = data[AtomicDataDict.STRESS_KEY].neg()
4545
return data
4646

4747

48-
class AddNaNStressTransform:
48+
class AddNaNStressTransform(torch.nn.Module):
4949
"""Add NaN stress tensors for structures without stress data.
5050
5151
Useful for datasets where stresses are not available for all structures.
@@ -54,9 +54,9 @@ class AddNaNStressTransform:
5454
"""
5555

5656
def __init__(self):
57-
pass
57+
super().__init__()
5858

59-
def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
59+
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
6060
# only add if stress is not already present
6161
if AtomicDataDict.STRESS_KEY not in data:
6262
num_frames = AtomicDataDict.num_frames(data)

nequip/data/transforms/type_mapper.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import List, Dict, Optional
88

99

10-
class ChemicalSpeciesToAtomTypeMapper:
10+
class ChemicalSpeciesToAtomTypeMapper(torch.nn.Module):
1111
"""Maps atomic numbers to atom types and adds the atom types to the ``AtomicDataDict``.
1212
1313
This transform accounts for how the atom types seen by the model can be different from the atomic species that one obtains from a conventional dataset. There could be cases where the same chemical species corresponds to multiple atom types, e.g. different charge states.
@@ -23,6 +23,8 @@ def __init__(
2323
chemical_species_to_atom_type_map: Optional[Dict[str, str]] = None,
2424
chemical_symbols: Optional[List[str]] = None,
2525
):
26+
super().__init__()
27+
2628
# TODO: eventually remove all this logic
2729
# error out with deprecated API usage
2830
if chemical_symbols is not None:
@@ -68,7 +70,7 @@ def __init__(
6870
)
6971

7072
# make a lookup table mapping atomic numbers to 0-based model type indexes
71-
self.lookup_table = torch.full(
73+
lookup_table = torch.full(
7274
(max(ase.data.atomic_numbers.values()),), -1, dtype=torch.long
7375
)
7476

@@ -78,9 +80,11 @@ def __init__(
7880
raise ValueError(f"Invalid chemical symbol '{chem_symbol}'")
7981
atomic_num = ase.data.atomic_numbers[chem_symbol]
8082
type_idx = type_name_to_index[atom_type_name]
81-
self.lookup_table[atomic_num] = type_idx
83+
lookup_table[atomic_num] = type_idx
84+
85+
self.register_buffer("lookup_table", lookup_table)
8286

83-
def __call__(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
87+
def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
8488
if AtomicDataDict.ATOM_TYPE_KEY in data:
8589
raise RuntimeError(f"Data already contains {AtomicDataDict.ATOM_TYPE_KEY}")
8690

0 commit comments

Comments
 (0)