Skip to content

Commit

Permalink
fix: refine code style
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiminzhang0830 committed Dec 29, 2024
1 parent d5bea62 commit 4ecad77
Show file tree
Hide file tree
Showing 24 changed files with 2,199 additions and 3,654 deletions.
11 changes: 7 additions & 4 deletions jointContribution/CHGNet/chgnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import os
from importlib.metadata import PackageNotFoundError, version
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version
from typing import Literal

try:
__version__ = version(__name__)
except PackageNotFoundError:
__version__ = 'unknown'
TrainTask = Literal['ef', 'efs', 'efsm']
PredTask = Literal['e', 'ef', 'em', 'efs', 'efsm']
__version__ = "unknown"
TrainTask = Literal["ef", "efs", "efsm"]
PredTask = Literal["e", "ef", "em", "efs", "efsm"]
ROOT = os.path.dirname(os.path.dirname(__file__))
566 changes: 352 additions & 214 deletions jointContribution/CHGNet/chgnet/data/dataset.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions jointContribution/CHGNet/chgnet/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from chgnet.graph.converter import CrystalGraphConverter
from chgnet.graph.crystalgraph import CrystalGraph

from chgnet.graph.converter import CrystalGraphConverter # noqa
from chgnet.graph.crystalgraph import CrystalGraph # noqa
193 changes: 113 additions & 80 deletions jointContribution/CHGNet/chgnet/graph/converter.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,46 @@
from __future__ import annotations
import paddle

import gc
import sys
import warnings
from typing import TYPE_CHECKING

import numpy as np
import paddle
from chgnet.graph.crystalgraph import CrystalGraph
from chgnet.graph.graph import Graph, Node
from chgnet.graph.graph import Graph
from chgnet.graph.graph import Node

if TYPE_CHECKING:
from typing import Literal

from pymatgen.core import Structure
from typing_extensions import Self
try:
from chgnet.graph.cygraph import make_graph
except (ImportError, AttributeError):
make_graph = None
DTYPE = 'float32'
# try:
# from chgnet.graph.cygraph import make_graph
# except (ImportError, AttributeError):
# make_graph = None
make_graph = None
DTYPE = "float32"


class CrystalGraphConverter(paddle.nn.Layer):
"""Convert a pymatgen.core.Structure to a CrystalGraph
The CrystalGraph dataclass stores essential field to make sure that
gradients like force and stress can be calculated through back-propagation later.
"""

make_graph = None

def __init__(self, *, atom_graph_cutoff: float=6, bond_graph_cutoff:
float=3, algorithm: Literal['legacy', 'fast']='fast',
on_isolated_atoms: Literal['ignore', 'warn', 'error']='error',
verbose: bool=False) ->None:
def __init__(
self,
*,
atom_graph_cutoff: float = 6,
bond_graph_cutoff: float = 3,
algorithm: Literal["legacy", "fast"] = "fast",
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
verbose: bool = False,
) -> None:
"""Initialize the Crystal Graph Converter.
Args:
Expand All @@ -49,37 +61,40 @@ def __init__(self, *, atom_graph_cutoff: float=6, bond_graph_cutoff:
"""
super().__init__()
self.atom_graph_cutoff = atom_graph_cutoff
self.bond_graph_cutoff = (atom_graph_cutoff if bond_graph_cutoff is
None else bond_graph_cutoff)
self.bond_graph_cutoff = (
atom_graph_cutoff if bond_graph_cutoff is None else bond_graph_cutoff
)
self.on_isolated_atoms = on_isolated_atoms
self.create_graph = self._create_graph_legacy
self.algorithm = 'legacy'
if algorithm == 'fast':
self.algorithm = "legacy"
if algorithm == "fast":
if make_graph is not None:
self.create_graph = self._create_graph_fast
self.algorithm = 'fast'
self.algorithm = "fast"
else:
warnings.warn(
'`fast` algorithm is not available, using `legacy`',
UserWarning, stacklevel=1)
elif algorithm != 'legacy':
warnings.warn(f'Unknown algorithm={algorithm!r}, using `legacy`',
UserWarning, stacklevel=1)
"`fast` algorithm is not available, using `legacy`",
UserWarning,
stacklevel=1,
)
elif algorithm != "legacy":
warnings.warn(
f"Unknown algorithm={algorithm!r}, using `legacy`",
UserWarning,
stacklevel=1,
)
if verbose:
print(self)

