-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add HamiltonianEvolution, Support py3.12 (#11)
* [Feature] Add HamiltonianEvolution, Support py3.12
- Loading branch information
1 parent
a64c900
commit 491cdb3
Showing
9 changed files
with
112 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
|
||
from jax import Array | ||
from jax.scipy.linalg import expm | ||
from jax.tree_util import register_pytree_node_class | ||
|
||
from .abstract import Primitive, QubitSupport | ||
|
||
|
||
@register_pytree_node_class | ||
@dataclass | ||
class _HamiltonianEvolution(Primitive): | ||
""" | ||
A slim wrapper class which evolves a 'hamiltonian' | ||
given a 'time_evolution' parameter and applies it to 'state' psi by doing: matrixexp(-iHt)|psi> | ||
""" | ||
|
||
generator_name: str | ||
target: QubitSupport | ||
control: QubitSupport | ||
|
||
def unitary(self, values: dict[str, Array] = dict()) -> Array: | ||
return expm(values["hamiltonian"] * (-1j * values["time_evolution"])) | ||
|
||
|
||
def HamiltonianEvolution( | ||
target: QubitSupport, control: QubitSupport = (None,) | ||
) -> _HamiltonianEvolution: | ||
return _HamiltonianEvolution("I", target, control) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,10 +9,10 @@ authors = [ | |
{ name = "Gert-Jan Both" , email = "[email protected]" }, | ||
{ name = "Dominik Seitz", email = "[email protected]" }, | ||
] | ||
requires-python = ">=3.9,<3.12" | ||
requires-python = ">=3.8,<3.13" | ||
license = {text = "Apache 2.0"} | ||
|
||
version = "0.5.0" | ||
version = "0.6.0" | ||
|
||
classifiers=[ | ||
"License :: Other/Proprietary License", | ||
|
@@ -21,6 +21,7 @@ classifiers=[ | |
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
"Programming Language :: Python :: Implementation :: CPython", | ||
"Programming Language :: Python :: Implementation :: PyPy", | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from __future__ import annotations | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
import pytest | ||
from jax import jit, vmap | ||
|
||
from horqrux.analog import HamiltonianEvolution | ||
from horqrux.apply import apply_gate | ||
from horqrux.utils import is_normalized, overlap, random_state, uniform_state | ||
|
||
sigmaz = jnp.diag(jnp.array([1.0, -1.0], dtype=jnp.cdouble)) | ||
Hbase = jnp.kron(sigmaz, sigmaz) | ||
|
||
Hamiltonian = jnp.kron(Hbase, Hbase) | ||
|
||
|
||
def test_hamevo_single() -> None: | ||
n_qubits = 4 | ||
t_evo = jnp.pi / 4 | ||
hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)])) | ||
psi = uniform_state(n_qubits) | ||
psi_star = apply_gate(psi, hamevo, {"hamiltonian": Hamiltonian, "time_evolution": t_evo}) | ||
result = overlap(psi_star, psi) | ||
assert jnp.isclose(result, 0.5) | ||
|
||
|
||
def Hamiltonian_general(n_qubits: int = 2, batch_size: int = 1) -> jnp.array: | ||
H_batch = jnp.zeros((batch_size, 2**n_qubits, 2**n_qubits), dtype=jnp.cdouble) | ||
for i in range(batch_size): | ||
H_0 = np.random.uniform(0.0, 1.0, (2**n_qubits, 2**n_qubits)).astype(np.cdouble) | ||
H = H_0 + jnp.conj(H_0.transpose(0, 1)) | ||
H_batch.at[i, :, :].set(H) | ||
return H_batch | ||
|
||
|
||
@pytest.mark.parametrize("n_qubits, batch_size", [(2, 1), (4, 2)]) | ||
def test_hamevo_general(n_qubits: int, batch_size: int) -> None: | ||
H = Hamiltonian_general(n_qubits, batch_size) | ||
t_evo = np.random.uniform(0, 1, (batch_size, 1)) | ||
hamevo = HamiltonianEvolution(tuple([i for i in range(n_qubits)])) | ||
psi = random_state(n_qubits) | ||
psi_star = jit(vmap(apply_gate, in_axes=(None, None, {"hamiltonian": 0, "time_evolution": 0})))( | ||
psi, hamevo, {"hamiltonian": H, "time_evolution": t_evo} | ||
) | ||
assert jnp.all(vmap(is_normalized, in_axes=(0,))(psi_star)) |