Feature-Request: Matrix Exponentiation #1779
In-Python implementation, yoinked from torch and ported w/ Claude. It appears to work in training, though:

```python
import mlx.core as mx


def _compute_T1(A):
    """I + A"""
    return mx.eye(A.shape[-1]) + A


def _compute_T2(A):
    """I + A + A^2/2"""
    A2 = A @ A
    return mx.eye(A.shape[-1]) + A + A2 / 2


def _compute_T4(A):
    """I + A + A^2 * (I/2 + A/6 + A^2/24)"""
    A2 = A @ A
    inner_term = mx.eye(A.shape[-1]) / 2 + A / 6 + A2 / 24
    return mx.eye(A.shape[-1]) + A + (A2 @ inner_term)


def _compute_T8(A):
    """Degree-8 Taylor polynomial with the optimized coefficients."""
    sqrt_177 = 0.1330413469565007072504e+2  # sqrt(177)
    x3 = 2 / 3
    x1 = x3 * ((1 + sqrt_177) / 88)
    x2 = x3 * ((1 + sqrt_177) / 352)
    x4 = (-271 + 29 * sqrt_177) / (315 * x3)
    x5 = (-11 + 11 * sqrt_177) / (1260 * x3)
    x6 = (-99 + 11 * sqrt_177) / (5040 * x3)
    x7 = (89 - sqrt_177) / (5040 * x3)
    y2 = (857 - 58 * sqrt_177) / 630
    A2 = A @ A
    A4 = A2 @ (x1 * A + x2 * A2)
    A8 = (x3 * A2 + A4) @ (x4 * mx.eye(A.shape[-1]) + x5 * A + x6 * A2 + x7 * A4)
    return mx.eye(A.shape[-1]) + A + y2 * A2 + A8


def matrix_exp(A):
    """
    Computes the matrix exponential using an optimized Taylor series.

    Based on PyTorch's implementation from the paper:
    Bader, P.; Blanes, S.; Casas, F.
    Computing the Matrix Exponential with an Optimized Taylor Polynomial
    Approximation.
    """
    if A.shape[-2:] == (0, 0):
        return A  # MLX arrays are immutable, so no clone is needed
    elif A.shape[-2:] == (1, 1):
        return mx.exp(A)

    # 1-norm (max absolute column sum), used to choose the degree
    matrix_norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)

    # These thresholds are from PyTorch's implementation.
    # They're carefully chosen based on the paper (degrees 1, 2, 4, 8, 12, 18).
    if A.dtype == mx.float32:
        thresholds = [
            1.192092800768788e-07,  # deg 1
            5.978858893805233e-04,  # deg 2
            5.116619363445086e-02,  # deg 4
            5.800524627688768e-01,  # deg 8
            1.461661507209034e+00,  # deg 12
            3.010066362817634e+00,  # deg 18
        ]
    else:  # float64
        thresholds = [
            2.220446049250313e-16,  # deg 1
            2.580956802971767e-08,  # deg 2
            3.397168839976962e-04,  # deg 4
            4.991228871115323e-02,  # deg 8
            2.996158913811580e-01,  # deg 12
            1.090863719290036e+00,  # deg 18
        ]

    # For small norms use lower-degree approximations.
    # Note: these comparisons implicitly evaluate the graph.
    if matrix_norm <= thresholds[0]:
        return _compute_T1(A)
    elif matrix_norm <= thresholds[1]:
        return _compute_T2(A)
    elif matrix_norm <= thresholds[2]:
        return _compute_T4(A)
    elif matrix_norm <= thresholds[3]:
        return _compute_T8(A)

    # For larger norms use scaling and squaring with T8:
    # exp(A) = exp(A / 2^s)^(2^s)
    s = mx.maximum(
        mx.zeros_like(matrix_norm),
        mx.ceil(mx.log2(matrix_norm / thresholds[3])),
    )
    s = s.astype(mx.int32)
    A_scaled = A / mx.expand_dims(mx.expand_dims(2.0 ** s, -1), -1)

    # Compute the exponential of the scaled matrix
    X = _compute_T8(A_scaled)

    # Square back up; the mx.where guard stops squaring once s reaches 0
    max_s = int(mx.max(s).item())
    for _ in range(max_s):
        X = mx.where(s > 0, X @ X, X)
        s = s - 1
    return X
```
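As a quick usage sketch (my addition, not from the thread) tying this back to the orthogonal-optimization use case: the exponential of a skew-symmetric matrix is orthogonal, which is easy to sanity-check with the implementation above.

```python
import mlx.core as mx

# Illustrative check (not from the thread): exp of a skew-symmetric matrix
# is orthogonal, the property that makes matrix_exp useful for
# orthogonal optimization.
W = mx.random.normal((16, 16))
S = W - W.T            # skew-symmetric: S.T == -S
Q = matrix_exp(S)      # matrix_exp as defined above
print(mx.allclose(Q @ Q.T, mx.eye(16), atol=1e-4))  # expect True
```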
Full implementation w/ optimization and custom vjp for smol kernels, all, again, yoinked from pytorch: https://gist.github.com/N8python/b3e24a4f88efa52bdd81a8762b7a7238. For two 1024x1024 matrices initialized with randn(0, 0.1), the provided matrix_exp implementation diverges from that of pytorch by a maximum absolute difference of 0.000975.

For two 1024x1024 matrices initialized with randn(0, 1) - intentionally made to have diverging eigenvalues - the average difference, as a percentage of the maximum element in the torch computation, is 0.03124%; unnormalized, it is roughly 1311244288.0 due to the diverging eigenvalues.
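For reference, a sketch of how such a cross-check can be reproduced (my reconstruction, not the author's actual script; it assumes both torch and mlx are installed):

```python
import numpy as np
import torch
import mlx.core as mx

# Reconstruction of the cross-check described above: compare this
# matrix_exp against torch.linalg.matrix_exp on the same input.
A_np = (0.1 * np.random.randn(1024, 1024)).astype(np.float32)  # ~ randn(0, 0.1)
ref = torch.linalg.matrix_exp(torch.from_numpy(A_np)).numpy()
out = np.array(matrix_exp(mx.array(A_np)))
print(np.abs(out - ref).max())  # reported above as ~0.000975
```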
This is very nicely done :-). I especially love the gist with the custom function... I think it is not quite ready to be made into an op yet, but it definitely raises some nice issues for us to solve. The main problem I see is the implicit graph evaluation in a couple of places (matrix norm conditionals and looping). I think this may be a good example for an …
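To make that point concrete (my illustration, not from the comment): MLX arrays are lazy, so a Python `if` on a comparison result has to call `bool()` on the array, which forces an implicit evaluation of the graph up to that point.

```python
import mlx.core as mx

# Illustration: branching on a lazy array forces an implicit evaluation
# of everything `norm` depends on before the branch can be taken.
A = mx.random.normal((8, 8))
norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)
if norm <= 1.0:      # bool(...) here implicitly evaluates the graph
    print("small-norm branch")
```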
Agreed. But how would I go about doing that, given that the `if` on the norm is necessary to choose the correct polynomial, and the scale factor also depends on the norm? (One possible in-graph workaround is sketched after the quote below.) Or are you saying this needs to be implemented:
> Useful for orthogonal optimization and other algorithms where calculating unitary matrices or dealing with complex numbers is necessary.
> Pytorch impl:
> https://discuss.pytorch.org/t/what-implementation-is-used-for-matrix-exp/159608/5
> https://pytorch.org/docs/stable/generated/torch.linalg.matrix_exp.html
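One possible direction (my sketch, not an answer from the thread): keep the degree selection inside the graph by computing every candidate polynomial and blending with mx.where, so no Python `if` forces an evaluation. This is wasteful, since all candidates are computed, and it omits the scaling-and-squaring branch, which would need a similar fixed-iteration treatment.

```python
import mlx.core as mx

def matrix_exp_branchless(A, thresholds):
    """Hypothetical in-graph degree selection: no Python `if` on the norm.

    `thresholds` is the same degree-threshold list as above; the helpers
    _compute_T1/T2/T4/T8 are the ones defined in the earlier comment.
    """
    norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)
    # Compute every candidate, then select lazily with mx.where
    out = _compute_T8(A)
    out = mx.where(norm <= thresholds[2], _compute_T4(A), out)
    out = mx.where(norm <= thresholds[1], _compute_T2(A), out)
    out = mx.where(norm <= thresholds[0], _compute_T1(A), out)
    return out
```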