Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure #125127

Open: jayanthd04 wants to merge 38 commits into base: main from test_cuda_OptimInfo
Commits (38). The changes shown below are from 1 commit.
fe29e60
Adding TestCudaOptims class to move test_graph_optims function and te…
jayanthd04 Apr 18, 2024
f6cec68
Creating test_new_graph_optims under TestCudaOptims class and using O…
jayanthd04 Apr 23, 2024
9e86434
Creating test_new_graph_scaling_fused_optimizers under TestCudaOptims…
jayanthd04 Apr 25, 2024
8d6abfb
Resolving conflict with upstream/main
jayanthd04 Apr 25, 2024
3a86b06
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 Apr 26, 2024
c93214e
Deleting old test_graph_scaling_fused_optimizers and test_graph_optims
jayanthd04 Apr 26, 2024
2aca95f
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 Apr 27, 2024
501590b
Fixing some linting issues
jayanthd04 Apr 27, 2024
c10432e
Fixing more linting issues
jayanthd04 Apr 27, 2024
8e0945a
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 Apr 27, 2024
1e230e7
Fixing final linting issues
jayanthd04 Apr 27, 2024
c7b33f4
Fixing almost all linting issues
jayanthd04 Apr 27, 2024
af8653c
Fixing most of the linting issues
jayanthd04 Apr 27, 2024
35e1b6a
Fixing even more linting issues
jayanthd04 Apr 27, 2024
287f741
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 Apr 28, 2024
39d99b5
Adding kwargs to common_optimizers.py for added test coverability
jayanthd04 May 3, 2024
6d72aff
Testing updated kwargs in common_optimizers.py
jayanthd04 May 3, 2024
eac2bd0
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 May 3, 2024
b8f1ad0
Adding kwargs to common_optimizers.py and deleting dictionaries in te…
jayanthd04 May 5, 2024
b75098e
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 May 5, 2024
be64baf
Deleting references to optim_kwargs dictionary in test_graph_scaling_…
jayanthd04 May 6, 2024
79f542b
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 May 6, 2024
e901a0a
Fixing linting issues
jayanthd04 May 6, 2024
91264d4
Cleaning up test_cuda.py
jayanthd04 May 6, 2024
2a6c915
Resolving OptimizerInput config names and deleting redundant configs
jayanthd04 May 6, 2024
ad331e2
Cleaning up common_optimizers.py
jayanthd04 May 7, 2024
e80cde4
Adding has_capturable_arg to common_optimizers.py, deleting helper fu…
jayanthd04 May 8, 2024
37e6979
Fixing linting issues
jayanthd04 May 8, 2024
d4815a7
Rearranging common_optimizers.py configs, cleaning up test_cuda.py an…
jayanthd04 May 9, 2024
a0af2fe
Adding comments to test_cuda.py
jayanthd04 May 14, 2024
543bba1
Fixing linting issues
jayanthd04 May 14, 2024
b13d6b0
Trying to fix merge conflict
jayanthd04 May 14, 2024
cd6249f
Fixing linting issue
jayanthd04 May 14, 2024
5f92373
Merge remote-tracking branch 'upstream/main' into test_cuda_OptimInfo
jayanthd04 May 14, 2024
8d968d7
Adding markDynamoStrictTest decorator to TestCudaOptims class in test…
jayanthd04 May 14, 2024
236f2fe
Fixing linting issue
jayanthd04 May 14, 2024
83fd6fe
Adding test_sgd_weight_decay configs to test_compiled_optimizers.py
jayanthd04 May 18, 2024
041afac
Fixing linting issues
jayanthd04 May 18, 2024
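
For context before the diff, here is a minimal sketch (not code from this PR) of the OptimizerInfo-driven pattern these tests move to: the @optims decorator parametrizes a test over optim_db entries, and each OptimizerInput supplies one kwargs configuration for the optimizer constructor. The names come from torch.testing._internal.common_optimizers and common_device_type; the candidate filter, toy step loop, and decorator arguments here are illustrative assumptions, not the PR's code.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_utils import TestCase, run_tests


class TestCudaOptimsSketch(TestCase):
    # Parametrize over every optimizer whose constructor accepts `capturable`,
    # mirroring how the graph tests pick candidates for CUDA-graph capture.
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_optimizer_configs(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        for optim_input in optim_info.optim_inputs_func(device=device):
            kwargs = optim_input.kwargs
            # Assigning overwrites any existing value, so a separate
            # `del kwargs["lr"]` before the override is unnecessary.
            kwargs["lr"] = 0.1
            params = [
                torch.randn(5, 5, device=device, dtype=dtype, requires_grad=True)
            ]
            opt = optim_cls(params, **kwargs)
            params[0].grad = torch.ones_like(params[0])
            opt.step()


instantiate_device_type_tests(TestCudaOptimsSketch, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()
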
25 changes: 2 additions & 23 deletions test/test_cuda.py
@@ -4459,23 +4459,13 @@ def test_graph_optims(self, device, dtype, optim_info):
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=("differentiable",)
)
has_betas = any(
"betas" in error_inp.optimizer_error_input.kwargs
for error_inp in optim_info.optim_error_inputs_func(
device="cpu", dtype=dtype
)
)

steps_warmup = 3
steps_train = 2

for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
if "lr" in kwargs:
del kwargs["lr"]
kwargs["lr"] = 0.1
if has_betas and optim_cls != torch.optim.Adamax:
kwargs["betas"] = (0.8, 0.7)
kwargs["lr"]=0.1

for actually_do_graphs in (True, False):
params = [
@@ -4543,26 +4533,15 @@ def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
steps_train = 2

optim_inputs = optim_info.optim_inputs_func(device=device)
has_betas = any(
"betas" in error_inp.optimizer_error_input.kwargs
for error_inp in optim_info.optim_error_inputs_func(
device="cpu", dtype=dtype
)
)

for optim_input in optim_inputs:
kwargs = optim_input.kwargs
kwargs["fused"] = True
if "lr" in kwargs:
del kwargs["lr"]
kwargs["lr"] = 0.1
if has_betas:
kwargs["betas"] = (0.8, 0.7)

for actually_do_graphs in (
(True, False) if optim_info.has_capturable_arg else (True,)
):
params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
params_control = [p.clone().requires_grad_() for p in params]
params_graphed = [p.clone().requires_grad_() for p in params]

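
The two tests above exercise CUDA-graph capture of optimizer steps. As a rough illustration of that pattern (an assumed sketch, not this PR's code): warm up on a side stream, capture opt.step() into a torch.cuda.CUDAGraph, then replay it; capturable=True is what keeps the step graph-safe.

import torch

if torch.cuda.is_available():
    param = torch.randn(5, 5, device="cuda", requires_grad=True)
    opt = torch.optim.Adam([param], lr=0.1, capturable=True)

    # Warm up (and initialize optimizer state) on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            param.grad = torch.ones_like(param)
            opt.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture a single optimizer step, then replay it; the grad tensor set
    # before capture stays static across replays.
    param.grad = torch.ones_like(param)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        opt.step()
    for _ in range(2):
        g.replay()
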
72 changes: 37 additions & 35 deletions torch/testing/_internal/common_optimizers.py
@@ -123,7 +123,7 @@ def __init__(
supported_impls: Tuple[str] = ("foreach", "differentiable"),
# the optim supports passing in sparse gradients as well as dense grads
supports_sparse: bool = False,
# the optim is capturable in a CUDA graph
# the optimizer constructor supports passing in capturable as a kwarg
has_capturable_arg: bool = False,
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
only_supports_sparse_grads: bool = False,
@@ -314,6 +314,7 @@ def optim_inputs_func_adadelta(device, dtype=None):
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
@@ -322,7 +323,7 @@ def optim_inputs_func_adadelta(device, dtype=None):
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),

] + (cuda_supported_configs if "cuda" in str(device) else [])


@@ -532,14 +533,15 @@ def optim_inputs_func_adamax(device, dtype=None):
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize, weight_decay",
kwargs={"maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"maximize": True},
desc="maximize",
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize, weight_decay",
),

] + (cuda_supported_configs if "cuda" in str(device) else [])


@@ -689,6 +691,13 @@ def optim_inputs_func_nadam(device, dtype=None):
kwargs={"momentum_decay": 6e-3},
desc="non-zero momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
},
desc="weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
@@ -702,13 +711,6 @@
},
desc="decoupled_weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
},
desc="weight_decay",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])


@@ -834,38 +836,38 @@ def optim_inputs_func_rmsprop(device, dtype=None):
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True},
desc="centered",
kwargs={
"maximize": True,
},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
desc="momentum",
kwargs={"weight_decay": 0.1, "centered": True},
desc="centered",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
"centered": True,
"momentum": 0.1,
"maximize": True,
"weight_decay": 0.1,
},
desc="maximize, centered, weight_decay, w/ momentum",
desc="maximize, weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
},
desc="maximize",
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
desc="momentum",
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
"weight_decay": 0.1,
"centered": True,
"momentum": 0.1,
"maximize": True,
},
desc="maximize, weight_decay",
desc="maximize, centered, weight_decay, w/ momentum",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])

@@ -936,7 +938,15 @@ def optim_inputs_func_sgd(device, dtype=None):
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
),
OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "dampening": 0.5},
@@ -952,14 +962,6 @@
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
desc="nesterov",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
),
]


Expand Down