Implement Dynamical CPU/GPU/TPU Backends #142
Comments
Hi @Sampreet, the code changes look really good. I haven't been able to run things on my system yet, as JAX on Linux + AMD is experimental and there are some barriers to building from source, but hopefully I'll be able to get that done soon. [On a related note, it looks like JAX requires Python 3.10+. That is not an issue, but we will need to note it in the README or similar, as I think we are still hoping to support 3.6 as a minimum - #80]

The dynamical switching looks fine, and I actually prefer the neatness of using

I guess it's nice to be able to choose JAX for either a

EDIT: Running CPU JAX, I've observed that the calls
RE Potential Issues

Shape
I don't see any issue with this. In fact, the current use of shape vs reshape in the code looks sporadic. I think we should switch to always using

Vectorization Signatures
Yes, signatures can be added and that sounds like a good addition. You are correct: Hamiltonians for all

Row Degeneracy Checks
Overall, this all looks very good and close to being PR-ready. It would be good to do widespread testing, because there may well be other small compatibility issues like those you have caught in the numerical backend with
RE Unit testing

We have a large number of coverage and 'physical' tests (assessing physical consistency or accuracy) under tests/. What is the best way to extend these to test against the JAX backend? I guess (assuming we want the code to work 100% with both backends) we can have two testing environments, one for each backend (the JAX one also being Python 3.10), and run all the tests against both.
Hi @piperfw, thanks again for your detailed comments. I shall implement the suggestions for shape, signatures and the degeneracy arrays and let you know if I face any issues. Based on my understanding of the source code, I would like to share the following in relation to your queries:

Data types
I do agree with certain checks on datatypes, especially since the precision requirements (single/double floating points and their corresponding complex datatypes) are interrelated in most of the modules. Also, as the operations for a single simulation (by a user) are likely to involve the same degree of precision, I feel it might not be useful to let the user alter the precision dynamically through the corresponding methods of

And yes, JAX, by default, uses 32-bit floats (and correspondingly, 64-bit complex) datatypes and requires a call to the

Unit tests
I am still to update the test files with the classes from

Regarding PR
It might be handy to have a separate branch (say

Update on speedup
In the latest commit of the pr/feature-numerical-backend branch, I have implemented the generalized method to create deltas (with a NumPy-friendly

About
Hi Sampreet, thanks for the detailed updates. I will respond to particular comments below. The code changes and generally the implementation look really good. For the bigger picture, there are two things we want to avoid:
These are potentially conflicting: if we require contributions to be JAX compatible, then that is a large barrier, but if we don't, then contributed code may break for the JAX backend and someone will need to assess and fix any issues. It seems like we should be able to keep the core OQuPy and existing features JAX-compatible (to the extent of lacking integrate, eigen decomposition etc.). I suggest we add a note to developers (and forgetful maintainers!) e.g. in

I'm not sure we should absolutely require developers to write modules for new features that are necessarily JAX compatible. I guess there might be some boilerplate code we can have ready to add to a module e.g. to raise

Otherwise, I realised we did in fact switch to Python 3.10, making my concern about versions above irrelevant. So that's good.

Data types
I'll let @gefux tell us about the intended purpose, and if this is still relevant, as I don't know.

Unit tests
I'm currently having difficulty getting a working JAX install on my system within Python testing environments (3.10 is not my OS Python). I'll hopefully get this sorted in due course and provide an update. But regardless, the assertion and vectorized changes sound fine. I don't know of any code that would suffer from the change of

My question here, Sampreet, is how we can make switching to the JAX testing neater (bear in mind we want to use GitHub's CI). I'm not so familiar with pytest, but do you think there is a simple command flag or plugin we can specify from the test invocation command to optionally load the JAX

PR
That's a good suggestion. For now I think just creating a PR from your current branch and linking it here would be helpful so everyone can follow changes.

create_delta
I or someone else will look into #140 when we get the chance, sounds good.

integrate / linalg
These deficiencies may be good to note in addition to advice for developers. Maybe we need a page in the docs about the JAX backend, what do you think?
Under 'Development', perhaps a new page on JAX after 'Contributing' (https://oqupy.readthedocs.io/en/latest/pages/contributing.html). In place of or in addition to changes to

I do not think we should use other libraries or projects. In particular, the bath eigendecomposition is only done once for small matrices if I'm not mistaken, so there is no cost in just avoiding the GPU here.
Dear @Sampreet, I finally found some time to look into your code - I apologize for the delay.

General comment
As piperfw writes, we need to be mindful of the maintenance difficulty and the contributing barriers. I think that using

Technical comments (in addition to piperfw's)

Dynamical switching and automated testing:
The general idea of defining the default in

Numpy data types
Fixing the data types to one specific type is simply a kind of defensive coding style, which I embraced when starting this project because numpy tends to become really slow if it performs computations on mixed types. At least that is what I used to experience a few years ago -- maybe this has changed in the meantime. I also found it potentially useful to be able to switch all computations to a different type from one point.
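The "switch from one point" idea for data types could look roughly like the sketch below. It is only an illustration under stated assumptions: the names `NP_DTYPE` and `to_backend_array` are hypothetical and are not taken from OQuPy's config module.

```python
import numpy as np

# Hypothetical single place that fixes the dtype used in all computations.
NP_DTYPE = np.complex128


def to_backend_array(data):
    """Convert input data to an array of the one configured dtype."""
    return np.asarray(data, dtype=NP_DTYPE)


# Integer and float inputs both end up with the configured dtype,
# so the later matrix product never operates on mixed types.
a = to_backend_array([[1, 0], [0, -1]])
b = to_backend_array([[0.0, 0.5], [0.5, 0.0]])
c = a @ b
```

Changing `NP_DTYPE` in this one place (for example to `np.complex64`) would then change the precision of every array created through the helper.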
Really good idea with the environment variable @gefux, that clears up a lot of the mess with the testing I was concerned about. Completely agree about the experimental label for the backend; we can advertise this in an appropriate page in the docs.
Thanks a lot for the detailed comments. I shall implement the suggested changes sometime this week and create a draft PR, where we can continue the code-related discussions. Below are a few points I would like to add here:

Contribution and Maintenance
This sounds like a good idea and more user/contributor/maintainer-friendly for now. The proposed changes are also fully compatible with vanilla NumPy users without requiring any change in the existing examples. Same goes for contributions with vanilla NumPy. Additionally, interested contributors can follow an "optional" set of guidelines (detailed next) to make sure that their changes work seamlessly with JAX, thereby easing maintenance. Moreover, the current
The following guidelines can be added for "optional" contributions to the "experimental" features:
It might be a good idea to add a separate section/page in the documentation to highlight the "experimental" nature of this feature. In time, this can include installation guidelines, gotchas (some points from

Continuous Integration and Unit Tests
From what I understand, the automated tests are performed with

    [testenv:pytest_jax]
    description = run pytests with JAX and make coverage report
    basepython = python3.10
    deps = -rrequirements_ci_jax.txt
    commands =
        pytest --cov-report term-missing --cov=oqupy ./tests/{posargs}
        python -m coverage xml

Then, we can use Python's

    from importlib.util import find_spec
    if find_spec('jax') is not None:
        # JAX configuration here
        pass

I have implemented this approach and it works well. What worries me is that the JAX-based tests take a very long time, since we do not have any explicit JAX-based primitives in OQuPy's modules yet (further details in the next section).

Dynamical Switching with Environment Variables
I believe what you are suggesting here is an explicit backend module for JAX, which is imported during the initialization of

The above description is similar to what has been implemented inside

On second thoughts, since
As of now, I don't see any major disadvantage in the environment-variable approach, other than the fact that we might need to add a few additional guidelines for "optional" contributions.
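For readers following the thread, the environment-variable approach could be sketched roughly as below. This is not the code from the branch: the variable name `OQUPY_NUMERICAL_BACKEND` and the function `select_backend` are hypothetical, and the sketch simply falls back to NumPy when JAX is not installed. The only JAX call used is `jax.config.update("jax_enable_x64", True)`, which switches JAX from its 32-bit default to 64-bit floats.

```python
import os
from importlib.util import find_spec

import numpy as np

# Hypothetical environment variable name, chosen for this sketch only.
ENV_VAR = "OQUPY_NUMERICAL_BACKEND"


def select_backend(environ=None):
    """Return (name, array module) according to the environment variable.

    Falls back to NumPy if JAX is not requested or not installed.
    """
    environ = os.environ if environ is None else environ
    requested = environ.get(ENV_VAR, "numpy").strip().lower()
    if requested == "jax" and find_spec("jax") is not None:
        import jax
        import jax.numpy as jnp

        # JAX defaults to 32-bit floats; enable 64-bit precision.
        jax.config.update("jax_enable_x64", True)
        return "jax", jnp
    return "numpy", np


# With the variable unset, existing NumPy users see no change.
name, xp = select_backend({})
identity = xp.eye(2)
```

A CI job could then export the variable before invoking the test suite, so the same tests run once per backend without editing any test file.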
Summary
With reference to the email conversation with @piperfw and @gefux, I am opening a dedicated issue pertaining to a dynamical numerical backend to support GPUs and TPUs. OQuPy utilizes TensorNetwork modules which support frameworks like TensorFlow, PyTorch and JAX. However, TensorNetwork's `Node` requires its `tensor` parameter to be passed with the same numerical backend, and OQuPy's modules use NumPy explicitly to create these nodes. As such, using a dynamical backend to switch from NumPy/SciPy to corresponding libraries of (say) JAX will facilitate the usage of TensorNetwork's GPU/TPU-friendly frameworks, which may speed up certain methods. This issue proposes an approach similar to what has been recently implemented in QuTiP to support auto-differentiation using its JAX backend.

Implementation
The `oqupy.config` module is updated with:

The specific choice of `scipy.linalg` instead of `scipy` is because `jax.scipy.integrate` doesn't yet have implementations of `dblquad`, `quad` and `quad_vec`, as required by some of OQuPy's modules.

A new `oqupy.backends.numerical_backend` module now takes care of the switching. The coarse contents of the module are:

All the modules of OQuPy are modified as:
The corresponding changes can be viewed by comparing the pr/feature-numerical-backend branch.
Testing
I have tested the above implementations using `tox` for the vanilla NumPy/SciPy backend and reproduced the plots of arXiv:2406.16650. To use the JAX backend, the following snippet can be added at any point in the scripts:

Issues
The JAX backend tests have multiple issues that require suggestions before any modification. A few of these issues are mentioned below:
Immutability of `jax.numpy` Arrays

The `add_singleton` function in `oqupy.utils` is often called by modules like `oqupy.backends.pt_tempo_backend` to create/update MPOs and MPSs. This function updates the shape of the array by adding an additional dimension to it at a specified index. Whereas NumPy supports altering the `shape` attribute, JAX's `ArrayImpl` arrays are immutable, so `shape` has no setter, which leads to recurrent errors. Similar mutations can be found in the `oqupy.gradient`, `oqupy.mps_mpo` and `oqupy.system_dynamics` modules. In this case, would reshaping the array instead of updating its shape break the functionality? Otherwise, a separate variable can be dedicated for the shape.

Vectorization Signatures
Modules such as `oqupy.bath_correlations` and `oqupy.system` use the `vectorize` method of NumPy without explicitly mentioning the signatures for the input and output. This raises issues with JAX when the parameters contain non-scalar inputs. Depending on the type of implementation, can the signatures be added for each vectorized function? For example, the Hamiltonians for `TimeDependentSystem` and `TimeDependentSystemWithField` will have the signatures `()->(m,m)` and `(),()->(m,m)` respectively (kindly correct me if I have made a mistake here).

Row Degeneracy Checks
Unlike the `numpy.unique` method, the `jax.numpy.unique` method returns array indices with the same shape as the input array even with the `axis` parameter. Although this can be resolved by using the `flatten` method, I am not sure if it will introduce additional errors.

Kindly share your views as and when time permits. Thank you once again.