Skip to content

Commit 8dd75af

Browse files
committed
[nnx] remove transforms.deprecated
1 parent 5f7af86 commit 8dd75af

File tree

7 files changed

+56
-2790
lines changed

7 files changed

+56
-2790
lines changed

docs_nnx/nnx_basics.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,9 @@
375375
"In the code below notice the following:\n",
376376
"\n",
377377
"1. The custom `create_model` function takes in a key and returns an `MLP` object, since you create five keys and use [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) over `create_model` a stack of 5 `MLP` objects is created.\n",
378-
"2. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) is used to iteratively apply each `MLP` in the stack to the input `x`.\n",
379-
"3. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) (consciously) deviates from [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) and instead mimics [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap), which is more expressive. [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.\n",
380-
"4. [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) updates for the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layers are automatically propagated by [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan)."
378+
"2. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) is used to iteratively apply each `MLP` in the stack to the input `x`.\n",
379+
"3. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) (consciously) deviates from [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) and instead mimics [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap), which is more expressive. [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.\n",
380+
"4. [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) updates for the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layers are automatically propagated by [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan)."
381381
]
382382
},
383383
{

docs_nnx/nnx_basics.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,9 @@ The next example uses Flax [`nnx.vmap`](https://flax.readthedocs.io/en/latest/ap
193193
In the code below notice the following:
194194

195195
1. The custom `create_model` function takes in a key and returns an `MLP` object, since you create five keys and use [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap) over `create_model` a stack of 5 `MLP` objects is created.
196-
2. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) is used to iteratively apply each `MLP` in the stack to the input `x`.
197-
3. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) (consciously) deviates from [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) and instead mimics [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap), which is more expressive. [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan) allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.
198-
4. [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) updates for the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layers are automatically propagated by [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.Scan).
196+
2. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) is used to iteratively apply each `MLP` in the stack to the input `x`.
197+
3. The [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) (consciously) deviates from [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan) and instead mimics [`nnx.vmap`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.vmap), which is more expressive. [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.
198+
4. [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) updates for the [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) and [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layers are automatically propagated by [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan).
199199

200200
```{code-cell} ipython3
201201
@nnx.vmap(in_axes=0, out_axes=0)

flax/nnx/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,6 @@
136136
from .training.metrics import Metric as Metric
137137
from .training.metrics import MultiMetric as MultiMetric
138138
from .training.optimizer import Optimizer as Optimizer
139-
from .transforms.deprecated import Jit as Jit
140-
from .transforms.deprecated import Remat as Remat
141-
from .transforms.deprecated import Scan as Scan
142-
from .transforms.deprecated import Vmap as Vmap
143-
from .transforms.deprecated import Pmap as Pmap
144139
from .transforms.autodiff import DiffState as DiffState
145140
from .transforms.autodiff import grad as grad
146141
from .transforms.autodiff import value_and_grad as value_and_grad

flax/nnx/transforms/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from . import deprecated

0 commit comments

Comments
 (0)