Skip to content

Commit a1bc4f6

Browse files
authored
minor update (#47)
* minor update fci_slow * remove duplicate scipy functions
1 parent 08b57da commit a1bc4f6

File tree

3 files changed

+20
-151
lines changed

3 files changed

+20
-151
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ disable=abstract-method,
162162
arguments-out-of-order,
163163
consider-using-in,
164164
invalid-unary-operand-type,
165+
unnecessary-lambda-assignment,
165166

166167

167168
[REPORTS]

pyscfad/_src/scipy/linalg.py

Lines changed: 0 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,133 +0,0 @@
1-
from functools import partial
2-
import numpy
3-
import scipy
4-
import scipy.linalg
5-
from jax import numpy as np
6-
from jax import scipy as jax_scipy
7-
from pyscfad.ops import custom_jvp, jit
8-
9-
# default threshold for degenerate eigenvalues
10-
DEG_THRESH = 1e-9
11-
12-
# pylint: disable = redefined-builtin
13-
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
14-
overwrite_b=False, turbo=True, eigvals=None, type=1,
15-
check_finite=True, subset_by_index=None, subset_by_value=None,
16-
driver=None, deg_thresh=DEG_THRESH):
17-
if overwrite_a is True or overwrite_b is True:
18-
raise NotImplementedError('Overwritting a or b is not implemeneted.')
19-
if type != 1:
20-
raise NotImplementedError('Only the type=1 case of eigh is implemented.')
21-
if not(eigvals is None and subset_by_index is None and subset_by_value is None):
22-
raise NotImplementedError('Subset of eigen values is not implemented.')
23-
24-
a = 0.5 * (a + a.T.conj())
25-
if b is not None:
26-
b = 0.5 * (b + b.T.conj())
27-
28-
w, v = _eigh(a, b, deg_thresh=deg_thresh)
29-
30-
if eigvals_only:
31-
return w
32-
else:
33-
return w, v
34-
35-
@partial(custom_jvp, nondiff_argnums=(2,))
36-
def _eigh(a, b, deg_thresh=DEG_THRESH):
37-
w, v = scipy.linalg.eigh(a, b=b)
38-
w = np.asarray(w, dtype=float)
39-
return w, v
40-
41-
@_eigh.defjvp
42-
def _eigh_jvp(deg_thresh, primals, tangents):
43-
a, b = primals
44-
at, bt = tangents
45-
w, v = _eigh(a, b, deg_thresh)
46-
47-
eji = w[None, :] - w[:, None]
48-
idx = numpy.asarray(abs(eji) <= deg_thresh, dtype=bool)
49-
eji = eji.at[idx].set(1e200)
50-
eji = eji.at[numpy.diag_indices_from(eji)].set(1)
51-
Fmat = 1 / eji - numpy.eye(a.shape[-1])
52-
if b is None:
53-
dw, dv = _eigh_jvp_jitted_nob(v, Fmat, at)
54-
else:
55-
bmask = numpy.zeros(a.shape)
56-
bmask[idx] = 1
57-
dw, dv = _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask)
58-
return (w, v), (dw, dv)
59-
60-
@jit
61-
def _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask):
62-
vt_at_v = np.dot(v.conj().T, np.dot(at, v))
63-
vt_bt_v = np.dot(v.conj().T, np.dot(bt, v))
64-
vt_bt_v_w = np.dot(vt_bt_v, np.diag(w))
65-
da_minus_ds = vt_at_v - vt_bt_v_w
66-
dw = np.diag(da_minus_ds).real
67-
68-
dv = np.dot(v, np.multiply(Fmat, da_minus_ds) - np.multiply(bmask, vt_bt_v) * .5)
69-
return dw, dv
70-
71-
@jit
72-
def _eigh_jvp_jitted_nob(v, Fmat, at):
73-
vt_at_v = np.dot(v.conj().T, np.dot(at, v))
74-
dw = np.diag(vt_at_v).real
75-
dv = np.dot(v, np.multiply(Fmat, vt_at_v))
76-
return dw, dv
77-
78-
79-
def svd(a, full_matrices=True, compute_uv=True,
80-
overwrite_a=False, check_finite=True,
81-
lapack_driver='gesdd'):
82-
if not full_matrices or not compute_uv:
83-
return jax_scipy.linalg.svd(a,
84-
full_matrices=full_matrices,
85-
compute_uv=compute_uv)
86-
else:
87-
return _svd(a)
88-
89-
@custom_jvp
90-
def _svd(a):
91-
return jax_scipy.linalg.svd(a)
92-
93-
@_svd.defjvp
94-
def _svd_jvp(primals, tangents):
95-
A, = primals
96-
dA, = tangents
97-
if np.iscomplexobj(A):
98-
raise NotImplementedError
99-
100-
m, n = A.shape
101-
if m > n:
102-
raise NotImplementedError('Use svd(A.conj().T) instead.')
103-
104-
U, s, Vt = _svd(A)
105-
Ut = U.conj().T
106-
V = Vt.conj().T
107-
s_dim = s[None, :]
108-
109-
dS = Ut @ dA @ V
110-
ds = np.diagonal(dS, 0, -2, -1).real
111-
112-
s_diffs = (s_dim + s_dim.T) * (s_dim - s_dim.T)
113-
s_diffs_zeros = (s_diffs == 0).astype(s_diffs.dtype)
114-
F = 1. / (s_diffs + s_diffs_zeros) - s_diffs_zeros
115-
116-
dP1 = dS[:,:m]
117-
dP2 = dS[:,m:]
118-
dSS = dP1 * s_dim
119-
SdS = s_dim.T * dP1
120-
121-
dU = U @ (F * (dSS + dSS.conj().T))
122-
dD1 = F * (SdS + SdS.conj().T)
123-
124-
s_zeros = (s == 0).astype(s.dtype)
125-
s_inv = 1. / (s + s_zeros) - s_zeros
126-
dD2 = s_inv[:,None] * dP2
127-
128-
dV = np.zeros_like(V)
129-
dV = dV.at[:m,:m].set(dD1)
130-
dV = dV.at[:m,m:].set(-dD2)
131-
dV = dV.at[m:,:m].set(dD2.conj().T)
132-
dV = V @ dV
133-
return (U, s, Vt), (dU, ds, dV.conj().T)

