
Implementation of AdamW differs from PyTorch #2433

Open
dpaetzel opened this issue May 3, 2024 · 10 comments

dpaetzel commented May 3, 2024

Hi, thank you for developing and maintaining this awesome library and ecosystem!

I'm not entirely sure but could it be that the documentation for the AdamW optimizer is a bit misleading? If I understand correctly, then its definition of

AdamW(η = 0.001, β = (0.9, 0.999), decay = 0) = Optimiser(Adam(η, β), WeightDecay(decay))

means that it performs this update (where $-\eta A$ is Adam's update):

$$ \theta_t \leftarrow \theta_{t-1} - \eta A - \texttt{decay} \, \theta_{t-1} $$

However, the paper on AdamW (which is linked to by the docs) parametrizes this differently as:

$$ \begin{align*} \theta_t \leftarrow \theta_{t-1} - \eta (\alpha A + \lambda \theta_{t-1}) \end{align*} $$

I.e. Flux's eta corresponds to the paper's $\eta\alpha$ and Flux's decay corresponds to the paper's $\eta \lambda$.
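
A quick numerical sketch of that correspondence, using hypothetical helper functions that are not part of Flux (here $A$ stands for the already-computed Adam direction):

```julia
# Purely illustrative helpers (not Flux API); A is the precomputed Adam direction.

# What Flux's AdamW(η, β, decay) = Optimiser(Adam(η, β), WeightDecay(decay)) applies:
flux_step(θ, A; η, decay) = θ .- η .* A .- decay .* θ

# Parametrization from the AdamW paper (η here plays the role of the paper's η_t):
paper_step(θ, A; η, α, λ) = θ .- η .* (α .* A .+ λ .* θ)

θ, A = randn(3), randn(3)

# The two coincide when Flux's eta equals η_t*α and Flux's decay equals η_t*λ:
flux_step(θ, A; η = 0.5 * 0.001, decay = 0.5 * 0.004) ≈
    paper_step(θ, A; η = 0.5, α = 0.001, λ = 0.004)    # true
```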

This is probably super unimportant (in that case, sorry for the noise) but since I just noticed this during bug hunting in an implementation of mine (which uses AdamW), I thought I'd report it.

CarloLucibello commented May 3, 2024

We fixed this in #1612 and then unfixed it again in #1868, due to FluxML/Optimisers.jl#46 (comment).

There is some ambiguity in the paper: they call $\alpha$ the learning rate and $\eta_t$ a scheduling coefficient.
So, assuming $\eta_t = 1$, the current implementation seems correct.

On the other hand, the PyTorch implementation seems to match #1612, so I think we should fix AdamW again.

dpaetzel commented May 3, 2024

Thank you for unravelling that for me and sorry that I didn't notice those issues/PRs in the first place.

Short elaboration for future reference:

The paper on AdamW decouples what it calls the “schedule multiplier $\eta_t$” from the learning rate $\alpha$ (applied to the Adam update) and the regularization factor/weight decay $\lambda$; $\eta_t$ is applied as an additional factor on top of both of the other two.

PyTorch only exposes two parameters, $\gamma_\text{torch}$ (the learning rate) and $\lambda_\text{torch}$ (regularization factor/weight decay), with the paper correspondence $\gamma_\text{torch} = \eta_t \alpha$ and $\gamma_\text{torch} \lambda_\text{torch} = \eta_t \lambda$. The implementation in #1612 hardcodes $\alpha$ to 1, which yields exactly the same parametrization.
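
For illustration (again with hypothetical helper functions, not library code), the PyTorch step and the paper's step with $\alpha$ hardcoded to 1 coincide:

```julia
# Purely illustrative; A is the precomputed Adam direction.
paper_step(θ, A; η, α, λ) = θ .- η .* (α .* A .+ λ .* θ)   # Loshchilov & Hutter
torch_step(θ, A; γ, λ)    = θ .- γ .* (A .+ λ .* θ)        # PyTorch AdamW (lr = γ, weight_decay = λ)

θ, A = randn(3), randn(3)

# With α fixed to 1, the correspondence is γ_torch = η_t and λ_torch = λ:
torch_step(θ, A; γ = 1e-3, λ = 1e-2) ≈ paper_step(θ, A; η = 1e-3, α = 1, λ = 1e-2)   # true
```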

I can't quite tell how important the additional control of an uncoupled $\alpha$ would be; probably not important enough to diverge from the PyTorch implementation, I guess?

@CarloLucibello

We should adhere to PyTorch's implementation for sure. Would you mind filing PRs here and in Optimisers.jl?

ToucheSir commented May 3, 2024

I don't have time to comment on this in detail now (will do so later), but the decision to diverge from PyTorch was not made lightly. IIRC it was something about how their handling and interpretation of the learning rate was unintuitive and would trip up people moving to AdamW from other optimizers. I also didn't find their justification particularly compelling.

ToucheSir commented May 3, 2024

Ok, I did some more digging into why PyTorch decided to couple the learning rate and weight decay coefficient for their AdamW implementation. My best guess is that this comment on one of the AdamW PRs triggered changes which cascaded all the way to the final AdamW PR. I don't find the point super compelling here because, unlike PyTorch, Flux lacks an Adam + coupled L2 constructor. Moreover, changing the calculation would be a breaking change for Flux and Optimisers.jl.

Now for an argument on semantics and usability. I agree that separate scheduling alone is not enough to justify a separate learning rate and weight decay rate. The problem lies more with tweaking hyperparameters. The AdamW paper makes a big point about being able to control both independently. With both coupled as in PyTorch, you always have to remember to tweak the weight decay every time you tweak the learning rate; otherwise you will be increasing/decreasing both simultaneously. We may even have public examples of people not realizing this, e.g. fastai/fastai#1806 (funnily enough, FastAI's AdamW used not to couple the two hyperparams).

There's also a practical concern if we do introduce hyperparameter scheduling (i.e. controlling $\eta_t$ using ParameterSchedulers.jl). The previous implementation in #1612 chained together two rules with a learning rate (eta) field, but one of them must remain fixed at eta = 1 for the algorithm to be correct. Optimisers.adjust! will by default adjust both learning rates, and getting it to adjust only one would require a fair amount of extra code.
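
For concreteness, here is a minimal sketch of that adjust! behaviour; this is not the exact #1612 composition, just two chained Optimisers.jl rules that both carry an eta field:

```julia
using Optimisers

# Two chained rules that both have an `eta` field; suppose the first one is
# meant to stay fixed at eta = 1 for the algorithm to be correct.
rule  = OptimiserChain(Adam(1.0), Descent(1e-3))
state = Optimisers.setup(rule, rand(3))

# adjust! walks the whole tree and changes *every* rule that has an `eta`
# field, so the Adam part that should have stayed at 1 gets rescaled too.
Optimisers.adjust!(state, 1e-4)
```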

As such, I think the best path forward would be to add a keyword arg to the AdamW constructor. Call it couple_lr or something, and have it return something closer to #1612 if couple_lr=true. As I noted, we'd likely also need to add a wrapper type for AdamW instead of using OptimiserChain directly.
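
As a rough sketch of that keyword (the name couple_lr comes from the discussion above; the function name, defaults, and everything else here are placeholders, and this ignores the adjust!/wrapper-type issue):

```julia
using Optimisers

# Placeholder sketch, not the real constructor. With couple_lr = true the
# weight-decay coefficient is multiplied by the learning rate at construction
# time (the PyTorch/#1612 parametrization); with couple_lr = false it is the
# fully decoupled coefficient, as Flux's AdamW behaves today.
function adamw_sketch(η = 0.001, β = (0.9, 0.999), λ = 0.0; couple_lr = false)
    return OptimiserChain(Adam(η, β), WeightDecay(couple_lr ? η * λ : λ))
end
```

Note, though, that this couples the two only at construction time: a later adjust! of eta would not rescale the decay term, which is part of why a wrapper type (or a standalone rule, see below) seems necessary.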

ToucheSir changed the title from “Documentation of AdamW may be misleading” to “Implementation of AdamW differs from PyTorch” on May 3, 2024
@CarloLucibello

I have to think a bit about it. Another data point is that Optax also couples the two:
https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw

@ToucheSir

I actually opened an issue on the Optax repo about this, and they more or less said they wanted to copy PyTorch...

CarloLucibello commented May 8, 2024

> There's also a practical concern if we do introduce hyperparam scheduling

I think we should simply implement AdamW by copy-pasting the code from Adam.

We can add the couple_lr option. The default should be couple_lr = true though; we should do what everybody else is doing. AdamW has become very popular in recent years and we want experiments in papers to be reproducible; finding an optimizer flag to be the source of a divergence would be a very frustrating experience.
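
For concreteness, a very rough sketch of what such a standalone rule could look like in the Optimisers.jl style; the name CoupledAdamW and every detail below are hypothetical, not the actual implementation. The point is only that the decay term $\lambda x$ sits inside the $\eta$-scaled step:

```julia
using Optimisers

# Hypothetical sketch only: an Adam-like rule with PyTorch-style *coupled*
# weight decay, i.e. the decay term λ * x is scaled by the learning rate η
# together with the Adam direction.
struct CoupledAdamW{T} <: Optimisers.AbstractRule
    eta::T
    beta::Tuple{T,T}
    lambda::T
    epsilon::T
end
CoupledAdamW(η = 0.001, β = (0.9, 0.999), λ = 0.01) = CoupledAdamW{Float64}(η, β, λ, 1e-8)

Optimisers.init(o::CoupledAdamW, x::AbstractArray) = (zero(x), zero(x), o.beta)

function Optimisers.apply!(o::CoupledAdamW, state, x, dx)
    η, λ, ϵ = o.eta, o.lambda, o.epsilon
    β1, β2 = o.beta
    mt, vt, (βt1, βt2) = state
    mt = @. β1 * mt + (1 - β1) * dx          # first moment
    vt = @. β2 * vt + (1 - β2) * abs2(dx)    # second moment
    # bias-corrected Adam direction plus the coupled decay term, all scaled by η
    dx′ = @. η * (mt / (1 - βt1) / (sqrt(vt / (1 - βt2)) + ϵ) + λ * x)
    return (mt, vt, (βt1 * β1, βt2 * β2)), dx′
end
```

With something like this, adjust!-ing eta would rescale the effective weight decay together with the step size, matching PyTorch's behaviour, while a decoupled variant could keep the λ term outside the η factor.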

@ToucheSir

No objections, we'd just have to make a breaking release with it. Anything else we'd want to get in said release?

@dpaetzel

Thank you for the extended discussion!

Just to make sure I understand correctly (I'll try to find time to submit a PR):

couple_lr == false would mean using the parametrization from the AdamW paper, i.e.

$$ \begin{align*} \theta_t \leftarrow \theta_{t-1} - \eta (\alpha \ A + \lambda \theta_{t-1}) \end{align*} $$

(where $A$ is the update coming from Adam, and we expose $\eta$, $\alpha$ and $\lambda$ to the user), whereas couple_lr == true would mean using this parametrization:

$$ \begin{align*} \theta_{t-1} - \eta (\alpha \, A + \lambda \theta_{t-1}) & = \theta_{t-1} - \eta \alpha \, A - \eta \lambda \theta_{t-1} \\ & = \theta_{t-1} - \gamma_\text{torch} \, A - \gamma_\text{torch} \lambda_\text{torch} \theta_{t-1} \\ & = \theta_{t-1} - \gamma_\text{torch} (A + \lambda_\text{torch} \theta_{t-1}) \end{align*} $$

where we expose to the user $\eta = \gamma_\text{torch}$ and $\lambda = \lambda_\text{torch}$ (i.e. changing $\alpha$ doesn't do anything as long as couple_lr == true).
