Commit 3a2ce6f

Implement gradient clipping (#286)
* Implement gradient clipping
* Fix tests
* Support both clipping by norm and by value
* Update base_trainer.py
* Update test_trainer.py
* Add gradient clipping tests
1 parent 1d36d80 commit 3a2ce6f

3 files changed: +167 −9 lines
Lines changed: 11 additions & 1 deletion

@@ -1,12 +1,22 @@
-# this is for basic training loop, i.e., forward/backward pass without any special task logic.
+# This is for basic training loop, i.e., forward/backward pass without any special task logic.
 name: train
 global_batch_size: 4
 max_steps: 15

+# Optimizer configuration.
 optimizer:
   learning_rate: 5.e-5
   type: adafactor

+# Defaults to clip the L2 norm of gradients to 1.0.
+# Set to null to disable gradient clipping by norm.
+max_grad_norm: 1.0
+
+# Defaults to not clip gradients by their absolute value.
+# Set to a number to clip gradients by the specified max absolute value.
+max_grad_value: null
+
+# Learning rate scheduler configuration.
 lr_scheduler:
   type: linear
   warmup_steps: 0
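For orientation, the two new knobs map onto the standard PyTorch clipping utilities used later in this commit: max_grad_norm rescales the whole gradient so its global L2 norm does not exceed the limit, while max_grad_value clamps each gradient element into [-value, value]. A minimal standalone sketch of that behavior with plain PyTorch (the toy parameter and the 0.3 value limit are made up for illustration; only the 1.0 norm limit mirrors the default above):

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils

# A toy parameter whose gradient is all ones, so the gradient's L2 norm is sqrt(4) = 2.0.
param = nn.Parameter(torch.ones(4))
param.grad = torch.ones(4)

# max_grad_norm: 1.0 -> the whole gradient is rescaled by 1.0 / 2.0 = 0.5.
total_norm = nn_utils.clip_grad_norm_([param], max_norm=1.0)
print(total_norm.item())  # 2.0, the norm *before* clipping
print(param.grad)         # tensor([0.5000, 0.5000, 0.5000, 0.5000])

# max_grad_value: 0.3 -> each element is clamped into [-0.3, 0.3] independently.
nn_utils.clip_grad_value_([param], clip_value=0.3)
print(param.grad)         # tensor([0.3000, 0.3000, 0.3000, 0.3000])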

torchprime/torch_xla_models/tests/test_trainer.py

Lines changed: 127 additions & 1 deletion
@@ -83,6 +83,8 @@ def dummy_config():
         "global_batch_size": 4,
         "max_steps": 2,
         "optimizer": {"type": "adafactor", "learning_rate": 1e-3},
+        "max_grad_norm": None,
+        "max_grad_value": None,
         "lr_scheduler": {"type": "constant", "warmup_steps": 0},
       },
       "run_name": None,
@@ -170,7 +172,131 @@ def test_trainer_train_step(monkeypatch, dummy_config):
   trainer = Trainer(model, dummy_config, dataset)

   batch = {k: v.unsqueeze(0).to(device) for k, v in dataset[0].items()}
-  loss = trainer.train_step(batch)
+  loss, grad_norm = trainer.train_step(batch)

   assert isinstance(loss, torch.Tensor)
   assert loss.ndim == 0  # scalar loss
+  assert isinstance(grad_norm, torch.Tensor)
+  assert grad_norm.ndim == 0  # scalar gradient norm
+
+
+def test_trainer_clip_gradients_by_norm(monkeypatch, dummy_config):
+  """Test correctness of gradient clipping by norm in a train step."""
+  import torch_xla
+
+  from torchprime.torch_xla_models.model_rewriting import sharding_initialization
+
+  # Arrange
+  monkeypatch.setattr(
+    sharding_initialization, "get_mesh", lambda *args, **kwargs: FakeMesh()
+  )
+  monkeypatch.setattr(
+    sharding_initialization,
+    "shard_torch_xla_model_from_config",
+    lambda model, *args, **kwargs: model,
+  )
+
+  class SumModel(nn.Module):
+    def __init__(self):
+      super().__init__()
+      self.linear = nn.Linear(4, 1, bias=False)
+
+    def forward(self, input_ids=None, attention_mask=None, **kwargs):
+      logits = self.linear(input_ids)
+      loss = logits.mean()
+      return logits, loss
+
+  dummy_config.task.max_grad_norm = 1.0
+  dummy_config.task.max_grad_value = None
+  model = SumModel().to("xla")
+  with torch.no_grad():
+    model.linear.weight.fill_(1.0)
+  dataset = DummyDataset()
+  trainer = Trainer(model, dummy_config, dataset)
+  torch_xla.sync()
+
+  # Act
+  batch = {k: v.unsqueeze(0).to("xla") for k, v in dataset[0].items()}
+  loss, grad_norm = trainer.train_step(batch)
+
+  # Assert
+  # Loss should be exactly 4.0 since we are summing 4 inputs of 1.0.
+  assert loss.item() == 4.0
+
+  # ∂L/∂W = 1.0 for each weight in the linear layer
+  # Expected gradient norm before clipping: sqrt(4 * 1^2) = 2.0
+  assert pytest.approx(grad_norm.item(), rel=1e-5) == 2.0
+
+  # Verify the actual gradients on the model
+  # The original gradient for each weight would be 1.0
+  # With clipping factor 0.5 (1.0/2.0), each gradient becomes 0.5
+  if hasattr(model.linear.weight, "grad") and model.linear.weight.grad is not None:
+    expected_clipped_grad = torch.full_like(model.linear.weight, 0.5)
+    torch.testing.assert_close(
+      model.linear.weight.grad, expected_clipped_grad, rtol=1e-5, atol=1e-5
+    )
+
+    # Also verify the gradient norm matches what we expect
+    actual_grad_norm = torch.norm(model.linear.weight.grad)
+    assert pytest.approx(actual_grad_norm.item(), rel=1e-5) == 1.0
+
+
+def test_trainer_clip_gradients_by_value(monkeypatch, dummy_config):
+  """Test correctness of gradient clipping by max absolute value in a train step."""
+  import torch_xla
+
+  from torchprime.torch_xla_models.model_rewriting import sharding_initialization
+
+  # Arrange
+  monkeypatch.setattr(
+    sharding_initialization, "get_mesh", lambda *args, **kwargs: FakeMesh()
+  )
+  monkeypatch.setattr(
+    sharding_initialization,
+    "shard_torch_xla_model_from_config",
+    lambda model, *args, **kwargs: model,
+  )
+
+  class SumModel(nn.Module):
+    def __init__(self):
+      super().__init__()
+      self.linear = nn.Linear(4, 1, bias=False)
+
+    def forward(self, input_ids=None, attention_mask=None, **kwargs):
+      logits = self.linear(input_ids)
+      loss = logits.mean()
+      return logits, loss
+
+  dummy_config.task.max_grad_value = 0.5
+  dummy_config.task.max_grad_norm = None
+  model = SumModel().to("xla")
+  with torch.no_grad():
+    model.linear.weight.fill_(1.0)
+  dataset = DummyDataset()
+  trainer = Trainer(model, dummy_config, dataset)
+  torch_xla.sync()
+
+  # Act
+  batch = {k: v.unsqueeze(0).to("xla") for k, v in dataset[0].items()}
+  loss, grad_norm = trainer.train_step(batch)
+
+  # Assert
+  # Loss should be exactly 4.0 since we are summing 4 inputs of 1.0.
+  assert loss.item() == 4.0
+
+  # ∂L/∂W = 1.0 for each weight in the linear layer
+  # Expected gradient norm before clipping: sqrt(4 * 1^2) = 2.0
+  assert pytest.approx(grad_norm.item(), rel=1e-5) == 2.0
+
+  # Verify the actual gradients on the model
+  # The original gradient for each weight would be 1.0
+  # With value clipping at 0.5, each gradient becomes 0.5
+  if hasattr(model.linear.weight, "grad") and model.linear.weight.grad is not None:
+    expected_clipped_grad = torch.full_like(model.linear.weight, 0.5)
+    torch.testing.assert_close(
+      model.linear.weight.grad, expected_clipped_grad, rtol=1e-5, atol=1e-5
+    )
+
+    # Verify all gradient values are within [-max_grad_value, max_grad_value]
+    assert torch.all(model.linear.weight.grad <= dummy_config.task.max_grad_value)
+    assert torch.all(model.linear.weight.grad >= -dummy_config.task.max_grad_value)
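The expected values in these tests follow from a short hand calculation: with unit weights and (as the test comments imply) a length-4 all-ones input, the single logit is 4.0, so the mean loss is 4.0; each of the four weight gradients is 1.0, giving a pre-clipping L2 norm of sqrt(4) = 2.0; clipping that norm to 1.0 scales every gradient by 0.5, and value clipping at 0.5 lands on the same numbers. A minimal CPU-only sketch that checks the same arithmetic without the XLA trainer (the all-ones input is an assumed stand-in for the dummy batch):

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils

# Same shape as SumModel in the tests: a 4 -> 1 linear layer with unit weights.
linear = nn.Linear(4, 1, bias=False)
with torch.no_grad():
  linear.weight.fill_(1.0)

x = torch.ones(1, 4)       # assumed stand-in for the dummy batch
loss = linear(x).mean()
loss.backward()

assert loss.item() == 4.0  # 1*1 + 1*1 + 1*1 + 1*1
assert torch.equal(linear.weight.grad, torch.ones(1, 4))

# Clipping to max norm 1.0: the pre-clip norm is 2.0, so every gradient becomes ~0.5.
grad_norm = nn_utils.clip_grad_norm_(linear.parameters(), max_norm=1.0)
assert abs(grad_norm.item() - 2.0) < 1e-6
assert torch.allclose(linear.weight.grad, torch.full((1, 4), 0.5))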

torchprime/torch_xla_models/trainer/base_trainer.py

Lines changed: 29 additions & 7 deletions
@@ -18,6 +18,7 @@
 from timeit import default_timer as timer

 import torch
+import torch.nn.utils as nn_utils
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
@@ -182,12 +183,13 @@ def _get_train_dataloader(self) -> pl.MpDeviceLoader:
     return loader

   def _log_to_tensorboard(
-    self, epoch: float, step: int, loss: float, learning_rate: float
+    self, epoch: float, step: int, loss: float, learning_rate: float, grad_norm: float
   ):
     """Log metrics to TensorBoard."""
     self.summary_writer.add_scalar("train/epoch", epoch, step)
     self.summary_writer.add_scalar("train/loss", loss, step)
     self.summary_writer.add_scalar("train/learning_rate", learning_rate, step)
+    self.summary_writer.add_scalar("train/grad_norm", grad_norm, step)
     self.summary_writer.flush()

   def train_loop(self, metrics_logger) -> None:
@@ -214,22 +216,26 @@ def train_loop(self, metrics_logger) -> None:
       batch = next(train_iterator)

       trace_start_time = timer()
-      loss = self.train_step(batch)
+      loss, grad_norm = self.train_step(batch)
       trace_end_time = timer()

       if step % self.config.logging_steps == 0:

-        def step_closure(epoch, step, loss, trace_start_time, trace_end_time, lr):
+        def step_closure(
+          epoch, step, loss, grad_norm, trace_start_time, trace_end_time, lr
+        ):
           loss = loss.detach().item()
+          grad_norm = grad_norm.detach().item()
           logger.info(
-            "Epoch: %d, step: %d, loss: %.4f, lr: %.2e, trace time: %.2f ms",
+            "Epoch: %d, step: %d, loss: %.4f, grad_norm: %.4f, lr: %.2e, trace time: %.2f ms",
             epoch,
             step,
             loss,
+            grad_norm,
             lr,
             (trace_end_time - trace_start_time) * 1000,
           )
-          self._log_to_tensorboard(epoch, step, loss, lr)
+          self._log_to_tensorboard(epoch, step, loss, lr, grad_norm)
           if math.isnan(loss):
             raise ValueError(f"Loss is NaN at step {step}")

@@ -239,6 +245,7 @@ def step_closure(epoch, step, loss, trace_start_time, trace_end_time, lr):
            epoch,
            step,
            loss,
+           grad_norm,
            trace_start_time,
            trace_end_time,
            self.lr_scheduler.get_last_lr()[0],
@@ -301,10 +308,25 @@ def step_closure(epoch, step, loss, trace_start_time, trace_end_time, lr):
    OmegaConf.save(config=self.config, f=config_save_path)

  @torch_xla.compile(full_graph=True)
-  def train_step(self, batch: dict) -> torch.Tensor:
+  def train_step(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor]:
    _logits, loss = self.model(**batch)
    loss.backward()
+    grad_norm = self.clip_gradients()
    self.optimizer.step()
    self.lr_scheduler.step()
    self.model.zero_grad()
-    return loss
+    return loss, grad_norm
+
+  def clip_gradients(self):
+    """Clip gradients by the specified max norm and/or max absolute value."""
+    max_grad_norm = self.config.task.max_grad_norm
+    if max_grad_norm is None or max_grad_norm <= 0:
+      grad_norm = nn_utils.get_total_norm(self.model.parameters(), norm_type=2)
+    else:
+      grad_norm = nn_utils.clip_grad_norm_(
+        self.model.parameters(), max_norm=max_grad_norm, norm_type=2
+      )
+    max_grad_value = self.config.task.max_grad_value
+    if max_grad_value is not None and max_grad_value > 0:
+      nn_utils.clip_grad_value_(self.model.parameters(), clip_value=max_grad_value)
+    return grad_norm
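The ordering is the important part: gradients are clipped after loss.backward() and before optimizer.step(), and the returned grad_norm (the pre-clipping norm reported by clip_grad_norm_, or a norm computed with nn_utils.get_total_norm when norm clipping is disabled) is what gets logged. A minimal CPU-only sketch of the same step structure, not the trainer itself: max_grad_norm and max_grad_value are plain variables standing in for the config values, the optimizer is plain SGD, and the fallback norm here is computed over the gradient tensors; nn_utils.get_total_norm needs a recent PyTorch release (it is the helper the commit imports via torch.nn.utils).

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils

# Hypothetical stand-ins for the values read from the task config.
max_grad_norm = 1.0    # clip the global L2 norm of gradients to this value
max_grad_value = None  # optionally clamp each gradient element to [-v, v]

model = nn.Linear(4, 1, bias=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train_step(batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  # Mirror the order in the trainer: backward, clip, step, zero.
  loss = model(batch).mean()
  loss.backward()
  if max_grad_norm is not None and max_grad_norm > 0:
    grad_norm = nn_utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
  else:
    # Without norm clipping, just measure the gradient norm for logging.
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    grad_norm = nn_utils.get_total_norm(grads, norm_type=2)
  if max_grad_value is not None and max_grad_value > 0:
    nn_utils.clip_grad_value_(model.parameters(), clip_value=max_grad_value)
  optimizer.step()
  model.zero_grad()
  return loss, grad_norm

loss, grad_norm = train_step(torch.ones(1, 4))
print(f"loss={loss.item():.4f}, grad_norm={grad_norm.item():.4f}")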
