Python provides a variety of functionality for parallelization, including threaded operations (in particular for linear algebra), parallel looping and map statements, and parallelization across multiple machines. For the CPU, this material focuses on Python’s ipyparallel package and JAX, with some discussion of Dask and Ray. For the GPU, the material focuses on PyTorch and JAX, with a bit of discussion of CuPy.
Note that all of the looping-based functionality discussed here applies only if the iterations/loops of your calculations can be done completely separately and do not depend on one another. This scenario is called an embarrassingly parallel computation. So coding up the evolution of a time series or a Markov chain is not possible using these tools. However, bootstrapping, random forests, simulation studies, cross-validation and many other statistical methods can be handled in this way.
The BLAS is the library of basic linear algebra operations (written in Fortran or C). A fast BLAS can greatly speed up linear algebra relative to the default BLAS on a machine. Some fast BLAS libraries are
- Intel’s MKL; may be available for educational use for free
- OpenBLAS; open source and free
- vecLib for Macs; provided with your Mac
In addition to being fast when used on a single core, all of these BLAS libraries are threaded - if your computer has multiple cores and there are free resources, your linear algebra will use multiple cores, provided your installed Python is linked against the threaded BLAS installed on your machine.
To use a fast, threaded BLAS, one approach is to use the Anaconda/Miniconda Python distribution. When you install numpy and scipy, these should be automatically linked against a fast, threaded BLAS (MKL). More generally, simply installing numpy from PyPI should make use of OpenBLAS.
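One way to check which BLAS your numpy installation is actually linked against is to ask numpy for its build configuration (the exact output format varies by numpy version):
import numpy as np
np.show_config()  # look for mentions of MKL, OpenBLAS, or Accelerate/vecLib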
Threading in Python is essentially limited to linear algebra (provided Python is linked against a threaded BLAS), unless you use packages such as Dask or JAX that manage their own parallelism. Python has something called the Global Interpreter Lock (GIL) that interferes with threading in Python code (but not in the threaded linear algebra libraries called by Python).
Here’s some linear algebra in Python that will use threading if numpy is linked against a threaded BLAS, though I don’t compare the timing for different numbers of threads here.
import numpy as np
n = 5000
x = np.random.normal(0, 1, size=(n, n))
x = x.T @ x
U = np.linalg.cholesky(x)
If you watch the Python process via the top command, you should see CPU usage above 100% if Python is linked against a threaded BLAS.
In general, threaded code will detect the number of cores available on a machine and make use of them. However, you can also explicitly control the number of threads available to a process.
For most threaded code (that based on the openMP protocol), the number of threads can be set by setting the OMP_NUM_THREADS environment variable. Note that under some circumstances you may need to use VECLIB_MAXIMUM_THREADS if on an (older, Intel-based) Mac or MKL_NUM_THREADS if numpy/scipy are linked against MKL.
For example, to set it for four threads in bash, do this before starting your Python session.
export OMP_NUM_THREADS=4
Alternatively, you can set OMP_NUM_THREADS as you invoke your job, e.g.,
OMP_NUM_THREADS=4 python job.py > job.out
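One can also limit the number of BLAS threads from within a running Python session. Here’s a minimal sketch using the threadpoolctl package (an assumption here: threadpoolctl is installed; it is not part of numpy itself):
import numpy as np
from threadpoolctl import threadpool_limits
x = np.random.normal(0, 1, size=(3000, 3000))
# Restrict the threaded BLAS to one thread just for this block of code.
with threadpool_limits(limits=1):
    x = x.T @ x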
First we’ll cover IPython Parallel (i.e., the ipyparallel package) functionality, which allows one to parallelize on a single machine (discussed here) or across multiple machines (see the next section). In later sections, I’ll discuss other packages that can be used for parallelization.
First we need to start our workers. As of ipyparallel version 7, we can start the workers from within Python.
## In newer versions of ipyparallel (v. 7 and later)
import ipyparallel as ipp
# Check the version:
ipp.__version__
n = 4
cluster = ipp.Cluster(n = n)
c = cluster.start_and_connect_sync()
Starting 4 engines with <class 'ipyparallel.cluster.launcher.LocalEngineSetLauncher'>
Let’s verify that things seem set up ok and we can interact with all our workers:
## Check that we have the number of workers expected:
c.ids
[0, 1, 2, 3]
## Set up a direct view to interface with all the workers
dview = c[:]
dview
<DirectView [0, 1, 2, 3]>
## Set blocking so that we wait for the result of the parallel execution
dview.block = True
dview.apply(lambda : "Hello, World")
['Hello, World', 'Hello, World', 'Hello, World', 'Hello, World']
dview stands for a ‘direct view’, which is an interface to our cluster that allows us to ‘manually’ send tasks to the workers.
Now let’s see an example of how we can use our workers to run code in parallel.
We’ll carry out a statistics/machine learning prediction method (random forest regression) with leave-one-out cross-validation, parallelizing over different held out data.
First let’s set up packages, data and our main function on the workers:
dview.execute('from sklearn.ensemble import RandomForestRegressor as rfr')
dview.execute('import numpy as np')
def looFit(index, Ylocal, Xlocal):
    rf = rfr(n_estimators=100)
    fitted = rf.fit(np.delete(Xlocal, index, axis = 0), np.delete(Ylocal, index))
    pred = rf.predict(np.array([Xlocal[index, :]]))
    return(pred[0])
import numpy as np
np.random.seed(0)
n = 200
p = 20
X = np.random.normal(0, 1, size = (n, p))
Y = X[: , 0] + pow(abs(X[:,1] * X[:,2]), 0.5) + X[:,1] - X[:,2] + \
np.random.normal(0, 1, n)
mydict = dict(X = X, Y = Y, looFit = looFit)
dview.push(mydict)
[None, None, None, None]
Now let’s set up a “load-balanced view”. With this type of interface, one submits the tasks and the controller decides how to divide up the tasks, ideally achieving good load balancing. A load-balanced computation is one that keeps all the workers busy throughout the computation.
lview = c.load_balanced_view()
lview.block = True
# need a wrapper function because map() only operates on one argument
def wrapper(i):
    return(looFit(i, Y, X))
# Now run the fitting, predicting on each held-out observation:
pred = lview.map(wrapper, range(n))
# Check a few predictions:
pred[0:3]
[2.0945225368269256, -0.8696741139958911, -0.32442762057816776]
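As a quick follow-up (not part of the timing above), one can summarize the quality of the leave-one-out predictions; pred and Y are the objects created above:
import numpy as np
pred = np.array(pred)
# Leave-one-out mean squared prediction error and correlation with the truth:
print(np.mean((pred - Y)**2))
print(np.corrcoef(pred, Y)[0, 1])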
One can also start the workers outside of Python. This was required in older versions of ipyparallel, before version 7.
# In the bash shell:
export NWORKERS=4
ipcluster start -n ${NWORKERS} &
Now in Python, we can connect to the running workers:
# In python
import os
import ipyparallel as ipp
c = ipp.Client()
c.wait_for_engines(n = int(os.environ['NWORKERS']))
c.ids
# Now do your parallel computations
Finally, stop the workers.
ipcluster stop
One can use ipyparallel in a context with multiple nodes, though the setup to get the worker processes started is a bit more involved when you have multiple nodes.
If we are using the SLURM scheduling software, here’s how we start up the worker processes:
# In the bash shell (e.g., in your Slurm job script)
ipcontroller --ip='*' &
sleep 60
# Next start as many ipengines (workers) as we have Slurm tasks.
# This works because srun is a Slurm command,
# so it knows it is running within a Slurm allocation
srun ipengine &
At this point you should be able to connect to the running cluster using the syntax seen for single-node usage.
Warning: Be careful to set the sleep period long enough that the controller starts before the workers try to connect to it, and that the workers start before you try to connect to them from within Python.
After doing your computations and quitting your main Python session, shut down the cluster of workers:
ipcluster stop
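Putting the pieces together, a minimal Slurm job script might look like the following sketch; the job name, partition, and resource requests are placeholders that you would replace with values appropriate for your cluster.
#!/bin/bash
#SBATCH --job-name=ipp-example     # placeholder job name
#SBATCH --partition=somepartition  # placeholder partition
#SBATCH --nodes=2
#SBATCH --ntasks=8                 # one Slurm task per ipengine worker
#SBATCH --time=1:00:00
ipcontroller --ip='*' &
sleep 60
srun ipengine &    # starts one ipengine per Slurm task
sleep 60
python job.py > job.out
ipcluster stop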
To start the engines in a context outside of Slurm (provided all machines share a filesystem), you should be able to ssh to each machine and run ipengine & for as many worker processes as you want to start, as follows. In some, but not all, cases (depending on how the network is set up) you may not need the --location flag, but if you do, it should be set to the name of the machine you’re working on, e.g., by using the HOST environment variable. Here we start all the workers on a single other machine, “other_host”:
ipcontroller --ip='*' --location=${HOST} &
sleep 60
NWORKERS=4
ssh other_host "for (( i = 0; i < ${NWORKERS}; i++ )); do ipengine & done"
Dask and Ray are powerful packages for parallelization that allow one to parallelize tasks in a similar fashion to ipyparallel. But they also provide additional useful functionality: Dask allows one to work with large datasets that are split up across multiple processes on (potentially) multiple nodes, providing Spark/Hadoop-like functionality. Ray allows one to develop complicated apps that execute in parallel using the notion of actors.
For more details on using distributed datasets with Dask, see this Dask tutorial. For more details on Ray’s actors, please see the Ray documentation.
There are various ways to do parallel loops in Dask, as discussed in detail in this Dask tutorial.
Here’s an example of doing it with “delayed” calculations set up via list comprehension. First we’ll start workers on a single machine. One can also start workers on multiple machines, as discussed in the tutorial linked to just above.
import dask.multiprocessing
dask.config.set(scheduler='processes', num_workers = 4)
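As an aside, one can instead use Dask’s distributed scheduler, which also works on a single machine and provides a diagnostics dashboard. Here’s a minimal sketch (assuming the dask.distributed package is installed); the examples below continue to use the ‘processes’ scheduler set up above.
from dask.distributed import Client, LocalCluster
cluster = LocalCluster(n_workers = 4)
client = Client(cluster)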
Now we’ll execute a set of tasks in parallel by wrapping the function of interest in dask.delayed to set up lazy evaluation that will be done in parallel using the workers already set up with the ‘processes’ scheduler above.
def calc_mean(i, n):
    import numpy as np
    rng = np.random.default_rng(i)
    data = rng.normal(size = n)
    return([np.mean(data), np.std(data)])
n = 1000
p = 10
futures = [dask.delayed(calc_mean)(i, n) for i in range(p)]
futures # This is an array of placeholders for the tasks to be carried out.
# [Delayed('calc_mean-b07564ff-149a-4db7-ac3c-1cc89b898fe5'),
# Delayed('calc_mean-f602cd67-97ad-4293-aeb8-e58be55a89d6'),
# Delayed('calc_mean-d9448f54-b1db-46aa-b367-93a46e1c202a'), ...
# Now ask for the output to trigger the lazy evaluation.
results = dask.compute(futures)
Execution only starts when we call dask.compute.
Note that we set a separate seed for each task to try to ensure independent random numbers between tasks, but Section 5 discusses better ways to do this.
We’ll start up workers on a single machine. To run across multiple machines, see this tutorial or the Ray documentation.
import ray
ray.init(num_cpus = 4)
To run a computation in parallel, we decorate the function of interest with the @ray.remote decorator:
@ray.remote
def calc_mean(i, n):
    import numpy as np
    rng = np.random.default_rng(i)
    data = rng.normal(size = n)
    return([np.mean(data), np.std(data)])
n = 1000
p = 10
futures = [calc_mean.remote(i, n) for i in range(p)]
futures # This is an array of placeholders for the tasks to be carried out.
# [ObjectRef(a67dc375e60ddd1affffffffffffffffffffffff0100000001000000),
# ObjectRef(63964fa4841d4a2effffffffffffffffffffffff0100000001000000), ...
# Now trigger the computation
ray.get(futures)
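When you’re done, you can shut down the Ray workers and release their resources:
ray.shutdown()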
The key thing when thinking about random numbers in a parallel context is that you want to avoid having the same ‘random’ numbers occur on multiple processes. On a computer, random numbers are not actually random but are generated as a sequence of pseudo-random numbers designed to mimic true random numbers. The sequence is finite (but very long) and eventually repeats itself. When one sets a seed, one is choosing a position in that sequence to start from. Subsequent random numbers are based on that subsequence. All random numbers can be generated from one or more random uniform numbers, so we can just think about a sequence of values between 0 and 1.
The worst thing that could happen is that one sets things up in such a way that every process is using the same sequence of random numbers. This could happen if you mistakenly set the same seed in each process, e.g., using rng = np.random.default_rng(1) or np.random.seed(1) in Python for every worker.
The naive approach is to use a different seed for each process. E.g., if your processes are numbered id = 1,2,...,p, with the variable id unique to each process, you could set the seed to the value of id on each process. This is unlikely to cause problems, but it raises the danger that two (or more) subsequences might overlap. For an algorithm with dependence on the full subsequence, such as an MCMC, this probably won’t cause big problems (though you likely wouldn’t know if it did), but for something like simple simulation studies, some of your ‘independent’ samples could be exact replicates of a sample on another process. Given the period length of the default generator in Python, this is actually quite unlikely, but it is a bit sloppy.
To avoid this problem, the key is to use an algorithm that ensures sequences that do not overlap.
In recent versions of numpy there has been attention paid to this problem and there are now multiple approaches to getting high-quality random number generation for parallel code.
One approach is to generate one random seed per task such that the blocks of random numbers avoid overlapping with high probability, as implemented in numpy’s SeedSequence approach.
Here we use that approach within the context of an ipyparallel load-balanced view.
import numpy as np
import ipyparallel as ipp
n = 4
cluster = ipp.Cluster(n = n)
cluster.start_cluster_sync()
c = cluster.connect_client_sync()
c.wait_for_engines(n)
c.ids
lview = c.load_balanced_view()
lview.block = True
n = 1000
p = 10
seed = 1
ss = np.random.SeedSequence(seed)
child_seeds = ss.spawn(p)
def calc_mean(i, n, seed_i):
    import numpy as np
    rng = np.random.default_rng(seed_i)
    data = rng.normal(size = n)
    return([np.mean(data), np.std(data)])
# need a wrapper function because map() only operates on one argument
def wrapper(i):
    return(calc_mean(i, n, child_seeds[i]))
dview = c[:]
dview.block = True
mydict = dict(calc_mean = calc_mean, n = n, child_seeds = child_seeds)
dview.push(mydict)
results = lview.map(wrapper, range(p))
A second approach is to advance the state of the random number generator as if a large number of random numbers had been drawn.
seed = 1
pcg64 = np.random.PCG64(seed)
def calc_mean(i, n, pcg64):
    import numpy as np
    rng = np.random.Generator(pcg64.jumped(i)) ## jump ahead in large steps, one jump per task
    data = rng.normal(size = n)
    return([np.mean(data), np.std(data)])
# need a wrapper function because map() only operates on one argument
def wrapper(i):
    return(calc_mean(i, n, pcg64))
dview = c[:]
dview.block = True
mydict = dict(calc_mean = calc_mean, n = n, pcg64 = pcg64)
dview.push(mydict)
results = lview.map(wrapper, range(p))
Note that above, I’ve done everything at the level of the computational tasks. One could presumably do this at the level of the workers, but one would need to figure out how to maintain the state of the generator from one task to the next for any given worker.
Python is the go-to language used to run computations on a GPU. Some of the packages that can easily offload computations to the GPU include PyTorch, Tensorflow, JAX, and CuPy. (Of course PyTorch and Tensorflow are famously used for deep learning, but they’re also general numerical computing packages.) We’ll discuss some of these.
There are a couple key things to remember about using a GPU:
- The GPU memory is separate from CPU memory, and transferring data from the CPU to the GPU (or back) is often more costly than doing the computation on the GPU.
  - If possible, generate the data on the GPU or keep the data on the GPU when carrying out a sequence of operations.
- By default, GPU calculations are often done using 32-bit (4-byte) floating point numbers rather than the 64-bit (8-byte) numbers that are standard on the CPU (a quick way to check this is shown below).
  - This can affect speed comparisons between CPU and GPU if one doesn’t compare operations with the same types of floating point numbers.
- GPU operations are often asynchronous: they’ll continue in the background after they start, returning control of your Python session to you and potentially making it seem like the computation happened more quickly than it did.
  - In the examples below, note the syntax that ensures the operation is done before timing concludes (e.g., cuda.synchronize for PyTorch and block_until_ready for JAX).
Note that for this section, I’m pasting in the output when running the code separately on a machine with a GPU because this document is generated on a machine without a GPU.
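Before diving in, here’s a quick check related to the points above (a sketch, assuming PyTorch with CUDA support is installed):
import torch
print(torch.cuda.is_available())          # True if a CUDA GPU is usable
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # which GPU model is available
print(torch.get_default_dtype())          # torch.float32 by default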
Here's an example of doing some linear algebra (simply matrix multiplication) on the GPU using PyTorch.
By default PyTorch will use 32-bit numbers.
import torch
import time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
gpu = torch.device("cuda:0")
n = 7000
def matmul_wrap(x, y):
    z = torch.matmul(x, y)
    return(z)
## Generate data on the CPU.
x = torch.randn(n,n)
y = torch.randn(n,n)
## Copy the objects to the GPU.
x_gpu = x.cuda() # or: `x.to("cuda")`
y_gpu = y.cuda()
torch.set_num_threads(1)
t0 = time.time()
z = matmul_wrap(x, y)
print(time.time() - t0) # 6.8 sec.
start.record()
z_gpu = matmul_wrap(x_gpu, y_gpu)
torch.cuda.synchronize()
end.record()
print(start.elapsed_time(end)) # 70 milliseconds (ms)
So we achieved a speedup of about 100-fold over a single CPU core using an A100 GPU in this case.
Let’s consider the time for copying data to the GPU:
x = torch.randn(n,n)
start.record()
x_gpu = x.cuda()
torch.cuda.synchronize()
end.record()
print(start.elapsed_time(end)) # 60 ms
This suggests that the time in copying the data is similar to that for doing the matrix multiplication.
We can generate data on the GPU like this:
x_gpu = torch.randn(n,n, device=gpu)
Here we’ll consider using the GPU for vectorized calculations. We’ll compare using numpy, CPU-based PyTorch, and GPU-based PyTorch, again with 32-bit numbers.
import torch
import numpy as np
import time
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
gpu = torch.device("cuda:0")
def myfun_np(x):
    y = np.exp(x) + 3 * np.sin(x)
    return(y)
def myfun_torch(x):
    y = torch.exp(x) + 3 * torch.sin(x)
    return(y)
n = 250000000
x = torch.randn(n)
x_gpu = x.cuda() # or: `x.to("cuda")`
tmp = np.random.normal(size = n)
x_np = tmp.astype(np.float32) # for fair comparison
## numpy
t0 = time.time()
y_np = myfun_np(x_np)
time.time()-t0 # 1.2 sec.
## CPU-based torch (1 thread)
torch.set_num_threads(1)
start.record()
y = myfun_torch(x)
end.record()
print(start.elapsed_time(end)) # 2200 ms (2.2 sec.)
## GPU-based torch
start.record()
y_gpu = myfun_torch(x_gpu)
torch.cuda.synchronize()
end.record()
print(start.elapsed_time(end)) # 9 ms
So using the GPU speeds things up by 150-fold (compared to numpy) and 250-fold (compared to CPU-based PyTorch).
One can also have PyTorch “fuse” the operations in the loop, which avoids having the different vectorized operations in myfun being done in separate loops under the hood. For an overview of loop fusion, see this discussion in the context of Julia.
To fuse the operations, we need to have the function in a module. In this case I defined myfun_torch in myfun_torch.py, and we need to compile the code using torch.jit.script.
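The module presumably just contains the function defined earlier; a sketch of myfun_torch.py:
# Contents of myfun_torch.py (a sketch)
import torch
def myfun_torch(x):
    y = torch.exp(x) + 3 * torch.sin(x)
    return(y)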
from myfun_torch import myfun_torch as myfun_torch_tmp
myfun_torch_compiled = torch.jit.script(myfun_torch_tmp)
## CPU plus loop fusion
start.record()
y = myfun_torch_compiled(x)
end.record()
print(start.elapsed_time(end)) # 1000 ms (1 sec.)
## GPU plus loop fusion
start.record()
y_gpu = myfun_torch_compiled(x_gpu)
torch.cuda.synchronize()
end.record()
print(start.elapsed_time(end)) # 3.5 ms
So that seems to give a 2-3 fold speedup compared to without loop fusion.
One can also use PyTorch to run computations on the GPU that comes with Apple’s M2 chips.
The “backend” is called “MPS” (Metal Performance Shaders); Metal is what Apple calls its GPU framework.
import torch
import time
start = torch.mps.Event(enable_timing=True)
end = torch.mps.Event(enable_timing=True)
mps_device = torch.device("mps")
n = 10000
x = torch.randn(n,n)
y = torch.randn(n,n)
x_mps = x.to("mps")
y_mps = y.to("mps")
## On the CPU
torch.set_num_threads(1)
t0 = time.time()
z = matmul_wrap(x, y)
print(time.time() - t0) # 1.8 sec (1800 ms)
## On the M2 GPU
start.record()
z_mps = matmul_wrap(x_mps, y_mps)
torch.mps.synchronize()
end.record()
print(start.elapsed_time(end)) # 950 ms
So there is about a two-fold speed up, which isn’t impressive compared to the speedup on a standard GPU.
Let’s see how much time is involved in transferring the data.
x = torch.randn(n,n)
start.record()
x_mps = x.to("mps")
torch.mps.synchronize()
end.record()
print(start.elapsed_time(end)) # 35 ms.
So it looks like the transfer time is pretty small compared to the computation time (and to the savings involved in using the M2 GPU).
We can generate data on the GPU like this:
x_mps = torch.randn(n,n, device=mps_device)
You can think of JAX as a version of numpy enabled to use the GPU (or automatically parallelize on CPU threads) and provide automatic differentiation.
One can also use just-in-time (JIT) compilation with JAX. Behind the scenes, the instructions are compiled to machine code for different backends (e.g., CPU and GPU) using XLA.
Let’s first consider running a vectorized calculation using JAX on the CPU, which will use multiple threads, each thread running on a separate CPU core on our computer.
import time
import numpy as np
import jax.numpy as jnp
def myfun_np(x):
    y = np.exp(x) + 3 * np.sin(x)
    return(y)
def myfun_jnp(x):
    y = jnp.exp(x) + 3 * jnp.sin(x)
    return(y)
n = 250000000
x = np.random.normal(size = n).astype(np.float32) # for consistency
x_jax = jnp.array(x) # 32-bit by default
print(x_jax.platform())
cpu
t0 = time.time()
z = myfun_np(x)
t1 = time.time() - t0
t0 = time.time()
z_jax = myfun_jnp(x_jax).block_until_ready()
t2 = time.time() - t0
print(f"numpy time: {round(t1,3)}\njax time: {round(t2,3)}")
numpy time: 4.49
jax time: 1.643
There’s a nice speedup compared to numpy.
Since JAX will often execute computations asynchronously (in particular when using the GPU), the block_until_ready invocation ensures that the computation finishes before we stop timing.
By default the JAX floating point type is 32-bit so we forced the use of 32-bit numbers for numpy for comparability. One could have JAX use 64-bit numbers like this:
import jax.config
jax.config.update("jax_enable_x64", True)
Next let’s consider JIT compiling it, which should fuse the vectorized operations and avoid temporary objects. The JAX docs have a nice discussion of when JIT compilation will be beneficial.
import jax
myfun_jnp_jit = jax.jit(myfun_jnp)
t0 = time.time()
z_jax_jit = myfun_jnp_jit(x_jax).block_until_ready()
t3 = time.time() - t0
print(f"jitted jax time: {round(t3,3)}")
jitted jax time: 0.793
So that gives another almost 2x speedup.
Linear algebra in JAX will use multiple threads (as discussed for numpy). Here we’ll compare 64-bit calculation, since matrix decompositions sometimes need more precision.
n = 7000
x = np.random.normal(0, 1, size=(n, n))
t0 = time.time()
mat = x.T @ x
print("numpy time:")
print(round(time.time() - t0,3))
t0 = time.time()
U = np.linalg.cholesky(mat)
print(round(time.time() - t0,3))
numpy time:
5.261
3.024
import jax.config
jax.config.update("jax_enable_x64", True)
x_jax = jnp.array(x, dtype = jnp.float64)
print(f"JAX dtype is {x_jax.dtype}")
t0 = time.time()
mat_jax = jnp.matmul(x_jax.transpose(), x_jax)
print("jax time:")
print(round(time.time() - t0,3))
t0 = time.time()
U_jax = jnp.linalg.cholesky(mat_jax)
print(round(time.time() - t0,3))
JAX dtype is float64
jax time:
8.288
1.835
So here the matrix multiplication is slower using JAX with 64-bit numbers but the Cholesky is a bit faster. If one uses 32-bit numbers, JAX is faster for both (not shown).
In general, the JAX speedups are not huge, which is not surprising given both approaches are using multiple threads to carry out the linear algebra. At the least it indicates one can move a numpy workflow to JAX without worrying about losing the threaded BLAS speed of numpy.
Getting threaded CPU computation automatically is nice, but the real benefit of JAX comes in offloading computations to the GPU (and in providing automatic differentiation, not discussed in this tutorial). If a GPU is available and a GPU-enabled JAX is installed, JAX will generally try to use the GPU.
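One way to check which backend and devices JAX sees (a quick sketch; on a machine without a GPU the backend will simply be ‘cpu’):
import jax
print(jax.default_backend())  # e.g., 'gpu' if a GPU is available, else 'cpu'
print(jax.devices())          # the available devices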
Note my general comments about using the GPU in the PyTorch section.
Note that for this section, I’m pasting in the output when running the code separately on a machine with a GPU because this document is generated on a machine without a GPU.
We’ll just repeat the experiments we ran earlier comparing numpy- and JAX-based calculations, but on a machine with an A100 GPU.
import time
import numpy as np
import jax.numpy as jnp
def myfun_np(x):
    y = np.exp(x) + 3 * np.sin(x)
    return(y)
def myfun_jnp(x):
    y = jnp.exp(x) + 3 * jnp.sin(x)
    return(y)
n = 250000000
x = np.random.normal(size = n).astype(np.float32) # for consistency
x_jax = jnp.array(x) # 32-bit by default
print(x_jax.platform()) # gpu
t0 = time.time()
z = myfun_np(x)
print(time.time() - t0) # 1.15 s.
t0 = time.time()
z_jax = myfun_jnp(x_jax).block_until_ready()
print(time.time() - t0) # 0.0099 s.
So that gives a speedup of more than 100x.
import jax
myfun_jnp_jit = jax.jit(myfun_jnp)
t0 = time.time()
z_jax_jit = myfun_jnp_jit(x_jax).block_until_ready() # 0.0052 s.
print(time.time() - t0)
JIT compilation helps a bit (about 2x).
Finally, here’s the linear algebra example on the GPU.
n = 7000
x = np.random.normal(0, 1, size=(n, n)).astype(np.float32) # for consistency
t0 = time.time()
mat = x.T @ x
print(time.time() - t0) # 3.7 s.
t0 = time.time()
U = np.linalg.cholesky(mat) # 3.3 s.
print(time.time() - t0)
x_jax = jnp.array(x)
t0 = time.time()
mat_jax = jnp.matmul(x_jax.transpose(), x_jax).block_until_ready()
print(time.time() - t0) # 0.025 sec.
t0 = time.time()
U_jax = jnp.linalg.cholesky(mat_jax).block_until_ready()
print(time.time() - t0) # 0.08 s.
Again we get a very impressive speedup.
As discussed elsewhere in this tutorial, it takes time to transfer data to and from the GPU, so it’s best to generate values on the GPU and keep objects on the GPU when possible.
Also, JAX arrays are designed to be manipulated as whole objects; they are immutable, so you cannot assign to individual elements.
x_jax[0,0] = 3.17
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/scratch/users/paciorek/conda/envs/jax-test/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 278, in _unimplemented_setitem
raise TypeError(msg.format(type(self)))
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
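Instead, as the error message indicates, one uses the functional update syntax, which returns a new array:
x_jax = x_jax.at[0, 0].set(3.17)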
We can use JAX’s vmap to automatically vectorize a map operation. Unlike numpy’s vectorize or apply_along_axis, which are just handy syntax (“syntactic sugar”) and don’t actually speed anything up (because the looping is still done in Python), vmap actually vectorizes the loop. Behind the scenes it generates a vectorized version of the code that can run in parallel on the CPU or GPU.
In general, one would use this to automatically iterate over the dimension(s) of one or more arrays. This is convenient from a coding perspective (compared to explicitly writing a loop) and potentially speeds up the computation based on parallelization and by avoiding the overhead of looping at the Python level.
Here we’ll standardize each column of an array using vmap rather than writing a loop over the columns.
import jax
import jax.numpy as jnp
import numpy as np
import time
nr = 10000
nc = 10000
x = np.random.normal(size = (nr,nc)).astype(np.float32) # for consistency
x_jax = jnp.array(x)
def f(x):
    ## Standardize a vector by its range.
    return x / (np.max(x) - np.min(x))
def f_jax(x):
    return x / (jnp.max(x) - jnp.min(x))
# Standardize each column.
t0 = time.time()
out = np.apply_along_axis(f, 0, x)
t1 = time.time() - t0
# Note: apply_along_axis's axis argument gives the axis along which the 1-d slices lie (axis 0: columns),
# while vmap's in_axes gives the axis that is mapped over, so axis 0 there corresponds to in_axes=1 here.
f_jax_vmap = jax.vmap(f_jax, in_axes=1, out_axes=1)
t0 = time.time()
out_jax = f_jax_vmap(x_jax).block_until_ready()
t2 = time.time() - t0
print(f"numpy time: {round(t1,3)}\njax vmap time: {round(t2,3)}")
numpy time: 3.654
jax vmap time: 1.566
That gives a nice speedup. Let’s also try JIT’ing it. That gives a further speedup.
f_jax_vmap_jit = jax.jit(f_jax_vmap)
t0 = time.time()
out_jax_jit = f_jax_vmap_jit(x_jax).block_until_ready()
t3 = time.time() - t0
print(f"jitted jax vmap time: {round(t3,3)}")
jitted jax vmap time: 0.322
It would make sense to explore the benefits of using a GPU here, though I haven’t done so.
vmap has a lot of flexibility to operate on various axes of its input arguments (and to structure the output axes). Suppose we want to do the same standardization of the columns of x, but scaling each column by the range of the corresponding column of a different array.
y = np.random.normal(size = (nr,nc)).astype(np.float32)
y_jax = jnp.array(y)
def f2_jax(x, y):
    return x / (jnp.max(y) - jnp.min(y))
out2 = jax.vmap(f2_jax, in_axes=(1,1), out_axes=1)(x_jax, y_jax)
f2_jax_jit = jax.jit(jax.vmap(f2_jax, in_axes=(1,1), out_axes=1))
out3 = f2_jax_jit(x_jax, y_jax)
Finally, note that pmap is a function with a similar-sounding name that allows one to parallelize a map operation over multiple devices (e.g., multiple GPUs).
CuPy is another package allowing one to execute numpy-type calculations on the GPU (Nvidia only). It has some similarity to JAX.
Here’s a basic illustration, where we get a 175x speedup for generating a random matrix and matrix multiplication when using an A100 GPU.
import cupy
import numpy as np
import time
def matmul_np(n):
    x = np.random.normal(size=(n,n))
    z = np.matmul(x,x)
    return(z)
def matmul_cupy(n):
    x = cupy.random.normal(size=(n,n))
    z = cupy.matmul(x,x)
    return(z)
n = 7000
t0 = time.time()
z = matmul_np(n)
print(time.time() - t0) # 8.8 s.
t0 = time.time()
z_cupy = matmul_cupy(n)
cupy.cuda.stream.get_current_stream().synchronize()
print(time.time() - t0) # .05 s.
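Given the data transfer costs discussed earlier, note that CuPy provides functions for moving arrays between the CPU and GPU; a minimal sketch (using n and z_cupy from above):
x = np.random.normal(size=(n, n))
x_gpu = cupy.asarray(x)       # copy a numpy array to the GPU
z_cpu = cupy.asnumpy(z_cupy)  # copy a CuPy array back to the CPU as a numpy array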
You can also use cupy.RawKernel to execute a GPU kernel written in CUDA C/C++ directly from Python. That’s a bit beyond our scope here, so I won’t show an example.