@@ -14,15 +14,17 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-import os
+import logging
 import math
+import os
+import time
 from abc import ABC, abstractmethod
 
 from dlio_benchmark.common.enumerations import CheckpointLocationType
 from dlio_benchmark.storage.storage_factory import StorageFactory
 from dlio_benchmark.utils.config import ConfigArguments
 from dlio_benchmark.utils.utility import DLIOMPI, utcnow
-import logging
+
 
 def get_datatype_size(datatype):
     if datatype == "int8" or datatype == "uint8":
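For reference, the helper whose first branch appears above is essentially a dtype-to-bytes table. A minimal sketch of such a mapping, assuming standard sizes in bytes; the function name and the set of covered dtypes here are illustrative, not the benchmark's actual list:

    # Hypothetical sketch of a dtype-size table; DLIO's real
    # get_datatype_size may cover a different set of datatypes.
    def get_datatype_size_sketch(datatype):
        sizes = {
            "int8": 1, "uint8": 1,   # 1-byte integer types
            "fp16": 2, "bf16": 2,    # 2-byte float types
            "fp32": 4, "int32": 4,
            "fp64": 8, "int64": 8,
        }
        return sizes[datatype]       # unknown dtypes raise KeyError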
@@ -137,9 +139,9 @@ def __init__(self, ext):
         model_checkpoint_size /= self.data_parallelism
         self.checkpoint_size = model_checkpoint_size + optimizer_checkpoint_size
         if self.args.my_rank == 0:
-            logging.info(f"{utcnow()} Model size: {model_checkpoint_size} GB")
-            logging.info(f"{utcnow()} Optimizer state size: {optimizer_checkpoint_size} GB")
-            logging.info(f"{utcnow()} Total checkpoint size: {self.checkpoint_size} GB")
+            logging.info(f"{utcnow()} Model size: {model_checkpoint_size:.4f} GB")
+            logging.info(f"{utcnow()} Optimizer state size: {optimizer_checkpoint_size:.4f} GB")
+            logging.info(f"{utcnow()} Total checkpoint size: {self.checkpoint_size:.4f} GB")
 
     @abstractmethod
     def get_tensor(self, length, datatype="int8"):
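The only change in this hunk is the `:.4f` format spec, which caps the logged sizes at four decimal places. A quick illustration of the difference, with a stand-in value:

    size_gb = 10 / 3  # stand-in for a computed checkpoint size
    print(f"Model size: {size_gb} GB")      # Model size: 3.3333333333333335 GB
    print(f"Model size: {size_gb:.4f} GB")  # Model size: 3.3333 GB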
@@ -262,10 +264,8 @@ def save_checkpoint(self, epoch, step_number):
             if self.model_state:
                 self.save_state(suffix=f"{checkpoint_id}/model_states-{my_rank}", state=self.model_state, fsync=self.args.checkpoint_fsync)
 
-            if self.optimization_state:
-                self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state, fsync=self.args.checkpoint_fsync)
-
             if self.layer_state:
+                start_time = time.time()
                 if self.args.zero_stage < 3 and self.args.zero_stage > 0:
                     # if pp is turned on, we assume that the model is sharded across the pipeline stages
                     if self.data_parallelism_rank == 0 and self.args.num_layers > 0:
@@ -279,6 +279,16 @@ def save_checkpoint(self, epoch, step_number):
                     # in this case, model is sharded across the data parallel ranks
                     assert (self.args.pipeline_parallelism == 1)
                     self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_model_states", state=self.layer_state, fsync=self.args.checkpoint_fsync)
+                save_model_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Saved model checkpoint in {save_model_time:.4f} seconds")
+
+            if self.optimization_state:
+                start_time = time.time()
+                self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state, fsync=self.args.checkpoint_fsync)
+                save_optimizer_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Saved optimizer checkpoint in {save_optimizer_time:.4f} seconds")
 
     @abstractmethod
     def load_checkpoint(self, epoch, step_number):
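The save path above uses a standard wall-clock pattern: snapshot `time.time()` before the write, subtract afterwards, and log on rank 0 only so the message appears once per job rather than once per rank. A minimal standalone sketch of the same pattern; `timed_save` and its parameters are illustrative, not part of the benchmark:

    import logging
    import time

    def timed_save(save_fn, label, my_rank):
        # Wrap any save callable and report its wall-clock cost.
        start_time = time.time()
        save_fn()
        elapsed = time.time() - start_time
        if my_rank == 0:  # log once, not once per rank
            logging.info(f"Saved {label} checkpoint in {elapsed:.4f} seconds")
        return elapsed

Note that `logging.info` only emits if the logging level is configured, e.g. via `logging.basicConfig(level=logging.INFO)`.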
@@ -288,13 +298,12 @@ def load_checkpoint(self, epoch, step_number):
         checkpoint_id = f"global_epoch{epoch}_step{step_number}"
         self.checkpoint_storage.create_node(checkpoint_id, exist_ok=True)
         if self.rank_to_checkpoint == my_rank:
+
             if self.model_state:
                 self.load_state(suffix=f"{checkpoint_id}/model_states-{my_rank}", state=self.model_state, fsync=self.args.checkpoint_fsync)
-
-            if self.optimization_state:
-                self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state)
 
             if self.layer_state:
+                start_time = time.time()
                 if self.args.zero_stage < 3 and self.args.zero_stage > 0:
                     # if pp is turned on, we assume that the model is sharded across the pipeline stages
                     if self.data_parallelism_rank == 0 and self.args.num_layers > 0:
@@ -308,6 +317,16 @@ def load_checkpoint(self, epoch, step_number):
                     # in this case, model is sharded across the data parallel ranks
                     assert (self.args.pipeline_parallelism == 1)
                     self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_model_states", state=self.layer_state)
+                load_model_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Loaded model checkpoint in {load_model_time:.4f} seconds")
+
+            if self.optimization_state:
+                start_time = time.time()
+                self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state)
+                load_optimizer_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Loaded optimizer checkpoint in {load_optimizer_time:.4f} seconds")
 
     @abstractmethod
     def finalize(self):
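One caveat with rank-0-only timing: it reports rank 0's local elapsed time, while the slowest rank usually determines the effective checkpoint duration. A hedged sketch, assuming mpi4py rather than whatever DLIOMPI actually exposes, of reducing to the global maximum:

    import logging
    from mpi4py import MPI

    def log_max_elapsed(elapsed, label):
        # Reduce each rank's local time to the global maximum;
        # the slowest writer bounds the end-to-end checkpoint time.
        max_elapsed = MPI.COMM_WORLD.allreduce(elapsed, op=MPI.MAX)
        if MPI.COMM_WORLD.Get_rank() == 0:
            logging.info(f"{label}: slowest rank took {max_elapsed:.4f} seconds")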