Add more fine-grained performance metrics (#11619)
* add support for megatron timers

Signed-off-by: ashors1 <[email protected]>

* add option to log tokens/sec/gpu

Signed-off-by: ashors1 <[email protected]>

* add flops callback to nemo2

Signed-off-by: ashors1 <[email protected]>

* remove print

Signed-off-by: ashors1 <[email protected]>

* fixes for multi-gpu

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* log min/max time across ranks

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* log flops to wandb/tb

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* cleanup

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* add copyright

Signed-off-by: ashors1 <[email protected]>

* bug fix

Signed-off-by: ashors1 <[email protected]>

* fix tokens/sec/gpu formula

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* update documentation

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* formatting

Signed-off-by: ashors1 <[email protected]>

* use dataclass for flops hparams

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* reorganize

Signed-off-by: ashors1 <[email protected]>

* fix

Signed-off-by: ashors1 <[email protected]>

* add typehint

Signed-off-by: ashors1 <[email protected]>

* fix

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: ashors1 <[email protected]>
ashors1 and ashors1 authored Jan 23, 2025
1 parent 6aeef75 commit f0bf77b
Showing 6 changed files with 378 additions and 121 deletions.
149 changes: 30 additions & 119 deletions nemo/collections/common/metrics/perf_metrics.py
@@ -17,8 +17,8 @@
import numpy as np
from lightning.pytorch.callbacks import Callback

from nemo.collections.common.parts.perf_metrics_utils import LLM_VOCAB_SIZE_MAP, read_tb_log
from nemo.utils import logging
from nemo.collections.common.parts.perf_metrics_utils import read_tb_log
from nemo.utils import flops_formulas, logging

__all__ = ["FLOPsMeasurementCallback"]

@@ -68,18 +68,29 @@ def __init__(
self.num_nodes = self.train_cfg.get('num_nodes', None)
self.num_gpus_per_node = self.train_cfg.get('devices', None)

self.gbs = self.model_cfg.get('global_batch_size', None)
self.enc_seq_len = self.model_cfg.get('encoder_seq_length', None)
self.hs = self.model_cfg.get('hidden_size', None)
self.layers = self.model_cfg.get('num_layers', None)
self.ffn_hs = self.model_cfg.get('ffn_hidden_size', None)
self.attention_heads = self.model_cfg.get('num_attention_heads', None)
self.moe_router_topk = self.model_cfg.get('moe_router_topk', None)
gbs = self.model_cfg.get('global_batch_size', None)
enc_seq_len = self.model_cfg.get('encoder_seq_length', None)
hs = self.model_cfg.get('hidden_size', None)
layers = self.model_cfg.get('num_layers', None)
ffn_hs = self.model_cfg.get('ffn_hidden_size', None)
attention_heads = self.model_cfg.get('num_attention_heads', None)
moe_router_topk = self.model_cfg.get('moe_router_topk', None)

# this handles both cases: 1. the key is present but its value is None; 2. the key is absent
self.query_groups = self.model_cfg.get('num_query_groups', None)
if self.query_groups is None:
self.query_groups = self.attention_heads
query_groups = self.model_cfg.get('num_query_groups', None)
if query_groups is None:
query_groups = attention_heads

self.flops_config = flops_formulas.FLOPSConfig(
gbs=gbs,
enc_seq_len=enc_seq_len,
hs=hs,
layers=layers,
ffn_hs=ffn_hs,
attention_heads=attention_heads,
moe_router_topk=moe_router_topk,
query_groups=query_groups,
)

self.model = self.model.lower() if self.model is not None else self.model

@@ -128,12 +139,12 @@ def eval_model_flops(self):
"""

model_flops_map = {
"gpt3": self._gpt3,
"llama2": self._llama2,
"llama3": self._llama3,
"nemotron": self._nemotron,
"mixtral": self._mixtral,
"bert": self._bert,
"gpt3": flops_formulas.gpt3,
"llama2": flops_formulas.llama2,
"llama3": flops_formulas.llama3,
"nemotron": flops_formulas.nemotron,
"mixtral": flops_formulas.mixtral,
"bert": flops_formulas.bert,
}

if self.model is not None:
@@ -143,107 +154,7 @@
logging.info(f"FLOPs measurement supported for {list(model_flops_map.keys())}")
raise KeyError(f"Failed to extract valid model name from or missing FLOPs calculations for {self.model}")

total_flops = model_flops_map[self.model]()
total_flops = model_flops_map[self.model](self.flops_config)
flops_per_gpu = total_flops / (self.num_nodes * self.num_gpus_per_node)

return total_flops, flops_per_gpu

def _gpt3(self):
"""Model FLOPs for GPT3 family"""

vocab_size = LLM_VOCAB_SIZE_MAP["gpt3"]

return (
24 * self.gbs * self.enc_seq_len * self.hs * self.hs
+ 4 * self.gbs * self.enc_seq_len * self.enc_seq_len * self.hs
) * (3 * self.layers) + (6 * self.gbs * self.enc_seq_len * self.hs * vocab_size)

def _llama2(self):
"""Model FLOPs for llama2 family"""
vocab_size = LLM_VOCAB_SIZE_MAP["llama2"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (18 * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _llama3(self):
"""Model FLOPs for llama3 family"""
vocab_size = LLM_VOCAB_SIZE_MAP["llama3"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (18 * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _nemotron(self):
"""Model FLOPs for nemotron family"""
vocab_size = LLM_VOCAB_SIZE_MAP["nemotron"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (12 * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _mixtral(self):
"""Model FLOPs for mixtral family"""
vocab_size = LLM_VOCAB_SIZE_MAP["mixtral"]

return (
self.gbs
* self.enc_seq_len
* self.layers
* self.hs
* self.hs
* (
12
+ (12 * self.query_groups / self.attention_heads)
+ (18 * self.moe_router_topk * self.ffn_hs / self.hs)
+ (12 * self.enc_seq_len / self.hs)
+ (6 * vocab_size / (self.layers * self.hs))
)
)

def _bert(self):
"""Model FLOPs for BERT family"""
vocab_size = LLM_VOCAB_SIZE_MAP["bert"]

return (
72
* self.gbs
* self.layers
* self.enc_seq_len
* self.hs
* self.hs
* (1 + (self.enc_seq_len / (6 * self.hs)) + (vocab_size / (12 * self.hs * self.layers)))
)
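
Since the per-model formulas now live in the shared nemo.utils.flops_formulas module (presumably one of the six changed files, not expanded on this page), here is a minimal sketch of what that module plausibly contains, reconstructed from the removed _gpt3 method and the FLOPSConfig(...) keyword arguments above. The dataclass defaults, the GPT-3 vocabulary size, and the worked numbers at the end are assumptions for illustration; the formula itself is copied from the deleted code.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FLOPSConfig:
    """Hyperparameters consumed by the per-model FLOPs formulas (field names taken from the diff)."""

    gbs: int                           # global batch size
    enc_seq_len: Optional[int] = None  # (encoder) sequence length
    hs: Optional[int] = None           # hidden size
    layers: Optional[int] = None       # number of transformer layers
    ffn_hs: Optional[int] = None       # FFN hidden size
    attention_heads: Optional[int] = None
    moe_router_topk: Optional[int] = None
    query_groups: Optional[int] = None


def gpt3(config: FLOPSConfig) -> float:
    """Model FLOPs per global batch for the GPT-3 family (formula from the removed _gpt3)."""
    vocab_size = 51200  # assumption; the real module reads LLM_VOCAB_SIZE_MAP["gpt3"]
    return (
        24 * config.gbs * config.enc_seq_len * config.hs * config.hs
        + 4 * config.gbs * config.enc_seq_len * config.enc_seq_len * config.hs
    ) * (3 * config.layers) + (6 * config.gbs * config.enc_seq_len * config.hs * vocab_size)


# Worked example mirroring eval_model_flops / eval_tflops_per_sec_per_gpu:
cfg = FLOPSConfig(gbs=256, enc_seq_len=2048, hs=12288, layers=96, attention_heads=96)
num_devices, avg_step_time_s = 64, 8.0  # illustrative values only
tflops_per_sec_per_gpu = gpt3(cfg) / num_devices / (1e12 * avg_step_time_s)
print(f"{tflops_per_sec_per_gpu:.1f} TFLOPs/sec/GPU")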
164 changes: 164 additions & 0 deletions nemo/lightning/pytorch/callbacks/flops_callback.py
@@ -0,0 +1,164 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback

from nemo.lightning.pytorch.callbacks import PEFT
from nemo.utils import flops_formulas, logging

__all__ = ["FLOPsMeasurementCallback"]


class FLOPsMeasurementCallback(Callback):
"""
Calculate and log FLOPs per second after every ``trainer.log_every_n_steps`` steps.
Args:
model_config (GPTConfig): Model parameters.
data_config (pl.LightningDataModule): Data module being used in the experiment.
model_name (str): Name of the model being run. The following models are supported:
gpt3, llama2, llama3, nemotron, mixtral, bert.
"""

higher_is_better = True

def __init__(
self,
model_config: "GPTConfig",
data_config: pl.LightningDataModule,
model_name: str,
):
self.model_cfg = model_config
self.data_cfg = data_config

# use config params only when NOT provided explicitly
self.model = model_name

gbs = self.data_cfg.global_batch_size
enc_seq_len = self.model_cfg.seq_length
hs = self.model_cfg.hidden_size
layers = self.model_cfg.num_layers
ffn_hs = self.model_cfg.ffn_hidden_size
attention_heads = self.model_cfg.num_attention_heads
moe_router_topk = self.model_cfg.moe_router_topk

# this handles both cases: 1. the key is present but its value is None; 2. the key is absent
query_groups = self.model_cfg.num_query_groups
if query_groups is None:
query_groups = attention_heads

self.flops_config = flops_formulas.FLOPSConfig(
gbs=gbs,
enc_seq_len=enc_seq_len,
hs=hs,
layers=layers,
ffn_hs=ffn_hs,
attention_heads=attention_heads,
moe_router_topk=moe_router_topk,
query_groups=query_groups,
)

self.model = self.model.lower() if self.model is not None else self.model

self.avg_train_step_time = 0

def on_train_start(self, trainer, pl_module):
"""
PyTorch Lightning callback hook. Ensures that user is not using PEFT
as FLOPS callback does not support it.
"""
for callback in trainer.callbacks:
if isinstance(callback, PEFT):
raise NotImplementedError("FLOPs measurement not supported for finetuning jobs")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
"""
PyTorch Lightning callback hook to calculate TFLOPs per sec per GPU after training
"""
try:
self.avg_train_step_time += trainer.progress_bar_metrics['train_step_timing in s']
except KeyError:
print("'train_step_timing in s' not found. Make sure to use TimingCallback with FLOPsMeasurementCallback.")

n = trainer.strategy.current_epoch_step
if n % trainer.log_every_n_steps == 0:
## skip calculation if we haven't accumulated any timing data
if self.avg_train_step_time == 0:
return
tflops_per_sec_per_gpu = self.eval_tflops_per_sec_per_gpu(
self.avg_train_step_time / trainer.log_every_n_steps
)
self.avg_train_step_time = 0
pl_module.log(
"tflops_per_sec_per_gpu",
tflops_per_sec_per_gpu,
on_step=True,
on_epoch=False,
batch_size=1,
prog_bar=True,
)

def eval_tflops_per_sec_per_gpu(self, train_step_time: List | float | int) -> float:
"""
Args:
train_step_time (List | float | int): Train step time (in seconds).
Step time is less stable for the first ~10 steps, so those
measurements are less accurate; use the average step time over
several steps for higher accuracy.
Returns:
(float): Model TFLOPs per sec per gpu
"""
total_flops, flops_per_gpu = self.eval_model_flops()

if not isinstance(train_step_time, list):
train_step_time = [train_step_time]
# efficient mean computation if num train steps is very large
step_time_arr = np.array(train_step_time)
train_step_time = np.mean(step_time_arr[len(step_time_arr) // 2 :])

return flops_per_gpu / (1e12 * train_step_time)

def eval_model_flops(self):
"""
Calculate model FLOPs for a given model
"""

model_flops_map = {
"gpt3": flops_formulas.gpt3,
"llama2": flops_formulas.llama2,
"llama3": flops_formulas.llama3,
"nemotron": flops_formulas.nemotron,
"mixtral": flops_formulas.mixtral,
"bert": flops_formulas.bert,
}

if self.model is not None:
model_matches = [model for model in model_flops_map if model in self.model]
self.model = model_matches[0] if len(model_matches) > 0 else self.model
if self.model not in model_flops_map:
logging.info(f"FLOPs measurement supported for {list(model_flops_map.keys())}")
raise KeyError(f"Failed to extract valid model name from or missing FLOPs calculations for {self.model}")

total_flops = model_flops_map[self.model](self.flops_config)
num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
flops_per_gpu = total_flops / num_devices

return total_flops, flops_per_gpu
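
As a usage illustration for the new NeMo 2.0 callback, the sketch below constructs it with stand-in config objects that expose only the attributes the constructor reads. The stub classes, their values, and the bare lightning.pytorch.Trainer are assumptions rather than part of this change; in a real run you would pass the actual GPTConfig and data module, and also attach NeMo's timing callback so that 'train_step_timing in s' is populated.

from dataclasses import dataclass

import lightning.pytorch as pl

from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback


@dataclass
class StubModelConfig:
    # attributes the callback reads from model_config (values are made up)
    seq_length: int = 8192
    hidden_size: int = 4096
    num_layers: int = 32
    ffn_hidden_size: int = 14336
    num_attention_heads: int = 32
    moe_router_topk: int = 0
    num_query_groups: int = 8


@dataclass
class StubDataConfig:
    # attribute the callback reads from data_config
    global_batch_size: int = 128


flops_cb = FLOPsMeasurementCallback(
    model_config=StubModelConfig(),
    data_config=StubDataConfig(),
    model_name="llama3",  # one of: gpt3, llama2, llama3, nemotron, mixtral, bert
)

# A timing callback that reports 'train_step_timing in s' (not shown in this diff)
# must also be in the callbacks list for tflops_per_sec_per_gpu to be logged.
trainer = pl.Trainer(devices=1, log_every_n_steps=10, callbacks=[flops_cb])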
16 changes: 15 additions & 1 deletion nemo/lightning/pytorch/callbacks/progress_printer.py
@@ -122,9 +122,11 @@ def on_train_start(self, trainer, *_):
## TODO(ashors): handle nan losses
@override
def on_train_batch_end(self, trainer, pl_module, *_, **__):
n = trainer.strategy.current_epoch_step

if self.is_disabled:
return
n = trainer.strategy.current_epoch_step

metrics = self.get_metrics(trainer, pl_module)
for key in metrics:
if key in self.exclude_metrics:
@@ -138,6 +140,12 @@ def on_train_batch_end(self, trainer, pl_module, *_, **__):
prefix = self.train_description + f" epoch {trainer.current_epoch}, iteration {n-1}/{self.total-1}"
log_string = self.format_string(prefix, self.average_metrics_dict)
print(log_string)
if getattr(trainer.strategy, "timers", None):
timers = trainer.strategy.timers
megatron_log_string = self.log_megatron_timers(timers)

if megatron_log_string:
print(megatron_log_string, flush=True)

self.total_metrics_dict = defaultdict(lambda: 0.0)

@@ -201,3 +209,9 @@ def on_test_batch_end(

def should_log(self, n):
return n % self.log_interval == 0

def log_megatron_timers(self, timers):
output_string = timers.get_all_timers_string(names=None, normalizer=self.log_interval)
if output_string is not None:
return output_string + "\n"
return None
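
For context on the Megatron-timers hook added above: the printer depends only on a small duck-typed surface of trainer.strategy.timers, namely a get_all_timers_string(names, normalizer) method returning a string or None. The stand-in below is enough to exercise the new log_megatron_timers path in isolation; its output format is an assumption modeled on the min/max-across-ranks reporting mentioned in the commit message, not Megatron-Core's exact string.

class StubTimers:
    """Minimal stand-in exposing the only method the progress printer calls."""

    def __init__(self, elapsed_ms):
        # name -> (min_ms, max_ms) across ranks; the values here are made up
        self.elapsed_ms = elapsed_ms

    def get_all_timers_string(self, names=None, normalizer=1.0):
        if not self.elapsed_ms:
            return None
        rows = [
            f"    {name}: ({lo / normalizer:.2f}, {hi / normalizer:.2f})"
            for name, (lo, hi) in self.elapsed_ms.items()
        ]
        return "(min, max) time across ranks (ms):\n" + "\n".join(rows)


# What the printer would append to its output every log_interval steps:
timers = StubTimers({"forward-backward": (410.3, 432.8), "optimizer": (55.1, 57.9)})
print(timers.get_all_timers_string(names=None, normalizer=10) + "\n")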