|
1 |
| -from functools import partial |
2 |
| -import numpy |
3 |
| -import scipy |
4 |
| -import scipy.linalg |
5 |
| -from jax import numpy as np |
6 |
| -from jax import scipy as jax_scipy |
7 |
| -from pyscfad.ops import custom_jvp, jit |
8 |
| - |
9 |
| -# default threshold for degenerate eigenvalues |
10 |
| -DEG_THRESH = 1e-9 |
11 |
| - |
12 |
| -# pylint: disable = redefined-builtin |
13 |
| -def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, |
14 |
| - overwrite_b=False, turbo=True, eigvals=None, type=1, |
15 |
| - check_finite=True, subset_by_index=None, subset_by_value=None, |
16 |
| - driver=None, deg_thresh=DEG_THRESH): |
17 |
| - if overwrite_a is True or overwrite_b is True: |
18 |
| - raise NotImplementedError('Overwritting a or b is not implemeneted.') |
19 |
| - if type != 1: |
20 |
| - raise NotImplementedError('Only the type=1 case of eigh is implemented.') |
21 |
| - if not(eigvals is None and subset_by_index is None and subset_by_value is None): |
22 |
| - raise NotImplementedError('Subset of eigen values is not implemented.') |
23 |
| - |
24 |
| - a = 0.5 * (a + a.T.conj()) |
25 |
| - if b is not None: |
26 |
| - b = 0.5 * (b + b.T.conj()) |
27 |
| - |
28 |
| - w, v = _eigh(a, b, deg_thresh=deg_thresh) |
29 |
| - |
30 |
| - if eigvals_only: |
31 |
| - return w |
32 |
| - else: |
33 |
| - return w, v |
34 |
| - |
35 |
| -@partial(custom_jvp, nondiff_argnums=(2,)) |
36 |
| -def _eigh(a, b, deg_thresh=DEG_THRESH): |
37 |
| - w, v = scipy.linalg.eigh(a, b=b) |
38 |
| - w = np.asarray(w, dtype=float) |
39 |
| - return w, v |
40 |
| - |
41 |
| -@_eigh.defjvp |
42 |
| -def _eigh_jvp(deg_thresh, primals, tangents): |
43 |
| - a, b = primals |
44 |
| - at, bt = tangents |
45 |
| - w, v = _eigh(a, b, deg_thresh) |
46 |
| - |
47 |
| - eji = w[None, :] - w[:, None] |
48 |
| - idx = numpy.asarray(abs(eji) <= deg_thresh, dtype=bool) |
49 |
| - eji = eji.at[idx].set(1e200) |
50 |
| - eji = eji.at[numpy.diag_indices_from(eji)].set(1) |
51 |
| - Fmat = 1 / eji - numpy.eye(a.shape[-1]) |
52 |
| - if b is None: |
53 |
| - dw, dv = _eigh_jvp_jitted_nob(v, Fmat, at) |
54 |
| - else: |
55 |
| - bmask = numpy.zeros(a.shape) |
56 |
| - bmask[idx] = 1 |
57 |
| - dw, dv = _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask) |
58 |
| - return (w, v), (dw, dv) |
59 |
| - |
60 |
| -@jit |
61 |
| -def _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask): |
62 |
| - vt_at_v = np.dot(v.conj().T, np.dot(at, v)) |
63 |
| - vt_bt_v = np.dot(v.conj().T, np.dot(bt, v)) |
64 |
| - vt_bt_v_w = np.dot(vt_bt_v, np.diag(w)) |
65 |
| - da_minus_ds = vt_at_v - vt_bt_v_w |
66 |
| - dw = np.diag(da_minus_ds).real |
67 |
| - |
68 |
| - dv = np.dot(v, np.multiply(Fmat, da_minus_ds) - np.multiply(bmask, vt_bt_v) * .5) |
69 |
| - return dw, dv |
70 |
| - |
71 |
| -@jit |
72 |
| -def _eigh_jvp_jitted_nob(v, Fmat, at): |
73 |
| - vt_at_v = np.dot(v.conj().T, np.dot(at, v)) |
74 |
| - dw = np.diag(vt_at_v).real |
75 |
| - dv = np.dot(v, np.multiply(Fmat, vt_at_v)) |
76 |
| - return dw, dv |
77 |
| - |
78 |
| - |
79 |
| -def svd(a, full_matrices=True, compute_uv=True, |
80 |
| - overwrite_a=False, check_finite=True, |
81 |
| - lapack_driver='gesdd'): |
82 |
| - if not full_matrices or not compute_uv: |
83 |
| - return jax_scipy.linalg.svd(a, |
84 |
| - full_matrices=full_matrices, |
85 |
| - compute_uv=compute_uv) |
86 |
| - else: |
87 |
| - return _svd(a) |
88 |
| - |
89 |
| -@custom_jvp |
90 |
| -def _svd(a): |
91 |
| - return jax_scipy.linalg.svd(a) |
92 |
| - |
93 |
| -@_svd.defjvp |
94 |
| -def _svd_jvp(primals, tangents): |
95 |
| - A, = primals |
96 |
| - dA, = tangents |
97 |
| - if np.iscomplexobj(A): |
98 |
| - raise NotImplementedError |
99 |
| - |
100 |
| - m, n = A.shape |
101 |
| - if m > n: |
102 |
| - raise NotImplementedError('Use svd(A.conj().T) instead.') |
103 |
| - |
104 |
| - U, s, Vt = _svd(A) |
105 |
| - Ut = U.conj().T |
106 |
| - V = Vt.conj().T |
107 |
| - s_dim = s[None, :] |
108 |
| - |
109 |
| - dS = Ut @ dA @ V |
110 |
| - ds = np.diagonal(dS, 0, -2, -1).real |
111 |
| - |
112 |
| - s_diffs = (s_dim + s_dim.T) * (s_dim - s_dim.T) |
113 |
| - s_diffs_zeros = (s_diffs == 0).astype(s_diffs.dtype) |
114 |
| - F = 1. / (s_diffs + s_diffs_zeros) - s_diffs_zeros |
115 |
| - |
116 |
| - dP1 = dS[:,:m] |
117 |
| - dP2 = dS[:,m:] |
118 |
| - dSS = dP1 * s_dim |
119 |
| - SdS = s_dim.T * dP1 |
120 |
| - |
121 |
| - dU = U @ (F * (dSS + dSS.conj().T)) |
122 |
| - dD1 = F * (SdS + SdS.conj().T) |
123 |
| - |
124 |
| - s_zeros = (s == 0).astype(s.dtype) |
125 |
| - s_inv = 1. / (s + s_zeros) - s_zeros |
126 |
| - dD2 = s_inv[:,None] * dP2 |
127 |
| - |
128 |
| - dV = np.zeros_like(V) |
129 |
| - dV = dV.at[:m,:m].set(dD1) |
130 |
| - dV = dV.at[:m,m:].set(-dD2) |
131 |
| - dV = dV.at[m:,:m].set(dD2.conj().T) |
132 |
| - dV = V @ dV |
133 |
| - return (U, s, Vt), (dU, ds, dV.conj().T) |
0 commit comments