Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more fine-grained performance metrics #11619

Merged
merged 30 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0b6dff6
add support for megatron timers
ashors1 Dec 14, 2024
22af9f5
add option to log tokens/sec/gpu
ashors1 Dec 16, 2024
90f2964
add flops callback to nemo2
ashors1 Dec 16, 2024
ffcc521
remove print
ashors1 Dec 16, 2024
04a3a98
fixes for multi-gpu
ashors1 Dec 17, 2024
f963bf6
Apply isort and black reformatting
ashors1 Dec 17, 2024
466f9c3
log min/max time across ranks
ashors1 Dec 18, 2024
1d1bbc1
Apply isort and black reformatting
ashors1 Dec 18, 2024
6d6376a
log flops to wandb/tb
ashors1 Dec 19, 2024
834cc29
Apply isort and black reformatting
ashors1 Dec 19, 2024
71dd945
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/megatron-ti…
ashors1 Jan 8, 2025
771a772
Merge branch 'ashors/megatron-timers' of github.com:NVIDIA/NeMo into …
ashors1 Jan 8, 2025
9b37cfe
cleanup
ashors1 Jan 8, 2025
987d5f2
Apply isort and black reformatting
ashors1 Jan 8, 2025
62d9f7e
add copyright
ashors1 Jan 9, 2025
0c11c99
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/megatron-ti…
ashors1 Jan 12, 2025
33b7977
bug fix
ashors1 Jan 12, 2025
9f27a34
fix tokens/sec/gpu formula
ashors1 Jan 14, 2025
21be58c
Apply isort and black reformatting
ashors1 Jan 14, 2025
da00f60
update documentation
ashors1 Jan 14, 2025
ded5e3c
Apply isort and black reformatting
ashors1 Jan 14, 2025
e99c381
formatting
ashors1 Jan 14, 2025
1415fac
use dataclass for flops hparams
ashors1 Jan 15, 2025
9211aa9
Apply isort and black reformatting
ashors1 Jan 15, 2025
4fb4414
reorganize
ashors1 Jan 15, 2025
8b49c07
fix
ashors1 Jan 15, 2025
8c81e07
Merge branch 'main' of github.com:NVIDIA/NeMo into ashors/megatron-ti…
ashors1 Jan 16, 2025
48b0ed9
add typehint
ashors1 Jan 17, 2025
df5c313
fix
ashors1 Jan 17, 2025
b89b2fb
Apply isort and black reformatting
ashors1 Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 264 additions & 0 deletions nemo/lightning/pytorch/callbacks/flops_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
Fixed Show fixed Hide fixed
import os
Fixed Show fixed Hide fixed
from typing import Any, Dict, List, Optional

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback

from nemo.collections.common.parts.perf_metrics_utils import LLM_VOCAB_SIZE_MAP
from nemo.lightning.pytorch.callbacks import PEFT
from nemo.utils import logging

__all__ = ["FLOPsMeasurementCallback"]


class FLOPsMeasurementCallback(Callback):
"""
Calculate FLOPs per second after last train step for a given job run.

Args:
model_config (Dict[str, Any]): params for running the experiment/job.
Expects a nested dictionary with parent keys
1. run- for assessing model name (Eg. 'gpt3', 'llama2', etc.) from sub-key 'name'.
'name' usually has value like- train_gpt3_5b_*, which is matched to model name 'gpt3'.
2. exp_manager- for accessing 'explicit_log_dir'. tensorboard log file is stored here,
used for accessing step time needed for calculating TFLOPs per sec per GPU
3. trainer- for accessing 'num_nodes' and 'devices' needed for calculating
TFLOPs per sec per GPU
4. model- Hyperparams for the model. Specifically- global batch size, sequence length,
hidden size, ffn hidden size, num_layers, num_attention_heads, num_query_groups,
moe_router_topk. (list might increase with new models as required)
log_dir (Optional[str]): Directory with tenbsorboard log file. If present, will overrride
'explicit_log_dir' in model_config. Defaults to None.
model_name (Optional[str]): If present, will override 'name' under 'run' in model_config.
Defaults to None.
"""

higher_is_better = True

def __init__(
self,
model_config: Dict[str, Any],
data_config: pl.LightningDataModule,
model_name: Optional[str],
ashors1 marked this conversation as resolved.
Show resolved Hide resolved
):
self.model_cfg = model_config
self.data_cfg = data_config

# use config params only when NOT provided explicitly
self.model = model_name

self.gbs = self.data_cfg.global_batch_size
self.enc_seq_len = self.model_cfg.seq_length
self.hs = self.model_cfg.hidden_size
self.layers = self.model_cfg.num_layers
self.ffn_hs = self.model_cfg.ffn_hidden_size
self.attention_heads = self.model_cfg.num_attention_heads
self.moe_router_topk = self.model_cfg.moe_router_topk
ashors1 marked this conversation as resolved.
Show resolved Hide resolved

