@@ -83,6 +83,8 @@ def dummy_config():
      "global_batch_size": 4,
      "max_steps": 2,
      "optimizer": {"type": "adafactor", "learning_rate": 1e-3},
+      "max_grad_norm": None,
+      "max_grad_value": None,
      "lr_scheduler": {"type": "constant", "warmup_steps": 0},
    },
    "run_name": None,
@@ -170,7 +172,131 @@ def test_trainer_train_step(monkeypatch, dummy_config):
  trainer = Trainer(model, dummy_config, dataset)

  batch = {k: v.unsqueeze(0).to(device) for k, v in dataset[0].items()}
-  loss = trainer.train_step(batch)
+  loss, grad_norm = trainer.train_step(batch)

  assert isinstance(loss, torch.Tensor)
  assert loss.ndim == 0  # scalar loss
+  assert isinstance(grad_norm, torch.Tensor)
+  assert grad_norm.ndim == 0  # scalar gradient norm
+
+
+def test_trainer_clip_gradients_by_norm(monkeypatch, dummy_config):
+  """Test correctness of gradient clipping by norm in a train step."""
+  import torch_xla
+
+  from torchprime.torch_xla_models.model_rewriting import sharding_initialization
+
+  # Arrange
+  monkeypatch.setattr(
+    sharding_initialization, "get_mesh", lambda *args, **kwargs: FakeMesh()
+  )
+  monkeypatch.setattr(
+    sharding_initialization,
+    "shard_torch_xla_model_from_config",
+    lambda model, *args, **kwargs: model,
+  )
+
+  class SumModel(nn.Module):
+    def __init__(self):
+      super().__init__()
+      self.linear = nn.Linear(4, 1, bias=False)
+
+    def forward(self, input_ids=None, attention_mask=None, **kwargs):
+      logits = self.linear(input_ids)
+      loss = logits.mean()
+      return logits, loss
+
+  dummy_config.task.max_grad_norm = 1.0
+  dummy_config.task.max_grad_value = None
+  model = SumModel().to("xla")
+  with torch.no_grad():
+    model.linear.weight.fill_(1.0)
+  dataset = DummyDataset()
+  trainer = Trainer(model, dummy_config, dataset)
+  torch_xla.sync()
+
+  # Act
+  batch = {k: v.unsqueeze(0).to("xla") for k, v in dataset[0].items()}
+  loss, grad_norm = trainer.train_step(batch)
+
+  # Assert
+  # Loss should be exactly 4.0 since we are summing 4 inputs of 1.0.
+  assert loss.item() == 4.0
+
+  # ∂L/∂W = 1.0 for each weight in the linear layer
+  # Expected gradient norm before clipping: sqrt(4 * 1^2) = 2.0
+  assert pytest.approx(grad_norm.item(), rel=1e-5) == 2.0
+
+  # Verify the actual gradients on the model
+  # The original gradient for each weight would be 1.0
+  # With clipping factor 0.5 (1.0/2.0), each gradient becomes 0.5
+  if hasattr(model.linear.weight, "grad") and model.linear.weight.grad is not None:
+    expected_clipped_grad = torch.full_like(model.linear.weight, 0.5)
+    torch.testing.assert_close(
+      model.linear.weight.grad, expected_clipped_grad, rtol=1e-5, atol=1e-5
+    )
+
+    # Also verify the gradient norm matches what we expect
+    actual_grad_norm = torch.norm(model.linear.weight.grad)
+    assert pytest.approx(actual_grad_norm.item(), rel=1e-5) == 1.0
+
+
+def test_trainer_clip_gradients_by_value(monkeypatch, dummy_config):
+  """Test correctness of gradient clipping by max absolute value in a train step."""
+  import torch_xla
+
+  from torchprime.torch_xla_models.model_rewriting import sharding_initialization
+
+  # Arrange
+  monkeypatch.setattr(
+    sharding_initialization, "get_mesh", lambda *args, **kwargs: FakeMesh()
+  )
+  monkeypatch.setattr(
+    sharding_initialization,
+    "shard_torch_xla_model_from_config",
+    lambda model, *args, **kwargs: model,
+  )
+
+  class SumModel(nn.Module):
+    def __init__(self):
+      super().__init__()
+      self.linear = nn.Linear(4, 1, bias=False)
+
+    def forward(self, input_ids=None, attention_mask=None, **kwargs):
+      logits = self.linear(input_ids)
+      loss = logits.mean()
+      return logits, loss
+
+  dummy_config.task.max_grad_value = 0.5
+  dummy_config.task.max_grad_norm = None
+  model = SumModel().to("xla")
+  with torch.no_grad():
+    model.linear.weight.fill_(1.0)
+  dataset = DummyDataset()
+  trainer = Trainer(model, dummy_config, dataset)
+  torch_xla.sync()
+
+  # Act
+  batch = {k: v.unsqueeze(0).to("xla") for k, v in dataset[0].items()}
+  loss, grad_norm = trainer.train_step(batch)
+
+  # Assert
+  # Loss should be exactly 4.0 since we are summing 4 inputs of 1.0.
+  assert loss.item() == 4.0
+
+  # ∂L/∂W = 1.0 for each weight in the linear layer
+  # Expected gradient norm before clipping: sqrt(4 * 1^2) = 2.0
+  assert pytest.approx(grad_norm.item(), rel=1e-5) == 2.0
+
+  # Verify the actual gradients on the model
+  # The original gradient for each weight would be 1.0
+  # With value clipping at 0.5, each gradient becomes 0.5
+  if hasattr(model.linear.weight, "grad") and model.linear.weight.grad is not None:
+    expected_clipped_grad = torch.full_like(model.linear.weight, 0.5)
+    torch.testing.assert_close(
+      model.linear.weight.grad, expected_clipped_grad, rtol=1e-5, atol=1e-5
+    )
+
+    # Verify all gradient values are within [-max_grad_value, max_grad_value]
+    assert torch.all(model.linear.weight.grad <= dummy_config.task.max_grad_value)
+    assert torch.all(model.linear.weight.grad >= -dummy_config.task.max_grad_value)
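
The Trainer's actual clipping code is not part of this diff; the tests only assume that train_step returns a (loss, grad_norm) pair and honors task.max_grad_norm / task.max_grad_value. As a point of reference, here is a minimal CPU-only sketch of the arithmetic the tests verify, using the stock torch.nn.utils helpers (an assumption, not necessarily how the Trainer implements it): a gradient of all ones has norm 2.0, clip-by-norm to 1.0 rescales every entry to 0.5, and clip-by-value at 0.5 caps every entry at 0.5.

# Hypothetical sketch, not part of this change: reproduces the expected
# clipping arithmetic with plain PyTorch on CPU.
import torch
import torch.nn as nn

linear = nn.Linear(4, 1, bias=False)
with torch.no_grad():
  linear.weight.fill_(1.0)

# Same toy setup as SumModel: loss = mean(W @ x) with x = ones, so dL/dW = 1.
x = torch.ones(1, 4)
loss = linear(x).mean()
loss.backward()

# Pre-clip norm is sqrt(4 * 1^2) = 2.0; clip_grad_norm_ returns it.
grad_norm = nn.utils.clip_grad_norm_(linear.parameters(), max_norm=1.0)
assert torch.isclose(grad_norm, torch.tensor(2.0))
# Clip-by-norm rescales by 1.0 / 2.0, so every gradient entry becomes 0.5.
assert torch.allclose(linear.weight.grad, torch.full_like(linear.weight, 0.5))

# Clip-by-value instead clamps each entry to [-clip_value, clip_value];
# starting again from gradients of 1.0, a clip_value of 0.5 also gives 0.5.
linear.weight.grad.fill_(1.0)
nn.utils.clip_grad_value_(linear.parameters(), clip_value=0.5)
assert torch.allclose(linear.weight.grad, torch.full_like(linear.weight, 0.5))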