pyscfad/fci/fci_slow.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pyscf.fci import cistring
33
from pyscfad import numpy as np
44
from pyscfad import ops
5-
from pyscfad.ops import vmap, stop_grad
5+
from pyscfad.ops import vmap, to_numpy
66
from pyscfad.lib.linalg_helper import davidson
77
from pyscfad.gto import mole
88
from pyscfad import ao2mo
@@ -82,16 +82,12 @@ def body(mo_ia, mo_ib, ida, idb):
8282
val = np.linalg.det(sij_a) * np.linalg.det(sij_b)
8383
return val
8484

85-
res = 0.
85+
res = 0
8686
for ia in range(na1):
8787
mo_ia = mo_a1[:,locs_a1[ia]]
8888
for ib in range(nb1):
8989
mo_ib = mo_b1[:,locs_b1[ib]]
9090
val = vmap(body, (None,None,0,0), signature='(i),(j)->()')(mo_ia, mo_ib, idxa, idxb)
91-
#val = []
92-
#for i in range(len(idxa)):
93-
# val.append(body(mo_ia, mo_ib, idxa[i], idxb[i]))
94-
#val = np.asarray(val)
9591
res += ci1[ia,ib] * (val * ci2.ravel()).sum()
9692
return res
9793

@@ -126,16 +122,19 @@ def contract_2e(eri, fcivec, norb, nelec, opt=None):
126122
fcinew = ops.index_add(fcinew, ops.index[:,str1], sign * t1[a,i,:,str0])
127123
return fcinew.reshape(fcivec.shape)
128124

129-
130125
def absorb_h1e(h1e, eri, norb, nelec, fac=1):
131126
if not isinstance(nelec, (int, np.integer)):
132-
nelec = sum(nelec)
127+
nelec = np.sum(nelec)
128+
assert nelec > 0
129+
133130
if eri.size != norb**4:
134-
h2e = ao2mo.restore(1, eri.copy(), norb)
131+
h2e = ao2mo.restore(1, eri, norb)
135132
else:
136-
h2e = eri.copy().reshape(norb,norb,norb,norb)
137-
f1e = h1e - np.einsum('jiik->jk', h2e) * .5
138-
f1e = f1e * (1./(nelec+1e-100))
133+
h2e = eri.reshape([norb,]*4)
134+
135+
f1e = h1e - np.einsum('jiik->jk', h2e) * .5
136+
f1e *= 1. / nelec
137+
139138
for k in range(norb):
140139
h2e = ops.index_add(h2e, ops.index[k,k,:,:], f1e)
141140
h2e = ops.index_add(h2e, ops.index[:,:,k,k], f1e)
@@ -153,7 +152,7 @@ def make_hdiag(h1e, eri, norb, nelec, opt=None):
153152
if eri.size != norb**4:
154153
eri = ao2mo.restore(1, eri, norb)
155154
else:
156-
eri = eri.reshape(norb,norb,norb,norb)
155+
eri = eri.reshape([norb,]*4)
157156
diagj = np.einsum('iijj->ij', eri)
158157
diagk = np.einsum('ijji->ij', eri)
159158
hdiag = []
@@ -173,10 +172,10 @@ def kernel(h1e, eri, norb, nelec, ecore=0, nroots=1):
173172
hdiag = make_hdiag(h1e, eri, norb, nelec)
174173
try:
175174
from pyscf.fci.direct_spin1 import pspace
176-
addrs, h0 = pspace(stop_grad(h1e), stop_grad(eri),
177-
norb, nelec, stop_grad(hdiag), nroots)
178-
# pylint: disable=bare-except
179-
except:
175+
addrs, _ = pspace(to_numpy(h1e), to_numpy(eri),
176+
norb, nelec, to_numpy(hdiag), nroots)
177+
# pylint: disable=broad-exception-caught
178+
except Exception:
180179
addrs = numpy.argsort(hdiag)[:nroots]
181180
ci0 = []
182181
for addr in addrs:
@@ -187,7 +186,9 @@ def kernel(h1e, eri, norb, nelec, ecore=0, nroots=1):
187186
def hop(c):
188187
hc = contract_2e(h2e, c, norb, nelec)
189188
return hc.ravel()
190-
# pylint: disable=unnecessary-lambda-assignment
189+
191190
precond = lambda x, e, *args: x/(hdiag-e+1e-4)
191+
192192
e, c = davidson(hop, ci0, precond, nroots=nroots)
193193
return e+ecore, c
194+

0 commit comments

Comments
 (0)