Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add compat #3921

Merged
merged 1 commit into from
May 20, 2024
Merged

[nnx] add compat #3921

merged 1 commit into from
May 20, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented May 13, 2024

What does this PR do?

Adds nnx.compat module with the goal of making it possible to port Linen codebases to an NNX system with as few changes as possible. It would contain the following functionality:

  • Module: inherits from nnx.Module and adds methods from linen.Module.
  • compact: allows defining submodules inline
  • wrappers: some types that simplify NNX <-> Linen interop.

@cgarciae cgarciae marked this pull request as ready for review May 13, 2024 15:42
@cgarciae cgarciae requested a review from chiamp May 13, 2024 15:42
@cgarciae cgarciae force-pushed the nnx-compact branch 2 times, most recently from 17e7fc2 to 3dc8838 Compare May 14, 2024 15:19
@cgarciae cgarciae marked this pull request as draft May 14, 2024 15:40
@cgarciae cgarciae marked this pull request as ready for review May 14, 2024 15:40
@cgarciae cgarciae force-pushed the nnx-object-refactor branch 3 times, most recently from c137583 to 24002bc Compare May 15, 2024 12:59
Base automatically changed from nnx-object-refactor to main May 16, 2024 16:38
@cgarciae cgarciae changed the title [nnx] add lazy [nnx] add compat May 16, 2024
@cgarciae cgarciae force-pushed the nnx-compact branch 5 times, most recently from e4af67e to 0da1116 Compare May 17, 2024 16:34
return x @ self.w + self.b[None]

@dataclasses.dataclass
class Foo(nnx.compat.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class Foo(nnx.compat.Module):
class Foo(compat.Module):

Change all nnx.compat references to compat for consistency?

>>> import jax.numpy as jnp
...
>>> class Linear(nnc.Module):
... def __init__(self, dout, rngs: nnx.Rngs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarification:

  • if we define an __init__ method, we will be able to instantiate module parameters using the .init method?
  • if we define a setup method or wrap __call__ with compact, we will be able to instantiate the module parameters by calling the module on a sample input and invoking shape inference?
  • the module parameters that are instantiated are bound to the module so they can be dot-accessed, which is different from Linen where they are returned separately as a variable dict?
  • Instead of defining an __init__ method, can we define a setup method or wrap __call__ with compact to use the .init method as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion here, this init method is the current init we have for nnx.Module but we are just moving it out to compat.Module, however its still need to create refactor the method so it follows the Linen API as closely as possible in a subsequent PR.

@copybara-service copybara-service bot merged commit 2bdcee7 into main May 20, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-compact branch May 20, 2024 19:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants