
Commit a0c5869

mathemakitten authored and jaredcasper committed
ADLR/megatron-lm!1841 - Calibration, weight initialization, and inference in FP8
1 parent 203b463 commit a0c5869

File tree (3 files changed, +249 −2 lines):

  tasks/finetune_utils.py
  tasks/quantize/calibrate_gpt.py
  tools/run_text_generation_server.py

tasks/finetune_utils.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 import sys
 import torch
 
-from megatron.training import get_args, get_num_microbatches
+from megatron.training import get_args
+from megatron.core.num_microbatches_calculator import get_num_microbatches
 from megatron.training import print_rank_0
 from megatron.training import get_timers
 from megatron.core import mpu
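For context only (not part of the diff): the call site does not change with this import move. A minimal sketch, assuming Megatron has already been initialized so the microbatch calculator is configured:

# Same getter as before, now imported from megatron.core rather than megatron.training.
from megatron.core.num_microbatches_calculator import get_num_microbatches

num_microbatches = get_num_microbatches()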

tasks/quantize/calibrate_gpt.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Calibrate a GPT model for FP8 scaling factors."""
import os
import sys

sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)
import math

import torch
import transformer_engine.pytorch as te

from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.pipeline_parallel.p2p_communication import recv_forward, send_forward
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, get_model, is_last_rank, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.training.training import save_checkpoint_and_time
from megatron.training.utils import unwrap_model
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from tasks.finetune_utils import build_data_loader
from tasks.zeroshot_gpt.datasets import build_dataset
from tasks.zeroshot_gpt.evaluate import process_batch


def model_provider(pre_process=True, post_process=True) -> GPTModel:
    """Builds the model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        GPTModel: The returned model. Only works for Transformer Engine implementations.
    """

    args = get_args()

    print_rank_0('building GPT model ...')

    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models or args.transformer_impl != "transformer_engine":
        raise NotImplementedError(
            'Calibration is only supported for models using TransformerEngine.'
        )
    else:
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                args.num_experts, args.moe_grouped_gemm
            )
        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent
        )

    return model


def forward_step(batch, model, config):
    """Forward step."""

    # Get the batch.
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(batch)

    args = get_args()
    args.micro_batch_size = len(labels)

    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    input_tensor = recv_forward(tensor_shape, config)

    # Forward pass through the model.
    unwrapped_model = unwrap_model(model)
    unwrapped_model.set_input_tensor(input_tensor)
    output = model(tokens, position_ids, attention_mask)

    send_forward(output, config)

    if parallel_state.is_pipeline_last_stage():
        losses = tensor_parallel.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels.contiguous()
        )
        loss = torch.sum(losses.view(-1) * loss_mask.contiguous().view(-1).float())
        return loss

    return None


def calibrate(data_loader, model):
    args = get_args()
    config = core_transformer_config_from_args(args)

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_output = 0.0
    num_examples = min(len(data_loader), args.calib_size)
    data_loader = iter(data_loader)

    with torch.no_grad():
        iteration = 0
        while iteration < num_examples - 1:
            batch = next(data_loader)
            if iteration % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(iteration))
            with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(
                device_type='cuda', dtype=torch.bfloat16
            ):
                output = forward_step(batch, model, config)

            # Reduce across processes.
            if parallel_state.is_pipeline_last_stage():
                torch.distributed.all_reduce(
                    output, group=parallel_state.get_data_parallel_group()
                )

                total_output += output
            iteration += 1

        print_rank_0(f"Compute scaling factors with FP8 autocast ...")
        with te.fp8_autocast(enabled=True), torch.autocast(
            device_type='cuda', dtype=torch.bfloat16
        ):
            forward_step(batch, model, config)

        if parallel_state.is_pipeline_last_stage():
            torch.distributed.all_reduce(output, group=parallel_state.get_data_parallel_group())

            total_output += output

    print_rank_0(f"Saving calibrated checkpoint ...")
    save_checkpoint_and_time(
        iteration,
        [model],
        optimizer=None,
        opt_param_scheduler=None,
        num_floating_point_operations_so_far=0,
        checkpointing_context=None,
    )

    return total_output


def calibrate_and_print_results(task, data_loader, model):
    """Calibrate and print results on screen."""

    # Calibrate and save scaling factors
    output = calibrate(data_loader, model)

    string = ' validation results on {} | '.format(task)
    if is_last_rank():
        num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
        num_original_tokens = data_loader.dataset.num_original_tokens
        val_loss = output / (num_tokenized_tokens - 1)
        ppl = math.exp(min(20, val_loss))
        token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
        adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
        string += 'avg loss: {:.4E} | '.format(val_loss)
        string += 'ppl: {:.4E} | '.format(ppl)
        string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
        string += 'token ratio: {} |'.format(token_ratio)

        length = len(string) + 1
        print('-' * length)
        print(string)
        print('-' * length)


def add_calib_args(parser):
    group = parser.add_argument_group(title='calibration')
    group.add_argument("--task", type=str, help="Calibration task to run. Defaults to WIKITEXT103.")
    group.add_argument('--valid-data', nargs='*', default=None, help='Calibration dataset')
    group.add_argument(
        '--overlapping-eval',
        type=int,
        default=32,  # Required for reusing _build_wikitext103_dataset()
        help='Sliding window for overlapping evaluation.',
    )
    group.add_argument(
        "--calib-size", type=int, default=512, help="Number of samples to use for calibration."
    )
    return parser


if __name__ == "__main__":
    initialize_megatron(
        extra_args_provider=add_calib_args,
        args_defaults={
            'tokenizer_type': 'GPT2BPETokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for calibration.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # Setup data loader.
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(
        dataset, args.micro_batch_size, args.num_workers, drop_last=False
    )

    # Run calibration.
    calibrate_and_print_results(args.task, dataloader, model)

    print_rank_0('Calibration successfully completed.')
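For illustration only (not part of the commit), a hypothetical launch of the calibration script. Every path, tokenizer file, and parallelism size below is a placeholder; --task, --valid-data, and --calib-size come from add_calib_args above, and the remaining flags are standard Megatron-LM arguments that must match the checkpoint being calibrated. Model-architecture flags such as --num-layers, --hidden-size, and --num-attention-heads are omitted here but are also required unless taken from the checkpoint:

torchrun --nproc_per_node=8 tasks/quantize/calibrate_gpt.py \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 1 \
    --load /path/to/bf16_checkpoint \
    --save /path/to/fp8_calibrated_checkpoint \
    --transformer-impl transformer_engine \
    --bf16 \
    --micro-batch-size 4 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --vocab-file /path/to/gpt2-vocab.json \
    --merge-file /path/to/gpt2-merges.txt \
    --task WIKITEXT103 \
    --valid-data /path/to/wikitext-103/wiki.valid.tokens \
    --calib-size 512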

tools/run_text_generation_server.py

Lines changed: 8 additions & 1 deletion
@@ -23,6 +23,7 @@
     get_gpt_layer_with_transformer_engine_spec,
 )
 
+from contextlib import nullcontext
 import torch
 from typing import Union
 import megatron
@@ -106,8 +107,14 @@ def add_text_generate_args(parser):
     print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text "
                  "generation.")
     args.exit_on_missing_checkpoint = True
+
     # Set up model and load checkpoint
-    model = get_model(model_provider, wrap_with_ddp=False)
+    load_context = nullcontext()
+    if args.fp8:
+        from transformer_engine.pytorch.fp8 import fp8_model_init
+        load_context = fp8_model_init()
+    with load_context:
+        model = get_model(model_provider, wrap_with_ddp=False)
 
     if args.load is not None:
         _ = load_checkpoint(model, None, None)
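A rough standalone sketch of the weight-initialization path added above (not part of the diff; assumes Transformer Engine is installed and an FP8-capable GPU is available): modules constructed under fp8_model_init hold FP8 weights directly instead of higher-precision master weights, and te.fp8_autocast then selects the FP8 compute path, which is the combination the server uses when --fp8 is set.

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import fp8_model_init

# Build a TE layer whose weights are stored in FP8 from the start.
with fp8_model_init():
    layer = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16)

# Inference-style forward pass using FP8 GEMMs; shapes are kept multiples of 16 as the FP8 kernels require.
inp = torch.randn(128, 4096, device='cuda', dtype=torch.bfloat16)
with torch.no_grad(), te.fp8_autocast(enabled=True):
    out = layer(inp)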
