Add a centered variance option to the ClippedAdam optimizer #3415

Merged: 7 commits, Jan 25, 2025
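
With this change, the option is enabled through Pyro's usual dict-based optimizer interface. A minimal, self-contained usage sketch (the toy model, guide, and data below are illustrative only and are not part of this PR):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    pyro.sample("loc", dist.Normal(q_loc, 1.0))

data = torch.randn(100) + 3.0

# The new flag rides along with the existing per-parameter options.
optim = ClippedAdam({"lr": 0.01, "clip_norm": 10.0, "centered_variance": True})
svi = SVI(model, guide, optim, Trace_ELBO())
for step in range(1000):
    svi.step(data)
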
5 changes: 4 additions & 1 deletion examples/lda.py
@@ -137,7 +137,9 @@ def main(args):
guide = functools.partial(parametrized_guide, predictor)
Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
elbo = Elbo(max_plate_nesting=2)
optim = ClippedAdam({"lr": args.learning_rate})
optim = ClippedAdam(
{"lr": args.learning_rate, "centered_variance": args.centered_variance}
)
svi = SVI(model, guide, optim, elbo)
logging.info("Step\tLoss")
for step in range(args.num_steps):
@@ -160,6 +162,7 @@ def main(args):
parser.add_argument("-n", "--num-steps", default=1000, type=int)
parser.add_argument("-l", "--layer-sizes", default="100-100")
parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
parser.add_argument("-cv", "--centered-variance", default=False, type=bool)
parser.add_argument("-b", "--batch-size", default=32, type=int)
parser.add_argument("--jit", action="store_true")
args = parser.parse_args()
16 changes: 13 additions & 3 deletions pyro/optim/clipped_adam.py
@@ -19,14 +19,21 @@ class ClippedAdam(Optimizer):
:param weight_decay: weight decay (L2 penalty) (default: 0)
:param clip_norm: magnitude of norm to which gradients are clipped (default: 10.0)
:param lrd: rate at which learning rate decays (default: 1.0)
:param centered_variance: use centered variance (default: False)

Small modification to the Adam algorithm implemented in torch.optim.Adam
to include gradient clipping and learning rate decay.
to include gradient clipping, learning rate decay, and an option to use
the centered variance.

Reference
References

`Adam: A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
https://arxiv.org/abs/1412.6980

`A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
https://doi.org/10.3390/computation11050095
"""

def __init__(
@@ -38,6 +45,7 @@ def __init__(
weight_decay=0,
clip_norm: float = 10.0,
lrd: float = 1.0,
centered_variance: bool = False,
):
defaults = dict(
lr=lr,
Expand All @@ -46,6 +54,7 @@ def __init__(
weight_decay=weight_decay,
clip_norm=clip_norm,
lrd=lrd,
centered_variance=centered_variance,
)
super().__init__(params, defaults)

@@ -87,7 +96,8 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
grad_var = (grad - exp_avg) if group["centered_variance"] else grad
exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(group["eps"])

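
The core of the change is the grad_var line in step() above: with centered_variance=True the second-moment buffer tracks an exponential moving average of the squared deviation of the gradient from its running mean (a variance estimate), rather than of the raw squared gradient. A standalone sketch of the two update modes, mirroring the diffed lines but written here purely for illustration:

import torch

def update_moments(exp_avg, exp_avg_sq, grad,
                   beta1=0.9, beta2=0.999, centered_variance=False):
    # First moment: exponential moving average of the gradient.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    # Second moment: EMA of grad**2 in standard Adam, or of the squared
    # deviation from the running mean when centered_variance=True.
    grad_var = (grad - exp_avg) if centered_variance else grad
    exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)
    return exp_avg, exp_avg_sq

# Toy comparison on a single gradient sample.
grad = torch.tensor([0.5, -0.2])
m, v = torch.zeros(2), torch.zeros(2)
mc, vc = torch.zeros(2), torch.zeros(2)
update_moments(m, v, grad)                            # uncentered: EMA of g**2
update_moments(mc, vc, grad, centered_variance=True)  # centered: EMA of (g - m)**2
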
102 changes: 102 additions & 0 deletions tests/optim/test_optim.py
@@ -435,3 +435,105 @@ def step(svi, optimizer):
actual.append(step(svi, optimizer))

assert_equal(actual, expected)


def test_centered_clipped_adam(plot_results=False):
w = torch.Tensor([1, 500])

def loss_fn(p):
return (1 + w * p * p).sqrt().sum() - len(w)

def fit(lr, centered_variance, num_iter=5000):
loss_vec = []
p = torch.nn.Parameter(torch.Tensor([10, 1]))
optim = pyro.optim.clipped_adam.ClippedAdam(
lr=lr, params=[p], centered_variance=centered_variance
)
for count in range(num_iter):
optim.zero_grad()
loss = loss_fn(p)
loss.backward()
optim.step()
loss_vec.append(loss)
return torch.Tensor(loss_vec)

def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
ultimate_loss = loss_vec[-tail_len:].mean()
convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
convergence_vec = loss_vec[:convergence_iter] - ultimate_loss
convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
return ultimate_loss, convergence_rate, convergence_iter

def get_convergence_vec(lr_vec, centered_variance):
ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
for lr in lr_vec:
loss_vec = fit(lr=lr, centered_variance=centered_variance)
ultimate_loss, convergence_rate, convergence_iter = calc_convergence(
loss_vec
)
ultimate_loss_vec.append(ultimate_loss)
convergence_rate_vec.append(convergence_rate)
convergence_iter_vec.append(convergence_iter)
print(lr, centered_variance, ultimate_loss, convergence_rate)
return (
torch.Tensor(ultimate_loss_vec),
torch.Tensor(convergence_rate_vec),
convergence_iter_vec,
)

lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
(
centered_ultimate_loss_vec,
centered_convergence_rate_vec,
centered_convergence_iter_vec,
) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
lr_vec=lr_vec, centered_variance=False
)

# All centered variance results should converge
assert (centered_ultimate_loss_vec < 0.01).all()
# Some uncentered variance results do not converge
assert (ultimate_loss_vec > 0.01).any()
# Verify convergence rate improvement
assert (
(centered_convergence_rate_vec / convergence_rate_vec)
> (torch.Tensor([1.2] * len(lr_vec)).cumprod(0))
).all()

if plot_results:
from matplotlib import pyplot as plt

plt.figure(figsize=(6, 8))
plt.subplot(3, 1, 1)
plt.loglog(
lr_vec, centered_convergence_iter_vec, "b.-", label="Centered Variance"
)
plt.loglog(lr_vec, convergence_iter_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Convergence Iteration")
plt.title("Convergence Iteration vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.subplot(3, 1, 2)
plt.loglog(
lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
)
plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Convergence Rate")
plt.title("Convergence Rate vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.subplot(3, 1, 3)
plt.semilogx(
lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
)
plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Ultimate Loss")
plt.title("Ultimate Loss vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.tight_layout()
plt.savefig("test_centered_variance.png")
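
The test doubles as a small benchmark: calling it by hand with plot_results=True writes the learning-rate sweep comparison to test_centered_variance.png. A sketch of that manual invocation, assuming matplotlib is installed and the repository root is on PYTHONPATH (the import path below is an assumption about the test layout):

# Manual run of the benchmark-style test with plotting enabled;
# pytest normally runs it with plot_results=False.
from tests.optim.test_optim import test_centered_clipped_adam

test_centered_clipped_adam(plot_results=True)  # writes test_centered_variance.png
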