def __repr__(self) ->str:
def __repr__(self) -> str:
"""String representation of the CrystalGraphConverter."""
atom_graph_cutoff = self.atom_graph_cutoff
bond_graph_cutoff = self.bond_graph_cutoff
algorithm = self.algorithm
cls_name = type(self).__name__
return (
f'{cls_name}(algorithm={algorithm!r}, atom_graph_cutoff={atom_graph_cutoff!r}, bond_graph_cutoff={bond_graph_cutoff!r})'
)
return f"{cls_name}(algorithm={algorithm!r}, atom_graph_cutoff={atom_graph_cutoff!r}, bond_graph_cutoff={bond_graph_cutoff!r})"

def forward(self, structure: Structure, graph_id=None, mp_id=None
) ->CrystalGraph:
def forward(self, structure: Structure, graph_id=None, mp_id=None) -> CrystalGraph:
"""Convert a structure, return a CrystalGraph.
Args:
Expand All @@ -93,55 +108,66 @@ def forward(self, structure: Structure, graph_id=None, mp_id=None
CrystalGraph that is ready to use by CHGNet
"""
n_atoms = len(structure)
data=[site.specie.Z for site in structure]
atomic_number = paddle.to_tensor(data, dtype='int32', stop_gradient=not False)
atom_frac_coord = paddle.to_tensor(data=structure.frac_coords,
dtype=DTYPE, stop_gradient=not True)
lattice = paddle.to_tensor(data=structure.lattice.matrix, dtype=
DTYPE, stop_gradient=not True)
center_index, neighbor_index, image, distance = (structure.
get_neighbor_list(r=self.atom_graph_cutoff, sites=structure.
sites, numerical_tol=1e-08))
graph = self.create_graph(n_atoms, center_index, neighbor_index,
image, distance)
data = [site.specie.Z for site in structure]
atomic_number = paddle.to_tensor(data, dtype="int32", stop_gradient=not False)
atom_frac_coord = paddle.to_tensor(
data=structure.frac_coords, dtype=DTYPE, stop_gradient=not True
)
lattice = paddle.to_tensor(
data=structure.lattice.matrix, dtype=DTYPE, stop_gradient=not True
)
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-08
)
graph = self.create_graph(
n_atoms, center_index, neighbor_index, image, distance
)
atom_graph, directed2undirected = graph.adjacency_list()
atom_graph = paddle.to_tensor(data=atom_graph, dtype='int32')
directed2undirected = paddle.to_tensor(data=directed2undirected,
dtype='int32')
atom_graph = paddle.to_tensor(data=atom_graph, dtype="int32")
directed2undirected = paddle.to_tensor(data=directed2undirected, dtype="int32")
try:
bond_graph, undirected2directed = graph.line_graph_adjacency_list(
cutoff=self.bond_graph_cutoff)
cutoff=self.bond_graph_cutoff
)
except Exception as exc:
structure.to(filename='bond_graph_error.cif')
structure.to(filename="bond_graph_error.cif")
raise RuntimeError(
f'Failed creating bond graph for {graph_id}, check bond_graph_error.cif'
) from exc
bond_graph = paddle.to_tensor(data=bond_graph, dtype='int32')
undirected2directed = paddle.to_tensor(data=undirected2directed,
dtype='int32')
f"Failed creating bond graph for {graph_id}, check bond_graph_error.cif"
) from exc
bond_graph = paddle.to_tensor(data=bond_graph, dtype="int32")
undirected2directed = paddle.to_tensor(data=undirected2directed, dtype="int32")
n_isolated_atoms = len({*range(n_atoms)} - {*center_index})
if n_isolated_atoms:
atom_graph_cutoff = self.atom_graph_cutoff
msg = (
f'Structure graph_id={graph_id!r} has {n_isolated_atoms} isolated atom(s) with atom_graph_cutoff={atom_graph_cutoff!r}. CHGNet calculation will likely go wrong'
)
if self.on_isolated_atoms == 'error':
msg = f"Structure graph_id={graph_id!r} has {n_isolated_atoms} isolated atom(s) with atom_graph_cutoff={atom_graph_cutoff!r}. CHGNet calculation will likely go wrong"
if self.on_isolated_atoms == "error":
raise ValueError(msg)
elif self.on_isolated_atoms == 'warn':
elif self.on_isolated_atoms == "warn":
print(msg, file=sys.stderr)
return CrystalGraph(atomic_number=atomic_number, atom_frac_coord=
atom_frac_coord, atom_graph=atom_graph, neighbor_image=paddle.
to_tensor(data=image, dtype=DTYPE), directed2undirected=
directed2undirected, undirected2directed=undirected2directed,
bond_graph=bond_graph, lattice=lattice, graph_id=graph_id,
mp_id=mp_id, composition=structure.composition.formula,
atom_graph_cutoff=self.atom_graph_cutoff, bond_graph_cutoff=
self.bond_graph_cutoff)
return CrystalGraph(
atomic_number=atomic_number,
atom_frac_coord=atom_frac_coord,
atom_graph=atom_graph,
neighbor_image=paddle.to_tensor(data=image, dtype=DTYPE),
directed2undirected=directed2undirected,
undirected2directed=undirected2directed,
bond_graph=bond_graph,
lattice=lattice,
graph_id=graph_id,
mp_id=mp_id,
composition=structure.composition.formula,
atom_graph_cutoff=self.atom_graph_cutoff,
bond_graph_cutoff=self.bond_graph_cutoff,
)

