Make using JAX (or any accelerator) an Instanced Python class (and toggle) #509

Open · wants to merge 14 commits into base: master · Changes from 1 commit
First attempted implementation of the API for toggling JAX. Needs lots of work still, including how to delay JIT

Rebased from master
Lnaden committed Jun 16, 2023
commit 5b121c959a180abb1f0e1caf0e604bb98f1a231c
12 changes: 12 additions & 0 deletions pymbar/mbar.py
@@ -96,6 +96,7 @@ def __init__(
n_bootstraps=0,
bootstrap_solver_protocol=None,
rseed=None,
accelerator="numpy"
):
"""Initialize multistate Bennett acceptance ratio (MBAR) on a set of simulation data.

@@ -186,6 +187,13 @@ def __init__(
We usually just do steps of adaptive sampling without. "robust" would be the backup.
Default: dict(method="adaptive", options=dict(min_sc_iter=0)),

accelerator: str, optional, default="numpy"
Set the accelerator to attempt. The named accelerator is tried for the solvers, and the accelerator
that was actually set is stored afterwards. Not case-sensitive. "numpy" means no acceleration and
always works.
(Valid options: "jax", "numpy")


Notes
-----
The reduced potential energy ``u_kn[k,n] = u_k(x_{ln})``, where the reduced potential energy ``u_l(x)`` is
@@ -225,6 +233,10 @@ def __init__(

"""

# Set the accelerator methods for the solvers
mbar_solvers.set_accelerator(accelerator)
self.accelerator = mbar_solvers.accelerator

# Store local copies of necessary data.
# N_k[k] is the number of samples from state k, some of which might be zero.
self.N_k = np.array(N_k, dtype=np.int64)
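
A minimal usage sketch of the new keyword (illustration only, not part of the diff): the array shapes and values below are placeholders; only the `accelerator` argument and the stored `self.accelerator` attribute come from the change above.

```python
# Sketch of the new `accelerator` keyword added above (placeholder data, not from the PR).
import numpy as np
from pymbar import MBAR

u_kn = np.random.rand(3, 300)      # reduced potentials: 3 states x 300 total samples (placeholder)
N_k = np.array([100, 100, 100])    # samples drawn from each state (placeholder)

# Request the plain NumPy/SciPy solvers; "jax" would request the accelerated path instead.
mbar = MBAR(u_kn, N_k, accelerator="numpy")
print(mbar.accelerator)            # reports the accelerator that was actually set
```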
199 changes: 126 additions & 73 deletions pymbar/mbar_solvers.py
@@ -11,15 +11,91 @@
logger = logging.getLogger(__name__)

use_jit = False
force_no_jax = False # Temporary until we can make a proper setting to enable/disable by choice
try:
#### JAX related imports
if force_no_jax:
# Capture user-disabled JAX instead "JAX not found"
raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py")
accelerator = "numpy"


# Import the methods functionally
# This is admittedly non-standard, but solves the following use case:
# * Has JAX
# * Wants to use PyMBAR
# * Does NOT want JAX to be set to 64-bit mode
# Also covers the future use case of additional accelerators
# that users may want to enable selectively
def init_numpy():
"""Set the imports for the basic numpy methods"""
# Fallback/default solver methods
# NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY
global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp
global jit, precondition_jit
global accelerator, use_jit
from numpy import exp, sum, newaxis, diag, dot, s_
from numpy import pad as npad
from numpy.linalg import lstsq
import scipy.optimize as scipy_optimize # pylint: disable=reimported
from scipy.special import logsumexp

# No jit, so make a passthrough decorator
def jit(fn):
return fn

# Separate precondition hook in case an accelerator needs its own setup before jit
def precondition_jit(fn):
return jit(fn)

use_jit = False
accelerator = "numpy"
logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy")


def init_jax():
"""Set the imports for the JAX accelerated methods"""
# NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY
global exp, sum, newaxis, diag, dot, s_, npad, lstsq, scipy_optimize, logsumexp
global jit, precondition_jit
global accelerator, use_jit
global config
try:
from jax.config import config

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
from jax.numpy.linalg import lstsq
import jax.scipy.optimize as scipy_optimize
from jax.scipy.special import logsumexp

from jax import jit
def precondition_jit(jitable_fn):
"""
Attempt to set JAX 64-bit precision before jit-compiling the wrapped function on its first call

Parameters
----------
jitable_fn: function
A function which can be jit'd
"""

@wraps(
jitable_fn
) # Helper to ensure the decorated function still registers for docs and inspection
def staggered_jit(*args, **kwargs):
# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit mode is being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
jited_fn = jit(jitable_fn)
return jited_fn(*args, **kwargs)

return staggered_jit

# Throw warning only if the whole of JAX is found
if not config.x64_enabled:
# Warn that we're going to be setting 64 bit jax
logger.warning(
@@ -36,15 +112,9 @@
"******************************************\n"
)

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
from jax.numpy.linalg import lstsq
import jax.scipy.optimize as optimize_maybe_jax
from jax.scipy.special import logsumexp

from jax import jit as jit_or_passthrough

use_jit = True
accelerator = "jax"
logger.info("JAX detected. Using JAX acceleration.")
except ImportError:
# Catch no JAX and throw a warning
logger.warning(
@@ -58,31 +128,46 @@
" conda install pymbar \n"
"*********************************"
)
raise # Continue with the raised Import Error
# Fall back to NumPy import
init_numpy()

except ImportError:
# No JAX found, overlap imports
# These imports MUST align exactly
from numpy import exp, sum, newaxis, diag, dot, s_
from numpy import pad as npad
from numpy.linalg import lstsq
import scipy.optimize as optimize_maybe_jax # pylint: disable=reimported
from scipy.special import logsumexp
# Accelerator map for the set method below
ACCELERATOR_MAP = {
"numpy": init_numpy,
"jax": init_jax
}
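
The map-based dispatch is also the extension point for the additional accelerators mentioned above; a hypothetical sketch (no `init_cupy` exists in this PR) of how a third backend would plug in:

```python
# Hypothetical, not part of this PR: a new backend only needs an init function that
# shadows the same names as init_numpy/init_jax, plus an entry in ACCELERATOR_MAP.
def init_cupy():
    """Placeholder: a real version would shadow exp, sum, logsumexp, jit, precondition_jit, etc."""
    raise NotImplementedError("illustrative stub only")

ACCELERATOR_MAP["cupy"] = init_cupy
```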

# Try to set the initial/default accelerator
init_jax()

# No jit, so make a passthrough decorator
def jit_or_passthrough(fn):
return fn

# Helper function for toggling the solver method
def set_accelerator(accelerator_name: str):
"""
Set the accelerator in the namespace for this module
"""
global accelerator  # reference the module-level accelerator set by the init functions
# A copy under a new name would not track later changes, since strings are immutable
accel = accelerator_name.lower()
if accel not in ACCELERATOR_MAP:
raise ValueError(f"No accelerator implementation for {accel}, please use one of the following:\n" +
"".join((f"* {a}\n" for a in ACCELERATOR_MAP.keys())) +
f"(case-insentive)"
)
logger.info(f"Attempting to change accelerator to {accel}...")
old_accelerator = accelerator
ACCELERATOR_MAP[accelerator_name.lower()]()
new_accelerator = accelerator
if new_accelerator == old_accelerator:
logger.warning(f"Attempted to change accelerator from {old_accelerator} to {accel},"
f" but something went wrong. Please check the log outputs above.")
return
logger.info(f"Successfully changed to accelerator {accel}!")

# Note on "pylint: disable=invalid-unary-operand-type"
# Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax.
# Can be mostly ignored

if use_jit is False:
logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy")
else:
logger.info("JAX detected. Using JAX acceleration.")

# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving
# the MBAR equations.
# Note: we use tuples instead of lists to avoid accidental mutability.
@@ -126,38 +211,6 @@ def jit_or_passthrough(fn):
scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included


def jit_or_pass_after_bitsize(jitable_fn):
"""
Attempt to set JAX precision if present. This does nothing if JAX is not present

Parameters
----------
jitable_fn: function
A function which can be jit'd
"""

@wraps(
jitable_fn
) # Helper to ensure the decorated function still registers for docs and inspection
def staggered_jit(*args, **kwargs):
# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit will being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
jited_fn = jit_or_passthrough(jitable_fn)
return jited_fn(*args, **kwargs)

return staggered_jit


def validate_inputs(u_kn, N_k, f_k):
"""Check types and return inputs for MBAR calculations.

@@ -215,7 +268,7 @@ def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None):
return jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples)


@jit_or_pass_after_bitsize
@precondition_jit
def _jit_self_consistent_update(u_kn, N_k, f_k):
"""JAX version of self_consistent update. For parameters, see self_consistent_update.
N_k must be float (should be cast at a higher level)
@@ -268,7 +321,7 @@ def mbar_gradient(u_kn, N_k, f_k):
return jax_mbar_gradient(u_kn, N_k, f_k)


@jit_or_pass_after_bitsize
@precondition_jit
def jax_mbar_gradient(u_kn, N_k, f_k):
"""JAX version of MBAR gradient function. See documentation of mbar_gradient.
N_k must be float (should be cast at a higher level)
@@ -311,7 +364,7 @@ def mbar_objective(u_kn, N_k, f_k):
return jax_mbar_objective(u_kn, N_k, f_k)


@jit_or_pass_after_bitsize
@precondition_jit
def jax_mbar_objective(u_kn, N_k, f_k):
"""JAX version of mbar_objective.
For parameters, mbar_objective_and_Gradient
@@ -325,7 +378,7 @@ def jax_mbar_objective(u_kn, N_k, f_k):
return obj


@jit_or_pass_after_bitsize
@precondition_jit
def jax_mbar_objective_and_gradient(u_kn, N_k, f_k):
"""JAX version of mbar_objective_and_gradient.
For parameters, mbar_objective_and_Gradient
@@ -379,7 +432,7 @@ def mbar_objective_and_gradient(u_kn, N_k, f_k):
return jax_mbar_objective_and_gradient(u_kn, N_k, f_k)


@jit_or_pass_after_bitsize
@precondition_jit
def jax_mbar_hessian(u_kn, N_k, f_k):
"""JAX version of mbar_hessian.
For parameters, see mbar_hessian
@@ -423,7 +476,7 @@ def mbar_hessian(u_kn, N_k, f_k):
return jax_mbar_hessian(u_kn, N_k, f_k)


@jit_or_pass_after_bitsize
@precondition_jit
def jax_mbar_log_W_nk(u_kn, N_k, f_k):
"""JAX version of mbar_log_W_nk.
For parameters, see mbar_log_W_nk
@@ -460,7 +513,7 @@ def mbar_log_W_nk(u_kn, N_k, f_k):
return jax_mbar_log_W_nk(u_kn, N_k, f_k)


@jit_or_pass_after_bitsize
@precondition_jit
def jax_mbar_W_nk(u_kn, N_k, f_k):
"""JAX version of mbar_W_nk.
For parameters, see mbar_W_nk
@@ -654,7 +707,7 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None):
return results


@jit_or_pass_after_bitsize
@precondition_jit
def jax_core_adaptive(u_kn, N_k, f_k, gamma):
"""JAX version of adaptive inner loop.
N_k must be float (should be cast at a higher level)
@@ -681,7 +734,7 @@ def jax_core_adaptive(u_kn, N_k, f_k, gamma):
return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr


@jit_or_pass_after_bitsize
@precondition_jit
def jax_precondition_u_kn(u_kn, N_k, f_k):
"""JAX version of precondition_u_kn
for parameters, see precondition_u_kn
@@ -808,7 +861,7 @@ def solve_mbar_once(
fpad = lambda x: npad(x, (1, 0))
obj = lambda x: mbar_objective(u_kn_nonzero, N_k_nonzero, fpad(x))
# objective function to be minimized (for derivative free methods, mostly jit)
jax_results = optimize_maybe_jax.minimize(
jax_results = scipy_optimize.minimize(
obj,
f_k_nonzero[1:],
method=method,
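
The commit message notes that how to delay JIT still needs work; `precondition_jit`/`staggered_jit` above already sketch one approach. A self-contained illustration of that pattern (my own minimal example, not pymbar code; it reuses the same `config.update`/`config.x64_enabled` calls as the diff, and the import path may differ between JAX versions):

```python
# Standalone sketch of the delayed-jit idea: compilation is deferred to the first call
# so that 64-bit mode can still be enabled beforehand.
from functools import wraps

try:
    from jax import config, jit  # import path may vary across JAX versions
    HAVE_JAX = True
except ImportError:
    HAVE_JAX = False


def deferred_jit(fn):
    """Return fn unchanged without JAX; otherwise enable x64 and jit-compile on first use."""
    if not HAVE_JAX:
        return fn  # passthrough decorator when JAX is unavailable

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not config.x64_enabled:
            config.update("jax_enable_x64", True)  # may affect other JAX users in-process
        return jit(fn)(*args, **kwargs)            # jit applied at call time, as in staggered_jit

    return wrapper
```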