[Feature] Allow parameter shifting on all parametric gates #31
base: main
Thanks for this @atiyo. Not much to say, probably due to my current lack of knowledge of JAX. One thing I would suggest is aligning the syntax with PyQ, either by suggesting changes there or by importing them here, allowing for the idiosyncrasies between torch and JAX, of course.
-def parse_dict(values: dict[str, float] = dict()) -> float:
-    return values[self.param]  # type: ignore[index]
+def parse_dict(self: Parametric, values: dict[str, float] = dict()) -> float:
+    return values[self.param_name] + self.shift  # type: ignore[index]
Would it be possible to remove these type ignores?
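One way the ignore could become unnecessary is sketched below. This is a minimal, hypothetical stand-in, not the PR's actual `Parametric` class: it assumes a frozen dataclass with only `param_name` and `shift`, and makes `values` a required argument for simplicity. Because `param_name` is typed as `str` (rather than `str | float`), indexing a `dict[str, float]` with it type-checks as is.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Parametric:
    """Hypothetical, simplified stand-in for a parametric gate."""

    param_name: str  # always a str, so it is a valid key for `values`
    shift: float = 0.0  # offset added to the looked-up value (used by the PSR)

    def parse_dict(self, values: Dict[str, float]) -> float:
        # `values` maps parameter names to floats; since `param_name` is a
        # plain `str`, this lookup needs no `# type: ignore[index]`.
        return values[self.param_name] + self.shift


gate = Parametric(param_name="theta", shift=0.5)
print(gate.parse_dict({"theta": 1.0}))  # 1.5
```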
Co-authored-by: RolandMacDoland <[email protected]>
Nothing to add, I think. LGTM.
Separating the `param_name` and `param_value` is probably a good idea to do in PyQ as well, although I don't think we have run into issues there.
The previous PR (#27) implements the parameter shift rule (PSR) for parameters defined in the `values` argument of expectations. However, it suffered from some limitations: [...] `values`, and a user couldn't pass a parameter directly into a gate.

This MR addresses the two above points. It also adds tests checking that the above can be jit-compiled and give the correct answers.
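For readers less familiar with JAX, here is a minimal sketch of how a PSR can be attached to a function as a custom JVP rule and then jit-compiled. It is not the PR's implementation: `expectation` is a hypothetical closed-form stand-in for a circuit expectation (a single RX gate on |0⟩ measured in Z, which gives cos θ), and `expectation_psr` is our own name. For a gate generated by a Pauli operator, the rule is df/dθ = [f(θ + π/2) − f(θ − π/2)] / 2.

```python
import jax
import jax.numpy as jnp


def expectation(theta):
    # Hypothetical circuit: <0| RX(theta)^dagger Z RX(theta) |0> = cos(theta).
    return jnp.cos(theta)


@jax.custom_jvp
def expectation_psr(theta):
    return expectation(theta)


@expectation_psr.defjvp
def expectation_psr_jvp(primals, tangents):
    (theta,) = primals
    (theta_dot,) = tangents
    shift = jnp.pi / 2
    # Parameter shift rule: two shifted circuit evaluations replace autodiff
    # through the circuit itself.
    grad = (expectation(theta + shift) - expectation(theta - shift)) / 2
    return expectation(theta), grad * theta_dot


# The custom rule composes with jit and grad.
grad_fn = jax.jit(jax.grad(expectation_psr))
print(grad_fn(0.3))  # approximately -sin(0.3)
```

Since the returned tangent is linear in `theta_dot`, JAX can also transpose the rule, which is why `jax.grad` (reverse mode) works on top of the custom JVP.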
Some noteworthy points:
- When using `checkify.check` points, the output type of the original function is changed to `(error, output_of_original_function)`. This is not ideal for end users, so these checks have been removed. It would be great to have such checks in the code, so an issue investigating a promising alternative has been raised ([Enhancement] Investigate/use Chex for assertions and type handling #30).
- The `param` attribute of a `Parametric` gate could be of type `str | float`. This was problematic when implementing custom JVP rules, since a `float` is a valid JAX type but a string is not (see e.g. "Allow str to be a valid JAX type", jax-ml/jax#3045). Consequently, `param` has been explicitly split into `param_name: str` and `param_val: float`, so that `param_val` is always a valid JAX type.

Closes #29
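To illustrate what the split buys, the sketch below shows PSR applied to any parametric gate by moving its `shift` field, whether the gate is parametrised by name (looked up in `values`) or directly by value. It is a plain-Python sketch under stated assumptions: the `RX` class, its `angle` method, `expectation_z` and `psr_derivative` are all hypothetical, and the PR wires this logic into JAX custom JVP rules rather than plain functions. The point is that the numeric field (`param_val`) is always a float, so it is a value that could be differentiated or traced, while `param_name` is only ever used as a dictionary key.

```python
import math
from dataclasses import dataclass, replace
from typing import Dict, Optional


@dataclass(frozen=True)
class RX:
    """Hypothetical RX gate with the parameter split into name and value."""

    param_name: Optional[str] = None  # looked up in `values` when set
    param_val: float = 0.0  # used when the parameter is passed directly
    shift: float = 0.0  # offset applied by the parameter shift rule

    def angle(self, values: Dict[str, float]) -> float:
        base = values[self.param_name] if self.param_name is not None else self.param_val
        return base + self.shift


def expectation_z(gate: RX, values: Dict[str, float]) -> float:
    # <0| RX(t)^dagger Z RX(t) |0> = cos(t) for one qubit starting in |0>.
    return math.cos(gate.angle(values))


def psr_derivative(gate: RX, values: Dict[str, float]) -> float:
    # Build two copies of the gate with its own `shift` moved by +/- pi/2.
    s = math.pi / 2
    plus = replace(gate, shift=gate.shift + s)
    minus = replace(gate, shift=gate.shift - s)
    return (expectation_z(plus, values) - expectation_z(minus, values)) / 2


named = RX(param_name="theta")
direct = RX(param_val=0.3)
print(psr_derivative(named, {"theta": 0.3}))  # approximately -sin(0.3)
print(psr_derivative(direct, {}))  # same value, parameter passed directly
```

Shifting via a field on the gate, rather than via the `values` dictionary, is what lets the same rule cover gates whose parameter never appears in `values`.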