How to load PyTorch checkpoints into JAX/Flax? #927

Answered by marcvanzee
marcvanzee asked this question in Q&A
PyTorch checkpoints contain a state_dict with all the weights/parameters of the model. Converting it to Flax involves:

  1. Defining the model using Flax modules
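  2. Renaming the dictionary entries so the keys line up with the Flax parameter names, and transposing weights into the layout Flax expects: conv kernels go from PyTorch's (out, in, H, W) to (H, W, in, out), and dense kernels from (out, in) to (in, out).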
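
The layout change in step 2 can be sketched with NumPy alone. This is a minimal illustration, not code from the linked repository: the helper names torch_conv_to_flax and torch_linear_to_flax are made up for this example, and the arrays stand in for tensors taken from a state_dict (for example via tensor.numpy()).

```python
import numpy as np


def torch_conv_to_flax(weight):
    """Reorder a 2D conv kernel from PyTorch (out, in, H, W) to Flax (H, W, in, out)."""
    return np.transpose(weight, (2, 3, 1, 0))


def torch_linear_to_flax(weight):
    """Reorder a dense kernel from PyTorch (out, in) to Flax (in, out)."""
    return np.transpose(weight)


# Placeholder arrays with PyTorch-style shapes (hypothetical sizes).
conv_w = np.zeros((8, 3, 5, 5), dtype=np.float32)   # out=8, in=3, 5x5 kernel
linear_w = np.zeros((10, 16), dtype=np.float32)     # out=10, in=16

flax_conv = torch_conv_to_flax(conv_w)
flax_linear = torch_linear_to_flax(linear_w)

print(flax_conv.shape, flax_linear.shape)
```

Bias vectors have the same shape in both frameworks, so they usually need renaming only.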

flax.traverse_util.flatten_dict is often useful here, because it lets you operate on a flat dict keyed by tuples instead of a nested dict. Once the keys and shapes line up, unflatten_dict restores the nested form.

@nikitakit wrote the following code for importing PyTorch BERT checkpoints into a Flax model: https://github.com/nikitakit/flax_bert/blob/master/import_weights.py

Answer selected by marcvanzee

This discussion was converted from issue #866 on January 22, 2021 12:42.