Skip to content

Commit

Permalink
Remake (#9)
Browse files Browse the repository at this point in the history
Remake of horqrux in the style of pyqtorch, i.e., geared towards Qadence.
  • Loading branch information
dominikandreasseitz authored Jan 25, 2024
1 parent dbb49cb commit d54c02d
Show file tree
Hide file tree
Showing 13 changed files with 800 additions and 603 deletions.
174 changes: 127 additions & 47 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,92 +18,172 @@ pip install horqrux
Let's have a look at primitive gates first.

```python exec="on" source="material-block"
from horqrux.gates import X
from horqrux.utils import prepare_state
from horqrux.ops import apply_gate
from horqrux import X, random_state, apply_gate

state = prepare_state(2)
state = random_state(2)
new_state = apply_gate(state, X(0))
```

We can also make any gate controlled, in the case of X, we have to pass the target qubit first!

```python exec="on" source="material-block"
import jax.numpy as jnp
from horqrux.gates import X
from horqrux.utils import prepare_state, equivalent_state
from horqrux.ops import apply_gate
from horqrux import X, product_state, equivalent_state, apply_gate

n_qubits = 2
state = prepare_state(n_qubits, '11')
state = product_state('11')
control = 0
target = 1
# This is equivalent to performing CNOT(0,1)
new_state= apply_gate(state, X(target,control))
assert jnp.allclose(new_state, prepare_state(n_qubits, '10'))
assert jnp.allclose(new_state, product_state('10'))
```

When applying parametric gates, we pass the numeric value for the parameter first
When applying parametric gates, we can either pass a numeric value or a parameter name for the parameter as the first argument.

```python exec="on" source="material-block"
import jax.numpy as jnp
from horqrux.gates import Rx
from horqrux.utils import prepare_state
from horqrux.ops import apply_gate
from horqrux import RX, random_state, apply_gate

target_qubit = 1
state = prepare_state(target_qubit+1)
state = random_state(target_qubit+1)
param_value = 1 / 4 * jnp.pi
new_state = apply_gate(state, Rx(param_value, target_qubit))
new_state = apply_gate(state, RX(param_value, target_qubit))
# Parametric horqrux gates also accept parameter names in the form of strings.
# Simply pass a dictionary of parameter names and values to the 'apply_gate' function
new_state = apply_gate(state, RX('theta', target_qubit), {'theta': jnp.pi})
```

We can also make any parametric gate controlled simply by passing a control qubit.

```python exec="on" source="material-block"
import jax.numpy as jnp
from horqrux.gates import Rx
from horqrux.utils import prepare_state
from horqrux.ops import apply_gate
from horqrux import RX, product_state, apply_gate

n_qubits = 2
target_qubit = 1
control_qubit = 0
state = prepare_state(2, '11')
state = product_state('11')
param_value = 1 / 4 * jnp.pi
new_state = apply_gate(state, Rx(param_value, target_qubit, control_qubit))
new_state = apply_gate(state, RX(param_value, target_qubit, control_qubit))
```

A fully differentiable variational circuit is simply a sequence of gates which are applied to a state.
We can now build a fully differentiable variational circuit by simply defining a sequence of gates
and a set of initial parameter values we want to optimize.
Lets fit a function using a simple circuit class wrapper.

```python exec="on" source="material-block" html="1"
from __future__ import annotations

```python exec="on" source="material-block"
import jax
from jax import grad, jit, Array, value_and_grad, vmap
from dataclasses import dataclass
import jax.numpy as jnp
from horqrux import gates
from horqrux.utils import prepare_state, overlap
from horqrux.ops import apply_gate
import optax
from functools import reduce, partial
from operator import add
from typing import Any, Callable
from uuid import uuid4

from horqrux.abstract import Operator
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate, overlap


n_qubits = 5
n_params = 3
n_layers = 3

n_qubits = 2
state = prepare_state(2, '00')
# Lets define a sequence of rotations
ops = [gates.Rx, gates.Ry, gates.Rx]
def ansatz_w_params(n_qubits: int, n_layers: int) -> tuple[list, list]:
all_ops = []
param_names = []
rots_fns = [RX ,RY, RX]
for _ in range(n_layers):
for i in range(n_qubits):
ops = [fn(str(uuid4()), qubit) for fn, qubit in zip(rots_fns, [i for _ in range(len(rots_fns))])]
param_names += [op.param for op in ops]
ops += [NOT((i+1) % n_qubits, i % n_qubits) for i in range(n_qubits)]
all_ops += ops

return all_ops, param_names

# We need a function to fit and use it to produce training data
fn = lambda x, degree: .05 * reduce(add, (jnp.cos(i*x) + jnp.sin(i*x) for i in range(degree)), 0)
x = jnp.linspace(0, 10, 100)
y = fn(x, 5)

@dataclass
class Circuit:
n_qubits: int
n_layers: int

def __post_init__(self) -> None:
# We will use a featuremap of RX rotations to encode some classical data
self.feature_map: list[Operator] = [RX('phi', i) for i in range(n_qubits)]
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
self.observable: list[Operator] = [Z(0)]

@partial(vmap, in_axes=(None, None, 0))
def forward(self, param_values: Array, x: Array) -> Array:
state = zero_state(self.n_qubits)
param_dict = {name: val for name, val in zip(self.param_names, param_values)}
state = apply_gate(state, self.feature_map + self.ansatz, {**param_dict, **{'phi': x}})
return overlap(state, apply_gate(state, self.observable))

def __call__(self, param_values: Array, x: Array) -> Array:
return self.forward(param_values, x)

@property
def n_vparams(self) -> int:
return len(self.param_names)

circ = Circuit(n_qubits, n_layers)
# Create random initial values for the parameters
key = jax.random.PRNGKey(0)
params = jax.random.uniform(key, shape=(n_qubits * len(ops),))

def circ(state) -> jax.Array:
for qubit in range(n_qubits):
for gate,param in zip(ops, params):
state = apply_gate(state, gate(param, qubit))
state = apply_gate(state,gates.NOT(1, 0))
projection = apply_gate(state, gates.Z(0))
return overlap(state, projection)

# Let's compute both values and gradients for a set of parameters and compile the circuit.
circ = jax.jit(jax.value_and_grad(circ))
# Run it on a state.
expval_and_grads = circ(state)
expval = expval_and_grads[0]
grads = expval_and_grads[1:]
print(f'Expval: {expval};'
f'Grads: {grads}')
key = jax.random.PRNGKey(42)
param_vals = jax.random.uniform(key, shape=(circ.n_vparams,))
# Check the initial predictions using randomly initialized parameters
y_init = circ(param_vals, x)

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(param_vals)

# Define a loss function
def loss_fn(param_vals: Array, x: Array, y: Array) -> Array:
y_pred = circ(param_vals, x)
return jnp.mean(optax.l2_loss(y_pred, y))


def optimize_step(params: dict[str, Array], opt_state: Array, grads: dict[str, Array]) -> tuple:
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state

@jit
def train_step(i: int, inputs: tuple
) -> tuple:
param_vals, opt_state = inputs
loss, grads = value_and_grad(loss_fn)(param_vals, x, y)
param_vals, opt_state = optimize_step(param_vals, opt_state, grads)
return param_vals, opt_state


n_epochs = 200
param_vals, opt_state = jax.lax.fori_loop(0, n_epochs, train_step, (param_vals, opt_state))
y_final = circ(param_vals, x)

# Lets plot the results
import matplotlib.pyplot as plt
plt.plot(x, y, label="truth")
plt.plot(x, y_init, label="initial")
plt.plot(x, y_final, "--", label="final", linewidth=3)
plt.legend()

from io import StringIO # markdown-exec: hide
from matplotlib.figure import Figure # markdown-exec: hide
def fig_to_html(fig: Figure) -> str: # markdown-exec: hide
buffer = StringIO() # markdown-exec: hide
fig.savefig(buffer, format="svg") # markdown-exec: hide
return buffer.getvalue() # markdown-exec: hide
# from docs import docutils # markdown-exec: hide
print(fig_to_html(plt.gcf())) # markdown-exec: hide
```
15 changes: 12 additions & 3 deletions horqrux/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from __future__ import annotations

from jax import config

config.update("jax_enable_x64", True) # you should really really do this
from .apply import apply_gate, apply_operator
from .parametric import PHASE, RX, RY, RZ
from .primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from .utils import (
equivalent_state,
hilbert_reshape,
overlap,
product_state,
random_state,
uniform_state,
zero_state,
)
125 changes: 125 additions & 0 deletions horqrux/abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Iterable, Tuple

import numpy as np
from jax import Array
from jax.tree_util import register_pytree_node_class

from .matrices import OPERATIONS_DICT
from .utils import (
ControlQubits,
QubitSupport,
TargetQubits,
_dagger,
_jacobian,
_unitary,
is_controlled,
none_like,
)


@register_pytree_node_class
@dataclass
class Operator:
"""Abstract class which stores information about generators target and control qubits
of a particular quantum operator."""

generator_name: str
target: QubitSupport
control: QubitSupport

@staticmethod
def parse_idx(
idx: Tuple,
) -> Tuple:
if isinstance(idx, (int, np.int64)):
return ((idx,),)
elif isinstance(idx, tuple):
return (idx,)
else:
return (idx.astype(int),)

def __post_init__(self) -> None:
self.target = Operator.parse_idx(self.target)
if self.control is None:
self.control = none_like(self.target)
else:
self.control = Operator.parse_idx(self.control)

def __iter__(self) -> Iterable:
return iter((self.generator_name, self.target, self.control))

def tree_flatten(self) -> Tuple[Tuple, Tuple[str, TargetQubits, ControlQubits]]:
children = ()
aux_data = (self.generator_name, self.target, self.control)
return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)

def unitary(self, values: dict[str, float] = dict()) -> Array:
return OPERATIONS_DICT[self.generator_name]

def dagger(self, values: dict[str, float] = dict()) -> Array:
return _dagger(self.unitary(values))

@property
def name(self) -> str:
return "C" + self.generator_name if is_controlled(self.control) else self.generator_name

def __repr__(self) -> str:
return self.name + f"(target={self.target[0]}, control={self.control[0]})"


Primitive = Operator


@register_pytree_node_class
@dataclass
class Parametric(Primitive):
"""Extension of the Primitive class adding the option to pass a parameter."""

generator_name: str
target: QubitSupport
control: QubitSupport
param: str | float = ""

def __post_init__(self) -> None:
super().__post_init__()

def parse_dict(values: dict[str, float] = dict()) -> float:
return values[self.param] # type: ignore[index]

def parse_val(values: dict[str, float] = dict()) -> float:
return self.param # type: ignore[return-value]

self.parse_values = parse_dict if isinstance(self.param, str) else parse_val

def tree_flatten(self) -> Tuple[Tuple, Tuple[str, Tuple, Tuple, str | float]]: # type: ignore[override]
children = ()
aux_data = (
self.name,
self.target,
self.control,
self.param,
)
return (children, aux_data)

def unitary(self, values: dict[str, float] = dict()) -> Array:
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))

def jacobian(self, values: dict[str, float] = dict()) -> Array:
return _jacobian(OPERATIONS_DICT[self.generator_name], self.parse_values(values))

@property
def name(self) -> str:
base_name = "R" + self.generator_name
return "C" + base_name if is_controlled(self.control) else base_name

def __repr__(self) -> str:
return (
self.name + f"(target={self.target[0]}, control={self.control[0]}, param={self.param})"
)
Loading

0 comments on commit d54c02d

Please sign in to comment.