Skip to content

Commit

Permalink
fix and test in-place atoms changes (#524)
Browse files Browse the repository at this point in the history
* fix and test in-place atoms changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] committed Jun 27, 2024
1 parent 8373446 commit 3ab06e0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
32 changes: 32 additions & 0 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,35 @@ def test_exotic_atoms():
)
npt.assert_array_equal(new_atoms.arrays["colors"], ["#ff0000"])
npt.assert_array_equal(new_atoms.arrays["radii"], [0.3])


def test_modified_atoms():
atoms = ase.Atoms("H2", positions=[[0, 0, 0], [0, 0, 1]])
new_atoms = znjson.loads(
znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])),
cls=znjson.ZnDecoder.from_converters([ASEConverter]),
)
npt.assert_array_equal(new_atoms.arrays["colors"], ["#ffffff", "#ffffff"])
npt.assert_almost_equal(new_atoms.arrays["radii"], [0.3458333, 0.3458333])

# subtract
atoms = new_atoms[:1]
new_atoms = znjson.loads(
znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])),
cls=znjson.ZnDecoder.from_converters([ASEConverter]),
)

npt.assert_array_equal(new_atoms.arrays["colors"], ["#ffffff"])
npt.assert_almost_equal(new_atoms.arrays["radii"], [0.3458333])

# add
atoms = new_atoms + ase.Atoms("H", positions=[[0, 0, 1]])

new_atoms = znjson.loads(
znjson.dumps(atoms, cls=znjson.ZnEncoder.from_converters([ASEConverter])),
cls=znjson.ZnDecoder.from_converters([ASEConverter]),
)

npt.assert_array_equal(new_atoms.get_atomic_numbers(), [1, 1])
npt.assert_array_equal(new_atoms.arrays["colors"], ["#ffffff", "#ffffff"])
npt.assert_almost_equal(new_atoms.arrays["radii"], [0.3458333, 0.3458333])
3 changes: 3 additions & 0 deletions zndraw/modify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ def run(self, vis: "ZnDraw", **kwargs) -> None:
for atom_id in vis.selection:
atoms[atom_id].symbol = self.symbol.name

del atoms.arrays["colors"]
del atoms.arrays["radii"]

vis.append(atoms)
vis.selection = []

Expand Down
4 changes: 2 additions & 2 deletions zndraw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def encode(self, obj: ase.Atoms) -> ASEDict:
# All additional information should be stored in calc.results
# and not in calc.arrays, thus we will not convert it here!
arrays = {}
if "colors" not in obj.arrays:
if ("colors" not in obj.arrays) or ("" in obj.arrays["colors"]):
arrays["colors"] = [rgb2hex(jmol_colors[number]) for number in numbers]
else:
arrays["colors"] = (
Expand All @@ -110,7 +110,7 @@ def encode(self, obj: ase.Atoms) -> ASEDict:
else obj.arrays["colors"]
)

if "radii" not in obj.arrays:
if ("radii" not in obj.arrays) or (0 in obj.arrays["radii"]):
# arrays["radii"] = [covalent_radii[number] for number in numbers]
arrays["radii"] = [get_scaled_radii()[number] for number in numbers]
else:
Expand Down

0 comments on commit 3ab06e0

Please sign in to comment.