
Feature-Request: Matrix Exponentiation #1779

Open

N8python opened this issue Jan 19, 2025 · 6 comments

Comments

@N8python

Useful for orthogonal optimization and other algorithms where calculating unitary matrices or dealing with complex numbers is necessary.

PyTorch impl:
https://discuss.pytorch.org/t/what-implementation-is-used-for-matrix-exp/159608/5

https://pytorch.org/docs/stable/generated/torch.linalg.matrix_exp.html
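
For context, a minimal sketch of the use case in PyTorch (the op being requested here): exponentiating a skew-symmetric matrix yields an orthogonal matrix, which is the standard trick behind orthogonal parameterizations.

import torch

# exp(A) is orthogonal whenever A is skew-symmetric (A.T == -A), which is
# what makes matrix_exp useful for orthogonal/unitary parameterizations.
X = torch.randn(64, 64)
A = X - X.T                              # skew-symmetric generator
Q = torch.linalg.matrix_exp(A)
print((Q @ Q.T - torch.eye(64)).abs().max())  # ~1e-5 in float32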

@N8python
Author

Pure-Python implementation, yoinked from torch and ported with Claude. It appears to work in training, at least:

import mlx.core as mx

def _compute_T1(A):
    """I + A"""
    return mx.eye(A.shape[-1]) + A

def _compute_T2(A):
    """I + A + A^2/2"""
    A2 = A @ A
    return mx.eye(A.shape[-1]) + A + A2/2

def _compute_T4(A):
    """I + A + A^2 * (I/2 + A/6 + A^2/24)"""
    A2 = A @ A
    inner_term = (mx.eye(A.shape[-1])/2 + A/6 + A2/24)
    return mx.eye(A.shape[-1]) + A + (A2 @ inner_term)

def _compute_T8(A):
    """Degree-8 Taylor polynomial with the optimized coefficients of Bader et al."""
    sqrt_177 = 0.1330413469565007072504e+2
    x3 = 2/3
    x1 = x3 * ((1 + sqrt_177) / 88)
    x2 = x3 * ((1 + sqrt_177) / 352)
    x4 = (-271 + 29 * sqrt_177) / (315 * x3)
    x5 = (-11 + 11 * sqrt_177) / (1260 * x3)
    x6 = (-99 + 11 * sqrt_177) / (5040 * x3)
    x7 = (89 - sqrt_177) / (5040 * x3)
    y2 = (857 - 58 * sqrt_177) / 630

    A2 = A @ A
    A4 = A2 @ (x1*A + x2*A2)
    A8 = (x3*A2 + A4) @ (x4*mx.eye(A.shape[-1]) + x5*A + x6*A2 + x7*A4)
    
    return mx.eye(A.shape[-1]) + A + y2*A2 + A8

def matrix_exp(A):
    """
    Computes matrix exponential using optimized Taylor series.
    Based on PyTorch's implementation from the paper:
    Bader, P.; Blanes, S.; Casas, F.
    Computing the Matrix Exponential with an Optimized Taylor Polynomial Approximation.
    """
    if A.shape[-2:] == (0, 0):
        # mlx arrays have no .clone(); mx.array(A) makes a copy
        return mx.array(A)
    elif A.shape[-2:] == (1, 1):
        return mx.exp(A)

    # Compute the matrix norm to choose degree
    matrix_norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)
    
    # These thresholds are from PyTorch's implementation
    # They're carefully chosen based on the paper
    if A.dtype == mx.float32:
        thresholds = [
            1.192092800768788e-07,  # deg 1
            5.978858893805233e-04,  # deg 2
            5.116619363445086e-02,  # deg 4
            5.800524627688768e-01,  # deg 8
            1.461661507209034e+00,  # deg 12
            3.010066362817634e+00   # deg 18
        ]
    else:  # float64
        thresholds = [
            2.220446049250313e-16,  # deg 1
            2.580956802971767e-08,  # deg 2
            3.397168839976962e-04,  # deg 4
            4.991228871115323e-02,  # deg 8
            2.996158913811580e-01,  # deg 12
            1.090863719290036e+00   # deg 18
        ]

    # For small norms use lower degree approximations
    # (only T1/T2/T4/T8 are ported here, so the deg-12/18 thresholds above go unused)
    if matrix_norm <= thresholds[0]:
        return _compute_T1(A)
    elif matrix_norm <= thresholds[1]:
        return _compute_T2(A)
    elif matrix_norm <= thresholds[2]:
        return _compute_T4(A)
    elif matrix_norm <= thresholds[3]:
        return _compute_T8(A)

    # For larger norms use scaling and squaring with T8
    s = mx.maximum(
        mx.zeros_like(matrix_norm),
        mx.ceil(mx.log2(matrix_norm / thresholds[3]))
    )
    s = s.astype(mx.int32)
    A_scaled = A / mx.expand_dims(mx.expand_dims(2.0**s, -1), -1)
    
    # Compute exponential of scaled matrix
    X = _compute_T8(A_scaled)
    
    # Square back up
    max_s = int(mx.max(s).item())
    for _ in range(max_s):
        X = mx.where(s > 0, X @ X, X)
        s = s - 1
        
    return X
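
A quick sanity check one can run against the snippet above: for a 2x2 rotation generator the exponential has a closed form, so the result is easy to verify.

import math
import mlx.core as mx

# exp([[0, -t], [t, 0]]) has the closed form [[cos t, -sin t], [sin t, cos t]]
t = 0.3
A = mx.array([[0.0, -t], [t, 0.0]])
R = mx.array([[math.cos(t), -math.sin(t)],
              [math.sin(t),  math.cos(t)]])
print(mx.max(mx.abs(matrix_exp(A) - R)))  # ~1e-7 in float32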

@N8python
Author

Full implementation w/ optimization and a custom vjp for smol kernels, all again yoinked from pytorch: https://gist.github.com/N8python/b3e24a4f88efa52bdd81a8762b7a7238.

For two 1024x1024 matrices initialized with randn(0, 0.1), the provided matrix_exp implementation diverges from that of pytorch by a maximum absolute difference of 0.000975.
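
The comparison harness isn't shown in the thread; a sketch of how such a check might look (reading randn(0, 0.1) as zero-mean Gaussian entries with std 0.1):

import numpy as np
import torch
import mlx.core as mx

# Cross-check the MLX port against torch.linalg.matrix_exp on the same input.
A_np = (np.random.randn(1024, 1024) * 0.1).astype(np.float32)
ref = torch.linalg.matrix_exp(torch.from_numpy(A_np)).numpy()
out = np.array(matrix_exp(mx.array(A_np)))
print(np.abs(out - ref).max())  # reported above as ~0.000975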

@N8python
Author

For two 1024x1024 matrices initialized with randn(0, 1) - intentionally made to have diverging eigenvalues - the average difference, as a percentage of the maximum element in the torch computation, is 0.03124%; unnormalized, it is roughly 1311244288.0 due to the diverging eigenvalues.


@angeloskath
Member

This is very nicely done :-). I especially love the gist with the custom function...

I think it is not quite ready to be made into an op yet, but it definitely raises some nice issues for us to solve. The main problem I see is the implicit graph evaluation in a couple of places (matrix norm conditionals and looping).

I think this may be a good example for an if and while that do not cause a graph evaluation.
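
For reference, a minimal illustration of the implicit evaluation being pointed at (MLX is lazy, and branching on an array forces the graph to be computed):

import mlx.core as mx

A = mx.random.normal((8, 8))
norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)
# Taking bool() of a lazy mlx array evaluates the graph up to this point;
# the same happens at every `if matrix_norm <= ...` check above and when
# the squaring loop reads max_s via .item().
if norm <= 0.58:
    print("small-norm branch taken")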

@N8python
Author

Agreed. But how would I go about doing that, given that the if on the norm is necessary to choose the correct polynomial, and the scale factor also depends on the norm?

Or are you saying that this needs to be implemented first:

"I think this may be a good example for an if and while that do not cause a graph evaluation."
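
One branch-free workaround, just as a sketch (it trades extra FLOPs for keeping everything in the graph, and is not the if/while primitive being alluded to; it reuses the _compute_T* helpers from the earlier comment):

import mlx.core as mx

def matrix_exp_branchless(A, thresholds):
    # Evaluate every candidate polynomial and pick one with mx.where, so
    # no Python branch ever inspects the (lazy) norm. The scaling-and-
    # squaring path for large norms would still need a loop.
    norm = mx.max(mx.sum(mx.abs(A), axis=-2), axis=-1)
    X = _compute_T8(A)
    X = mx.where(norm <= thresholds[2], _compute_T4(A), X)
    X = mx.where(norm <= thresholds[1], _compute_T2(A), X)
    X = mx.where(norm <= thresholds[0], _compute_T1(A), X)
    return X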
