[Documentation] Add DQC Tutorial (#12)
dominikandreasseitz authored Apr 8, 2024
1 parent 491cdb3 commit 0db46bd
Showing 9 changed files with 382 additions and 173 deletions.
39 changes: 18 additions & 21 deletions README.md
@@ -1,46 +1,43 @@
# horqrux
[![Linting / Tests/ Documentation](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml/badge.svg)](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml)
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/)

**horqrux** is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning.
It acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.
`horqrux` is a [JAX](https://jax.readthedocs.io/en/latest/)-based state vector simulator designed for quantum machine learning and acts as a backend for [`Qadence`](https://github.com/pasqal-io/qadence), a digital-analog quantum programming interface.

## Installation

`horqrux` (CPU-only) can be installed from PyPI with `pip` as follows:
To install the CPU-only version, simply use `pip`:
```bash
pip install horqrux
```
If you want to install the GPU version, simply do:
If you intend to use a GPU:

```bash
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

[![Linting / Tests/ Documentation](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml/badge.svg)](https://github.com/pasqal-io/horqrux/actions/workflows/run-tests-and-mypy.yml)
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Pypi](https://badge.fury.io/py/horqrux.svg)](https://pypi.org/project/horqrux/)
## Getting started
`horqrux` adopts a minimalistic and functional interface; the [docs](https://pasqal-io.github.io/horqrux/latest/) provide a comprehensive A-Z guide, ranging from applying simple primitive and parametric gates, to using [adjoint differentiation](https://arxiv.org/abs/2009.02823) to fit a nonlinear function, to implementing [DQC](https://arxiv.org/abs/2011.10395) to solve a partial differential equation.
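
As a quick sketch of that interface (using only the gate constructors and state helpers shown in the docs; treat it as illustrative rather than exhaustive):

```python
import jax.numpy as jnp

from horqrux import NOT, RX, apply_gate, zero_state

# prepare the two-qubit |00> state
state = zero_state(2)
# rotate qubit 0 by pi/2 around X, then apply a NOT on qubit 1 controlled by qubit 0
new_state = apply_gate(state, [RX(jnp.pi / 2, 0), NOT(1, 0)])
```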

## Contributing

## Install from source
To learn how to contribute, please visit the [CONTRIBUTING](docs/CONTRIBUTING.md) page.

We recommend using the [`hatch`](https://hatch.pypa.io/latest/) environment manager to install `horqrux` from source:
When developing within `horqrux`, you can either use the Python environment manager [`hatch`](https://hatch.pypa.io/latest/):

```bash
python -m pip install hatch
pip install hatch

# get into a shell with all the dependencies
python -m hatch shell
# enter a shell containing all the dependencies
hatch shell

# run a command within the virtual environment with all the dependencies
python -m hatch run python my_script.py
hatch run python my_script.py
```

Please note that `hatch` will not combine nicely with other environment managers such as Conda. If you want to use Conda, install `horqrux` from source using `pip`:
When using any other environment manager like `venv` or `conda`, simply do:

```bash
# within the Conda environment
python -m pip install -e .
# within the virtual environment
pip install -e .
```

## Contributing

Please refer to [CONTRIBUTING](docs/CONTRIBUTING.md) to learn how to contribute to `horqrux`.
233 changes: 220 additions & 13 deletions docs/index.md
@@ -11,7 +11,7 @@ choice and install it normally with `pip`:
pip install horqrux
```

## Gates
## Digital operations

`horqrux` implements a large selection of both primitive and parametric digital quantum gates, from single-qubit to n-qubit operations.

@@ -68,10 +68,34 @@ param_value = 1 / 4 * jnp.pi
new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit))
```
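
Applying non-parametric (primitive) gates follows the same pattern. A minimal sketch, assuming the values argument of `apply_gate` can be omitted when no named parameters are involved:

```python
from horqrux import NOT, Z, apply_gate
from horqrux.utils import uniform_state

# start from the two-qubit equal-superposition state
state = uniform_state(2)
# apply a Pauli-Z on qubit 0, then a NOT on qubit 1 controlled by qubit 0
new_state = apply_gate(state, [Z(0), NOT(1, 0)])
```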

## Analog operations

`horqrux` also allows for global state evolution via the `HamiltonianEvolution` operation.
Note that it expects a Hamiltonian and a time-evolution parameter passed as `numpy` or `jax.numpy` arrays. To build arbitrary Pauli Hamiltonians, we recommend using [Qadence](https://github.com/pasqal-io/qadence/blob/main/examples/backends/low_level/horqrux_analog.py).

```python exec="on" source="material-block"
from jax.numpy import pi, array, diag, kron, cdouble
from horqrux.analog import HamiltonianEvolution
from horqrux.apply import apply_gate
from horqrux.utils import uniform_state

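# build a dense 16 x 16 Hamiltonian H = Z ⊗ Z ⊗ Z ⊗ Z from the single-qubit Pauli-Z matrix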
sigmaz = diag(array([1.0, -1.0], dtype=cdouble))
Hbase = kron(sigmaz, sigmaz)

Hamiltonian = kron(Hbase, Hbase)
n_qubits = 4
t_evo = pi / 4
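# HamiltonianEvolution is parametrized by the target qubits; the Hamiltonian and evolution time are supplied at application time via the values dictionary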
hamevo = HamiltonianEvolution(tuple(range(n_qubits)))
psi = uniform_state(n_qubits)
psi_star = apply_gate(psi, hamevo, {"hamiltonian": Hamiltonian, "time_evolution": t_evo})
```

## Fitting a nonlinear function using adjoint differentiation

We can now build a fully differentiable variational circuit by simply defining a sequence of gates
and a set of initial parameter values we want to optimize.
Horqrux provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
which we can use to fit a function using a simple circuit class wrapper.
`horqrux` provides an implementation of [adjoint differentiation](https://arxiv.org/abs/2009.02823),
which we can use to fit a function using a simple `Circuit` class.

```python exec="on" source="material-block" html="1"
from __future__ import annotations
Expand All @@ -87,7 +111,7 @@ from typing import Any, Callable
from uuid import uuid4

from horqrux.adjoint import adjoint_expectation
from horqrux.abstract import Primitive
from horqrux.primitive import Primitive
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate


@@ -121,18 +145,16 @@ class Circuit:

def __post_init__(self) -> None:
# We will use a featuremap of RX rotations to encode some classical data
self.feature_map: list[Primitive] = [RX('phi', i) for i in range(n_qubits)]
self.feature_map: list[Primitive] = [RX('phi', i) for i in range(self.n_qubits)]
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
self.observable: list[Primitive] = [Z(0)]

@partial(vmap, in_axes=(None, None, 0))
def forward(self, param_values: Array, x: Array) -> Array:
def __call__(self, param_values: Array, x: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
return adjoint_expectation(state, self.feature_map + self.ansatz, self.observable, {**param_dict, **{'phi': x}})

def __call__(self, param_values: Array, x: Array) -> Array:
return self.forward(param_values, x)

@property
def n_vparams(self) -> int:
@@ -154,15 +176,15 @@ def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
return jnp.mean(optax.l2_loss(y_pred, y))


def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
def optimize_step(param_vals: Array, opt_state: Array, grads: Array) -> tuple:
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state
param_vals = optax.apply_updates(param_vals, updates)
return param_vals, opt_state

@jit
def train_step(i: int, inputs: tuple
def train_step(i: int, paramvals_w_optstate: tuple
) -> tuple:
param_vals, opt_state = inputs
param_vals, opt_state = paramvals_w_optstate
loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
return param_vals, opt_state
@@ -188,3 +210,188 @@ def fig_to_html(fig: Figure) -> str: # markdown-exec: hide
# from docs import docutils # markdown-exec: hide
print(fig_to_html(plt.gcf())) # markdown-exec: hide
```
## Fitting a partial differential equation using DQC

Finally, we show how [DQC](https://arxiv.org/abs/2011.10395) can be implemented in `horqrux` to solve a partial differential equation.
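
Concretely, reading off the interior and boundary residuals encoded in the loss function below (unit-square domain, boundary data f(x) = sin(pi x)), the target problem is the two-dimensional Laplace equation

$$
\begin{aligned}
&\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} = 0, \qquad (x, y) \in (0, 1)^2, \\
&u(0, y) = u(1, y) = u(x, 1) = 0, \qquad u(x, 0) = \sin(\pi x).
\end{aligned}
$$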

```python exec="on" source="material-block" html="1"
from __future__ import annotations

from dataclasses import dataclass
from functools import reduce
from itertools import product
from operator import add
from uuid import uuid4

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import Array, jit, value_and_grad, vmap
from numpy.random import uniform

from horqrux import NOT, RX, RY, Z, apply_gate, zero_state
from horqrux.primitive import Primitive
from horqrux.utils import inner

LEARNING_RATE = 0.01
N_QUBITS = 4
DEPTH = 3
VARIABLES = ("x", "y")
X_POS = 0
Y_POS = 1
N_POINTS = 150
N_EPOCHS = 1000


def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
all_ops = []
param_names = []
rots_fns = [RX, RY, RX]
for _ in range(n_layers):
for i in range(n_qubits):
ops = [
fn(str(uuid4()), qubit)
for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])
]
param_names += [op.param for op in ops]
ops += [NOT((i + 1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
all_ops += ops

return all_ops, param_names


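# observable used below: the total magnetization, i.e. the sum of single-qubit Pauli-Z operators applied to the state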
@dataclass
class TotalMagnetization:
n_qubits: int

def __post_init__(self) -> None:
self.paulis = [Z(i) for i in range(self.n_qubits)]

def __call__(self, state: Array, values: dict) -> Array:
return reduce(add, [apply_gate(state, pauli, values) for pauli in self.paulis])


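# DQC circuit: an RX feature map encoding (x, y) on the two qubit halves, a variational ansatz, and the total-magnetization observable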
@dataclass
class Circuit:
n_qubits: int
n_layers: int

def __post_init__(self) -> None:
self.feature_map: list[Primitive] = [RX("x", i) for i in range(self.n_qubits // 2)] + [
RX("y", i) for i in range(self.n_qubits // 2, self.n_qubits)
]
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
self.observable = TotalMagnetization(self.n_qubits)

def __call__(self, param_vals: Array, x: Array, y: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_vals)}
out_state = apply_gate(
state, self.feature_map + self.ansatz, {**param_dict, **{"x": x, "y": y}}
)
        projected_state = self.observable(out_state, param_dict)
return jnp.real(inner(out_state, projected_state))

@property
def n_vparams(self) -> int:
return len(self.param_names)


circ = Circuit(N_QUBITS, DEPTH)
# Create random initial values for the parameters
key = jax.random.PRNGKey(42)
param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))

optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(param_vals)


def exp_fn(param_vals: Array, x: Array, y: Array) -> Array:
return circ(param_vals, x, y)


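# PDE loss: squared residuals of the four boundary conditions plus the squared interior residual u_xx + u_yy, averaged over the sampled collocation points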
def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
def pde_loss(x: float, y: float) -> Array:
l_b, r_b, t_b, b_b = list(
map(
lambda xy: exp_fn(param_vals, *xy),
[
[jnp.zeros((1, 1)), y], # u(0,y)=0
[jnp.ones((1, 1)), y], # u(L,y)=0
[x, jnp.ones((1, 1))], # u(x,H)=0
[x, jnp.zeros((1, 1))], # u(x,0)=f(x)
],
)
)
b_b -= jnp.sin(jnp.pi * x)
hessian = jax.hessian(lambda xy: exp_fn(param_vals, xy[0], xy[1]))(
jnp.concatenate(
[
x.reshape(
1,
),
y.reshape(
1,
),
]
)
)
interior = hessian[X_POS][X_POS] + hessian[Y_POS][Y_POS] # uxx+uyy=0
return reduce(add, list(map(lambda term: jnp.power(term, 2), [l_b, r_b, t_b, b_b, interior])))

return jnp.mean(vmap(pde_loss, in_axes=(0, 0))(x, y))


def optimize_step(param_vals: Array, opt_state: Array, grads: Array) -> tuple:
updates, opt_state = optimizer.update(grads, opt_state, param_vals)
param_vals = optax.apply_updates(param_vals, updates)
return param_vals, opt_state


# collocation points sampling and training
def sample_points(n_in: int, n_p: int) -> Array:
return uniform(0, 1.0, (n_in, n_p))


@jit
def train_step(i: int, paramvals_w_optstate: tuple) -> tuple:
param_vals, opt_state = paramvals_w_optstate
x, y = sample_points(2, N_POINTS)
loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
return optimize_step(param_vals, opt_state, grads)


param_vals, opt_state = jax.lax.fori_loop(0, N_EPOCHS, train_step, (param_vals, opt_state))
# compare the solution to known ground truth
single_domain = jnp.linspace(0, 1, num=N_POINTS)
domain = jnp.array(list(product(single_domain, single_domain)))
# analytical solution
analytic_sol = (
(np.exp(-np.pi * domain[:, 0]) * np.sin(np.pi * domain[:, 1])).reshape(N_POINTS, N_POINTS).T
)
# DQC solution

dqc_sol = vmap(lambda domain: exp_fn(param_vals, domain[0], domain[1]), in_axes=(0,))(domain).reshape(
N_POINTS, N_POINTS
)
# plot results
fig, ax = plt.subplots(1, 2, figsize=(7, 7))
ax[0].imshow(analytic_sol, cmap="turbo")
ax[0].set_xlabel("x")
ax[0].set_ylabel("y")
ax[0].set_title("Analytical solution u(x,y)")
ax[1].imshow(dqc_sol, cmap="turbo")
ax[1].set_xlabel("x")
ax[1].set_ylabel("y")
ax[1].set_title("DQC solution u(x,y)")
from io import StringIO # markdown-exec: hide
from matplotlib.figure import Figure # markdown-exec: hide
def fig_to_html(fig: Figure) -> str: # markdown-exec: hide
buffer = StringIO() # markdown-exec: hide
fig.savefig(buffer, format="svg") # markdown-exec: hide
return buffer.getvalue() # markdown-exec: hide
# from docs import docutils # markdown-exec: hide
print(fig_to_html(plt.gcf())) # markdown-exec: hide
```