*Module Parameters* section of docs is outdated. #3761

Open
PaulScemama opened this issue Mar 14, 2024 · 4 comments

@PaulScemama

Hi, first off thanks for a great library -- flax is awesome.

I wanted to revisit the documentation to gain a better understanding of flax. In basics there is a section on module parameters.

I wanted to point out that the code in that section doesn't seem to work at the moment.

Here is a stripped-down version of what is currently in the docs:

import flax.linen as nn
import jax.numpy as jnp
import jax.random as random


class SimpleDense(nn.Module):
  features: int
  kernel_init = nn.initializers.lecun_normal()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # init_args
    y = jnp.dot(inputs, kernel)
    return y

x = jnp.ones((1, 7))
model = SimpleDense(features=3)
key, init_key = random.split(random.key(123))

params = model.init(init_key, x)
# Error: TypeError: Cannot interpret '7' as a data type

It seems to have something to do with how *init_args is being unpacked. I tried reproducing similar behaviour with the following:

initializer = nn.initializers.glorot_normal()

def foo(rng_key, args):
    
    def initialize():
        return nn.initializers.glorot_normal()(rng_key, *args)

    return initialize()

foo(random.key(1), (4,5))
# TypeError: Cannot interpret '5' as a data type
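
For what it's worth, the standalone repro above may be producing the same message for a different reason. A minimal sketch, assuming the initializer takes its positional arguments in the order (key, shape, dtype): the function `init` below is a hypothetical stand-in, not the JAX initializer, and it only shows where each argument lands when the shape tuple is unpacked with `*`.

```python
import inspect


def init(key, shape, dtype=float):
    """Hypothetical stand-in; only the parameter order matters here."""
    return key, shape, dtype


sig = inspect.signature(init)

# Unpacking the tuple spreads it over two parameters: shape=4, dtype=5.
bound_unpacked = dict(sig.bind("rng", *(4, 5)).arguments)

# Passing the tuple as one argument keeps it together as the shape.
bound_tuple = dict(sig.bind("rng", (4, 5)).arguments)

print(bound_unpacked)
print(bound_tuple)
```

Under that assumption, calling the initializer with `args` instead of `*args` would keep (4, 5) as the shape.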

But I had trouble navigating the flax codebase as I am unfamiliar with it. Thanks again!

@chiamp chiamp added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Mar 19, 2024
@chiamp chiamp self-assigned this Mar 19, 2024
@chiamp
Collaborator

chiamp commented Mar 19, 2024

You need to add a type annotation to the dataclass field (Callable comes from typing):

class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  ...

@PaulScemama
Author

@chiamp thanks!

I also think a dedicated error message for a dataclass field that is missing its type annotation would be helpful, since the error I got was a bit cryptic.

@chiamp
Collaborator

chiamp commented Mar 25, 2024

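Without a type annotation, kernel_init is not a dataclass field; it stays an ordinary class attribute, so it behaves like a method of the class:

import jax
import flax.linen as nn

class SimpleDense(nn.Module):
  features: int
  kernel_init = nn.initializers.lecun_normal()

# Called on the class, kernel_init gives the same result as the initializer itself:
SimpleDense.kernel_init(jax.random.key(0), (1, 1)) == nn.initializers.lecun_normal()(jax.random.key(0), (1, 1))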

I believe there are use-cases for these, but @cgarciae can speak more to this.

@PaulScemama
Author

Ahh I see @chiamp. So when we don't type annotate kernel_init, accessing it on an instance gives a bound method. E.g.

from typing import Callable

import flax.linen as nn
import jax.numpy as jnp
import jax.random as random


class SimpleDense(nn.Module):
  features: int
  kernel_init = nn.initializers.lecun_normal()


x = jnp.ones((1, 7))
model = SimpleDense(features=3)
print(model.kernel_init)
# <bound method variance_scaling.<locals>.init of SimpleDense(
#    # attributes
#    features = 3
# )>

And when we do type annotate it, kernel_init is a dataclass field, so model.kernel_init is the plain function rather than a bound method.

from typing import Callable

import flax.linen as nn
import jax.numpy as jnp
import jax.random as random


class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()


x = jnp.ones((1, 7))
model = SimpleDense(features=3)
print(model.kernel_init)
# <function variance_scaling.<locals>.init at 0x7f498fb66200>

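In the former case, the binding shifts the arguments. self.param calls the initializer with the RNG key and the shape, but a bound method also receives the module instance as its first argument, so every argument lands one position too late and the shape tuple ends up where the dtype is expected. That is presumably where the "Cannot interpret '7' as a data type" error at the top of the thread comes from.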
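
The same behaviour can be reproduced with plain dataclasses, without Flax. In this sketch `init`, `Unannotated` and `Annotated` are hypothetical stand-ins, and `init` is assumed to take (key, shape, dtype) in that order like the initializers discussed above; it simply returns its arguments so the shift is visible.

```python
import inspect
from dataclasses import dataclass
from typing import Callable


def init(key, shape, dtype=float):
    """Stand-in initializer that returns its arguments unchanged."""
    return key, shape, dtype


@dataclass
class Unannotated:
    features: int
    kernel_init = init  # no annotation: ordinary class attribute


@dataclass
class Annotated:
    features: int
    kernel_init: Callable = init  # annotation: dataclass field stored on the instance


u = Unannotated(features=3)
a = Annotated(features=3)

# Bound method: the instance is passed first, so "rng" lands in the
# shape slot and the shape tuple lands in the dtype slot.
unannotated_call = u.kernel_init("rng", (7, 3))

# Plain function: the arguments land where they were intended.
annotated_call = a.kernel_init("rng", (7, 3))

print(inspect.ismethod(u.kernel_init), inspect.ismethod(a.kernel_init))
```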
