Misleading type signature of UpdateFn #707
Not sure about adding it, but you can try opening a PR with the change to see if it works, @JGameCreation, and we can take it from there.
I created a draft pull request, but it doesn't seem to run any tests for some reason.
I updated the setting, so the tests should run for you now.
Related to #700.
A similar problem holds for `RunFn` in `AdaptationAlgorithm`. I also don't understand why blackjax uses the pattern of a `NamedTuple` of `Callable`/`Protocol` instead of just abstract classes.
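To make the design question concrete, here is a toy comparison of the two patterns in plain Python. All names (`ToyAlgorithm`, `make_toy_algorithm`, `AlgorithmBase`, `ToyAlgorithmClass`) are hypothetical and do not mirror blackjax's definitions. The tuple style bundles closures over the configuration into immutable data; the class style stores configuration on an instance and reports a missing abstract method as soon as a subclass is instantiated. In both styles the declared signatures are only checked by a static type checker, never at runtime, which is why a misleading annotation like the one in this issue can go unnoticed. Whether blackjax chose the tuple style for a specific reason (for example, keeping every piece a plain function that can be handed to JAX transformations) is not stated in this thread.

```python
from abc import ABC, abstractmethod
from typing import Callable, NamedTuple

# Hypothetical names for illustration; neither style mirrors blackjax's code.


# Style 1: a NamedTuple whose fields are plain functions.
class ToyAlgorithm(NamedTuple):
    init: Callable[[float], float]
    step: Callable[[float, float], float]


def make_toy_algorithm(step_size: float) -> ToyAlgorithm:
    # Configuration (step_size) is captured by closures; the result is an
    # immutable bundle of functions.
    def init(position: float) -> float:
        return position

    def step(state: float, target: float) -> float:
        return state + step_size * (target - state)

    return ToyAlgorithm(init=init, step=step)


# Style 2: an abstract base class with required methods.
class AlgorithmBase(ABC):
    @abstractmethod
    def init(self, position: float) -> float: ...

    @abstractmethod
    def step(self, state: float, target: float) -> float: ...


class ToyAlgorithmClass(AlgorithmBase):
    def __init__(self, step_size: float) -> None:
        self.step_size = step_size  # configuration stored on the instance

    def init(self, position: float) -> float:
        return position

    def step(self, state: float, target: float) -> float:
        return state + self.step_size * (target - state)


functional = make_toy_algorithm(0.5)
oo = ToyAlgorithmClass(0.5)
print(functional.step(functional.init(0.0), 2.0), oo.step(oo.init(0.0), 2.0))
```

Both calls compute the same value (`1.0`); the difference is only in how the pieces are bundled and where configuration lives.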
As specified in `base.py`, `UpdateFn` takes only two arguments: an RNG key and the previous state. However, many samplers take additional arguments. For example, `sgld` also takes a minibatch of data and a step size after the RNG key and state. Maybe we should add `*args` to the signature of `UpdateFn.__call__`?
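To make the mismatch concrete, here is a minimal sketch in plain Python (no JAX, so it runs anywhere). Everything in it is hypothetical: `UpdateFn`, `VariadicUpdateFn`, `ToyState`, `sgld_like_step` and `run_chain` are simplified stand-ins, not blackjax's actual code, and the update is a deterministic toy rather than real SGLD. A driver written against the two-argument signature would call `step(key, state)`, which raises `TypeError` for a kernel that also needs a minibatch and a step size; forwarding the extra arguments avoids that. The protocol uses `*args: Any, **kwargs: Any` rather than a bare `*args` because, as far as I know, the typing spec treats that pair as a gradual "any extra arguments" tail, whereas a bare `*args` would still promise that the kernel can be called with only `(rng_key, state)`.

```python
from typing import Any, NamedTuple, Protocol, Tuple

# Hypothetical, simplified stand-ins for illustration; not blackjax's code.
State = Any
Info = Any


class UpdateFn(Protocol):
    """The two-argument signature described in the issue."""

    def __call__(self, rng_key: Any, state: State) -> Tuple[State, Info]: ...


class VariadicUpdateFn(Protocol):
    """Proposed relaxation: the kernel may take extra arguments."""

    def __call__(
        self, rng_key: Any, state: State, *args: Any, **kwargs: Any
    ) -> Tuple[State, Info]: ...


class ToyState(NamedTuple):
    position: float


def sgld_like_step(
    rng_key: Any, state: ToyState, minibatch: list, step_size: float
) -> Tuple[ToyState, dict]:
    # Deterministic toy update (no noise, so NOT real SGLD): one gradient
    # step on 0.5 * (x - mean(minibatch))**2. rng_key is accepted but unused.
    grad = state.position - sum(minibatch) / len(minibatch)
    new_state = ToyState(state.position - step_size * grad)
    return new_state, {"grad": grad}


def run_chain(step: VariadicUpdateFn, rng_keys, state, *args: Any) -> State:
    # Generic driver: forwards any extra arguments to the kernel unchanged.
    for key in rng_keys:
        state, _ = step(key, state, *args)
    return state


final = run_chain(sgld_like_step, [0, 1, 2], ToyState(0.0), [1.0, 3.0], 0.5)
print(final.position)  # 1.75
```

Starting from `0.0`, each toy step halves the distance to the minibatch mean `2.0`, giving `1.0`, `1.5`, then `1.75`.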