Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure #125127


Closed
jayanthd04 wants to merge 38 commits

Conversation

Contributor

@jayanthd04 jayanthd04 commented Apr 28, 2024

This PR addresses issue #123451: the test_graph_optims and test_graph_scaling_fused_optimizers functions in test_cuda.py have been updated to use the new OptimizerInfo infrastructure.

Lintrunner passed:

$ lintrunner test/test_cuda.py
ok No lint issues.

Tests passed:

>python test_cuda.py -k test_graph_optims
Ran 19 tests in 7.463s

OK (skipped=9)

>python test_cuda.py -k test_graph_scaling_fused_optimizers
Ran 6 tests in 2.800s

OK (skipped=3)

Both functions have been moved to the newly created TestCase class TestCudaOptims. The tests are mostly unchanged, except that the @optims decorator at the top of each function now implicitly runs the test once per optimizer listed in the decorator, replacing the explicit for loop that previously iterated over the optimizers. A sketch of the pattern is shown below.
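
For reference, a minimal sketch of the decorator-based pattern (the unfiltered optim_db and the dtype here are illustrative assumptions, not the PR's exact code):

# A minimal sketch of the @optims pattern described above; the list passed
# to the decorator and the dtype are assumptions, not the PR's exact code.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_optimizers import optim_db, optims

class TestCudaOptims(TestCase):
    # The decorator generates one test per OptimizerInfo entry, so the test
    # body receives optim_info instead of looping over optimizers itself.
    @optims(optim_db, dtypes=[torch.float32])
    def test_graph_optims(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls  # optimizer class under test
        ...  # capture a CUDA graph and step the optimizer, as before

instantiate_device_type_tests(TestCudaOptims, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()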

I was unable to use _get_optim_inputs_including_global_cliquey_kwargs to get all kwargs for each optimizer, because some of the kwargs used in the original test_graph_optims function are not returned by the new OptimizerInfo infrastructure. Specifically, for the torch.optim.rmsprop.RMSprop optimizer, the following kwargs are never returned when _get_optim_inputs_including_global_cliquey_kwargs is called:

{'foreach': False, 'maximize': True, 'weight_decay': 0}
{'foreach': True, 'maximize': True, 'weight_decay': 0}

I ran into the same issue in test_graph_scaling_fused_optimizers: for the torch.optim.adamw.AdamW optimizer, calling optim_info.optim_inputs_func(device=device) never returned the following kwarg:

{'amsgrad': True}

Due to this issue, I resorted to using a dictionary to store the kwargs for each optimizer (sketched below), which I am aware is less than ideal. Should I instead use the OptimizerInfo infrastructure to get all the kwargs, even though it lacks some of them?
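
For illustration, a sketch of that fallback (the dict name is hypothetical; the entries shown are the kwargs quoted above):

# Hypothetical sketch of the hand-maintained fallback; only the kwargs
# quoted in this thread are shown, and the dict name is an assumption.
import torch

optimizer_kwargs = {
    torch.optim.RMSprop: [
        {"foreach": False, "maximize": True, "weight_decay": 0},
        {"foreach": True, "maximize": True, "weight_decay": 0},
    ],
    torch.optim.AdamW: [
        {"amsgrad": True},
    ],
}
# Each test then looks up optimizer_kwargs[optim_info.optim_cls] to cover
# configs the OptimizerInfo helpers do not currently return.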

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented Apr 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125127

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 041afac with merge base 1a28f73:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Apr 28, 2024

@pytorch-bot pytorch-bot bot added the topic: not user facing label Apr 28, 2024
@cpuhrsch cpuhrsch requested a review from janeyx99 April 30, 2024 19:52
@cpuhrsch cpuhrsch added the triaged label Apr 30, 2024
Contributor

@janeyx99 janeyx99 left a comment


Thanks for your thorough look at the kwargs. For RMSprop, feel free to add a maximize with no weight_decay option here:

supports_param_groups: bool = True,

For Adam/W, it's okay to not have amsgrad alone--we test it sufficiently with the "capturable, amsgrad" and the "amsgrad" described inputs.

This way we can use the helper function to get the kwargs instead of needing a separate dictionary.
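
For illustration, such an addition to optim_inputs_func_rmsprop might look like this (the desc string is an assumption):

from torch.testing._internal.common_optimizers import OptimizerInput

# Appended to the list returned by optim_inputs_func_rmsprop: maximize
# enabled while weight_decay stays at its default of 0.
OptimizerInput(
    params=None,  # None lets the test supply its own parameters
    kwargs={"maximize": True},
    desc="maximize",
)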

@jayanthd04
Contributor Author

Thank you for the review! To add a maximize with no weight_decay option to RMSprop, should I just edit optim_inputs_func_rmsprop? Also, I've noticed that some other optimizers, such as torch.optim.Adadelta, are missing kwargs like maximize with no weight_decay. Should I add those options in common_optimizers.py as well, or just use the current options for those optimizers?

@janeyx99
Contributor

janeyx99 commented May 2, 2024

Thank you for the review! To add a maximize with no weight_decay option in RMSprop, should I just edit optim_inputs_func_rmsprop?

Yes

Also I've noticed some other optimizers such as torch.optim.Adadelta are also missing some kwargs such as no weight_decay and maximize, should I also add those options in common_optimizers.py

Yes

@janeyx99
Contributor

Yes, if the compiled optimizer test fails again, you may need to add a new entry to

"test_sgd_weight_decay_maximize_cpu": 4,

@jayanthd04
Contributor Author

jayanthd04 commented May 14, 2024

I believe 'test/inductor/test_compiled_optimizers.py::CompiledOptimizerTests::test_sgd_weight_decay_cpu' is the test that is failing.

@janeyx99
Contributor

Yep, looks like you need to add an entry to the code snippet I linked above as you added new configs.

@jayanthd04
Contributor Author

Should I just add this to pytorch/test/inductor/test_compiled_optimizers.py?

test_sgd_weight_decay_cpu=4,
test_sgd_weight_decay_cuda=4,
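
For context, a minimal sketch of how such entries sit in the kernel-count mapping referenced earlier (the dict name is an assumption; the thread only shows individual entries):

# Hypothetical excerpt: the compiled-optimizer tests map each generated test
# name to an expected kernel count, so new configs need new entries.
KERNEL_COUNTS = {  # assumed name for the mapping linked above
    "test_sgd_weight_decay_maximize_cpu": 4,
    "test_sgd_weight_decay_cpu": 4,  # added for the new configs
    "test_sgd_weight_decay_cuda": 4,
}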

@janeyx99
Contributor

That looks reasonable--does that pass the test?

@jayanthd04
Contributor Author

It looks like it works.

@janeyx99
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@DanilBaibak
Contributor

@pytorchbot revert -m "Broken trunk" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request May 20, 2024
Revert "Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure (#125127)"

This reverts commit cf35a59.

Reverted #125127 on behalf of https://github.com/DanilBaibak due to broken trunk (see comment on #125127).
@pytorchmergebot
Collaborator

@jayanthd04 your PR has been successfully reverted.

@pytorch-bot pytorch-bot bot dismissed janeyx99’s stale review May 20, 2024 12:14

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@DanilBaibak
Contributor

Hi @jayanthd04! Sorry, I need to revert your PR because it broke cuda tests. Here you can find more details.

github-actions bot

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jul 19, 2024
@github-actions github-actions bot closed this Aug 18, 2024
pytorchmergebot pushed a commit that referenced this pull request Aug 28, 2024
…rs (#133749)

Fixes #123451

This is a rework of a reverted pull request, #125127.
The test failure is fixed.

Pull Request resolved: #133749
Approved by: https://github.com/janeyx99
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…rs (pytorch#133749)

Fixes pytorch#123451

This is a rework of a reverted pull request, pytorch#125127.
The test failure is fixed.

Pull Request resolved: pytorch#133749
Approved by: https://github.com/janeyx99
Labels
ciflow/trunk · Merged · module: inductor · open source · Reverted · Stale · topic: not user facing · triaged

6 participants