Implement Dynamical CPU/GPU/TPU Backends #142

Open
Sampreet opened this issue Sep 1, 2024 · 8 comments
Labels
enhancement New feature or request

Comments

Sampreet commented Sep 1, 2024

Summary

With reference to the email conversation with @piperfw and @gefux, I am opening a dedicated issue pertaining to a dynamical numerical backend to support GPUs and TPUs. OQuPy utilizes TensorNetwork modules which support frameworks like TensorFlow, PyTorch and JAX. However, TensorNetwork's Node requires its tensor parameter to be passed with the same numerical backend, and OQuPy's modules use NumPy explicitly to create these nodes. As such, using a dynamical backend to switch from NumPy/SciPy to corresponding libraries of (say) JAX will facilitate the usage of TensorNetwork's GPU/TPU-friendly frameworks, which may speed up certain methods. This issue proposes a way similar to what has been recently implemented in QuTiP to support auto-differentiation using its JAX backend.

Implementation

The oqupy.config module is updated with:

# numerical backend
import numpy as default_np
import scipy.linalg as default_la
NUMERICAL_BACKEND_NUMPY = default_np
NUMERICAL_BACKEND_LINALG = default_la
NumPyDtypeComplex = default_np.complex128   # earlier NpDtype
NumPyDtypeFloat = default_np.float64        # earlier NpDtypeReal

The specific choice of scipy.linalg instead of scipy is because jax.scipy.integrate doesn't yet have implementations of dblquad, quad and quad_vec as required by some of OQuPy's modules.
A new oqupy.backends.numerical_backend module now takes care of the switching. The coarse contents of the module are:

class NumPy:
    """
    The NumPy backend for dynamic
    switching through `oqupy.config`.
    """
    @property
    def backend(self) -> default_np:
        """Getter for the backend."""
        return oqupy.config.NUMERICAL_BACKEND_NUMPY
    def __getattr__(self, name):
        """Return the backend's default attribute."""
        backend = object.__getattribute__(self, 'backend')
        return getattr(backend, name)
    # additional overridden methods for datatypes,
    # array updates, random number generators, etc.

class LinAlg:
    # same as above for ``default_scipy.linalg``

# initialize for import
np = NumPy()
la = LinAlg()
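For concreteness, a minimal self-contained sketch of the delegation pattern (module and class names simplified from the snippet above, not the exact OQuPy code):

```python
import numpy as default_np

class _Config:
    """Stand-in for oqupy.config: holds the active backend module."""
    NUMERICAL_BACKEND_NUMPY = default_np

config = _Config()

class NumPy:
    """Proxy forwarding attribute access to the configured backend."""
    @property
    def backend(self):
        return config.NUMERICAL_BACKEND_NUMPY

    def __getattr__(self, name):
        # __getattr__ is only called when normal lookup fails, so the
        # `backend` property resolves normally and all other names are
        # forwarded to the active backend module.
        return getattr(self.backend, name)

np = NumPy()

# Call sites use `np.` exactly as if it were the numpy module:
x = np.array([1.0, 2.0, 3.0])
print(np.sum(x))  # 6.0
```

Swapping `config.NUMERICAL_BACKEND_NUMPY` for `jax.numpy` then redirects every `np.` call without touching the call sites.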

All the modules of OQuPy are modified as:

# remove all `numpy` and `scipy.linalg` dependences other than
# `numpy.ndarray` before importing the dynamical backend
from oqupy.backends.numerical_backend import np, la
# change occurrences of NpDtype and NpDtypeReal
# to the corresponding methods of the new backend

The corresponding changes can be viewed by comparing the pr/feature-numerical-backend branch.

Testing

I have tested the above implementations using tox for vanilla NumPy/SciPy backend and reproduced the plots of arXiv:2406.16650. To use the JAX backend, the following snippet can be added at any point in the scripts:

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jla
import os
import oqupy.config as oc
import tensornetwork as tn
jax.config.update('jax_enable_x64', True)
oc.NUMERICAL_BACKEND_NUMPY = jnp
oc.NumPyDtypeComplex = jnp.complex128
oc.NumPyDtypeFloat = jnp.float64
oc.NUMERICAL_BACKEND_LINALG = jla
tn.set_default_backend('jax')

Issues

The JAX backend tests have multiple issues that require suggestions before any modification. A few of these issues are mentioned below:

Immutability of jax.numpy Arrays

The add_singleton function in oqupy.utils is often called by modules like oqupy.backends.pt_tempo_backend to create/update MPOs and MPSs. This function updates the shape of the array by adding an additional dimension to it at a specified index. Whereas NumPy supports altering the shape attribute in place, JAX's ArrayImpl datatype is immutable, so shape has no setter, leading to recurrent errors. Similar mutations can be found in the oqupy.gradient, oqupy.mps_mpo and oqupy.system_dynamics modules. In this case, would reshaping the array instead of updating its shape break any functionality? Otherwise, a separate variable could be dedicated to the shape.
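A minimal sketch of the reshape-based alternative in question (the helper name here is hypothetical, mirroring oqupy.util.add_singleton; shown with vanilla NumPy, but the same code runs unchanged on jax.numpy):

```python
import numpy as np  # the same pattern runs unchanged on jax.numpy

def add_singleton_reshape(arr, index):
    """Hypothetical reshape-based variant of oqupy.util.add_singleton:
    insert a singleton dimension at `index` without mutating arr.shape."""
    new_shape = arr.shape[:index] + (1,) + arr.shape[index:]
    return arr.reshape(new_shape)

a = np.zeros((2, 3))
b = add_singleton_reshape(a, 1)
print(b.shape)  # (2, 1, 3)
# In-place mutation `a.shape = (2, 1, 3)` works for NumPy arrays but
# fails for immutable jax.numpy arrays; reshape works for both.
```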

Vectorization Signatures

Modules such as oqupy.bath_correlations and oqupy.system use the vectorize method of NumPy without explicitly mentioning the signatures for the input and output. This raises issues with JAX when the parameters contain non-scalar inputs. Depending on the type of implementation, can the signatures be added for each vectorized function? For example, the Hamiltonians for TimeDependentSystem and TimeDependentSystemWithField will have the signatures ()->(m,m) and (),()->(m,m) respectively (kindly correct me if I have made a mistake here).
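For illustration, a NumPy-only sketch of such an explicit signature (the Hamiltonian here is made up for the example):

```python
import numpy as np

# A made-up time-dependent Hamiltonian on a 2-dimensional Hilbert space,
# vectorized with an explicit signature: scalar in, (m, m) matrix out.
hamiltonian = np.vectorize(
    lambda t: np.array([[0.0, t], [t, 0.0]]),
    signature='()->(m,m)')

times = np.array([0.0, 0.5, 1.0])
hs = hamiltonian(times)
print(hs.shape)  # (3, 2, 2)
```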

Row Degeneracy Checks

Unlike the numpy.unique method, the jax.numpy.unique method returns array indices with the same shape as the input array even with the axis parameter. This can be resolved by using the flatten method, but I am not sure whether it will introduce additional errors.
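A minimal NumPy illustration of the degeneracy check in question; the JAX difference described above concerns only the shape of the returned indices:

```python
import numpy as np

# Rows 0 and 1 are degenerate (identical); find the unique rows and the
# index of the first occurrence of each.
rows = np.array([[1, 2], [1, 2], [3, 4]])
unique_rows, first_idx = np.unique(rows, axis=0, return_index=True)
print(unique_rows.tolist())  # [[1, 2], [3, 4]]
print(first_idx.tolist())    # [0, 2]
# Under jax.numpy, the returned indices may need an explicit
# first_idx.flatten() before they can be used for integer indexing.
```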

Kindly share your views as and when time permits. Thank you once again.

@piperfw piperfw added the enhancement New feature or request label Sep 1, 2024
piperfw (Collaborator) commented Sep 1, 2024

Hi @Sampreet, the code changes look really good. I haven't been able to run things on my system yet, as JAX support for Linux + AMD is experimental and there are some barriers to building from source, but hopefully I'll be able to get that done soon.

[On a related note it looks like JAX requires Python 3.10+; not an issue but will need to make a note of it in the README or similar as I think we are still hoping to support 3.6 as a minimum - #80]

The dynamical switching looks fine, and I actually prefer the neatness of using np.function calls instead of importing a load of specific functions in each module. I wonder if we should adjust the config setup slightly, at least e.g. to check DTYPE_COMPLEX and DTYPE_REAL are consistent with BACKEND_NUMPY, or even remove the freedom of being able to set the datatypes - @gefux what do you think? Similarly, if setting the default tensor network backend to 'jax' is important, maybe we want to check or take care of that.

I guess it's nice to be able to choose JAX for either a numpy or scipy.linalg replacement, or both, and we can pin an issue to review the scipy import if the integrate functions are brought to JAX.

EDIT: Running CPU JAX, I've observed that the calls jax.config.update and tn.set_default_backend are necessary to avoid errors/warnings. On the other hand, JAX seems to understand np.complex128, np.float64 (although we may still want to advise their setting - depending on what Gerald intended for those config variables).

piperfw (Collaborator) commented Sep 1, 2024

RE Potential Issues

Shape

I don't see any issue with this. In fact, the current use of shape vs reshape in the code looks sporadic. I think we should switch to always using .reshape, which is safer (assigning to .shape in place will fail if a copy is required). See the deprecation warning at https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html

Vectorization Signatures

Yes, signatures can be added and that sounds like a good addition. You are correct: the Hamiltonians for System, TimeDependentSystem and TimeDependentSystemWithField must all return (m,m), where m is the Hilbert space dimension (._dimension), and take 0, 1 and 2 scalar parameters respectively (for the mean-field system, the field parameter is complex). See e.g. the parsing functions at lines 986 and 998 of system.py on your branch. I suggest you go ahead and add the signatures as you see fit; if there are others you are unsure about, just ask or share a list and Gerald or I should be able to make those changes.

Row Degeneracy Checks

.flatten() there looks absolutely fine. Note that the routine is only run a single time at the start of the computation, so there is no worry about the overhead of unnecessary copies etc.

Overall, this all looks very good and close to being PR-ready. It would be good to do widespread testing, because presumably there may well be other small compatibility issues such as the one you caught in the numerical backend with update. That being said, I think we can already (or soon) be confident it does not break the existing/default numpy backend, and we can always advertise the JAX backend as more experimental or cutting edge.

piperfw (Collaborator) commented Sep 1, 2024

RE Unit testing

We have a large number of coverage and 'physical' tests (assessing physical consistency or accuracy) under tests/. What is the best way to extend these to test against the JAX backend? I guess (assuming we want the code to work 100% with both backends) we can have two testing environments, one for each backend (JAX also requiring Python 3.10), to run all the tests against.

Sampreet (Author) commented Sep 2, 2024

Hi @piperfw, thanks again for your detailed comments. I shall implement the suggestions for shape, signatures and the degeneracy arrays and let you know if I face any issue. Based on my understanding of the source code, I would like to share the following in relation to your queries:

Data types

I do agree with certain checks on datatypes, especially since the precision requirements (single/double floating points and their corresponding complex datatypes) are interrelated in most of the modules. Also, as the operations in a single simulation (by a user) are likely to involve the same degree of precision, I feel it might not be useful to let the user change the precision dynamically through the corresponding methods of numerical_backend once the NumPy and LinAlg classes are initialized. Kindly let me know your thoughts on the same. Likewise, I am also interested in knowing the intended purpose of the datatype variables in oqupy.config.

And yes, JAX defaults to 32-bit floats (and, correspondingly, 64-bit complex) and requires a call to the jax.config.update method to utilize 64-bit floats. Also, using JAX with TensorNetwork requires a call to the tn.set_default_backend method.

Unit tests

I have yet to update the test files with the classes from numerical_backend. As such, I have only tested the branch with vanilla NumPy. To run the tests on the different backends, can we utilize the __init__ files in tests to automatically detect the presence of JAX (with/without CUDA) and set up the backend configurations before the actual tests run? I guess that if those pass for JAX, they will pass for NumPy as well. To be sure about that, we can have two testing environments as you have suggested.

Regarding PR

It might be handy to have a separate branch (say dev-jax) in the main repository for JAX-related changes. In that case, once the necessary changes are done and the tests pass and everything else is finalized, a draft PR can be created to that branch and we can continue with code-specific changes in that draft.

Update on speedup

In the latest commit of the pr/feature-numerical-backend branch, I have implemented the generalized method to create deltas (with a NumPy-friendly oqupy.util.create_delta) which is discussed elaborately in #140. This works seamlessly with JAX and the update introduces a good speedup (benchmarks added inside the issue) for computation and application of PT-MPOs. Kindly share your comments on the same as and when time permits.

About scipy.integrate

The integrate module will probably not be incorporated into the main branch of JAX due to the presence of loopy algorithms. There are, however, other libraries that wrap scipy.integrate for JAX. The link might also be useful for planning future development.

Update 2024-09-04

I have made the changes corresponding to shape -> reshape, added vectorization signatures and flattened degeneracy check arrays in the pr/feature-numerical-backend branch, with some other minor changes (list -> ndarray, in-place updates, parameter type-casts and code-style-consistency fixes). I have also added the numerical backend to the test files in tests/coverage and tests/physics. The following additional changes are made in the tests:

  • np.testing.assert_almost_equal is changed to default_np.testing.assert_array_almost_equal, since jax.numpy doesn't have a testing module and JAX arrays are not accepted without the _array variant. The same goes for np.testing.assert_equal.
  • type(<ndarray_element>) is changed to <ndarray_element>.dtype.

All tests pass with both NumPy/SciPy and their corresponding JAX-CPU backends (although slower). The configurations for JAX are added in the corresponding __init__ files (commented out) of tests/coverage and tests/physics. I would need suggestions regarding the following:

  • isinstance(<vectorized_function>, np.vectorize) in tests/coverage/system_test doesn't work with jax.numpy.vectorize. I'm not sure if there is a workaround for that. I have commented out the assertions pertaining to those for now.
  • the set_initial_tensor method of oqupy.process_tensor.SimpleProcessTensor throws a FutureWarning because, when creating a JAX array, the element None is treated as NaN. I think the tensor element can be set to NaN to avoid the warning, or handled separately when calling the initial tensor.
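For the type-check change in the second bullet above, a small illustration of why `.dtype` is the more portable check (vanilla NumPy shown; the scalar classes differ under JAX, but dtypes compare consistently):

```python
import numpy as np

x = np.array([1.0 + 0.0j], dtype=np.complex128)
# `type(x[0])` is a backend-specific scalar class (different under JAX),
# whereas `.dtype` compares consistently across backends:
print(x[0].dtype == np.complex128)  # True
```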

Update 2024-09-05

The JAX-GPU backend doesn't yet support eigen decomposition for non-symmetric matrices, although numpy.linalg is within scope for JAX. This leads to issues with the initialization of oqupy.bath.Bath for non-diagonal coupling operators. Since it is a single operation, performing the eigen decomposition on the CPU backend and passing the eigenvectors back to the GPU is a workaround. Otherwise, contributed libraries like jaxeigs can be used. A curated list of JAX-based projects can be found in awesome-jax. With the workaround mentioned, all tests pass on the JAX-GPU backend (taking a good amount of time) without any additional errors/warnings.

Thank you.

piperfw (Collaborator) commented Sep 8, 2024

Hi Sampreet, thanks for the detailed updates. I will respond to particular comments below.

The code changes and generally the implementation look really good. For the bigger picture, there are two things we want to avoid:

  1. increasing the maintenance difficulty of the codebase and
  2. increasing barriers for new contributions

These are potentially conflicting: if we require contributions to be JAX compatible, then that is a large barrier, but if we don't then contributed code may break for the JAX backend and someone will need to assess and fix any issues.

It seems like we should be able to keep the core OQuPy and existing features JAX-compatible (to the extent of the missing integrate, eigen decomposition, etc.). I suggest we add a note for developers (and forgetful maintainers!), e.g. in CONTRIBUTING.md, to advise on the JAX backend and best practices. The main ones I can see are to use the dynamically imported np module rather than explicit numpy functions, to update array values using the new update method, and to pass array-type arguments explicitly as arrays, not lists. Can you help provide any points I'm missing / advisory text?

I'm not sure we should absolutely require developers to write modules for new features that are necessarily JAX compatible. I guess there might be some boilerplate code we can have ready to add to a module, e.g. to raise NotImplementedError if the JAX backend is loaded for an incompatible module (similarly, one would need to skip tests of that module with the JAX backend). I don't think this point needs any action now.

Otherwise, I realised we did in fact switch to Python 3.10, making my concern about versions above irrelevant. So that's good.

Data types

I'll let @gefux tell us about the intended purpose, and if this is still relevant, as I don't know.

Unit tests

I'm currently having difficulty getting a working JAX install on my system within python testing environments (3.10 is not my OS python). I'll hopefully get this sorted in due course and provide an update. But regardless, the assertion and vectorized changes sound fine. I don't know of any code that would suffer from the change of None to NaN in SimpleProcessTensor, but I'll highlight this for @gefux in case he does.

My question here, Sampreet, is how we can make switching to the JAX testing neater (bearing in mind we want to use GitHub's CI). I'm not so familiar with pytest, but do you think there is a simple command flag or plugin we can specify from the test invocation command to optionally load the JAX __init__.py code that is currently commented out?

PR

That's a good suggestion. For now I think just creating a PR from your current branch and linking here would be helpful so everyone can follow changes.

create_delta

I or someone else will look into #140 when we get the chance, sounds good.

integrate / linalg

These deficiencies may be good to note in addition to advise for developers. Maybe we need a page in the docs about the JAX backend, what do you think? Under 'Development', perhaps a new page on JAX after 'Contributing' (https://oqupy.readthedocs.io/en/latest/pages/contributing.html). In place of or addition to changes to CONTRIBUTING.md suggested above.

I do not think we should use other libraries or projects. In particular, the bath eigendecomposition is only done once for small matrices if I'm not mistaken, so there is no cost in just avoiding the GPU here.

gefux (Member) commented Sep 8, 2024

Dear @Sampreet,

I finally found some time to look into your code - I apologize for the delay.
Overall, I agree with @piperfw. I think the changes look good.

General comment

As piperfw writes, we need to be mindful of the maintenance difficulty and the contribution barriers. I think that using from oqupy.backends.numerical_backend import np instead of import numpy as np is completely tolerable from the point of view of potential future contributors. On the other hand, demanding JAX compatibility for all current and future modules could be a discouraging factor for many potential contributors, or at least leave more work for the maintainers, who would then have to make the code JAX compatible. I'd thus suggest making the changes you propose, but keeping the focus of OQuPy on vanilla numpy and declaring the JAX option 'experimental' for users and 'optional' for contributors (at least for now).

Technical comments (in addition to piperfw's)

Dynamical switching and automated testing:

The general idea of defining the default in config.py and then constructing the np and sp modules in backends/numerical_backend.py looks good to me. The switching, however, seems to be done by dynamically changing the variables defined in config.py. I think that the variable default_np should not be touched when switching; instead, backends/numerical_backend.py should check a system environment variable such as "OQUPY_BACKEND". If it is set to "JAX" then np should be set to the jax module, and otherwise np should be set to the default. As JAX literally requires a different system environment, I think that a system environment variable is the right place to signal the switch. This would eliminate the __init__.py in the testing directory. The automated tests would then be done in a separate "JAX" environment that has everything necessary installed and the "OQUPY_BACKEND=JAX" variable set. For functionality that doesn't work on JAX one could then add @skip decorators conditioned on the backend type.
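A minimal sketch of what this environment-variable switch could look like (assuming the variable name "OQUPY_BACKEND" from the comment above; the JAX branch presumes JAX is installed):

```python
import os

def select_backend():
    """Sketch: choose the numerical backend once, at import time, from
    an environment variable, leaving the defaults in config.py untouched."""
    if os.environ.get('OQUPY_BACKEND', '').upper() == 'JAX':
        import jax
        import jax.numpy as backend_np
        jax.config.update('jax_enable_x64', True)
    else:
        import numpy as backend_np
    return backend_np

os.environ['OQUPY_BACKEND'] = 'NUMPY'  # demonstration: default backend
np = select_backend()
print(np.__name__)  # numpy
```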

Numpy data types

Fixing the data types to one specific type is simply a kind of defensive coding style which I embraced when starting this project, because numpy tends to become really slow if it performs computations on mixed types. At least that is what I used to experience a few years ago -- maybe this has changed in the meantime. I also found it potentially useful to be able to switch all computations to a different type from one point.

None vs NaN in SimpleProcessTensor.set_initial_tensor

If I am not mistaken, the line self._initial_tensor = None there is completely useless, as it is overwritten in the next line anyway. In SimpleProcessTensor.__init__(), however, the line self._initial_tensor = None is important for the logic of the class, as it signals that the process tensor doesn't have an initial tensor (which is still a valid form of a process tensor). It would be important to keep this logic of a tensor being either None or a tensor depending on its existence, because it appears in many places in the code and is very convenient.

piperfw (Collaborator) commented Sep 8, 2024

Really good idea with the environment variable @gefux; that clears up a lot of the mess with the testing I was concerned about. Completely agree about the experimental label for the backend; we can advertise this in an appropriate page in the docs.

Sampreet (Author) commented

Dear @piperfw and @gefux,

Thanks a lot for the detailed comments. I shall implement the suggested changes sometime this week and create a draft PR, where we can continue the code-related discussions. Below are a few points I would like to add here:

Contribution and Maintenance

It seems like we should be able to keep the core OQuPy and existing features JAX-compatible (to the extent of lacking integrate, eigen decomposition etc.) - @piperfw
I'd thus suggest to make the changes you propose, but keeping the focus of OQuPy on vanilla numpy and declaring the JAX option as 'experimental' for users and as 'optional' for contributors (at least for now). - @gefux

This sounds like a good idea and is more user/contributor/maintainer-friendly for now. The proposed changes are fully compatible for vanilla NumPy users without requiring any change to the existing examples; the same goes for contributions with vanilla NumPy. Additionally, interested contributors can follow an "optional" set of guidelines (detailed next) to make sure that their changes work seamlessly with JAX, thereby easing maintenance. Moreover, the current oqupy.backends.numerical_backend module simply handles the breaking changes in JAX NumPy, such as in-place updates and pseudo-random number generation (which defaults to vanilla NumPy for now), and therefore no explicit JAX-related imports are required in any of the modules. This keeps JAX dependencies out of the core library, and users can utilize the oqupy.config module to switch to JAX-based "experimental" features dynamically. Alternatively, a new function (say enable_experimental_features) could be added to ease the switching, but it would require explicit JAX imports. I elaborate on this in the section on environment variables below.

I suggest we add a note to developers (and forgetful maintainers!) e.g. in CONTRIBUTING.md to advise on the JAX backend, and best practices. Can you help provide any points I'm missing / advisory text? - @piperfw

The following guidelines can be added for "optional" contributions to the "experimental" features:

  • using from oqupy.backends.numerical_backend import np instead of import numpy as np, using the alias default_np for vanilla NumPy, and avoiding wildcard imports.
  • using from oqupy.backends.numerical_backend import la instead of import scipy.linalg as la; except for non-symmetric eigen-decomposition (scipy.linalg.eig for non-symmetric matrices).
  • using one of np.dtype_complex (np.dtype_float) or oqupy.config.NumPyDtypeComplex (oqupy.config.NumPyDtypeFloat) instead of np.complex_ (np.float_).
  • converting lists or tuples to NumPy arrays when passing them as arguments inside NumPy functions.
  • using np.update(array, indices, values) instead of array[indices] = values.
  • using np.get_random_floats(seed, shape) instead of np.random.default_rng(seed).random(shape).
  • declaring signatures for np.vectorize explicitly.
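For the update guideline above, a hedged sketch of what such a helper could look like internally (the exact signature and dispatch in oqupy.backends.numerical_backend may differ):

```python
import numpy as default_np

def update(array, indices, values):
    """Hypothetical internals of the `update` helper: functional
    `.at[].set()` for immutable JAX arrays, in-place assignment for
    vanilla NumPy arrays."""
    if hasattr(array, 'at'):          # jax.numpy arrays expose .at
        return array.at[indices].set(values)
    array[indices] = values           # vanilla NumPy path
    return array

a = default_np.zeros(4)
a = update(a, slice(1, 3), 1.0)
print(a.tolist())  # [0.0, 1.0, 1.0, 0.0]
```

Because the JAX path returns a new array, call sites should always rebind the result (`a = np.update(a, ...)`) rather than rely on in-place mutation.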

Maybe we need a page in the docs about the JAX backend, what do you think? Under 'Development', perhaps a new page on JAX after 'Contributing'. In place of or addition to changes to CONTRIBUTING.md suggested above. - @piperfw

It might be a good idea to add a separate section/page in the documentation to highlight the "experimental" nature of this feature. In time, this can include installation guidelines, gotchas (some points from CONTRIBUTING.md), and configurational changes required to use JAX with OQuPy.

Continuous Integration and Unit Tests

My question here Sampreet is how we can make switching to the JAX testing neater (bear in mind we want to use GitHub's CI). - @piperfw

From what I understand, the automated tests are performed with tox using the GitHub CI routine described in .github/workflows/python-package-tests.yml. For the current changes, we can therefore add a separate testenv block in tox.ini:

[testenv:pytest_jax]
description = run pytests with JAX and make coverage report
basepython = python3.10
deps = -rrequirements_ci_jax.txt
commands =
    pytest --cov-report term-missing --cov=oqupy ./tests/{posargs}
    python -m coverage xml

Then, we can use Python's importlib to run the tests with and without JAX via the __init__.py modules of tests/coverage and tests/physics:

from importlib.util import find_spec

if find_spec('jax') is not None:
    # JAX configuration here

I have implemented this approach and it works well. What I am worried about is that the time taken by the JAX-based tests is very high, since we do not have any explicit JAX-based primitives in OQuPy's modules yet (further details in the next section).

Dynamical Switching with Environment Variables

I think that the variable default_np should not be touched when switching, but instead the backends/numerical_backend.py should check a system environment variable such as "OQUPY_BACKEND". If it is set to "JAX" then np should be set to be the jax module and otherwise np should be set to the default. As JAX literally requires a different system environment I think that a system environment variable is the right place to signal the switch. This would eliminate the __init__.py in the testing directory. - @gefux

I believe what you are suggesting here is an explicit backend module for JAX, which is imported during the initialization of numerical_backend.NumPy class only when the environment variable "OQUPY_BACKEND" is set to "JAX". Otherwise, the initialization defaults to the NumPy library defined by NUMERICAL_BACKEND_NUMPY in oqupy.config (which stays untouched). Kindly let me know if I am mistaken here.

The above description is similar to what has been implemented inside tensornetwork. I have also implemented something similar for my RL library quantrl by using an abstract base class inherited by each numerical backend (NumPy/JAX/PyTorch), with individually overridden abstract methods. This facilitates fine-grained control (and thereby speedups) over each backend-specific implementation (such as precision, built-in methods or even custom wrappers for functions). The idea of utilizing oqupy.config here instead of this approach was to initially support GPU/TPU arrays without major changes to the original codebase (e.g. the syntax of conditional statements and loops would be governed by the backend methods), avoiding any explicit mention of JAX inside it.
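For illustration, the abstract-base-class pattern described here could look roughly like the following (a generic sketch under my own naming, not OQuPy or quantrl code):

```python
from abc import ABC, abstractmethod
import numpy

class BaseBackend(ABC):
    """Generic sketch: each numerical backend overrides the same small
    set of primitives declared on the abstract base class."""
    @abstractmethod
    def array(self, data, dtype=None): ...
    @abstractmethod
    def matmul(self, a, b): ...

class NumPyBackend(BaseBackend):
    def array(self, data, dtype=None):
        return numpy.array(data, dtype=dtype)
    def matmul(self, a, b):
        return numpy.matmul(a, b)

# A JAXBackend would override the same methods with jax.numpy calls,
# enabling backend-specific choices (precision, JIT, lax primitives).
backend = NumPyBackend()
m = backend.array([[1.0, 0.0], [0.0, 1.0]])
print(backend.matmul(m, m).tolist())  # [[1.0, 0.0], [0.0, 1.0]]
```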

On second thoughts, since tensornetwork is already a required module for OQuPy, we can utilize its JAX backend for this purpose and wrap it via numerical_backend based on the value of "OQUPY_BACKEND" (a few additional changes might make it compatible with PyTorch/Symmetric/TensorFlow too). This will allow:

  • the use of JIT-ted functions and jax.lax primitives inside the modules, resulting in speedups in computation and application.
  • testing without the requirement of explicit JAX configurations (setting "OQUPY_BACKEND" would still require importlib as mentioned in the previous section).
  • changing the backend dynamically by using methods like enable_experimental_features and disable_experimental_features (this will require reinitialization of numerical_backend.NumPy).
  • unavailable JAX functions to be automatically overridden by vanilla NumPy ones through the numerical backend.

As of now, I don't see any major disadvantage in the environment-variable approach, other than the fact that we might need to add a few additional guidelines for "optional" contributions.
