
Commit 5e9a98a

committed
added timing for model and optimizer separately
1 parent b535005 commit 5e9a98a

File tree

1 file changed: +30 -11 lines changed

dlio_benchmark/checkpointing/base_checkpointing.py

Lines changed: 30 additions & 11 deletions
@@ -14,15 +14,17 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-import os
+import logging
 import math
+import os
+import time
 from abc import ABC, abstractmethod
 
 from dlio_benchmark.common.enumerations import CheckpointLocationType
 from dlio_benchmark.storage.storage_factory import StorageFactory
 from dlio_benchmark.utils.config import ConfigArguments
 from dlio_benchmark.utils.utility import DLIOMPI, utcnow
-import logging
+
 
 def get_datatype_size(datatype):
     if datatype == "int8" or datatype == "uint8":
@@ -137,9 +139,9 @@ def __init__(self, ext):
             model_checkpoint_size /= self.data_parallelism
         self.checkpoint_size = model_checkpoint_size + optimizer_checkpoint_size
         if self.args.my_rank == 0:
-            logging.info(f"{utcnow()} Model size: {model_checkpoint_size} GB")
-            logging.info(f"{utcnow()} Optimizer state size: {optimizer_checkpoint_size} GB")
-            logging.info(f"{utcnow()} Total checkpoint size: {self.checkpoint_size} GB")
+            logging.info(f"{utcnow()} Model size: {model_checkpoint_size:.4f} GB")
+            logging.info(f"{utcnow()} Optimizer state size: {optimizer_checkpoint_size:.4f} GB")
+            logging.info(f"{utcnow()} Total checkpoint size: {self.checkpoint_size:.4f} GB")
 
     @abstractmethod
     def get_tensor(self, length, datatype="int8"):
@@ -262,10 +264,8 @@ def save_checkpoint(self, epoch, step_number):
             if self.model_state:
                 self.save_state(suffix=f"{checkpoint_id}/model_states-{my_rank}", state=self.model_state, fsync = self.args.checkpoint_fsync)
 
-            if self.optimization_state:
-                self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state, fsync = self.args.checkpoint_fsync)
-
             if self.layer_state:
+                start_time = time.time()
                 if self.args.zero_stage < 3 and self.args.zero_stage > 0:
                     # if pp is turned on, we assume that the model is sharded across the pipeline stages
                     if self.data_parallelism_rank == 0 and self.args.num_layers > 0:
@@ -279,6 +279,16 @@ def save_checkpoint(self, epoch, step_number):
                     # in this case, model is sharded across the data parallel ranks
                     assert(self.args.pipeline_parallelism == 1)
                     self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_model_states", state=self.layer_state, fsync = self.args.checkpoint_fsync)
+                save_model_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Saved model checkpoint in {save_model_time:.4f} seconds")
+
+            if self.optimization_state:
+                start_time = time.time()
+                self.save_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state, fsync = self.args.checkpoint_fsync)
+                save_optimizer_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Saved optimizer checkpoint in {save_optimizer_time:.4f} seconds")
 
     @abstractmethod
     def load_checkpoint(self, epoch, step_number):
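
The change in the two save_checkpoint hunks above is a simple pattern: record time.time() before each state write, take the difference afterwards, and log it from rank 0 only. A minimal standalone sketch of that pattern, separate from DLIO itself (timed_phase and write_fn are illustrative names, not part of the benchmark's API):

import logging
import time


def timed_phase(name, write_fn, my_rank):
    # Time one checkpoint phase and report it on rank 0, mirroring the
    # start_time / elapsed / rank-0 logging.info pattern in the diff.
    start_time = time.time()
    write_fn()  # stand-in for save_state(...) on the model or optimizer shard
    elapsed = time.time() - start_time
    if my_rank == 0:
        logging.info(f"Saved {name} checkpoint in {elapsed:.4f} seconds")
    return elapsed


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    timed_phase("model", lambda: time.sleep(0.1), my_rank=0)
    timed_phase("optimizer", lambda: time.sleep(0.05), my_rank=0)

As in the commit, each rank measures only its own write, and rank 0 reports just its local duration rather than an aggregate across ranks.
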
@@ -288,13 +298,12 @@ def load_checkpoint(self, epoch, step_number):
         checkpoint_id = f"global_epoch{epoch}_step{step_number}"
         self.checkpoint_storage.create_node(checkpoint_id, exist_ok=True)
         if self.rank_to_checkpoint == my_rank:
+
             if self.model_state:
                 self.load_state(suffix=f"{checkpoint_id}/model_states-{my_rank}", state=self.model_state, fsync = self.args.checkpoint_fsync)
-
-            if self.optimization_state:
-                self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state)
 
             if self.layer_state:
+                start_time = time.time()
                 if self.args.zero_stage < 3 and self.args.zero_stage > 0:
                     # if pp is turned on, we assume that the model is sharded across the pipeline stages
                     if self.data_parallelism_rank == 0 and self.args.num_layers > 0:
@@ -308,6 +317,16 @@ def load_checkpoint(self, epoch, step_number):
                     # in this case, model is sharded across the data parallel ranks
                     assert(self.args.pipeline_parallelism == 1)
                     self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_model_states", state=self.layer_state)
+                load_model_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Loaded model checkpoint in {load_model_time:.4f} seconds")
+
+            if self.optimization_state:
+                start_time = time.time()
+                self.load_state(suffix=f"{checkpoint_id}/zero_pp_rank_{self.data_parallelism_rank}_mp_rank_{self.model_parallelism_rank}_optim_states", state=self.optimization_state)
+                load_optimizer_time = time.time() - start_time
+                if my_rank == 0:
+                    logging.info(f"{utcnow()} Loaded optimizer checkpoint in {load_optimizer_time:.4f} seconds")
 
     @abstractmethod
     def finalize(self):
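
Because the new messages go through logging.info with fixed wording ("Saved model checkpoint in N seconds", "Loaded optimizer checkpoint in N seconds", and so on), the per-phase durations can be recovered from a benchmark run's log afterwards. A rough post-processing sketch, assuming the log lines keep exactly the format shown in the diff; the dlio.log path is hypothetical:

import re
from collections import defaultdict

# Matches the rank-0 messages added by this commit, e.g.
# "... Saved model checkpoint in 1.2345 seconds"
PATTERN = re.compile(r"(Saved|Loaded) (model|optimizer) checkpoint in ([0-9.]+) seconds")


def collect_phase_times(log_path):
    # Group durations by (action, phase), e.g. ("Saved", "optimizer").
    times = defaultdict(list)
    with open(log_path) as f:
        for line in f:
            match = PATTERN.search(line)
            if match:
                action, phase, seconds = match.groups()
                times[(action, phase)].append(float(seconds))
    return times


if __name__ == "__main__":
    for key, values in collect_phase_times("dlio.log").items():  # hypothetical log file
        print(key, f"mean {sum(values) / len(values):.4f} s over {len(values)} checkpoints")
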
