Skip to content

Commit

Permalink
[Feature] Add HamiltonianEvolution, Support py3.12 (#11)
Browse files Browse the repository at this point in the history
* [Feature] Add HamiltonianEvolution, Support py3.12
  • Loading branch information
dominikandreasseitz authored Feb 9, 2024
1 parent a64c900 commit 491cdb3
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests-and-mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- name: Checkout main code and submodules
uses: actions/checkout@v4
Expand Down
8 changes: 4 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ from typing import Any, Callable
from uuid import uuid4

from horqrux.adjoint import adjoint_expectation
from horqrux.abstract import Operator
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate, overlap
from horqrux.abstract import Primitive
from horqrux import Z, RX, RY, NOT, zero_state, apply_gate


n_qubits = 5
Expand Down Expand Up @@ -121,9 +121,9 @@ class Circuit:

def __post_init__(self) -> None:
# We will use a featuremap of RX rotations to encode some classical data
self.feature_map: list[Operator] = [RX('phi', i) for i in range(n_qubits)]
self.feature_map: list[Primitive] = [RX('phi', i) for i in range(n_qubits)]
self.ansatz, self.param_names = ansatz_w_params(self.n_qubits, self.n_layers)
self.observable: list[Operator] = [Z(0)]
self.observable: list[Primitive] = [Z(0)]

@partial(vmap, in_axes=(None, None, 0))
def forward(self, param_values: Array, x: Array) -> Array:
Expand Down
11 changes: 4 additions & 7 deletions horqrux/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

@register_pytree_node_class
@dataclass
class Operator:
"""Abstract class which stores information about generators target and control qubits
class Primitive:
"""Primitive gate class which stores information about generators target and control qubits
of a particular quantum operator."""

generator_name: str
Expand All @@ -42,11 +42,11 @@ def parse_idx(
return (idx.astype(int),)

def __post_init__(self) -> None:
self.target = Operator.parse_idx(self.target)
self.target = Primitive.parse_idx(self.target)
if self.control is None:
self.control = none_like(self.target)
else:
self.control = Operator.parse_idx(self.control)
self.control = Primitive.parse_idx(self.control)

def __iter__(self) -> Iterable:
return iter((self.generator_name, self.target, self.control))
Expand Down Expand Up @@ -74,9 +74,6 @@ def __repr__(self) -> str:
return self.name + f"(target={self.target[0]}, control={self.control[0]})"


Primitive = Operator


@register_pytree_node_class
@dataclass
class Parametric(Primitive):
Expand Down
21 changes: 11 additions & 10 deletions horqrux/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,46 @@
from typing import Tuple

from jax import Array, custom_vjp
from jax.numpy import real as jnpreal

from horqrux.abstract import Operator, Parametric
from horqrux.abstract import Parametric, Primitive
from horqrux.apply import apply_gate
from horqrux.utils import OperationType, overlap
from horqrux.utils import OperationType, inner


def expectation(
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return overlap(out_state, projected_state)
return jnpreal(inner(out_state, projected_state))


@custom_vjp
def adjoint_expectation(
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Array:
return expectation(state, gates, observable, values)


def adjoint_expectation_fwd(
state: Array, gates: list[Operator], observable: list[Operator], values: dict[str, float]
) -> Tuple[Array, Tuple[Array, Array, list[Operator], dict[str, float]]]:
state: Array, gates: list[Primitive], observable: list[Primitive], values: dict[str, float]
) -> Tuple[Array, Tuple[Array, Array, list[Primitive], dict[str, float]]]:
out_state = apply_gate(state, gates, values, OperationType.UNITARY)
projected_state = apply_gate(out_state, observable, values, OperationType.UNITARY)
return overlap(out_state, projected_state), (out_state, projected_state, gates, values)
return jnpreal(inner(out_state, projected_state)), (out_state, projected_state, gates, values)


def adjoint_expectation_bwd(
res: Tuple[Array, Array, list[Operator], dict[str, float]], tangent: Array
res: Tuple[Array, Array, list[Primitive], dict[str, float]], tangent: Array
) -> tuple:
out_state, projected_state, gates, values = res
grads = {}
for gate in gates[::-1]:
out_state = apply_gate(out_state, gate, values, OperationType.DAGGER)
if isinstance(gate, Parametric):
mu = apply_gate(out_state, gate, values, OperationType.JACOBIAN)
grads[gate.param] = tangent * 2 * overlap(mu, projected_state)
grads[gate.param] = tangent * 2 * jnpreal(inner(mu, projected_state))
projected_state = apply_gate(projected_state, gate, values, OperationType.DAGGER)
return (None, None, None, grads)

Expand Down
31 changes: 31 additions & 0 deletions horqrux/analog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from dataclasses import dataclass

from jax import Array
from jax.scipy.linalg import expm
from jax.tree_util import register_pytree_node_class

from .abstract import Primitive, QubitSupport


@register_pytree_node_class
@dataclass
class _HamiltonianEvolution(Primitive):
"""
A slim wrapper class which evolves a 'hamiltonian'
given a 'time_evolution' parameter and applies it to 'state' psi by doing: matrixexp(-iHt)|psi>
"""

generator_name: str
target: QubitSupport
control: QubitSupport

def unitary(self, values: dict[str, Array] = dict()) -> Array:
return expm(values["hamiltonian"] * (-1j * values["time_evolution"]))


def HamiltonianEvolution(
target: QubitSupport, control: QubitSupport = (None,)
) -> _HamiltonianEvolution:
return _HamiltonianEvolution("I", target, control)
6 changes: 3 additions & 3 deletions horqrux/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from jax import Array

from horqrux.abstract import Operator
from horqrux.abstract import Primitive

from .utils import OperationType, State, _controlled, is_controlled

Expand Down Expand Up @@ -53,7 +53,7 @@ def apply_operator(

def apply_gate(
state: State,
gate: Operator | Iterable[Operator],
gate: Primitive | Iterable[Primitive],
values: dict[str, float] = dict(),
op_type: OperationType = OperationType.UNITARY,
) -> State:
Expand All @@ -68,7 +68,7 @@ def apply_gate(
State after applying 'gate'.
"""
operator: Tuple[Array, ...]
if isinstance(gate, Operator):
if isinstance(gate, Primitive):
operator_fn = getattr(gate, op_type)
operator, target, control = (operator_fn(values),), gate.target, gate.control
else:
Expand Down
10 changes: 9 additions & 1 deletion horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,12 @@ def equivalent_state(s0: Array, s1: Array) -> bool:
return jnp.allclose(overlap(s0, s1), 1.0, atol=ATOL) # type: ignore[no-any-return]


def inner(state: Array, projection: Array) -> Array:
return jnp.dot(jnp.conj(state.flatten()), projection.flatten())


def overlap(state: Array, projection: Array) -> Array:
return jnp.real(jnp.dot(jnp.conj(state.flatten()), projection.flatten()))
return jnp.real(jnp.power(inner(state, projection), 2))


def uniform_state(
Expand Down Expand Up @@ -150,3 +154,7 @@ def _normalize(wf: Array) -> Array:
return _normalize(
(jnp.sqrt(x / sumx) * jnp.exp(1j * phases)).reshape(tuple(2 for _ in range(n_qubits)))
)


def is_normalized(state: Array) -> bool:
return equivalent_state(state, state)
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ authors = [
{ name = "Gert-Jan Both" , email = "[email protected]" },
{ name = "Dominik Seitz", email = "[email protected]" },
]
requires-python = ">=3.9,<3.12"
requires-python = ">=3.8,<3.13"
license = {text = "Apache 2.0"}

version = "0.5.0"
version = "0.6.0"

classifiers=[
"License :: Other/Proprietary License",
Expand All @@ -21,6 +21,7 @@ classifiers=[
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
Expand Down
46 changes: 46 additions & 0 deletions tests/test_analog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

import jax.numpy as jnp
import numpy as np
import pytest
from jax import jit, vmap

from horqrux.analog import HamiltonianEvolution
from horqrux.apply import apply_gate
from horqrux.utils import is_normalized, overlap, random_state, uniform_state

sigmaz = jnp.diag(jnp.array([1.0, -1.0], dtype=jnp.cdouble))
Hbase = jnp.kron(sigmaz, sigmaz)

Hamiltonian = jnp.kron(Hbase, Hbase)


def test_hamevo_single() -> None:
n_qubits = 4
t_evo = jnp.pi / 4
hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)]))
psi = uniform_state(n_qubits)
psi_star = apply_gate(psi, hamevo, {"hamiltonian": Hamiltonian, "time_evolution": t_evo})
result = overlap(psi_star, psi)
assert jnp.isclose(result, 0.5)


def Hamiltonian_general(n_qubits: int = 2, batch_size: int = 1) -> jnp.array:
H_batch = jnp.zeros((batch_size, 2**n_qubits, 2**n_qubits), dtype=jnp.cdouble)
for i in range(batch_size):
H_0 = np.random.uniform(0.0, 1.0, (2**n_qubits, 2**n_qubits)).astype(np.cdouble)
H = H_0 + jnp.conj(H_0.transpose(0, 1))
H_batch.at[i, :, :].set(H)
return H_batch


@pytest.mark.parametrize("n_qubits, batch_size", [(2, 1), (4, 2)])
def test_hamevo_general(n_qubits: int, batch_size: int) -> None:
H = Hamiltonian_general(n_qubits, batch_size)
t_evo = np.random.uniform(0, 1, (batch_size, 1))
hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)]))
psi = random_state(n_qubits)
psi_star = jit(vmap(apply_gate, in_axes=(None, None, {"hamiltonian": 0, "time_evolution": 0})))(
psi, hamevo, {"hamiltonian": H, "time_evolution": t_evo}
)
assert jnp.all(vmap(is_normalized, in_axes=(0,))(psi_star))

0 comments on commit 491cdb3

Please sign in to comment.