Skip to content

Commit b742618

Browse files
authored
Fix log density spline extrapolation (#133)
* Add extrapolation limits to DensitySpline * Handle log/exp cutoffs better * Comment and clean up extrapolation code * Set extrapolation cutoff limit to r = r_N
1 parent 511a702 commit b742618

File tree

2 files changed

+23
-27
lines changed

2 files changed

+23
-27
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ cython_debug/
147147
.vscode/
148148

149149
# Raw data files
150-
atomdb/datasets/*/raw/
150+
atomdb/datasets/*/db/
151151

152152
# Generated documentation
153153
docs/source/api/

atomdb/species.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,35 +15,23 @@
1515

1616
r"""AtomDB, a database of atomic and ionic properties."""
1717

18-
from dataclasses import dataclass, field, asdict
19-
20-
from glob import glob
21-
22-
from importlib import import_module
23-
2418
import json
25-
19+
import re
20+
from dataclasses import asdict, dataclass, field
21+
from importlib import import_module
2622
from numbers import Integral
27-
2823
from os import makedirs, path
2924

30-
from msgpack import packb, unpackb
31-
32-
from msgpack_numpy import encode, decode
33-
3425
import numpy as np
35-
36-
from numpy import ndarray
37-
3826
import pooch
39-
import re
4027
import requests
41-
28+
from msgpack import packb, unpackb
29+
from msgpack_numpy import decode, encode
30+
from numpy import ndarray
4231
from scipy.interpolate import CubicSpline
4332

44-
from atomdb.utils import DEFAULT_DATASET, DEFAULT_DATAPATH, DEFAULT_REMOTE
45-
from atomdb.periodic import element_symbol, Element
46-
33+
from atomdb.periodic import Element, element_symbol
34+
from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_DATASET, DEFAULT_REMOTE
4735

4836
__all__ = [
4937
"Species",
@@ -166,7 +154,9 @@ def __init__(self, x, y, log=False):
166154
self._log = log
167155
self._obj = CubicSpline(
168156
x,
169-
np.log(y) if log else y,
157+
# Clip y values to >= ε^2 if using log because they have to be above 0;
158+
# having them be at least ε^2 seems to work based on my testing
159+
np.log(y.clip(min=np.finfo(float).eps ** 2)) if log else y,
170160
axis=0,
171161
bc_type="not-a-knot",
172162
extrapolate=True,
@@ -192,7 +182,9 @@ def __call__(self, x, deriv=0):
192182
if not (0 <= deriv <= 2):
193183
raise ValueError(f"Invalid derivative order {deriv}; must be 0 <= `deriv` <= 2")
194184
elif self._log:
195-
y = np.exp(self._obj(x))
185+
# Get y = exp(log y). We'll handle errors from small log y values later.
186+
with np.errstate(over="ignore"):
187+
y = np.exp(self._obj(x))
196188
if deriv == 1:
197189
# d(ρ(r)) = d(log(ρ(r))) * ρ(r)
198190
dlogy = self._obj(x, nu=1)
@@ -201,9 +193,13 @@ def __call__(self, x, deriv=0):
201193
# d^2(ρ(r)) = d^2(log(ρ(r))) * ρ(r) + [d(ρ(r))]^2/ρ(r)
202194
dlogy = self._obj(x, nu=1)
203195
d2logy = self._obj(x, nu=2)
204-
y = d2logy.flatten() * y + dlogy.flatten() ** 2 * y
196+
y = d2logy.flatten() * y + dlogy.flatten() ** 2 / y
205197
else:
206198
y = self._obj(x, nu=deriv)
199+
# Handle errors from the y = exp(log y) operation -- set NaN to zero
200+
np.nan_to_num(y, nan=0., copy=False)
201+
# Cutoff value: assume y(x) is zero where x > final given point x_n
202+
y[x > self._obj.x[-1]] = 0
207203
return y
208204

209205

@@ -218,7 +214,7 @@ def default(self, obj):
218214
return JSONEncoder.default(self, obj)
219215

220216

221-
class _AtomicOrbitals(object):
217+
class _AtomicOrbitals:
222218
"""Atomic orbitals class."""
223219

224220
def __init__(self, data) -> None:
@@ -883,13 +879,13 @@ def datafile(
883879
url=f"{remotepath}{dataset.lower()}/db/repodata.txt",
884880
known_hash=None,
885881
path=path.join(datapath, dataset.lower(), "db"),
886-
fname=f"repo_data.txt",
882+
fname="repo_data.txt",
887883
)
888884
# if the file is not found or remote was not valid, use the local repodata file
889885
except (requests.exceptions.HTTPError, ValueError):
890886
repodata = path.join(datapath, dataset.lower(), "db", "repo_data.txt")
891887

892-
with open(repodata, "r") as f:
888+
with open(repodata) as f:
893889
data = f.read()
894890
files = re.findall(rf"\b{elem}+_{charge}+_{mult}+_{nexc}\.msg\b", data)
895891
species_list = []

0 commit comments

Comments
 (0)