@staticmethod
def _create_graph_legacy(n_atoms: int, center_index: np.ndarray,
neighbor_index: np.ndarray, image: np.ndarray, distance: np.ndarray
) ->Graph:
def _create_graph_legacy(
n_atoms: int,
center_index: np.ndarray,
neighbor_index: np.ndarray,
image: np.ndarray,
distance: np.ndarray,
) -> Graph:
"""Given structure information, create a Graph structure to be used to
create Crystal_Graph using pure python implementation.
Expand All @@ -160,16 +186,20 @@ def _create_graph_legacy(n_atoms: int, center_index: np.ndarray,
Graph data structure used to create Crystal_Graph object
"""
graph = Graph([Node(index=idx) for idx in range(n_atoms)])
for ii, jj, img, dist in zip(center_index, neighbor_index, image,
distance, strict=True):
graph.add_edge(center_index=ii, neighbor_index=jj, image=img,
distance=dist)
for ii, jj, img, dist in zip(
center_index, neighbor_index, image, distance, strict=True
):
graph.add_edge(center_index=ii, neighbor_index=jj, image=img, distance=dist)
return graph

@staticmethod
def _create_graph_fast(n_atoms: int, center_index: np.ndarray,
neighbor_index: np.ndarray, image: np.ndarray, distance: np.ndarray
) ->Graph:
def _create_graph_fast(
n_atoms: int,
center_index: np.ndarray,
neighbor_index: np.ndarray,
image: np.ndarray,
distance: np.ndarray,
) -> Graph:
"""Given structure information, create a Graph structure to be used to
create Crystal_Graph using C implementation.
Expand Down Expand Up @@ -197,17 +227,18 @@ def _create_graph_fast(n_atoms: int, center_index: np.ndarray,
gc_saved = gc.get_threshold()
gc.set_threshold(0)
nodes, dir_edges_list, undir_edges_list, undirected_edges = make_graph(
center_index, len(center_index), neighbor_index, image,
distance, n_atoms)
center_index, len(center_index), neighbor_index, image, distance, n_atoms
)
graph = Graph(nodes=nodes)
graph.directed_edges_list = dir_edges_list
graph.undirected_edges_list = undir_edges_list
graph.undirected_edges = undirected_edges
gc.set_threshold(gc_saved[0])
return graph

def set_isolated_atom_response(self, on_isolated_atoms: Literal[
'ignore', 'warn', 'error']) ->None:
def set_isolated_atom_response(
self, on_isolated_atoms: Literal["ignore", "warn", "error"]
) -> None:
"""Set the graph converter's response to isolated atom graph
Args:
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
Expand All @@ -219,13 +250,15 @@ def set_isolated_atom_response(self, on_isolated_atoms: Literal[
"""
self.on_isolated_atoms = on_isolated_atoms

def as_dict(self) ->dict[str, str | float]:
def as_dict(self) -> dict[str, str | float]:
"""Save the args of the graph converter."""
return {'atom_graph_cutoff': self.atom_graph_cutoff,
'bond_graph_cutoff': self.bond_graph_cutoff, 'algorithm': self.
algorithm}
return {
"atom_graph_cutoff": self.atom_graph_cutoff,
"bond_graph_cutoff": self.bond_graph_cutoff,
"algorithm": self.algorithm,
}

@classmethod
def from_dict(cls, dct: dict) ->Self:
def from_dict(cls, dct: dict) -> Self:
"""Create converter from dictionary."""
return cls(**dct)
Loading

0 comments on commit 4ecad77

Please sign in to comment.