*Module Parameters* section of docs is outdated. #3761
You need to specify a type annotation to the dataclass field:
@chiamp thanks! I also think an error message for a missing type annotation on a dataclass field would be useful, since the error that came from it was a bit cryptic.
Not adding a type annotation turns `kernel_init` into a plain class attribute rather than a dataclass field.
I believe there are use-cases for these, but @cgarciae can speak more to this.
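The field-versus-class-attribute distinction comes from Python's `dataclasses` module, which Flax modules are built on: only annotated class variables become fields. A plain-Python sketch, with a hypothetical `toy_init` standing in for an initializer such as `lecun_normal()`:

```python
import dataclasses
from typing import Callable


def toy_init(key, shape):
    # Hypothetical stand-in for an initializer like lecun_normal().
    return (key, shape)


@dataclasses.dataclass
class Unannotated:
    features: int
    kernel_init = toy_init  # no annotation: a plain class attribute


@dataclasses.dataclass
class Annotated:
    features: int
    kernel_init: Callable = toy_init  # annotated: a dataclass field


print([f.name for f in dataclasses.fields(Unannotated)])  # ['features']
print([f.name for f in dataclasses.fields(Annotated)])  # ['features', 'kernel_init']

# Because it is not a field, kernel_init is not an __init__ parameter either.
try:
    Unannotated(features=3, kernel_init=toy_init)
except TypeError as err:
    print("TypeError:", err)
```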
Ahh I see @chiamp. So when we don't type annotate, `kernel_init` comes back as a bound method:

```python
import flax.linen as nn


class SimpleDense(nn.Module):
    features: int
    kernel_init = nn.initializers.lecun_normal()  # no type annotation


model = SimpleDense(features=3)
print(model.kernel_init)
# <bound method variance_scaling.<locals>.init of SimpleDense(
# # attributes
# features = 3
# )>
```

And when we do type annotate, `kernel_init` becomes a dataclass field stored on the instance, so it is not bound:

```python
from typing import Callable

import flax.linen as nn


class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()


model = SimpleDense(features=3)
print(model.kernel_init)
# <function variance_scaling.<locals>.init at 0x7f498fb66200>
```

In the former case, the binding messed up the order of the arguments passed to the initializer during initialization: the module instance itself is inserted as the first argument, shifting all the others by one position.
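The argument shift can be reproduced without Flax. The sketch below uses a hypothetical `param` helper that, in simplified form, calls an initializer with a key followed by the unpacked `*init_args`, roughly the way Flax's `self.param` calls its `init_fn`. The helper, the `toy_init` signature and the string `"rng-key"` (standing in for a real PRNG key) are all assumptions for illustration:

```python
import dataclasses
from typing import Callable


def toy_init(key, shape, dtype="float32"):
    # Hypothetical initializer with the usual (key, shape, dtype) signature.
    return {"key": key, "shape": shape, "dtype": dtype}


@dataclasses.dataclass
class Unannotated:
    kernel_init = toy_init  # class attribute: bound on instance access


@dataclasses.dataclass
class Annotated:
    kernel_init: Callable = toy_init  # field: stored on the instance, not bound


def param(init_fn, *init_args):
    # Hypothetical, simplified parameter helper: key first, then *init_args.
    rng = "rng-key"
    return init_fn(rng, *init_args)


bound = param(Unannotated().kernel_init, (7, 3))
unbound = param(Annotated().kernel_init, (7, 3))

# Bound: the instance fills `key`, the key lands in `shape`, the shape in `dtype`.
print(type(bound["key"]).__name__, bound["shape"], bound["dtype"])
# Unannotated rng-key (7, 3)

# Unbound: the arguments arrive in the intended order.
print(unbound)
# {'key': 'rng-key', 'shape': (7, 3), 'dtype': 'float32'}
```

In Flax the misplaced values reach a real initializer rather than a toy one, which is presumably why the resulting error is hard to trace back to the missing annotation.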
Hi, first off thanks for a great library -- flax is awesome.
I wanted to revisit the documentation to gain a better understanding of Flax. In the basics guide there is a section on module parameters.
I wanted to point out that the code in that section does not seem to work at the moment.
Here is a stripped-down version of what is currently in the docs:
It seems to be something to do with how `*init_args` is being unpacked. I tried reproducing similar behaviour with the following:

But I had trouble navigating the Flax codebase, as I am unfamiliar with it. Thanks again!