JAX backend fails for a simple pymc linear regression model #157

Comments
Thank you for the bug report. I had seen something possibly related recently, but didn't manage to find an example in a smaller model. This example should make it much easier to find the problem. Right now you can work around the issue by freezing the pymc model:

```python
from pymc.model.transform.optimization import freeze_dims_and_data

trace = nutpie.sample(nutpie.compile_pymc_model(freeze_dims_and_data(model), **kwargs))
```
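Spelled out with the compile arguments from the report (a sketch; `model` is the model from the minimal example):

```python
import nutpie
from pymc.model.transform.optimization import freeze_dims_and_data

# Freeze dims and data first, then compile and sample as usual.
compiled = nutpie.compile_pymc_model(
    freeze_dims_and_data(model), backend="jax", gradient_backend="jax"
)
trace = nutpie.sample(compiled)
```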
Well, that was easier than I thought, and won't be hard to fix. The problem is that the argument name for the point in parameter space is `x`, which clashes with the data variable of the same name in the model.
@aseyboldt: Thanks a lot for the very helpful reply. I renamed 'x' -> 'X' and now things are working. This is really good to know, but it should probably be fixed either by ensuring that a unique name is used for the point in parameter space, or by forbidding 'x' as a name in the model (which would be a bit cumbersome, since predictors are often denoted by 'x'). On the upside, using JAX gives a very nice speedup on my machine (Apple M1) :-).
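For reference, a minimal sketch of the kind of model that hits the clash (hypothetical toy data; the essential part is just the data variable named 'x'):

```python
import numpy as np
import pymc as pm

rng = np.random.default_rng(0)

with pm.Model() as model:
    # The data variable named "x" collided with the argument name the
    # JAX backend used for the point in parameter space; renaming it
    # to "X" avoids the clash.
    x = pm.Data("x", rng.normal(size=100))
    beta = pm.Normal("beta")
    sigma = pm.HalfNormal("sigma")
    pm.Normal("y", mu=beta * x, sigma=sigma, observed=rng.normal(size=100))
```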
Yes, definitely needs a fix, I'll push one soon. Out of curiosity (I don't have an Apple machine), could you do me a small favor and run this with jax and numba, and tell me what the compile time and the runtime are in each case?
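Something along these lines would do it (a sketch; `model` is the model from the report, and the numba run just swaps the backend arguments):

```python
import time

import nutpie

# Time the JAX compile and run; for the numba comparison use
# backend="numba", gradient_backend="pytensor" instead.
start = time.time()
compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
print(f"compile time: {time.time() - start:.1f}s")

start = time.time()
trace = nutpie.sample(compiled)
print(f"sampling time: {time.time() - start:.1f}s")
```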
jax: …

numba: …
I hope that helps. I have another follow-up question: while I observe a great speed-up when using the JAX backend on my Apple M1 machine, I see significantly slower sampling with the JAX backend compared to numba/pytensor when running on a Google Cloud VM with a lot more cores (32) and memory. This is for a hierarchical linear regression with thousands of groups and a couple of predictors. On the VM, sampling with the "jax" backend is about 30% slower than with the "numba" backend. Specifically, I observe that with the "numba" backend I get a couple of (4-8) busy thread/CPU bars in `htop`. If you have any ideas/insights into what could cause this, and how to ensure the best performance, I'd be glad for any suggestions.
Thanks for the numbers :-) First, I think it is important to distinguish compile time and sampling time. The numbers you just gave me show that the numba backend samples faster on the mac as well; only the compile time is much larger. If the model gets bigger, the compile time will play less of a role, because it doesn't depend much on the data size. I think what you observe with the jax backend is an issue with how the jax backend currently works: each sampler thread has to hold the GIL to set up and dispatch every logp evaluation, and releases it for the actual computation. If the computation of the logp function takes a long time, and there aren't that many threads, then most of the time only one or even no thread will hold the GIL, because each thread spends most of its time in the "do the actual logp function evaluation" phase, and all is good. With many threads, however, they start to compete for the GIL, and the time spent waiting for the lock eats up the extra cores. There are two things that might make this situation better in the future: …
In the meantime: if the cores of your machine aren't used well, you can at least try to limit the number of threads that run at the same time (see the sketch below). If you are willing to go to some extra lengths: you can start multiple separate processes that sample your model (with different seeds!) and then combine the traces. This is much more annoying, but should completely avoid the lock contention. In that case you can run into other issues, however, for instance if each process tries to use all available cores on the machine; fixing that would then require limiting the number of threads each process may use. I hope that helps to clear it up a bit :-)
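For the first suggestion, something like this (a sketch, assuming the `cores` argument of `nutpie.sample` caps how many chains run concurrently; `compiled` is a compiled model as above):

```python
# Run 8 chains, but at most 4 at the same time, so that fewer
# threads compete for the GIL.
trace = nutpie.sample(compiled, chains=8, cores=4)
```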
Thanks @aseyboldt, this is really helpful. Do you know of a minimal example for the "start multiple separate processes" approach? I have seen https://discourse.pymc.io/t/harnessing-multiple-cores-to-speed-up-fits-with-small-number-of-chains/7669 where the idea is to concatenate multiple smaller chains to more efficiently harness the CPUs on a machine. Btw, with jax: … and numba: … So JAX is a lot faster for sampling, which also matches my observation for a hierarchical linear model.
For sampling in separate processes:

```python
# At the very start...
import os
os.environ["JOBLIB_START_METHOD"] = "forkserver"

import arviz
import numpy as np
import nutpie
from joblib import parallel_config, Parallel, delayed


def run_chain(data, idx, seed):
    # make_model is your function that builds the pymc model from the data.
    model = make_model(data)
    # Derive an independent per-chain seed from the shared one.
    seeds = np.random.SeedSequence(seed)
    seed = np.random.default_rng(seeds.spawn(idx + 1)[-1]).integers(2 ** 63)
    compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
    trace = nutpie.sample(compiled, seed=seed, chains=1, progress_bar=False)
    return trace.assign_coords(chain=[idx])


# frame is your dataset.
with parallel_config(n_jobs=10, prefer="processes"):
    traces = Parallel()(delayed(run_chain)(frame, i, 123) for i in range(10))

trace = arviz.concat(traces, dim="chain")
```

This comes with quite a bit of overhead (mostly constant though), so probably not worth it for smaller models. Funnily enough, I see big differences between different ways of writing the dot product in the model.
And jax and numba react quite differently. Maybe an issue with the blas config? What blas implementation are you using? (On conda-forge you can choose it as explained here: https://conda-forge.org/docs/maintainer/knowledge_base/#switching-blas-implementation)
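If you're unsure which BLAS numpy is linked against, `numpy.show_config()` will tell you:

```python
import numpy as np

# Prints the build configuration, including which BLAS/LAPACK
# libraries numpy is linked against (openblas, mkl, accelerate, ...).
np.show_config()
```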
Thanks a lot for the again very helpful suggestions. I will benchmark the two versions of the "dot-product" to see whether I observe different performance. Regarding BLAS: on Apple M1 I have … and on the VM … I can try to use the accelerate BLAS. But I am more curious about speeding things up on the VM now.
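Per the conda-forge docs linked above, switching should be a one-liner (a sketch; run inside the target environment):

```bash
# Select Apple's Accelerate as the BLAS implementation; replace
# "accelerate" with "mkl" or "openblas" to pick a different one.
conda install "libblas=*=*accelerate"
```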
Minimal example

…

Error message

…

Sampling with `backend="numba"` and `gradient_backend="pytensor"` runs successfully.

Version

…