Skip to content

Commit b8f710f

Browse files
authored
Small fixes (#62)
* fix dynamic polarizability * update dft * update Davidson solver * update tdscf * update doc
1 parent 0342d5f commit b8f710f

File tree

18 files changed

+439
-315
lines changed

18 files changed

+439
-315
lines changed

.github/workflows/install_pyscf.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ pip install pytest-cov
77
pip install numpy
88
pip install scipy
99
pip install h5py
10-
pip install jaxlib
11-
pip install jax
12-
pip install 'pyscf>=2.3,<2.7'
10+
pip install 'jaxlib<=0.4.35'
11+
pip install 'jax<=0.4.35'
12+
pip install 'pyscf>=2.3'

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ The dependent C library `pyscfadlib` can be compiled from source following the i
3131
`numpy`, `scipy`,
3232
`pyscf>=2.3.0`,
3333
`pyscfadlib>=0.1.4`,
34-
`jax>=0.4.14`, and `jaxlib>=0.4.14`.
34+
`jax>=0.4.14,<=0.4.35`, and `jaxlib>=0.4.14,<=0.4.35`.
3535

3636
Citing PySCFAD
3737
--------------

doc/source/getting_started/install.rst

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,15 @@ Package supported versions
8282
`numpy <https://numpy.org>`_ >=1.17
8383
`scipy <https://scipy.org>`_ >=1.7
8484
`h5py <https://www.h5py.org/>`_ >=2.7
85-
`jax <https://jax.readthedocs.io/en/latest/>`_ >=0.4.14
86-
`jaxlib <https://pypi.org/project/jaxlib/>`_ >=0.4.14
85+
`jax <https://jax.readthedocs.io/en/latest/>`_ >=0.4.14,<=0.4.35
86+
`jaxlib <https://pypi.org/project/jaxlib/>`_ >=0.4.14,<=0.4.35
8787
`pyscf <https://pyscf.org/>`_ >=2.3.0
8888
`pyscfadlib <https://pypi.org/project/pyscfadlib/>`_ >=0.1.4
8989
===================================================== ==================
9090

91+
.. note::
92+
93+
Since jax version 0.4.36, the tracing machinery has been modified
94+
to eliminate data-dependent tracing, which conflicts with pyscfad's flexibility.
95+
A comprehensive update to ensure compatibility with later jax versions may be
96+
introduced in the future pyscfad 0.2 release.

examples/scf/10-polarizability.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy
21
import jax
32
from pyscfad import numpy as np
43
from pyscfad import gto, scf
@@ -19,7 +18,7 @@ def apply_E(E):
1918
mf.kernel()
2019
return mf.dip_moment(mol, mf.make_rdm1(), unit='AU', verbose=0)
2120

22-
E0 = numpy.zeros((3))
21+
E0 = np.zeros((3))
2322
polar = jax.jacfwd(apply_E)(E0)
2423
print(polar)
2524

examples/scf/11-born_charge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy
55
import jax
66
from pyscfad import gto, scf
7-
from pyscfad.lib import numpy as jnp
7+
from pyscfad import numpy as np
88

99
mol = gto.Mole()
1010
mol.atom = '''H , 0. 0. 0.
@@ -17,7 +17,7 @@ def dip_moment(mol):
1717
ao_dip = mol.intor_symmetric('int1e_r', comp=3)
1818
h1 = mf.get_hcore()
1919
E = numpy.zeros((3))
20-
mf.get_hcore = lambda *args, **kwargs: h1 + jnp.einsum('x,xij->ij', E, ao_dip)
20+
mf.get_hcore = lambda *args, **kwargs: h1 + np.einsum('x,xij->ij', E, ao_dip)
2121
mf.kernel()
2222
return mf.dip_moment(mol, mf.make_rdm1(), unit='AU', verbose=0)
2323

examples/scf/12-raman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
'''
1111
import jax
1212
from pyscfad import gto, scf
13-
from pyscfad.lib import numpy as np
13+
from pyscfad import numpy as np
1414
from pyscfad.prop.polarizability.rhf import Polarizability
1515
from pyscfad.prop.thermo import vib
1616

examples/scf/21-ir.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
'''
99
import jax
1010
from pyscfad import gto, scf, cc
11-
from pyscfad.lib import numpy as np
11+
from pyscfad import numpy as np
1212
from pyscfad.prop.thermo import vib
1313

1414
mol = gto.Mole()

pyscfad/dft/libxc.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ def _eval_xc_comp_jvp(xc_code, spin, relativity, deriv, omega, verbose,
5959
vrho1, vsigma1, vlapl1, vtau1 = _vxc_partial_deriv(rho, val, val1, 'MGGA')
6060
vrho_jvp = np.einsum('np,np->p', vrho1, rho_t)
6161
vsigma_jvp = np.einsum('np,np->p', vsigma1, rho_t)
62-
vlapl_jvp = np.einsum('np,np->p', vlapl1, rho_t)
62+
if vlapl1 is None:
63+
vlapl_jvp = None
64+
else:
65+
vlapl_jvp = np.einsum('np,np->p', vlapl1, rho_t)
6366
vtau_jvp = np.einsum('np,np->p', vtau1, rho_t)
6467
vrho1 = vsigma1 = vlapl1 = vtau1 = None
65-
jvp = np.vstack((vrho_jvp, vsigma_jvp, vlapl_jvp, vtau_jvp))
68+
jvp = (vrho_jvp, vsigma_jvp, vlapl_jvp, vtau_jvp)
6669
else:
6770
raise NotImplementedError
6871
else:
@@ -92,7 +95,10 @@ def _exc_partial_deriv(rho, exc, vxc, xctype='LDA'):
9295
dsigma = vxc[1] / rho[0] * 2. * rho[1:4]
9396
exc1 = np.vstack((drho, dsigma))
9497
if xctype == 'MGGA':
95-
dlap = vxc[2] / rho[0]
98+
if vxc[2] is None:
99+
dlap = np.zeros_like(rho[0])
100+
else:
101+
dlap = vxc[2] / rho[0]
96102
dtau = vxc[3] / rho[0]
97103
exc1 = np.vstack((exc1, dlap, dtau))
98104
else:
@@ -108,10 +114,21 @@ def _vxc_partial_deriv(rho, vxc, fxc, xctype='LDA'):
108114
vrho1 = np.vstack((fxc[0], fxc[1] * 2. * rho[1:4]))
109115
vsigma1 = np.vstack((fxc[1], fxc[2] * 2. * rho[1:4]))
110116
if xctype == 'MGGA':
111-
vrho1 = np.vstack((vrho1, fxc[5], fxc[6]))
112-
vsigma1 = np.vstack((vsigma1, fxc[8], fxc[9]))
113-
vlapl1 = np.vstack((fxc[5], fxc[8] * 2. * rho[1:4], fxc[3], fxc[7]))
114-
vtau1 = np.vstack((fxc[6], fxc[9] * 2. * rho[1:4], fxc[7], fxc[4]))
117+
ZERO = np.zeros_like(rho[0])
118+
if vxc[2] is None:
119+
fxc3 = fxc5 = fxc7 = fxc8 = ZERO
120+
else:
121+
fxc3 = fxc[3]
122+
fxc5 = fxc[5]
123+
fxc7 = fxc[7]
124+
fxc8 = fxc[8]
125+
vrho1 = np.vstack((vrho1, fxc5, fxc[6]))
126+
vsigma1 = np.vstack((vsigma1, fxc8, fxc[9]))
127+
if vxc[2] is None:
128+
vlapl1 = None
129+
else:
130+
vlapl1 = np.vstack((fxc5, fxc8 * 2. * rho[1:4], fxc3, fxc7))
131+
vtau1 = np.vstack((fxc[6], fxc[9] * 2. * rho[1:4], fxc7, fxc[4]))
115132
else:
116133
raise KeyError
117134
return vrho1, vsigma1, vlapl1, vtau1

0 commit comments

Comments
 (0)