diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml
index 0e377a21..62410f9f 100644
--- a/.github/workflows/CI.yaml
+++ b/.github/workflows/CI.yaml
@@ -62,7 +62,7 @@ jobs:
       - name: Run tests (pytest)
         shell: bash -l {0}
         run: |
-          pytest -v --cov=$PACKAGE --cov-report=xml --color=yes --doctest-modules $PACKAGE/
+          pytest -v --cov=$PACKAGE --cov-report=xml --color=yes --doctest-modules --doctest-ignore-import-errors $PACKAGE/

      - name: Run examples
        shell: bash -l {0}
diff --git a/README.md b/README.md
index 2005c383..fa5b4217 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,13 @@ PyMBAR needs 64-bit floats to provide reliable answers. JAX by default uses
 PyMBAR will turn on JAX's 64-bit mode, which may cause issues with some
 separate uses of JAX in the same code as PyMBAR, such as existing Neural
 Network (NN) Models for machine learning.

+If you would like to use JAX in 32-bit mode alongside PyMBAR in the same script, instantiate your MBAR object with
+the `accelerator="numpy"` option, e.g.
+```python
+mbar = MBAR(..., accelerator="numpy")
+```
+replacing `...` with your other options.
+
 Authors
 -------
 * Kyle A. Beauchamp
diff --git a/devtools/conda-envs/test_env_jax.yaml b/devtools/conda-envs/test_env_jax.yaml
index d466eeae..c2dd795e 100644
--- a/devtools/conda-envs/test_env_jax.yaml
+++ b/devtools/conda-envs/test_env_jax.yaml
@@ -22,4 +22,5 @@ dependencies:
   - xlrd
   # Docs
   - numpydoc
+  - sphinx <7
   - sphinxcontrib-bibtex
diff --git a/pymbar/mbar.py b/pymbar/mbar.py
index 893a46a9..6c705d75 100644
--- a/pymbar/mbar.py
+++ b/pymbar/mbar.py
@@ -96,6 +96,7 @@ def __init__(
         n_bootstraps=0,
         bootstrap_solver_protocol=None,
         rseed=None,
+        accelerator=None,
     ):
         """Initialize multistate Bennett acceptance ratio (MBAR) on a set of simulation data.

@@ -186,6 +187,13 @@ def __init__(
             We usually just do steps of adaptive sampling without. "robust" would be the backup.
             Default: dict(method="adaptive", options=dict(min_sc_iter=0)),

+        accelerator: str, optional, default=None
+            Set the accelerator library used by the solvers. PyMBAR attempts to use the named accelerator
+            and stores the accelerator actually in use after the attempt. Not case-sensitive. "numpy" means
+            no acceleration and will always work. If nothing is specified, the default is JAX when JAX is
+            installed, otherwise NumPy. (Valid options: jax, numpy)
+
+
         Notes
         -----
         The reduced potential energy ``u_kn[k,n] = u_k(x_{ln})``, where the reduced potential energy ``u_l(x)`` is
@@ -225,6 +233,9 @@ def __init__(

         """

+        # Set the accelerator methods for the solvers
+        self.solver = mbar_solvers.get_accelerator(accelerator)
+
         # Store local copies of necessary data.
         # N_k[k] is the number of samples from state k, some of which might be zero.
         self.N_k = np.array(N_k, dtype=np.int64)
@@ -407,7 +418,7 @@ def __init__(
         else:
             np.random.seed(rseed)

-        self.f_k = mbar_solvers.solve_mbar_for_all_states(
+        self.f_k = self.solver.solve_mbar_for_all_states(
            self.u_kn, self.N_k, self.f_k, self.states_with_samples, solver_protocol
        )

@@ -431,7 +442,7 @@ def __init__(
                # If we initialized with BAR, then BAR, starting from the provided initial_f_k as well.
if initialize == "BAR": f_k_init = self._initialize_with_bar(self.u_kn[:, rints], f_k_init=self.f_k) - self.f_k_boots[b, :] = mbar_solvers.solve_mbar_for_all_states( + self.f_k_boots[b, :] = self.solver.solve_mbar_for_all_states( self.u_kn[:, rints], self.N_k, f_k_init, @@ -449,7 +460,7 @@ def __init__( # bootstrapped weight matrices not generated here, but when expectations are needed # otherwise, it's too much memory to keep - self.Log_W_nk = mbar_solvers.mbar_log_W_nk(self.u_kn, self.N_k, self.f_k) + self.Log_W_nk = self.solver.mbar_log_W_nk(self.u_kn, self.N_k, self.f_k) # Print final dimensionless free energies. if self.verbose: @@ -904,7 +915,7 @@ def compute_expectations_inner( f_k[0:K] = self.f_k_boots[n - 1, :] ri = self.bootstrap_rints[n - 1] u_kn = self.u_kn[:, ri] - Log_W_nk[:, 0:K] = mbar_solvers.mbar_log_W_nk(u_kn, self.N_k, f_k[0:K]) + Log_W_nk[:, 0:K] = self.solver.mbar_log_W_nk(u_kn, self.N_k, f_k[0:K]) # Pre-calculate the log denominator: Eqns 13, 14 in MBAR paper states_with_samples = self.N_k > 0 diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py deleted file mode 100644 index 6c36e654..00000000 --- a/pymbar/mbar_solvers.py +++ /dev/null @@ -1,1004 +0,0 @@ -import logging -import warnings -from functools import wraps - -import numpy as np - -# Optimize imported here and below as the jax-optimized one is jax or passthrough, but this is required regardless -import scipy.optimize -from pymbar.utils import ensure_type, check_w_normalized, ParameterError - -logger = logging.getLogger(__name__) - -use_jit = False -force_no_jax = False # Temporary until we can make a proper setting to enable/disable by choice -try: - #### JAX related imports - if force_no_jax: - # Capture user-disabled JAX instead "JAX not found" - raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py") - try: - from jax.config import config - - if not config.x64_enabled: - # Warn that we're going to be setting 64 bit jax - logger.warning( - "\n" - "****** PyMBAR will use 64-bit JAX! *******\n" - "* JAX is currently set to 32-bit bitsize *\n" - "* which is its default. *\n" - "* *\n" - "* PyMBAR requires 64-bit mode and WILL *\n" - "* enable JAX's 64-bit mode when called. *\n" - "* *\n" - "* This MAY cause problems with other *\n" - "* Uses of JAX in the same code. 
*\n" - "******************************************\n" - ) - - from jax.numpy import exp, sum, newaxis, diag, dot, s_ - from jax.numpy import pad as npad - from jax.numpy.linalg import lstsq - import jax.scipy.optimize as optimize_maybe_jax - from jax.scipy.special import logsumexp - - from jax import jit as jit_or_passthrough - - use_jit = True - except ImportError: - # Catch no JAX and throw a warning - logger.warning( - "\n" - "********* JAX NOT FOUND *********\n" - " PyMBAR can run faster with JAX \n" - " But will work fine without it \n" - "Either install with pip or conda:\n" - " pip install pybar[jax] \n" - " OR \n" - " conda install pymbar \n" - "*********************************" - ) - raise # Continue with the raised Import Error - -except ImportError: - # No JAX found, overlap imports - # These imports MUST align exactly - from numpy import exp, sum, newaxis, diag, dot, s_ - from numpy import pad as npad - from numpy.linalg import lstsq - import scipy.optimize as optimize_maybe_jax # pylint: disable=reimported - from scipy.special import logsumexp - - # No jit, so make a passthrough decorator - def jit_or_passthrough(fn): - return fn - - -# Note on "pylint: disable=invalid-unary-operand-type" -# Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax. -# Can be mostly ignored - -if use_jit is False: - logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy") -else: - logger.info("JAX detected. Using JAX acceleration.") - -# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving -# the MBAR equations. -# Note: we use tuples instead of lists to avoid accidental mutability. -JAX_SOLVER_PROTOCOL = ( - dict(method="BFGS", continuation=True), - dict(method="adaptive", options=dict(min_sc_iter=0)), -) - -DEFAULT_SOLVER_PROTOCOL = ( - dict(method="hybr", continuation=True), - dict(method="adaptive", options=dict(min_sc_iter=0)), -) - -ROBUST_SOLVER_PROTOCOL = ( - dict(method="adaptive", options=dict(maxiter=1000)), - dict(method="L-BFGS-B", options=dict(maxiter=1000)), -) - -BOOTSTRAP_SOLVER_PROTOCOL = (dict(method="adaptive", options=dict(min_sc_iter=0)),) - -# Allows all of the gradient based methods, but not the non-gradient methods ["Nelder-Mead", "Powell", "COBYLA"]", -scipy_minimize_options = [ - "L-BFGS-B", - "dogleg", - "CG", - "BFGS", - "Newton-CG", - "TNC", - "trust-ncg", - "trust-krylov", - "trust-exact", - "SLSQP", -] -scipy_nohess_options = [ - "L-BFGS-B", - "BFGS", - "CG", - "TNC", - "SLSQP", -] # don't pass a hessian to these to avoid warnings to these. -scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included - - -def jit_or_pass_after_bitsize(jitable_fn): - """ - Attempt to set JAX precision if present. This does nothing if JAX is not present - - Parameters - ---------- - jitable_fn: function - A function which can be jit'd - """ - - @wraps( - jitable_fn - ) # Helper to ensure the decorated function still registers for docs and inspection - def staggered_jit(*args, **kwargs): - # This will only trigger if JAX is set - if use_jit and not config.x64_enabled: - # Warn that JAX 64-bit will being turned on - logger.warning( - "\n" - "******* JAX 64-bit mode is now on! *******\n" - "* JAX is now set to 64-bit mode! *\n" - "* This MAY cause problems with other *\n" - "* uses of JAX in the same code. 
*\n" - "******************************************\n" - ) - config.update("jax_enable_x64", True) - jited_fn = jit_or_passthrough(jitable_fn) - return jited_fn(*args, **kwargs) - - return staggered_jit - - -def validate_inputs(u_kn, N_k, f_k): - """Check types and return inputs for MBAR calculations. - - Parameters - ---------- - u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies or unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies or unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='float' - The number of samples in each state. Converted to float because this cast is required when log is calculated. - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - """ - n_states, n_samples = u_kn.shape - - u_kn = ensure_type(u_kn, "float", 2, "u_kn or Q_kn", shape=(n_states, n_samples)) - N_k = ensure_type( - N_k, "float", 1, "N_k", shape=(n_states,), warn_on_cast=False - ) # Autocast to float because will be eventually used in float calculations. - f_k = ensure_type(f_k, "float", 1, "f_k", shape=(n_states,)) - - return u_kn, N_k, f_k - - -def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None): - """Return an improved guess for the dimensionless free energies - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - f_k : np.ndarray, shape=(n_states), dtype='float' - Updated estimate of f_k - - Notes - ----- - Equation C3 in MBAR JCP paper. - """ - - return jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples) - - -@jit_or_pass_after_bitsize -def _jit_self_consistent_update(u_kn, N_k, f_k): - """JAX version of self_consistent update. For parameters, see self_consistent_update. - N_k must be float (should be cast at a higher level) - - """ - # Asteroid - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - # All states can contribute to the numerator term. Check transpose - return -1.0 * logsumexp( - -log_denominator_n - u_kn, axis=1 - ) # pylint: disable=invalid-unary-operand-type - - -def jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=None): - """JAX version of self_consistent update. For parameters, see self_consistent_update. - N_k must be float (should be cast at a higher level) - - """ - # Only the states with samples can contribute to the denominator term. - # Precondition before feeding the op to the JIT'd function - # In theory, this can be computed with jax.lax.cond, but trying to reuse code for non-jax paths - states_with_samples = s_[:] if states_with_samples is None else states_with_samples - # Feed to the JIT'd function. Can't pass slice types, so slice here - return _jit_self_consistent_update( - u_kn[states_with_samples], N_k[states_with_samples], f_k[states_with_samples] - ) - - -def mbar_gradient(u_kn, N_k, f_k): - """Gradient of MBAR objective function. 
- - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - grad : np.ndarray, dtype=float, shape=(n_states) - Gradient of mbar_objective - - Notes - ----- - This is equation C6 in the JCP MBAR paper. - """ - return jax_mbar_gradient(u_kn, N_k, f_k) - - -@jit_or_pass_after_bitsize -def jax_mbar_gradient(u_kn, N_k, f_k): - """JAX version of MBAR gradient function. See documentation of mbar_gradient. - N_k must be float (should be cast at a higher level) - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1) - return -1 * N_k * (1.0 - exp(f_k + log_numerator_k)) - - -def mbar_objective(u_kn, N_k, f_k): - """Calculates objective function for MBAR. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - - Returns - ------- - obj : float - Objective function - - Notes - ----- - This objective function is essentially a doubly-summed partition function and is - quite sensitive to precision loss from both overflow and underflow. For optimal - results, u_kn can be preconditioned by subtracting out a `n` dependent - vector. - - More optimal precision, the objective function uses math.fsum for the - outermost sum and logsumexp for the inner sum. - """ - - return jax_mbar_objective(u_kn, N_k, f_k) - - -@jit_or_pass_after_bitsize -def jax_mbar_objective(u_kn, N_k, f_k): - """JAX version of mbar_objective. - For parameters, mbar_objective_and_Gradient - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - obj = sum(log_denominator_n) - dot(N_k, f_k) - - return obj - - -@jit_or_pass_after_bitsize -def jax_mbar_objective_and_gradient(u_kn, N_k, f_k): - """JAX version of mbar_objective_and_gradient. - For parameters, mbar_objective_and_Gradient - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1) - grad = -1 * N_k * (1.0 - exp(f_k + log_numerator_k)) - - obj = sum(log_denominator_n) - dot(N_k, f_k) - - return obj, grad - - -def mbar_objective_and_gradient(u_kn, N_k, f_k): - """Calculates both objective function and gradient for MBAR. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - - Returns - ------- - obj : float - Objective function - grad : np.ndarray, dtype=float, shape=(n_states) - Gradient of objective function - - Notes - ----- - This objective function is essentially a doubly-summed partition function and is - quite sensitive to precision loss from both overflow and underflow. 
For optimal - results, u_kn can be preconditioned by subtracting out a `n` dependent - vector. - - More optimal precision, the objective function uses math.fsum for the - outermost sum and logsumexp for the inner sum. - - The gradient is equation C6 in the JCP MBAR paper; the objective - function is its integral. - """ - - return jax_mbar_objective_and_gradient(u_kn, N_k, f_k) - - -@jit_or_pass_after_bitsize -def jax_mbar_hessian(u_kn, N_k, f_k): - """JAX version of mbar_hessian. - For parameters, see mbar_hessian - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, newaxis] - W = exp(logW) - - H = dot(W.T, W) - H *= N_k - H *= N_k[:, newaxis] - H -= diag(W.sum(0) * N_k) - return -1.0 * H - - -def mbar_hessian(u_kn, N_k, f_k): - """Hessian of MBAR objective function. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - H : np.ndarray, dtype=float, shape=(n_states, n_states) - Hessian of mbar objective function. - - Notes - ----- - Equation (C9) in JCP MBAR paper. - """ - - return jax_mbar_hessian(u_kn, N_k, f_k) - - -@jit_or_pass_after_bitsize -def jax_mbar_log_W_nk(u_kn, N_k, f_k): - """JAX version of mbar_log_W_nk. - For parameters, see mbar_log_W_nk - N_k must be float (should be cast at a higher level) - - """ - - log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1) - logW = f_k - u_kn.T - log_denominator_n[:, newaxis] - return logW - - -def mbar_log_W_nk(u_kn, N_k, f_k): - """Calculate the log weight matrix. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - logW_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) - The normalized log weights. - - Notes - ----- - Equation (9) in JCP MBAR paper. - """ - return jax_mbar_log_W_nk(u_kn, N_k, f_k) - - -@jit_or_pass_after_bitsize -def jax_mbar_W_nk(u_kn, N_k, f_k): - """JAX version of mbar_W_nk. - For parameters, see mbar_W_nk - N_k must be float (should be cast at a higher level) - - """ - return exp(jax_mbar_log_W_nk(u_kn, N_k, f_k)) - - -def mbar_W_nk(u_kn, N_k, f_k): - """Calculate the weight matrix. - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - W_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) - The normalized weights. - - Notes - ----- - Equation (9) in JCP MBAR paper. - """ - return jax_mbar_W_nk(u_kn, N_k, f_k) - - -def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None): - """ - Determine dimensionless free energies by a combination of Newton-Raphson iteration and self-consistent iteration. - Picks whichever method gives the lowest gradient. 
- Is slower than NR since it calculates the log norms twice each iteration. - - OPTIONAL ARGUMENTS - tol (float between 0 and 1) - relative tolerance for convergence (default 1.0e-12) - - options : dictionary of options - gamma (float between 0 and 1) - incrementor for NR iterations (default 1.0). Usually not changed now, since adaptively switch. - maxiter (int) - maximum number of Newton-Raphson iterations (default 10000: either NR converges or doesn't, pretty quickly) - verbose (boolean) - verbosity level for debug output - - NOTES - - This method determines the dimensionless free energies by - minimizing a convex function whose solution is the desired - estimator. The original idea came from the construction of a - likelihood function that independently reproduced the work of - Geyer (see [1] and Section 6 of [2]). This can alternatively be - formulated as a root-finding algorithm for the Z-estimator. More - details of this procedure will follow in a subsequent paper. Only - those states with nonzero counts are include in the estimation - procedure. - - REFERENCES - See Appendix C.2 of [1]. - - """ - # put the defaults here in case we get passed an 'options' dictionary that is only partial - options.setdefault("verbose", False) - options.setdefault("maxiter", 10000) - options.setdefault("print_warning", False) - options.setdefault("gamma", 1.0) - options.setdefault("min_sc_iter", 2) # set a minimum number of self-consistent iterations - - gamma = options["gamma"] - - doneIterating = False - if options["verbose"] == True: - logger.info( - "Determining dimensionless free energies by Newton-Raphson / self-consistent iteration." - ) - - if tol < 4.0 * np.finfo(float).eps: - logger.info("Tolerance may be too close to machine precision to converge.") - - success = False # fail unless solution is found. - # keep track of Newton-Raphson and self-consistent iterations - nr_iter = 0 - sci_iter = 0 - - f_sci = np.zeros(len(f_k), dtype=np.float64) - f_nr = np.zeros(len(f_k), dtype=np.float64) - - # Perform Newton-Raphson iterations (with sci computed on the way) - - # usually calculated at the end of the loop and saved, but we need - # to calculate the first time. - g = mbar_gradient(u_kn, N_k, f_k) # Objective function gradient. - - maxiter = options["maxiter"] - min_sc_iter = options["min_sc_iter"] - warn = "Did not converge." - for iteration in range(0, maxiter): - if use_jit: - (f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr) = jax_core_adaptive( - u_kn, N_k, f_k, options["gamma"] - ) - else: - H = mbar_hessian(u_kn, N_k, f_k) # Objective function hessian - Hinvg = np.linalg.lstsq(H, g, rcond=-1)[0] - Hinvg -= Hinvg[0] - f_nr = f_k - gamma * Hinvg - - # self-consistent iteration gradient norm and saved log sums. - f_sci = self_consistent_update(u_kn, N_k, f_k) - f_sci = f_sci - f_sci[0] # zero out the minimum - g_sci = mbar_gradient(u_kn, N_k, f_sci) - gnorm_sci = dot(g_sci, g_sci) - - # newton raphson gradient norm and saved log sums. - g_nr = mbar_gradient(u_kn, N_k, f_nr) - gnorm_nr = dot(g_nr, g_nr) - - # we could save the gradient, for the next round, but it's not too expensive to - # compute since we are doing the Hessian anyway. 
- - if options["verbose"]: - logger.info( - "self consistent iteration gradient norm is %10.5g, Newton-Raphson gradient norm is %10.5g" - % (np.sqrt(gnorm_sci), np.sqrt(gnorm_nr)) - ) - # decide which directon to go depending on size of gradient norm - f_old = f_k - - if gnorm_sci < gnorm_nr or sci_iter < min_sc_iter: - f_k = f_sci - g = g_sci - sci_iter += 1 - if options["verbose"]: - if sci_iter < min_sc_iter: - logger.info( - f"Choosing self-consistent iteration on iteration {iteration:d} because min_sci_iter={min_sc_iter:d}" - ) - else: - logger.info( - f"Choosing self-consistent iteration for lower gradient on iteration {iteration:d}" - ) - else: - f_k = f_nr - g = g_nr - nr_iter += 1 - if options["verbose"]: - logger.info(f"Newton-Raphson used on iteration {iteration:}") - - div = np.abs(f_k[1:]) # what we will divide by to get relative difference - zeroed = np.abs(f_k[1:]) < np.min( - [10**-8, tol] - ) # check which values are near enough to zero, hard coded max for now. - div[zeroed] = 1.0 # for these values, use absolute values. - max_delta = np.max(np.abs(f_k[1:] - f_old[1:]) / div) - max_diff = np.max(np.abs(f_sci[1:] - f_nr[1:]) / div) - # add this just to make sure they are not too different. - # if we start with bad states, the f_k - f_k_old might be far off. - if np.isnan(max_delta) or ((max_delta < tol) and max_diff < np.sqrt(tol)): - doneIterating = True - success = True - warn = "Convergence achieved by change in f with respect to previous guess." - break - - if doneIterating: - if options["verbose"]: - logger.info(f"Converged to tolerance of {max_delta:e} in {iteration+1:d} iterations.") - logger.info( - f"Of {iteration+1:d} iterations, {nr_iter:d} were Newton-Raphson iterations and {sci_iter:d} were self-consistent iterations" - ) - if np.all(f_k == 0.0): - logger.info("WARNING: All f_k appear to be zero.") - else: - logger.warning("WARNING: Did not converge to within specified tolerance.") - - if maxiter <= 0: - logger.warning( - f"No iterations ran be cause maximum_iterations was <= 0 ({maxiter:s})!" - ) - else: - logger.warning( - f"max_delta = {max_delta:e}, tol = {tol:e}, maximum_iterations = {maxiter:d}, iterations completed = {iteration:d}" - ) - - results = dict() - results["success"] = success - results["message"] = warn - results["x"] = f_k - - return results - - -@jit_or_pass_after_bitsize -def jax_core_adaptive(u_kn, N_k, f_k, gamma): - """JAX version of adaptive inner loop. - N_k must be float (should be cast at a higher level) - - """ - - # Perform Newton-Raphson iterations (with sci computed on the way) - g = mbar_gradient(u_kn, N_k, f_k) # Objective function gradient - H = mbar_hessian(u_kn, N_k, f_k) # Objective function hessian - Hinvg = lstsq(H, g, rcond=-1)[0] - Hinvg -= Hinvg[0] - f_nr = f_k - gamma * Hinvg - - # self-consistent iteration gradient norm and saved log sums. - f_sci = self_consistent_update(u_kn, N_k, f_k) - f_sci = f_sci - f_sci[0] # zero out the minimum - g_sci = mbar_gradient(u_kn, N_k, f_sci) - gnorm_sci = dot(g_sci, g_sci) - - # newton raphson gradient norm and saved log sums. 
- g_nr = mbar_gradient(u_kn, N_k, f_nr) - gnorm_nr = dot(g_nr, g_nr) - - return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr - - -@jit_or_pass_after_bitsize -def jax_precondition_u_kn(u_kn, N_k, f_k): - """JAX version of precondition_u_kn - for parameters, see precondition_u_kn - N_k must be float (should be cast at a higher level) - - """ - - u_kn = u_kn - u_kn.min(0) - u_kn += (logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - dot(N_k, f_k) / N_k.sum() - return u_kn - - -def precondition_u_kn(u_kn, N_k, f_k): - """Subtract a sample-dependent constant from u_kn to improve precision - - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - - Returns - ------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - - Notes - ----- - Returns u_kn - x_n, where x_n is based on the current estimate of f_k. - Upon subtraction of x_n, the MBAR objective function changes by an - additive constant, but its derivatives remain unchanged. We choose - x_n such that the current objective function value is zero, which - should give maximum precision in the objective function. - """ - return jax_precondition_u_kn(u_kn, N_k, f_k) - - -def solve_mbar_once( - u_kn_nonzero, - N_k_nonzero, - f_k_nonzero, - method="adaptive", - tol=1e-12, - continuation=None, - options=None, -): - """Solve MBAR self-consistent equations using some form of equation solver. - - Parameters - ---------- - u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - for the nonempty states - N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state for the nonempty states - f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies for the nonempty states - method : str, optional, default="hybr" - The optimization routine to use. This can be any of the methods - available via scipy.optimize.minimize() or scipy.optimize.root(). - tol : float, optional, default=1E-14 - The convergance tolerance for minimize() or root() - verbose: bool - Whether to print information about the solution method. - options: dict, optional, default=None - Optional dictionary of algorithm-specific parameters. See - scipy.optimize.root or scipy.optimize.minimize for details. - - Returns - ------- - f_k : np.ndarray - The converged reduced free energies. - results : dict - Dictionary containing entire results of optimization routine, may - be useful when debugging convergence. - - Notes - ----- - This function requires that N_k_nonzero > 0--that is, you should have - already dropped all the states for which you have no samples. - Internally, this function works in a reduced coordinate system defined - by subtracting off the first component of f_k and fixing that component - to be zero. - - For fast but precise convergence, we recommend calling this function - multiple times to polish the result. `solve_mbar()` facilitates this. 
- """ - - # we only validate at the outside of the call - u_kn_nonzero, N_k_nonzeo, f_k_nonzero = validate_inputs(u_kn_nonzero, N_k_nonzero, f_k_nonzero) - f_k_nonzero = f_k_nonzero - f_k_nonzero[0] # Work with reduced dimensions with f_k[0] := 0 - N_k_nonzero = 1.0 * N_k_nonzero # convert to float for acceleration. - u_kn_nonzero = precondition_u_kn(u_kn_nonzero, N_k_nonzero, f_k_nonzero) - - pad = lambda x: np.pad( - x, (1, 0), mode="constant" - ) # Helper function inserts zero before first element - unpad_second_arg = lambda obj, grad: ( - obj, - grad[1:], - ) # Helper function drops first element of gradient - - # Create objective functions / nonlinear equations to send to scipy.optimize, fixing f_0 = 0 - grad = lambda x: mbar_gradient(u_kn_nonzero, N_k_nonzero, pad(x))[ - 1: - ] # Objective function gradient - - grad_and_obj = lambda x: unpad_second_arg( - *mbar_objective_and_gradient(u_kn_nonzero, N_k_nonzero, pad(x)) - ) # Objective function gradient and objective function - - de_jax_grad_and_obj = lambda x: ( - *map(np.array, grad_and_obj(x)), # (...,) Casts to tuple instead of object - ) # Force any jax-based array output to normal numpy for scipy.optimize.minimize. np.asarray does not work. - - hess = lambda x: mbar_hessian(u_kn_nonzero, N_k_nonzero, pad(x))[1:][ - :, 1: - ] # Hessian of objective function - with warnings.catch_warnings(record=True) as w: - if use_jit and method == "BFGS": - fpad = lambda x: npad(x, (1, 0)) - obj = lambda x: mbar_objective(u_kn_nonzero, N_k_nonzero, fpad(x)) - # objective function to be minimized (for derivative free methods, mostly jit) - jax_results = optimize_maybe_jax.minimize( - obj, - f_k_nonzero[1:], - method=method, - tol=tol, - options=dict(maxiter=options["maxiter"]), - ) - results = dict() # there should be a way to copy this. - results["x"] = jax_results[0] - f_k_nonzero = pad(results["x"]) - results["success"] = jax_results[1] - elif method in scipy_minimize_options: - if method in scipy_nohess_options: - hess = None # To suppress warning from passing a hessian function. - results = scipy.optimize.minimize( - de_jax_grad_and_obj, - f_k_nonzero[1:], - jac=True, - hess=hess, - method=method, - tol=tol, - options=options, - ) - f_k_nonzero = pad(results["x"]) - elif method == "adaptive": - results = adaptive(u_kn_nonzero, N_k_nonzero, f_k_nonzero, tol=tol, options=options) - f_k_nonzero = results["x"] - elif method in scipy_root_options: - # find the root in the gradient. - results = scipy.optimize.root( - grad, f_k_nonzero[1:], jac=hess, method=method, tol=tol, options=options - ) - f_k_nonzero = pad(results["x"]) - else: - raise ParameterError(f"Method {method} for solution of free energies not recognized") - - # If there were runtime warnings, show the messages - if len(w) > 0: - can_ignore = True - for warn_msg in w: - if "Unknown solver options" in str(warn_msg.message): - continue - warnings.showwarning( - warn_msg.message, - warn_msg.category, - warn_msg.filename, - warn_msg.lineno, - warn_msg.file, - "", - ) - can_ignore = False # If any warning is not just unknown options, can not skip check - if not can_ignore: - # Ensure MBAR solved correctly - w_nk_check = mbar_W_nk(u_kn_nonzero, N_k_nonzero, f_k_nonzero) - check_w_normalized(w_nk_check, N_k_nonzero) - logger.warning( - "MBAR weights converged within tolerance, despite the SciPy Warnings. Please validate your results." 
- ) - - return f_k_nonzero, results - - -def solve_mbar(u_kn_nonzero, N_k_nonzero, f_k_nonzero, solver_protocol=None): - """Solve MBAR self-consistent equations using some sequence of equation solvers. - - Parameters - ---------- - u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - for the nonempty states - N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state for the nonempty states - f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies for the nonempty states - solver_protocol : tuple(dict()), optional, default=None - Optional list of dictionaries of steps in solver protocol. - If None, a default protocol will be used. - - Returns - ------- - f_k : np.ndarray - The converged reduced free energies. - all_results : list(dict()) - List of results from each step of solver_protocol. Each element in - list contains the results dictionary from solve_mbar_once() - for the corresponding step. - - Notes - ----- - This function requires that N_k_nonzero > 0--that is, you should have - already dropped all the states for which you have no samples. - Internally, this function works in a reduced coordinate system defined - by subtracting off the first component of f_k and fixing that component - to be zero. - - This function calls `solve_mbar_once()` multiple times to achieve - converged results. Generally, a single call to solve_mbar_once() - will not give fully converged answers because of limited numerical precision. - Each call to `solve_mbar_once()` re-conditions the nonlinear - equations using the current guess. - """ - - if solver_protocol is None: - solver_protocol = DEFAULT_SOLVER_PROTOCOL - - all_fks = [] - all_gnorms = [] - all_results = [] - - for solver in solver_protocol: - f_k_nonzero_result, results = solve_mbar_once( - u_kn_nonzero, N_k_nonzero, f_k_nonzero, **solver - ) - all_fks.append(f_k_nonzero_result) - all_gnorms.append( - np.linalg.norm(mbar_gradient(u_kn_nonzero, N_k_nonzero, f_k_nonzero_result)) - ) - all_results.append(results) - - if results["success"]: - success = True - best_gnorm = all_gnorms[-1] - logger.info(f"Reached a solution to within tolerance with {solver['method']}") - break - else: - logger.warning( - f"Failed to reach a solution to within tolerance with {solver['method']}: trying next method" - ) - logger.info(f"Ending gnorm of method {solver['method']} = {all_gnorms[-1]:e}") - if solver["continuation"]: - f_k_nonzero = f_k_nonzero_result - logger.info("Will continue with results from previous method") - - if results["success"]: - logger.info("Solution found within tolerance!") - else: - i_best_gnorm = np.argmin(all_gnorms) - logger.warning("No solution found to within tolerance.") - best_method = solver_protocol[i_best_gnorm]["method"] - best_gnorm = all_gnorms[i_best_gnorm] - logger.warning( - f"The solution with the smallest gradient {best_gnorm:e} norm is {best_method}" - ) - f_k_nonzero_result = all_fks[i_best_gnorm] - logger.warning( - "Please exercise caution with this solution and consider alternative methods or a different tolerance." - ) - - logger.info(f"Final gradient norm: {best_gnorm:.3g}") - - return f_k_nonzero_result, all_results - - -def solve_mbar_for_all_states(u_kn, N_k, f_k, states_with_samples, solver_protocol): - """Solve for free energies of states with samples, then calculate for - empty states. 
- - Parameters - ---------- - u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' - The reduced potential energies, i.e. -log unnormalized probabilities - N_k : np.ndarray, shape=(n_states), dtype='int' - The number of samples in each state - f_k : np.ndarray, shape=(n_states), dtype='float' - The reduced free energies of each state - solver_protocol : tuple(dict()), optional, default=None - Sequence of dictionaries of steps in solver protocol for final - stage of refinement. - - Returns - ------- - f_k : np.ndarray, shape=(n_states), dtype='float' - The free energies of states - """ - - if len(states_with_samples) == 1: - f_k_nonzero = np.array([0.0]) - else: - f_k_nonzero, all_results = solve_mbar( - u_kn[states_with_samples], - N_k[states_with_samples], - f_k[states_with_samples], - solver_protocol=solver_protocol, - ) - - f_k[states_with_samples] = np.array(f_k_nonzero) - - # Update all free energies because those from states with zero samples are not correctly computed by solvers. - f_k = self_consistent_update(u_kn, N_k, f_k) - # This is necessary because state 0 might have had zero samples, - # but we still want that state to be the reference with free energy 0. - f_k -= f_k[0] - - return f_k diff --git a/pymbar/mbar_solvers/__init__.py b/pymbar/mbar_solvers/__init__.py new file mode 100644 index 00000000..861ee14f --- /dev/null +++ b/pymbar/mbar_solvers/__init__.py @@ -0,0 +1,117 @@ +############################################################################## +# pymbar: A Python Library for MBAR +# +# Copyright 2017-2022 University of Colorado Boulder +# Copyright 2010-2017 Memorial Sloan-Kettering Cancer Center +# Portions of this software are Copyright (c) 2010-2016 University of Virginia +# Portions of this software are Copyright (c) 2006-2007 The Regents of the University of California. All Rights Reserved. +# Portions of this software are Copyright (c) 2007-2008 Stanford University and Columbia University. +# +# Authors: Michael Shirts, John Chodera +# Contributors: Kyle Beauchamp, Levi Naden +# +# pymbar is free software: you can redistribute it and/or modify +# it under the terms of the MIT License as +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# MIT License for more details. +# +# You should have received a copy of the MIT License along with pymbar. +############################################################################## + +""" +########### +pymbar.mbar_solvers +########### + +A module implementing the solvers array operations for the MBAR solvers with various code bases for acceleration. + +All methods have the same calls and returns, independent of their underlying codes for solution. + +Please reference the following if you use this code in your research: + +[1] Shirts MR and Chodera JD. Statistically optimal analysis of samples from multiple equilibrium states. +J. Chem. Phys. 129:124105, 2008. 
http://dx.doi.org/10.1063/1.2978177
+
+"""
+
+import logging
+from typing import Union
+
+from .mbar_solver import (
+    validate_inputs,
+    JAX_SOLVER_PROTOCOL,
+    DEFAULT_SOLVER_PROTOCOL,
+    ROBUST_SOLVER_PROTOCOL,
+    BOOTSTRAP_SOLVER_PROTOCOL,
+)
+from .mbar_solver import MBARSolver
+from .numpy_solver import MBARSolverNumpy
+
+logger = logging.getLogger(__name__)
+
+INSTANCED_ACCELERATORS = {}  # Cache the accelerators to avoid re-jit on instancing
+ACCELERATOR_MAP = {"numpy": MBARSolverNumpy}
+default_solver = "numpy"  # Set fallback solver
+
+try:
+    from .jax_solver import MBARSolverJAX
+
+    ACCELERATOR_MAP["jax"] = MBARSolverJAX
+    default_solver = "jax"
+    logger.info("JAX detected. Using JAX acceleration by default.")
+except ImportError:
+    logger.warning(
+        "\n"
+        "********* JAX NOT FOUND *********\n"
+        " PyMBAR can run faster with JAX  \n"
+        " But will work fine without it   \n"
+        "Either install with pip or conda:\n"
+        "     pip install pymbar[jax]     \n"
+        "               OR                \n"
+        "      conda install pymbar       \n"
+        "*********************************"
+    )
+
+
+# Helper function for toggling the solver method
+def get_accelerator(accelerator_name: Union[str, None]) -> MBARSolver:
+    """
+    Get the accelerator instance registered in this module's namespace.
+    """
+    if accelerator_name is None:
+        accelerator_name = default_solver
+    # Lower-casing into a new name is safe: the caller's string object is immutable
+    accel = accelerator_name.lower()
+    if accel in INSTANCED_ACCELERATORS:
+        return INSTANCED_ACCELERATORS[accel]
+    if accel not in ACCELERATOR_MAP:
+        raise ValueError(
+            f"Accelerator {accel} is not implemented or did not load correctly. Please use one of the following:\n"
+            + "".join((f"* {a}\n" for a in ACCELERATOR_MAP.keys()))
+            + "(case-insensitive)\n"
+            + f"If you expected {accel} to load, please check the logs above for details."
+        )
+    logger.info(f"Instancing accelerator {accel}...")
+    INSTANCED_ACCELERATORS[accel] = ACCELERATOR_MAP[accel]()
+    return INSTANCED_ACCELERATORS[accel]
+
+
+# Imports done, handle initialization
+module_solver = get_accelerator(default_solver)
+
+# Establish API methods for 4.x consistency
+self_consistent_update = module_solver.self_consistent_update
+mbar_gradient = module_solver.mbar_gradient
+mbar_objective = module_solver.mbar_objective
+mbar_objective_and_gradient = module_solver.mbar_objective_and_gradient
+mbar_hessian = module_solver.mbar_hessian
+mbar_log_W_nk = module_solver.mbar_log_W_nk
+mbar_W_nk = module_solver.mbar_W_nk
+adaptive = module_solver.adaptive
+precondition_u_kn = module_solver.precondition_u_kn
+solve_mbar_once = module_solver.solve_mbar_once
+solve_mbar = module_solver.solve_mbar
+solve_mbar_for_all_states = module_solver.solve_mbar_for_all_states
diff --git a/pymbar/mbar_solvers/jax_solver.py b/pymbar/mbar_solvers/jax_solver.py
new file mode 100644
index 00000000..b9853b02
--- /dev/null
+++ b/pymbar/mbar_solvers/jax_solver.py
@@ -0,0 +1,160 @@
+"""Set the imports for the JAX accelerated methods"""
+
+import logging
+from functools import wraps
+
+try:
+    from jax.config import config
+
+    import jax.numpy as jnp
+    from jax.numpy.linalg import lstsq
+    import jax.scipy.optimize
+    from jax.scipy.special import logsumexp
+
+    from jax.tree_util import register_pytree_node_class
+
+    from jax import jit
+except ImportError:
+    raise ImportError("JAX not found!")
+
+from pymbar.mbar_solvers.mbar_solver import MBARSolver
+
+logger = logging.getLogger(__name__)
+
+
+@register_pytree_node_class
+class MBARSolverJAX(MBARSolver):
+    """
+    Solver methods for MBAR.
Implementations use specific libraries/accelerators to solve the code paths. + + Default solver is the numpy solution + """ + + def __init__(self): + # Throw warning only if the whole of JAX is found + if not config.x64_enabled: + # Warn that we're going to be setting 64 bit jax + logger.warning( + "\n" + "****** PyMBAR will use 64-bit JAX! *******\n" + "* JAX is currently set to 32-bit bitsize *\n" + "* which is its default. *\n" + "* *\n" + "* PyMBAR requires 64-bit mode and WILL *\n" + "* enable JAX's 64-bit mode when called. *\n" + "* *\n" + "* This MAY cause problems with other *\n" + "* Uses of JAX in the same code. *\n" + "* *\n" + "* If you want 32-bit JAX and PyMBAR *\n" + "* please set: *\n" + "* accelerator=numpy *\n" + "* when you instance the MBAR object *\n" + "******************************************\n" + ) + super().__init__() + + @property + def exp(self): + return jnp.exp + + @property + def sum(self): + return jnp.sum + + @property + def diag(self): + return jnp.diag + + @property + def newaxis(self): + return jnp.newaxis + + @property + def dot(self): + return jnp.dot + + @property + def s_(self): + return jnp.s_ + + @property + def pad(self): + return jnp.pad + + @property + def lstsq(self): + return lstsq + + @property + def optimize(self): + return jax.scipy.optimize + + @property + def logsumexp(self): + return logsumexp + + @property + def jit(self): + return jit + + @property + def real_jit(self): + return True + + def _precondition_jit(self, jitable_fn): + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection + def staggered_jit(*args, **kwargs): + # This will only trigger if JAX is set + if not config.x64_enabled: + # Warn that JAX 64-bit will being turned on + logger.warning( + "\n" + "******* JAX 64-bit mode is now on! *******\n" + "* JAX is now set to 64-bit mode! *\n" + "* This MAY cause problems with other *\n" + "* uses of JAX in the same code. *\n" + "******************************************\n" + ) + config.update("jax_enable_x64", True) + jited_fn = self.jit(jitable_fn) + return jited_fn(*args, **kwargs) + + return staggered_jit + + def _adaptive_core(self, u_kn, N_k, f_k, g, gamma): + """JAX version of adaptive inner loop. + N_k must be float (should be cast at a higher level) + + """ + # Perform Newton-Raphson iterations (with sci computed on the way) + g = self.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient + H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian + Hinvg = lstsq(H, g, rcond=-1)[0] + Hinvg -= Hinvg[0] + f_nr = f_k - gamma * Hinvg + + # self-consistent iteration gradient norm and saved log sums. + f_sci = self.self_consistent_update(u_kn, N_k, f_k) + f_sci = f_sci - f_sci[0] # zero out the minimum + g_sci = self.mbar_gradient(u_kn, N_k, f_sci) + gnorm_sci = self.dot(g_sci, g_sci) + + # newton raphson gradient norm and saved log sums. 
+ g_nr = self.mbar_gradient(u_kn, N_k, f_nr) + gnorm_nr = self.dot(g_nr, g_nr) + + return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr + + def tree_flatten(self): + """Required method for PyTree registration with JAX""" + children = () + aux_data = {} + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Required method for PyTree registration with JAX""" + return cls() diff --git a/pymbar/mbar_solvers/mbar_solver.py b/pymbar/mbar_solvers/mbar_solver.py new file mode 100644 index 00000000..5a760fb1 --- /dev/null +++ b/pymbar/mbar_solvers/mbar_solver.py @@ -0,0 +1,815 @@ +import logging +import warnings + +import numpy as np + +# Optimize imported here and below as the jax-optimized one is jax or passthrough, but this is required regardless +import scipy.optimize +from pymbar.utils import ensure_type, check_w_normalized, ParameterError +from pymbar.mbar_solvers.solver_api import MBARSolverAPI, MBARSolverAcceleratorMethods + +logger = logging.getLogger(__name__) + +# Note on "pylint: disable=invalid-unary-operand-type" +# Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax. +# Can be mostly ignored + +# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving +# the MBAR equations. +# Note: we use tuples instead of lists to avoid accidental mutability. +JAX_SOLVER_PROTOCOL = ( + dict(method="BFGS", continuation=True), + dict(method="adaptive", options=dict(min_sc_iter=0)), +) + +DEFAULT_SOLVER_PROTOCOL = ( + dict(method="hybr", continuation=True), + dict(method="adaptive", options=dict(min_sc_iter=0)), +) + +ROBUST_SOLVER_PROTOCOL = ( + dict(method="adaptive", options=dict(maxiter=1000)), + dict(method="L-BFGS-B", options=dict(maxiter=1000)), +) + +BOOTSTRAP_SOLVER_PROTOCOL = (dict(method="adaptive", options=dict(min_sc_iter=0)),) + +# Allows all of the gradient based methods, but not the non-gradient methods ["Nelder-Mead", "Powell", "COBYLA"]", +scipy_minimize_options = [ + "L-BFGS-B", + "dogleg", + "CG", + "BFGS", + "Newton-CG", + "TNC", + "trust-ncg", + "trust-krylov", + "trust-exact", + "SLSQP", +] +scipy_nohess_options = [ + "L-BFGS-B", + "BFGS", + "CG", + "TNC", + "SLSQP", +] # don't pass a hessian to these to avoid warnings to these. +scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included + + +class MBARSolver(MBARSolverAPI, MBARSolverAcceleratorMethods): + """ + Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. + + Default solver is the numpy solution + """ + + JITABLE_IMPLEMENTATION_METHODS = ("jit_self_consistent_update",) + + def __init__(self): + """ + JIT the methods on instancing to avoid doing this at runtime. + + In theory, you want all JIT methods to be static (at least in JAX) because otherwise you can suffer a massive + performance loss if you try to JIT a bound method of a class (i.e. anything with a reference to 'self'). + See: https://github.com/google/jax/discussions/16020#discussioncomment-5915882 + + For this use case however, we do not appear to suffer a performance loss of note due to the simplicity + of this class, and the use of the exact methods we need in all the @property decorators and the PyTree + recommendation of JAX itself. 
+ See: https://jax.readthedocs.io/en/latest/faq.html#strategy-3-making-customclass-a-pytree + + Testing the timing of test_protocols test using static-generated methods as a relative baseline: + The test is 99% as fast on average with PyTree registration. + The test is 95% as fast on average without the PyTree registration. + See commit hash d65e882 to view the static-generated methods for this code. + + Marking self as static with a partial doesn't work because we're wrapping the function already once, and we + still need the functions/properties found in the class to make this an extensible class for other accelerators + in the future. + See https://jax.readthedocs.io/en/latest/faq.html#strategy-2-marking-self-as-static + + """ + # Apply the precondition to each of the JITABLE_METHODS + for method in ( + self.JITABLE_ACCELERATOR_METHODS + + self.JITABLE_API_METHODS + + self.JITABLE_IMPLEMENTATION_METHODS + ): + # Jit + setattr(self, method, self._precondition_jit(getattr(self, method))) + + def self_consistent_update(self, u_kn, N_k, f_k, states_with_samples=None): + """Return an improved guess for the dimensionless free energies + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + f_k : np.ndarray, shape=(n_states), dtype='float' + Updated estimate of f_k + + Notes + ----- + Equation C3 in MBAR JCP paper. + """ + + # Only the states with samples can contribute to the denominator term. + # Precondition before feeding the op to the JIT'd function + # In theory, this can be computed with jax.lax.cond, but trying to reuse code for non-jax paths + states_with_samples = self.s_[:] if states_with_samples is None else states_with_samples + return self.jit_self_consistent_update( + u_kn[states_with_samples], N_k[states_with_samples], f_k[states_with_samples] + ) + + def jit_self_consistent_update(self, u_kn, N_k, f_k): + """JAX version of self_consistent update. For parameters, see self_consistent_update. + N_k must be float (should be cast at a higher level) + + """ + # Asteroid + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + # All states can contribute to the numerator term. Check transpose + return -1.0 * self.logsumexp( + -log_denominator_n - u_kn, axis=1 + ) # pylint: disable=invalid-unary-operand-type + + def mbar_gradient(self, u_kn, N_k, f_k): + """Gradient of MBAR objective function. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + grad : np.ndarray, dtype=float, shape=(n_states) + Gradient of mbar_objective + + Notes + ----- + This is equation C6 in the JCP MBAR paper. + """ + # N_k must be float (should be cast at a higher level) + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1) + return -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k)) + + def mbar_objective(self, u_kn, N_k, f_k): + """Calculates objective function for MBAR. 
+
+        Parameters
+        ----------
+        u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float'
+            The reduced potential energies, i.e. -log unnormalized probabilities
+        N_k : np.ndarray, shape=(n_states), dtype='int'
+            The number of samples in each state
+        f_k : np.ndarray, shape=(n_states), dtype='float'
+            The reduced free energies of each state
+
+
+        Returns
+        -------
+        obj : float
+            Objective function
+
+        Notes
+        -----
+        This objective function is essentially a doubly-summed partition function and is
+        quite sensitive to precision loss from both overflow and underflow. For optimal
+        results, u_kn can be preconditioned by subtracting out an `n`-dependent
+        vector.
+
+        For optimal precision, the objective function uses math.fsum for the
+        outermost sum and logsumexp for the inner sum.
+        """
+        log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1)
+        obj = self.sum(log_denominator_n) - self.dot(N_k, f_k)
+
+        return obj
+
+    def mbar_objective_and_gradient(self, u_kn, N_k, f_k):
+        """Calculates both objective function and gradient for MBAR.
+
+        Parameters
+        ----------
+        u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float'
+            The reduced potential energies, i.e. -log unnormalized probabilities
+        N_k : np.ndarray, shape=(n_states), dtype='int'
+            The number of samples in each state
+        f_k : np.ndarray, shape=(n_states), dtype='float'
+            The reduced free energies of each state
+
+
+        Returns
+        -------
+        obj : float
+            Objective function
+        grad : np.ndarray, dtype=float, shape=(n_states)
+            Gradient of objective function
+
+        Notes
+        -----
+        This objective function is essentially a doubly-summed partition function and is
+        quite sensitive to precision loss from both overflow and underflow. For optimal
+        results, u_kn can be preconditioned by subtracting out an `n`-dependent
+        vector.
+
+        For optimal precision, the objective function uses math.fsum for the
+        outermost sum and logsumexp for the inner sum.
+
+        The gradient is equation C6 in the JCP MBAR paper; the objective
+        function is its integral.
+        """
+        log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1)
+        log_numerator_k = self.logsumexp(-log_denominator_n - u_kn, axis=1)
+        grad = -1 * N_k * (1.0 - self.exp(f_k + log_numerator_k))
+
+        obj = self.sum(log_denominator_n) - self.dot(N_k, f_k)
+
+        return obj, grad
+
+    def mbar_hessian(self, u_kn, N_k, f_k) -> np.ndarray:
+        """Hessian of MBAR objective function.
+
+        Parameters
+        ----------
+        u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float'
+            The reduced potential energies, i.e. -log unnormalized probabilities
+        N_k : np.ndarray, shape=(n_states), dtype='int'
+            The number of samples in each state
+        f_k : np.ndarray, shape=(n_states), dtype='float'
+            The reduced free energies of each state
+
+        Returns
+        -------
+        H : np.ndarray, dtype=float, shape=(n_states, n_states)
+            Hessian of mbar objective function.
+
+        Notes
+        -----
+        Equation (C9) in JCP MBAR paper.
+        """
+        log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1)
+        logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis]
+        W = self.exp(logW)
+
+        H = self.dot(W.T, W)
+        H *= N_k
+        H *= N_k[:, self.newaxis]
+        H -= self.diag(W.sum(0) * N_k)
+        return -1.0 * H
+
+    def mbar_log_W_nk(self, u_kn, N_k, f_k):
+        """Calculate the log weight matrix.
+
+        Parameters
+        ----------
+        u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float'
+            The reduced potential energies, i.e.
-log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + logW_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) + The normalized log weights. + + Notes + ----- + Equation (9) in JCP MBAR paper. + """ + log_denominator_n = self.logsumexp(f_k - u_kn.T, b=N_k, axis=1) + logW = f_k - u_kn.T - log_denominator_n[:, self.newaxis] + return logW + + def mbar_W_nk(self, u_kn, N_k, f_k): + """Calculate the weight matrix. + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + W_nk : np.ndarray, dtype='float', shape=(n_samples, n_states) + The normalized weights. + + Notes + ----- + Equation (9) in JCP MBAR paper. + """ + return self.exp(self.mbar_log_W_nk(u_kn, N_k, f_k)) + + def precondition_u_kn(self, u_kn, N_k, f_k): + """Subtract a sample-dependent constant from u_kn to improve precision + + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + + Notes + ----- + Returns u_kn - x_n, where x_n is based on the current estimate of f_k. + Upon subtraction of x_n, the MBAR objective function changes by an + additive constant, but its derivatives remain unchanged. We choose + x_n such that the current objective function value is zero, which + should give maximum precision in the objective function. + """ + u_kn = u_kn - u_kn.min(0) + u_kn += (self.logsumexp(f_k - u_kn.T, b=N_k, axis=1)) - self.dot(N_k, f_k) / N_k.sum() + return u_kn + + def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): + """ + Determine dimensionless free energies by a combination of Newton-Raphson iteration and self-consistent iteration. + Picks whichever method gives the lowest gradient. + Is slower than NR since it calculates the log norms twice each iteration. + + OPTIONAL ARGUMENTS + tol (float between 0 and 1) - relative tolerance for convergence (default 1.0e-12) + + options : dictionary of options + gamma (float between 0 and 1) - incrementor for NR iterations (default 1.0). Usually not changed now, since adaptively switch. + maxiter (int) - maximum number of Newton-Raphson iterations (default 10000: either NR converges or doesn't, pretty quickly) + verbose (boolean) - verbosity level for debug output + + NOTES + + This method determines the dimensionless free energies by + minimizing a convex function whose solution is the desired + estimator. The original idea came from the construction of a + likelihood function that independently reproduced the work of + Geyer (see [1] and Section 6 of [2]). This can alternatively be + formulated as a root-finding algorithm for the Z-estimator. More + details of this procedure will follow in a subsequent paper. 
Only + those states with nonzero counts are included in the estimation + procedure. + + REFERENCES + See Appendix C.2 of [1]. + + """ + # put the defaults here in case we get passed an 'options' dictionary that is only partial + options.setdefault("verbose", False) + options.setdefault("maxiter", 10000) + options.setdefault("print_warning", False) + options.setdefault("gamma", 1.0) + options.setdefault("min_sc_iter", 2) # set a minimum number of self-consistent iterations + + doneIterating = False + if options["verbose"] == True: + logger.info( + "Determining dimensionless free energies by Newton-Raphson / self-consistent iteration." + ) + + if tol < 4.0 * np.finfo(float).eps: + logger.info("Tolerance may be too close to machine precision to converge.") + + success = False # fail unless solution is found. + # keep track of Newton-Raphson and self-consistent iterations + nr_iter = 0 + sci_iter = 0 + + f_sci = np.zeros(len(f_k), dtype=np.float64) + f_nr = np.zeros(len(f_k), dtype=np.float64) + + # Perform Newton-Raphson iterations (with sci computed on the way) + + # usually calculated at the end of the loop and saved, but we need + # to calculate the first time. + g = self.mbar_gradient(u_kn, N_k, f_k) # Objective function gradient. + + maxiter = options["maxiter"] + min_sc_iter = options["min_sc_iter"] + warn = "Did not converge." + for iteration in range(0, maxiter): + f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr = self._adaptive_core( + u_kn, N_k, f_k, g, options["gamma"] + ) + # we could save the gradient, for the next round, but it's not too expensive to + # compute since we are doing the Hessian anyway. + if options["verbose"]: + logger.info( + "self consistent iteration gradient norm is %10.5g, Newton-Raphson gradient norm is %10.5g" + % (np.sqrt(gnorm_sci), np.sqrt(gnorm_nr)) + ) + # decide which direction to go depending on size of gradient norm + f_old = f_k + + if gnorm_sci < gnorm_nr or sci_iter < min_sc_iter: + f_k = f_sci + g = g_sci + sci_iter += 1 + if options["verbose"]: + if sci_iter < min_sc_iter: + logger.info( + f"Choosing self-consistent iteration on iteration {iteration:d} because min_sc_iter={min_sc_iter:d}" + ) + else: + logger.info( + f"Choosing self-consistent iteration for lower gradient on iteration {iteration:d}" + ) + else: + f_k = f_nr + g = g_nr + nr_iter += 1 + if options["verbose"]: + logger.info(f"Newton-Raphson used on iteration {iteration:}") + + div = np.abs(f_k[1:]) # what we will divide by to get relative difference + zeroed = np.abs(f_k[1:]) < np.min( + [10**-8, tol] + ) # check which values are near enough to zero, hard coded max for now. + div[zeroed] = 1.0 # for these values, use absolute values. + max_delta = np.max(np.abs(f_k[1:] - f_old[1:]) / div) + max_diff = np.max(np.abs(f_sci[1:] - f_nr[1:]) / div) + # add this just to make sure they are not too different. + # if we start with bad states, the f_k - f_k_old might be far off. + if np.isnan(max_delta) or ((max_delta < tol) and max_diff < np.sqrt(tol)): + doneIterating = True + success = True + warn = "Convergence achieved by change in f with respect to previous guess." + break + + if doneIterating: + if options["verbose"]: + logger.info( + f"Converged to tolerance of {max_delta:e} in {iteration+1:d} iterations." 
+ ) + logger.info( + f"Of {iteration+1:d} iterations, {nr_iter:d} were Newton-Raphson iterations and {sci_iter:d} were self-consistent iterations" + ) + if np.all(f_k == 0.0): + logger.info("WARNING: All f_k appear to be zero.") + else: + logger.warning("WARNING: Did not converge to within specified tolerance.") + + if maxiter <= 0: + logger.warning( + f"No iterations ran because maximum_iterations was <= 0 ({maxiter:d})!" + ) + else: + logger.warning( + f"max_delta = {max_delta:e}, tol = {tol:e}, maximum_iterations = {maxiter:d}, iterations completed = {iteration:d}" + ) + + results = dict() + results["success"] = success + results["message"] = warn + results["x"] = f_k + + return results + + def solve_mbar_once( + self, + u_kn_nonzero, + N_k_nonzero, + f_k_nonzero, + method="adaptive", + tol=1e-12, + continuation=None, + options=None, + ): + """Solve MBAR self-consistent equations using some form of equation solver. + + Parameters + ---------- + u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + for the nonempty states + N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state for the nonempty states + f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies for the nonempty states + method : str, optional, default="adaptive" + The optimization routine to use. This can be any of the methods + available via scipy.optimize.minimize() or scipy.optimize.root(). + tol : float, optional, default=1e-12 + The convergence tolerance for minimize() or root() + options: dict, optional, default=None + Optional dictionary of algorithm-specific parameters. See + scipy.optimize.root or scipy.optimize.minimize for details. + + Returns + ------- + f_k : np.ndarray + The converged reduced free energies. + results : dict + Dictionary containing entire results of optimization routine, may + be useful when debugging convergence. + + Notes + ----- + This function requires that N_k_nonzero > 0--that is, you should have + already dropped all the states for which you have no samples. + Internally, this function works in a reduced coordinate system defined + by subtracting off the first component of f_k and fixing that component + to be zero. + + For fast but precise convergence, we recommend calling this function + multiple times to polish the result. `solve_mbar()` facilitates this. + """ + + # we only validate at the outside of the call + u_kn_nonzero, N_k_nonzero, f_k_nonzero = validate_inputs( + u_kn_nonzero, N_k_nonzero, f_k_nonzero + ) + f_k_nonzero = f_k_nonzero - f_k_nonzero[0] # Work with reduced dimensions with f_k[0] := 0 + N_k_nonzero = 1.0 * N_k_nonzero # convert to float for acceleration. 
+ u_kn_nonzero = self.precondition_u_kn(u_kn_nonzero, N_k_nonzero, f_k_nonzero) + + pad = lambda x: np.pad( + x, (1, 0), mode="constant" + ) # Helper function inserts zero before first element + unpad_second_arg = lambda obj, grad: ( + obj, + grad[1:], + ) # Helper function drops first element of gradient + + # Create objective functions / nonlinear equations to send to scipy.optimize, fixing f_0 = 0 + grad = lambda x: self.mbar_gradient(u_kn_nonzero, N_k_nonzero, pad(x))[ + 1: + ] # Objective function gradient + + grad_and_obj = lambda x: unpad_second_arg( + *self.mbar_objective_and_gradient(u_kn_nonzero, N_k_nonzero, pad(x)) + ) # Objective function gradient and objective function + + de_jax_grad_and_obj = lambda x: ( + *map(np.array, grad_and_obj(x)), # (...,) Casts to tuple instead of object + ) # Force any jax-based array output to normal numpy for scipy.optimize.minimize. np.asarray does not work. + + hess = lambda x: self.mbar_hessian(u_kn_nonzero, N_k_nonzero, pad(x))[1:][ + :, 1: + ] # Hessian of objective function + with warnings.catch_warnings(record=True) as w: + if ( + self.real_jit and method == "BFGS" + ): # Might be a way to fold this in now that accelerators are class-ified + fpad = lambda x: self.pad(x, (1, 0)) + # Make sure to use the static method here + obj = lambda x: self.mbar_objective(u_kn_nonzero, N_k_nonzero, fpad(x)) + # objective function to be minimized (for derivative free methods, mostly jit) + minimize_results = self.optimize.minimize( + obj, + f_k_nonzero[1:], + method=method, + tol=tol, + options=dict(maxiter=options["maxiter"]), + ) + results = dict() # there should be a way to copy this. + results["x"] = minimize_results.x + f_k_nonzero = pad(results["x"]) + results["success"] = minimize_results[1] + elif method in scipy_minimize_options: + if method in scipy_nohess_options: + hess = None # To suppress warning from passing a hessian function. + # This needs to be stock scipy.optimize (at least it won't work for JAX) + results = scipy.optimize.minimize( + de_jax_grad_and_obj, + f_k_nonzero[1:], + jac=True, + hess=hess, + method=method, + tol=tol, + options=options, + ) + f_k_nonzero = pad(results["x"]) + elif method == "adaptive": + results = self.adaptive( + u_kn_nonzero, N_k_nonzero, f_k_nonzero, tol=tol, options=options + ) + f_k_nonzero = results["x"] + elif method in scipy_root_options: + # find the root in the gradient. + results = scipy.optimize.root( + grad, f_k_nonzero[1:], jac=hess, method=method, tol=tol, options=options + ) + f_k_nonzero = pad(results["x"]) + else: + raise ParameterError( + f"Method {method} for solution of free energies not recognized" + ) + + # If there were runtime warnings, show the messages + if len(w) > 0: + can_ignore = True + for warn_msg in w: + if "Unknown solver options" in str(warn_msg.message): + continue + warnings.showwarning( + warn_msg.message, + warn_msg.category, + warn_msg.filename, + warn_msg.lineno, + warn_msg.file, + "", + ) + can_ignore = ( + False # If any warning is not just unknown options, can not skip check + ) + if not can_ignore: + # Ensure MBAR solved correctly + w_nk_check = self.mbar_W_nk(u_kn_nonzero, N_k_nonzero, f_k_nonzero) + check_w_normalized(w_nk_check, N_k_nonzero) + logger.warning( + "MBAR weights converged within tolerance, despite the SciPy Warnings. Please validate your results." 
+ ) + + return f_k_nonzero, results + + def solve_mbar(self, u_kn_nonzero, N_k_nonzero, f_k_nonzero, solver_protocol=None): + """Solve MBAR self-consistent equations using some sequence of equation solvers. + + Parameters + ---------- + u_kn_nonzero : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + for the nonempty states + N_k_nonzero : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state for the nonempty states + f_k_nonzero : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies for the nonempty states + solver_protocol : tuple(dict()), optional, default=None + Optional list of dictionaries of steps in solver protocol. + If None, a default protocol will be used. + + Returns + ------- + f_k : np.ndarray + The converged reduced free energies. + all_results : list(dict()) + List of results from each step of solver_protocol. Each element in + list contains the results dictionary from solve_mbar_once() + for the corresponding step. + + Notes + ----- + This function requires that N_k_nonzero > 0--that is, you should have + already dropped all the states for which you have no samples. + Internally, this function works in a reduced coordinate system defined + by subtracting off the first component of f_k and fixing that component + to be zero. + + This function calls `solve_mbar_once()` multiple times to achieve + converged results. Generally, a single call to solve_mbar_once() + will not give fully converged answers because of limited numerical precision. + Each call to `solve_mbar_once()` re-conditions the nonlinear + equations using the current guess. + """ + + if solver_protocol is None: + solver_protocol = DEFAULT_SOLVER_PROTOCOL + + all_fks = [] + all_gnorms = [] + all_results = [] + + for solver in solver_protocol: + f_k_nonzero_result, results = self.solve_mbar_once( + u_kn_nonzero, N_k_nonzero, f_k_nonzero, **solver + ) + all_fks.append(f_k_nonzero_result) + all_gnorms.append( + np.linalg.norm(self.mbar_gradient(u_kn_nonzero, N_k_nonzero, f_k_nonzero_result)) + ) + all_results.append(results) + + if results["success"]: + success = True + best_gnorm = all_gnorms[-1] + logger.info(f"Reached a solution to within tolerance with {solver['method']}") + break + else: + logger.warning( + f"Failed to reach a solution to within tolerance with {solver['method']}: trying next method" + ) + logger.info(f"Ending gnorm of method {solver['method']} = {all_gnorms[-1]:e}") + if solver["continuation"]: + f_k_nonzero = f_k_nonzero_result + logger.info("Will continue with results from previous method") + + if results["success"]: + logger.info("Solution found within tolerance!") + else: + i_best_gnorm = np.argmin(all_gnorms) + logger.warning("No solution found to within tolerance.") + best_method = solver_protocol[i_best_gnorm]["method"] + best_gnorm = all_gnorms[i_best_gnorm] + logger.warning( + f"The solution with the smallest gradient {best_gnorm:e} norm is {best_method}" + ) + f_k_nonzero_result = all_fks[i_best_gnorm] + logger.warning( + "Please exercise caution with this solution and consider alternative methods or a different tolerance." + ) + + logger.info(f"Final gradient norm: {best_gnorm:.3g}") + + return f_k_nonzero_result, all_results + + def solve_mbar_for_all_states(self, u_kn, N_k, f_k, states_with_samples, solver_protocol): + """Solve for free energies of states with samples, then calculate for + empty states. 
+ + Parameters + ---------- + u_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies, i.e. -log unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + solver_protocol : tuple(dict()), optional, default=None + Sequence of dictionaries of steps in solver protocol for final + stage of refinement. + + Returns + ------- + f_k : np.ndarray, shape=(n_states), dtype='float' + The free energies of states + """ + + if len(states_with_samples) == 1: + f_k_nonzero = np.array([0.0]) + else: + f_k_nonzero, all_results = self.solve_mbar( + u_kn[states_with_samples], + N_k[states_with_samples], + f_k[states_with_samples], + solver_protocol=solver_protocol, + ) + + f_k[states_with_samples] = np.array(f_k_nonzero) + + # Update all free energies because those from states with zero samples are not correctly computed by solvers. + f_k = self.self_consistent_update(u_kn, N_k, f_k) + # This is necessary because state 0 might have had zero samples, + # but we still want that state to be the reference with free energy 0. + f_k -= f_k[0] + + return f_k + + +def validate_inputs(u_kn, N_k, f_k): + """Check types and return inputs for MBAR calculations. + + Parameters + ---------- + u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies or unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='int' + The number of samples in each state + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + + Returns + ------- + u_kn or q_kn : np.ndarray, shape=(n_states, n_samples), dtype='float' + The reduced potential energies or unnormalized probabilities + N_k : np.ndarray, shape=(n_states), dtype='float' + The number of samples in each state. Converted to float because this cast is required when log is calculated. + f_k : np.ndarray, shape=(n_states), dtype='float' + The reduced free energies of each state + """ + n_states, n_samples = u_kn.shape + + u_kn = ensure_type(u_kn, "float", 2, "u_kn or Q_kn", shape=(n_states, n_samples)) + N_k = ensure_type( + N_k, "float", 1, "N_k", shape=(n_states,), warn_on_cast=False + ) # Autocast to float because will be eventually used in float calculations. + f_k = ensure_type(f_k, "float", 1, "f_k", shape=(n_states,)) + + return u_kn, N_k, f_k diff --git a/pymbar/mbar_solvers/numpy_solver.py b/pymbar/mbar_solvers/numpy_solver.py new file mode 100644 index 00000000..bb000240 --- /dev/null +++ b/pymbar/mbar_solvers/numpy_solver.py @@ -0,0 +1,94 @@ +# Import the methods functionally +# This is admittedly non-standard, but solves the following use case: +# * Has JAX +# * Wants to use PyMBAR +# * Does NOT want JAX to be set to 64-bit mode +# Also solves the future use case of different accelerator, +# but want to selectively use them + +# Fallback/default solver methods +# NOTE: ALL ACCELERATORS MUST SHADOW THIS NAMESPACE EXACTLY +import numpy as np +from numpy.linalg import lstsq +import scipy.optimize +from scipy.special import logsumexp + +from pymbar.mbar_solvers.mbar_solver import MBARSolver + + +class MBARSolverNumpy(MBARSolver): + """ + Solver methods for MBAR. Implementations use specific libraries/accelerators to solve the code paths. 
+ + Default solver is the numpy solution + """ + + @property + def exp(self): + return np.exp + + @property + def sum(self): + return np.sum + + @property + def diag(self): + return np.diag + + @property + def newaxis(self): + return np.newaxis + + @property + def dot(self): + return np.dot + + @property + def s_(self): + return np.s_ + + @property + def pad(self): + return np.pad + + @property + def lstsq(self): + return lstsq + + @property + def optimize(self): + return scipy.optimize + + @property + def logsumexp(self): + return logsumexp + + @staticmethod + def _passthrough_jit(fn): + return fn + + @property + def jit(self): + """Passthrough JIT""" + return self._passthrough_jit + + def _adaptive_core(self, u_kn, N_k, f_k, g, gamma): + """ + Core function to execute per iteration of a method. + """ + H = self.mbar_hessian(u_kn, N_k, f_k) # Objective function hessian + Hinvg = np.linalg.lstsq(H, g, rcond=-1)[0] + Hinvg -= Hinvg[0] + f_nr = f_k - gamma * Hinvg + + # self-consistent iteration gradient norm and saved log sums. + f_sci = self.self_consistent_update(u_kn, N_k, f_k) + f_sci = f_sci - f_sci[0] # zero out the minimum + g_sci = self.mbar_gradient(u_kn, N_k, f_sci) + gnorm_sci = self.dot(g_sci, g_sci) + + # newton raphson gradient norm and saved log sums. + g_nr = self.mbar_gradient(u_kn, N_k, f_nr) + gnorm_nr = self.dot(g_nr, g_nr) + + return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr diff --git a/pymbar/mbar_solvers/solver_api.py b/pymbar/mbar_solvers/solver_api.py new file mode 100644 index 00000000..06a5da54 --- /dev/null +++ b/pymbar/mbar_solvers/solver_api.py @@ -0,0 +1,162 @@ +""" +API Definitions of the solver module to be consistent with PyMBAR 4.0 +and for subclassing any solvers for implementation. +""" + +from functools import wraps +from abc import ABC, abstractmethod + + +class MBARSolverAPI(ABC): + """ + API for MBAR solvers + """ + + JITABLE_API_METHODS = ( + "mbar_gradient", + "mbar_objective", + "mbar_objective_and_gradient", + "mbar_hessian", + "mbar_log_W_nk", + "mbar_W_nk", + "precondition_u_kn", + ) + + @abstractmethod + def self_consistent_update(self, u_kn, N_k, f_k, states_with_samples=None): + pass + + @abstractmethod + def mbar_gradient(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_objective(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_objective_and_gradient(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_hessian(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_log_W_nk(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def mbar_W_nk(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def adaptive(self, u_kn, N_k, f_k, tol=1.0e-8, options=None): + pass + + @abstractmethod + def precondition_u_kn(self, u_kn, N_k, f_k): + pass + + @abstractmethod + def solve_mbar_once( + self, + u_kn_nonzero, + N_k_nonzero, + f_k_nonzero, + method="adaptive", + tol=1e-12, + continuation=None, + options=None, + ): + pass + + @abstractmethod + def solve_mbar(self, u_kn_nonzero, N_k_nonzero, f_k_nonzero, solver_protocol=None): + pass + + @abstractmethod + def solve_mbar_for_all_states(self, u_kn, N_k, f_k, states_with_samples, solver_protocol): + pass + + +class MBARSolverAcceleratorMethods(ABC): + """ + Methods which have to be implemented by MBAR solver accelerators + """ + + JITABLE_ACCELERATOR_METHODS = ("_adaptive_core",) + + @property + @abstractmethod + def exp(self): + pass + + @property + @abstractmethod + def sum(self): + pass + + @property + @abstractmethod + def diag(self): + pass + + @property + 
@abstractmethod + def newaxis(self): + pass + + @property + @abstractmethod + def dot(self): + pass + + @property + @abstractmethod + def s_(self): + pass + + @property + @abstractmethod + def pad(self): + pass + + @property + @abstractmethod + def lstsq(self): + pass + + @property + @abstractmethod + def optimize(self): + pass + + @property + @abstractmethod + def logsumexp(self): + pass + + @property + @abstractmethod + def jit(self): + pass + + @property + def real_jit(self): + return False + + def _precondition_jit(self, jitable_fn): + @wraps( + jitable_fn + ) # Helper to ensure the decorated function still registers for docs and inspection + def wrapped_precog_jit(*args, **kwargs): + # Uses "self" here as intercepted first arg for instance of the decorated class + jited_fn = self.jit(jitable_fn) + return jited_fn(*args, **kwargs) + + return wrapped_precog_jit + + @abstractmethod + def _adaptive_core(self, u_kn, N_k, f_k, g, options): + pass diff --git a/pymbar/tests/test_accelerators.py b/pymbar/tests/test_accelerators.py new file mode 100644 index 00000000..4be8a002 --- /dev/null +++ b/pymbar/tests/test_accelerators.py @@ -0,0 +1,118 @@ +"""Test MBAR accelerators by ensuring they yield comparable results to the default (numpy) and can cycle between them +""" + +import numpy as np +import pytest + +from pymbar import MBAR +from pymbar.mbar_solvers import get_accelerator, default_solver +from pymbar.utils_for_testing import assert_equal, assert_allclose + +# Pylint doesn't like the interplay between pytest and importing fixtures. Disabled the one problem. +from pymbar.tests.test_mbar import ( # pylint: disable=unused-import + system_generators, + N_k, + free_energies_almost_equal, + fixed_harmonic_sample, +) + +# Setup skip if conditions +has_jax = False +try: + # pylint: disable=unused-import + from jax import jit + from jax.numpy import ndarray as jax_ndarray + + has_jax = True +except ImportError: + pass + +# Establish marks +needs_jax = pytest.mark.skipif(not has_jax, reason="Needs Jax Accelerator") + + +# Required test function for testing that the accelerator worked correctly. +def check_numpy(mbar: MBAR): + assert isinstance(mbar.f_k, np.ndarray) + + +def check_jax(mbar: MBAR): + assert isinstance(mbar.f_k, jax_ndarray) + + +# Setup accelerator list. 
Each parameter is (string_of_accelerator, accelerator_check) +numpy_accel = pytest.param(("numpy", check_numpy), id="numpy") +jax_accel = pytest.param(("jax", check_jax), marks=needs_jax, id="jax") +accelerators = [numpy_accel, jax_accel] + + +@pytest.fixture +def fallback_accelerator(): + return "numpy", check_numpy + + +@pytest.fixture(scope="module", params=system_generators) +def only_test_data(request): + _, test = request.param() + x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode="u_kn") + assert_equal(N_k, N_k_output) + yield_bundle = {"test": test, "x_n": x_n, "u_kn": u_kn} + yield yield_bundle + + +@pytest.fixture() +def static_ukn_nk(fixed_harmonic_sample): + _, u_kn, N_k_output, _ = fixed_harmonic_sample.sample(N_k, mode="u_kn") + assert_equal(N_k, N_k_output) + return u_kn, N_k_output + + +@pytest.mark.parametrize("accelerator", accelerators) +def test_mbar_accelerators_are_accurate(only_test_data, accelerator): + """Test that each accelerator is scientifically accurate""" + accelerator_name, accelerator_check = accelerator + test, x_n, u_kn = only_test_data["test"], only_test_data["x_n"], only_test_data["u_kn"] + x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode="u_kn") + mbar = build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check, bootstraps=200) + results = mbar.compute_free_energy_differences() + fe = results["Delta_f"] + fe_sigma = results["dDelta_f"] + free_energies_almost_equal(fe, fe_sigma, test.analytical_free_energies()) + accelerator_check(mbar) + + +def build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check, bootstraps=0): + """Helper function to build an MBAR object""" + mbar = MBAR(u_kn, N_k, verbose=True, accelerator=accelerator_name, n_bootstraps=bootstraps) + assert mbar.solver == get_accelerator(accelerator_name) + accelerator_check(mbar) + return mbar + + +@pytest.mark.parametrize("accelerator", accelerators) +def test_mbar_accelerators_can_toggle(static_ukn_nk, accelerator, fallback_accelerator): + """ + Test that accelerators can be toggled and that doing so doesn't corrupt each other's output. + """ + u_kn, N_k_output = static_ukn_nk + # Setup and check the accelerator + accelerator_name, accelerator_check = accelerator + mbar = build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check) + # Setup and check the fallback + fall_back_name, fall_back_check = fallback_accelerator + mbar_fallback = build_out_an_mbar(u_kn, N_k, fall_back_name, fall_back_check) + # Ensure fallback and accelerator match + assert_allclose(mbar.f_k, mbar_fallback.f_k) + # Rebuild the accelerated version again. 
+ mbar_rebuild = build_out_an_mbar(u_kn, N_k, accelerator_name, accelerator_check) + assert_allclose(mbar.f_k, mbar_rebuild.f_k) + + +def test_default_accelerator_is_correct(static_ukn_nk): + u_kn, N_k_output = static_ukn_nk + + def blank_check(*args): + return True + + mbar = build_out_an_mbar(u_kn, N_k, default_solver, blank_check) + assert mbar.solver == get_accelerator(default_solver) diff --git a/pymbar/tests/test_timeseries.py b/pymbar/tests/test_timeseries.py index e5455ba7..1a57fd11 100644 --- a/pymbar/tests/test_timeseries.py +++ b/pymbar/tests/test_timeseries.py @@ -130,7 +130,12 @@ def test_compare_detectEquil(show_hist=False): bs_de = timeseries.detect_equilibration_binary_search(D_t, bs_nodes=10) std_de = timeseries.detect_equilibration(D_t, fast=False, nskip=1) t_res.append(bs_de[0] - std_de[0]) - t_res_mode = float(stats.mode(t_res)[0][0]) + try: + # scipy<1.9 + t_res_mode = float(stats.mode(t_res)[0][0]) + except IndexError: + # scipy>=1.9 + t_res_mode = float(stats.mode(t_res, keepdims=True)[0][0]) assert_almost_equal(t_res_mode, 0.0, decimal=1) if show_hist: import matplotlib.pyplot as plt
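For reviewers, a minimal usage sketch of the accelerator selection added in this patch (not part of the diff itself). The `u_kn`/`N_k` values below are placeholder data, and the sketch only relies on names that appear in the patch (`MBAR(..., accelerator=...)`, `mbar.solver`, `get_accelerator`, `default_solver`):

```python
# Illustrative sketch only; assumes pymbar with this PR applied.
import numpy as np

from pymbar import MBAR
from pymbar.mbar_solvers import default_solver, get_accelerator

rng = np.random.default_rng(0)
u_kn = rng.random((3, 300))      # placeholder reduced potentials, shape (n_states, n_samples)
N_k = np.array([100, 100, 100])  # placeholder sample counts per state

# Force the pure-NumPy solver, e.g. to leave an existing JAX session in 32-bit mode.
mbar_np = MBAR(u_kn, N_k, accelerator="numpy")
assert mbar_np.solver == get_accelerator("numpy")

# Omitting `accelerator` uses the default chosen at import time (JAX if importable, else NumPy),
# which the new module exposes as `default_solver`.
mbar_default = MBAR(u_kn, N_k)
print("default accelerator:", default_solver)
print("Delta_f:", mbar_default.compute_free_energy_differences()["Delta_f"])
```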