
[Feature] Allow parameter shifting on all parametric gates #31

Open: wants to merge 3 commits into base: main

atiyo (Collaborator) commented Sep 19, 2024

The previous PR (#27) implemented the parameter shift rule (PSR) for parameters defined in the values argument of expectations. However, it had some limitations:

This PR addresses both of these points. It also adds tests checking that these cases can be JIT-compiled and give the correct results.
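For reference, the PSR itself can be sketched in plain jax.numpy on a toy one-qubit circuit. This is not horqrux's API: the helpers rx, expectation_z, circuit, psr_grad and naive_psr_grad are hypothetical. For a gate RX(θ) = exp(-iθX/2), the derivative of an expectation f with respect to θ is [f(θ + π/2) - f(θ - π/2)]/2. When one parameter feeds several gates, each occurrence has to be shifted separately and the contributions summed, which illustrates why repeated parameters need special handling.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for a toy one-qubit circuit (not horqrux's API).
X = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
I2 = jnp.eye(2, dtype=jnp.complex64)


def rx(theta):
    # RX(theta) = exp(-i * theta * X / 2) = cos(theta/2) I - i sin(theta/2) X
    return jnp.cos(theta / 2) * I2 - 1j * jnp.sin(theta / 2) * X


def expectation_z(angles):
    # Apply one RX gate per angle to |0>, then return <Z>.
    state = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    for theta in angles:
        state = rx(theta) @ state
    return jnp.real(jnp.vdot(state, Z @ state))


def circuit(theta):
    # The same parameter theta is used by two gates.
    return expectation_z([theta, theta])


def psr_grad(theta, n_occurrences=2):
    # Shift each occurrence of theta on its own by +/- pi/2 and sum the terms.
    total = 0.0
    for k in range(n_occurrences):
        plus = [theta + (jnp.pi / 2 if j == k else 0.0) for j in range(n_occurrences)]
        minus = [theta - (jnp.pi / 2 if j == k else 0.0) for j in range(n_occurrences)]
        total = total + (expectation_z(plus) - expectation_z(minus)) / 2
    return total


def naive_psr_grad(theta):
    # Incorrect for repeated parameters: shifts every occurrence at once.
    return (circuit(theta + jnp.pi / 2) - circuit(theta - jnp.pi / 2)) / 2


theta = 0.3
print(jax.grad(circuit)(theta))  # autodiff reference
print(psr_grad(theta))           # agrees with the reference
print(naive_psr_grad(theta))     # approximately 0, which is wrong
```

Here the two gates compose to RX(2θ), so ⟨Z⟩ = cos(2θ) and the exact derivative is -2 sin(2θ). The per-occurrence version reproduces it, while the naive shift cancels to roughly zero.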

Some noteworthy points:

  • To jit a function containing checkify.check calls, it must first be wrapped with checkify.checkify, which changes its output type to (error, output_of_original_function). This is not ideal for end users, so the checks have been removed. Having such checks in the code would still be valuable, so an issue investigating a promising alternative has been raised ([Enhancement] Investigate/use Chex for assertions and type handling #30).
  • Previously, the param attribute of a Parametric gate could be of type str | float. This was problematic when implementing custom JVP rules, since a float is a valid JAX type but a string is not (see jax-ml/jax#3045, "Allow str to be a valid JAX type"). Consequently, param has been explicitly split into param_name: str and param_val: float, so that param_val is always a valid JAX type.
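The checkify point can be reproduced with a minimal sketch that uses jax.experimental.checkify directly. The function normalized is hypothetical and unrelated to horqrux; it only shows that once checkify.checkify is applied so the function can be jitted, callers receive an (error, output) pair instead of the bare output.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def normalized(x):
    # A runtime assertion inside the function body.
    checkify.check(jnp.all(x >= 0), "inputs must be non-negative")
    return x / jnp.sum(x)


# Under jit, a function using checkify.check must be functionalized with
# checkify.checkify, which changes the return type to (error, output).
checked = jax.jit(checkify.checkify(normalized))

err, out = checked(jnp.array([1.0, 3.0]))
err.throw()  # does nothing here, since the check passed
print(out)   # the normalized values

err, out = checked(jnp.array([-1.0, 3.0]))
print(err.get())  # an error message string instead of None
```

Every call site has to unpack and handle err, which is the change to the user-facing signature described above.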

Closes #29

atiyo self-assigned this Sep 19, 2024
atiyo added the bug (Something isn't working) and feature labels Sep 19, 2024
RolandMacDoland (Collaborator) left a comment

Thanks for this @atiyo. Not much to say, probably due to my current lack of knowledge of JAX. One thing I would suggest is aligning the syntax with PyQ, either by suggesting changes there or by importing them here, allowing for the idiosyncrasies between torch and JAX, of course.

horqrux/parametric.py (outdated, resolved)

Before:

    def parse_dict(values: dict[str, float] = dict()) -> float:
        return values[self.param]  # type: ignore[index]

After:

    def parse_dict(self: Parametric, values: dict[str, float] = dict()) -> float:
        return values[self.param_name] + self.shift  # type: ignore[index]

Would it be possible to remove these type: ignore comments?
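One possible direction, sketched under assumptions: once param is split into param_name: str and param_val: float, the string can be stored as static pytree auxiliary data while only the numeric fields are traced, and values[self.param_name] type-checks without an ignore. ToyParametric, parse_values and energy are hypothetical names for illustration, not horqrux's classes, and whether horqrux registers its gates as pytrees this way is an assumption. The shift field mirrors self.shift in the diff above.

```python
from dataclasses import dataclass, replace

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass
class ToyParametric:
    # Hypothetical stand-in for a parametric gate (not horqrux's class).
    param_name: str = ""    # static metadata, never traced by JAX
    param_val: float = 0.0  # numeric leaf, always a valid JAX type
    shift: float = 0.0      # per-gate shift used by the PSR

    def tree_flatten(self):
        # Numeric fields become traced children; the string goes into aux_data.
        return (self.param_val, self.shift), (self.param_name,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data[0], *children)

    def parse_values(self, values: dict[str, float]) -> float:
        # param_name is always a str, so indexing dict[str, float] type-checks
        # without a "type: ignore" comment.
        base = values[self.param_name] if self.param_name else self.param_val
        return base + self.shift


def energy(values: dict[str, float], gate: ToyParametric) -> jax.Array:
    # cos(theta) equals <Z> after applying RX(theta) to |0>.
    return jnp.cos(gate.parse_values(values))


values = {"theta": 0.3}
gate = ToyParametric("theta")

# jit flattens the gate: only param_val and shift are traced.
print(jax.jit(energy)(values, gate))

# Gradient with respect to the gate's numeric fields.
g = jax.grad(energy, argnums=1)(values, gate)

# PSR for this gate via its shift field; agrees with g.shift.
plus = replace(gate, shift=jnp.pi / 2)
minus = replace(gate, shift=-jnp.pi / 2)
psr = (energy(values, plus) - energy(values, minus)) / 2
print(g.shift, psr)
```

Keeping the name out of the traced leaves is what lets jit and grad accept the gate at all; a str leaf would be rejected as an invalid JAX type.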

Co-authored-by: RolandMacDoland
jpmoutinho (Collaborator) left a comment

Nothing to add, I think. LGTM.

Separating param_name and param_val is probably a good idea in PyQ as well, although I don't think we have run into issues there.

Successfully merging this pull request may close these issues.

[BUG] PSR with repeated parameters incorrect
3 participants