That is exactly our thinking as well! We have actually been working on the JAX backend for quite some time now, in the agent_jax branch (https://github.com/infer-actively/pymdp/tree/agent_jax). We haven't benchmarked it against the Cpp-AIF paper yet, but that would be nice to do at some point.
-
Hello,
According to a recent publication about the Cpp-AIF GitHub project, the main pain point of pymdp is its speed. Looking at the pymdp source code, I think it could benefit greatly from more vectorized code and hardware acceleration. With its NumPy-like API, JAX could be a great choice here. Are there any plans to rewrite this library using JAX? I'm fairly confident it could speed up computations by several orders of magnitude while keeping the simplicity of Python rather than moving to C++.
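To illustrate the kind of vectorization meant here, below is a minimal sketch of a single discrete Bayesian state-inference step (posterior proportional to likelihood times prior, computed in the log domain) written in JAX and batched over many agents with `jax.vmap`, then compiled with `jax.jit`. It assumes one hidden-state factor and one observation modality; the function name `posterior_over_states` and the array shapes are hypothetical and this is not code from pymdp or its agent_jax branch.

```python
import jax
import jax.numpy as jnp

def posterior_over_states(A, obs, prior):
    """Hypothetical helper: posterior over hidden states for ONE agent.

    A:     likelihood matrix p(o|s), shape (num_obs, num_states), columns sum to 1
    obs:   integer index of the observed outcome
    prior: prior p(s), shape (num_states,)
    """
    # Log-domain Bayes rule: log p(s|o) = log p(o|s) + log p(s) - log Z.
    # The small constant guards against log(0).
    log_post = jnp.log(A[obs] + 1e-16) + jnp.log(prior + 1e-16)
    # softmax normalizes, i.e. subtracts log Z.
    return jax.nn.softmax(log_post)

# vmap maps the single-agent function over the leading (agent) axis of every
# argument; jit compiles the batched function with XLA for CPU/GPU/TPU.
batched_posterior = jax.jit(jax.vmap(posterior_over_states))

num_agents, num_obs, num_states = 1000, 3, 4
k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# Random column-stochastic likelihoods, one matrix per agent.
A = jax.random.uniform(k1, (num_agents, num_obs, num_states))
A = A / A.sum(axis=1, keepdims=True)
obs = jax.random.randint(k2, (num_agents,), 0, num_obs)
prior = jnp.full((num_agents, num_states), 1.0 / num_states)

qs = batched_posterior(A, obs, prior)
print(qs.shape)  # (1000, 4)
```

The design choice is to write the per-agent update once, as plain NumPy-style code, and let `vmap` and `jit` handle batching and compilation instead of looping in Python.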