# this handles both- 1. key is present, value is None; 2. key is absent
self.query_groups = self.model_cfg.num_query_groups
if self.query_groups is None:
self.query_groups = self.attention_heads

self.model = self.model.lower() if self.model is not None else self.model

self.avg_train_step_time = 0

def on_train_start(self, trainer, pl_module):
has_lora = False
Fixed Show fixed Hide fixed
for callback in trainer.callbacks:
if isinstance(callback, PEFT):
raise NotImplementedError("FLOPs measurement not supported for finetuning jobs")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
"""
PyTorch Lightning callback hook to calculate TFLOPs per sec per GPU after training
"""
tflops_per_sec_per_gpu = -1
Fixed Show fixed Hide fixed

try:
self.avg_train_step_time += trainer.progress_bar_metrics['train_step_timing in s']
except KeyError:
print("'train_step_timing in s' not found. Make sure to use TimingCallback with FLOPsMeasurementCallback.")

n = trainer.strategy.current_epoch_step
if n % trainer.log_every_n_steps == 0:
## skip calculation if we haven't accumulated any timing data
if self.avg_train_step_time == 0:
return
tflops_per_sec_per_gpu = self.eval_tflops_per_sec_per_gpu(
self.avg_train_step_time / trainer.log_every_n_steps
)
self.avg_train_step_time = 0
pl_module.log(
"tflops_per_sec_per_gpu",
tflops_per_sec_per_gpu,
on_step=True,
on_epoch=False,
batch_size=1,
prog_bar=True,
)

def eval_tflops_per_sec_per_gpu(self, train_step_time: List | float | int) -> float:
"""
Args:
train_step_time (Any[List, float, int]): Train step time (in seconds).
Step time will be less stable for initial steps (~10 steps)- less
accurate measurement
Use average step time over several steps for higher accuracy
Returns:
(float): Model TFLOPs per sec per gpu
"""
total_flops, flops_per_gpu = self.eval_model_flops()

if not isinstance(train_step_time, list):
train_step_time = [train_step_time]
# efficient mean computation if num train steps is very large
step_time_arr = np.array(train_step_time)
train_step_time = np.mean(step_time_arr[len(step_time_arr) // 2 :])

return flops_per_gpu / (1e12 * train_step_time)

def eval_model_flops(self):
"""
Calculate model FLOPs for a given model
"""

model_flops_map = {
"gpt3": self._gpt3,
"llama2": self._llama2,
"llama3": self._llama3,
"nemotron": self._nemotron,
"mixtral": self._mixtral,
"bert": self._bert,
}

if self.model is not None:
model_matches = [model for model in model_flops_map if model in self.model]
self.model = model_matches[0] if len(model_matches) > 0 else self.model
if self.model not in model_flops_map:
logging.info(f"FLOPs measurement supported for {list(model_flops_map.keys())}")
raise KeyError(f"Failed to extract valid model name from or missing FLOPs calculations for {self.model}")

total_flops = model_flops_map[self.model]()
num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
flops_per_gpu = total_flops / num_devices

return total_flops, flops_per_gpu

def _gpt3(self):
"""Model FLOPs for GPT3 family"""

vocab_size = LLM_VOCAB_SIZE_MAP["gpt3"]

return (
24 * self.gbs * self.enc_seq_len * self.hs * self.hs
+ 4 * self.gbs * self.enc_seq_len * self.enc_seq_len * self.hs
) * (3 * self.layers) + (6 * self.gbs * self.enc_seq_len * self.hs * vocab_size)
ashors1 marked this conversation as resolved.
Show resolved Hide resolved

def _llama2(self):
"""Model FLOPs for llama2 family"""
vocab_size = LLM_VOCAB_SIZE_MAP["llama2"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (18 * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _llama3(self):
"""Model FLOPs for llama3 family"""
vocab_size = LLM_VOCAB_SIZE_MAP["llama3"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (18 * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _nemotron(self):
"""Model FLOPs for nemotron family"""
vocab_size = LLM_VOCAB_SIZE_MAP["nemotron"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (12 * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _mixtral(self):
"""Model FLOPs for mixtral family"""
vocab_size = LLM_VOCAB_SIZE_MAP["mixtral"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (18 * self.moe_router_topk * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _bert(self):
"""Model FLOPs for BERT family"""
vocab_size = LLM_VOCAB_SIZE_MAP["bert"]

return (
72
* self.gbs
* self.layers
* self.enc_seq_len
* self.hs
* self.hs
* (1 + (self.enc_seq_len / (6 * self.hs)) + (vocab_size / (12 * self.hs * self.layers)))
)
14 changes: 13 additions & 1 deletion nemo/lightning/pytorch/callbacks/progress_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,15 @@
## TODO(ashors): handle nan losses
@override
def on_train_batch_end(self, trainer, pl_module, *_, **__):
n = trainer.strategy.current_epoch_step

if self.should_log(n) and getattr(trainer.strategy, "timers", None):
timers = trainer.strategy._mcore_config.timers # pointer to timers used in megatron
megatron_log_string = self.log_megatron_timers(timers)

if self.is_disabled:
return
n = trainer.strategy.current_epoch_step

metrics = self.get_metrics(trainer, pl_module)
for key in metrics:
if key in self.exclude_metrics:
Expand All @@ -138,6 +144,7 @@
prefix = self.train_description + f" epoch {trainer.current_epoch}, iteration {n-1}/{self.total-1}"
log_string = self.format_string(prefix, self.average_metrics_dict)
print(log_string)
print(megatron_log_string, flush=True)

self.total_metrics_dict = defaultdict(lambda: 0.0)

Expand Down Expand Up @@ -201,3 +208,8 @@

def should_log(self, n):
return n % self.log_interval == 0

def log_megatron_timers(self, timers):
Fixed Show fixed Hide fixed
output_string = timers.get_all_timers_string(names=None, normalizer=self.log_interval)
if output_string is not None:
return output_string + "\n"
ashors1 marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 12 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.types import STEP_OUTPUT
from megatron.core import Timers
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from torch import nn
Expand Down Expand Up @@ -168,6 +169,12 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
that prints the metrics to stdout. Suitable for non-interactive settings.
progress_interval (int): How frequently to print progress to stdout. Only used when
replace_progress_bar is True.
megatron_log_level (int): Granularity level to measure and report timing.
0: report only iteration time and make sure timing does not introduce extra overhead.
1: report timing for operations that are executed very limited times (basically once) during
each iteration (such as gradient all-reduce)
2: report timing for operations that migh be executed numerous times during each iteration.
Note that setting the level to 1 or 2 might cause increase in iteration time.
**kwargs: Additional keyword arguments.

Note:
Expand Down Expand Up @@ -214,6 +221,7 @@ def __init__(
replace_progress_bar: bool = True,
progress_interval: int = 1,
restore_config: Optional[RestoreConfig] = None,
megatron_log_level: int = 0,
**kwargs,
) -> None:
super().__init__(
Expand Down Expand Up @@ -264,6 +272,7 @@ def __init__(
self.progress_interval = progress_interval

self.restore_config = restore_config
self.timers = Timers(megatron_log_level, "minmax") ## could also set this for optimizer if we want

self._ddp = ddp
if ddp == "megatron":
Expand Down Expand Up @@ -309,6 +318,9 @@ def connect(self, model: pl.LightningModule) -> None:

model.config = update_config_with_dtype_overrides(dtype_config, model.config)

## add megatron timer to config
model.config.timers = self.timers

has_optim = getattr(model, "optim", None)
if has_optim and self._setup_optimizers:
opt_config = getattr(model.optim, "config", None)
Expand Down
19 changes: 18 additions & 1 deletion nemo/utils/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,17 @@
from nemo.utils.env_var_parsing import get_envbool
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.import_utils import safe_import_from
from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger
from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams
from nemo.utils.mcore_logger import add_handlers_to_mcore_logger
from nemo.utils.model_utils import uninject_model_parallel_rank

get_num_microbatches, HAVE_MCORE_MBATCH_CALCULATOR = safe_import_from(
"megatron.core.num_microbatches_calculator", "get_num_microbatches"
)


try:
# `ptl_resiliency` is included in `gwe_resiliency_pkg` package
from ptl_resiliency import StragglerDetectionCallback
Expand Down Expand Up @@ -242,7 +248,8 @@ class TimingCallback(Callback):
Logs execution time of train/val/test steps
"""

def __init__(self, timer_kwargs={}):
def __init__(self, log_tokens_per_sec: bool = False, timer_kwargs={}):
self.log_tokens_per_sec = log_tokens_per_sec
self.timer = timers.NamedTimer(**timer_kwargs)

def _on_batch_start(self, name):
Expand Down Expand Up @@ -276,6 +283,16 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._on_batch_end("train_step_timing", pl_module)
if self.log_tokens_per_sec:
tokens_per_gpu = batch["tokens"].shape[0] * batch["tokens"].shape[1] * get_num_microbatches()
pl_module.log(
"tokens_per_sec_per_gpu",
tokens_per_gpu / (torch.as_tensor(self.timer["train_step_timing"])),
on_step=True,
on_epoch=False,
batch_size=1,
prog_bar=True,
)

def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
self._on_batch_start("validation_step_timing")
Expand Down
Loading