# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Calibrate a GPT model for FP8 scaling factors."""
import os
import sys

# Make the repository root importable so that the `megatron` and `tasks`
# packages resolve when this script is run from two levels below the root.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)
import math

import torch
import transformer_engine.pytorch as te

from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.pipeline_parallel.p2p_communication import recv_forward, send_forward
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, get_model, is_last_rank, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.training.training import save_checkpoint_and_time
from megatron.training.utils import unwrap_model
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from tasks.finetune_utils import build_data_loader
from tasks.zeroshot_gpt.datasets import build_dataset
from tasks.zeroshot_gpt.evaluate import process_batch


def model_provider(pre_process=True, post_process=True) -> GPTModel:
    """Builds the model.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        GPTModel: The returned model. Only works with the Transformer Engine implementation.
    """

    args = get_args()

    print_rank_0('building GPT model ...')

    # Experimental: load arguments from a YAML config.
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    if args.use_legacy_models or args.transformer_impl != "transformer_engine":
        raise NotImplementedError(
            'Calibration is only supported for models using TransformerEngine.'
        )

    if args.spec is not None:
        transformer_layer_spec = import_module(args.spec)
    else:
        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
            args.num_experts, args.moe_grouped_gemm
        )

    model = GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
    )

    return model


def forward_step(batch, model, config):
    """Forward step."""

    # Get the batch.
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(batch)

    args = get_args()
    # The last batch may be smaller than the configured micro batch size because
    # the data loader is built with drop_last=False.
    args.micro_batch_size = len(labels)

    # Receive activations from the previous pipeline stage (None on the first stage).
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    input_tensor = recv_forward(tensor_shape, config)

    # Forward pass through the model.
    unwrapped_model = unwrap_model(model)
    unwrapped_model.set_input_tensor(input_tensor)
    output = model(tokens, position_ids, attention_mask)

    # Send activations to the next pipeline stage (a no-op on the last stage).
    send_forward(output, config)

    # Only the last pipeline stage holds logits, so only it can compute the loss.
    if parallel_state.is_pipeline_last_stage():
        losses = tensor_parallel.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels.contiguous()
        )
        loss = torch.sum(losses.view(-1) * loss_mask.contiguous().view(-1).float())
        return loss

    return None


def calibrate(data_loader, model):
    """Run calibration forward passes and save a checkpoint with FP8 scaling factors."""
    args = get_args()
    config = core_transformer_config_from_args(args)

    # Turn on evaluation mode, which disables dropout.
    model.eval()
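
    # Calibration runs in two phases. For the first num_examples - 1 batches,
    # te.fp8_autocast(enabled=False, calibrating=True) executes the model in high
    # precision while Transformer Engine records amax statistics for weights and
    # activations. A final batch with enabled=True then computes the FP8 scaling
    # factors from the recorded history. (Description based on Transformer
    # Engine's documented fp8_autocast semantics.)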

    total_output = 0.0
    num_examples = min(len(data_loader), args.calib_size)
    data_loader = iter(data_loader)

    with torch.no_grad():
        iteration = 0
        while iteration < num_examples - 1:
            batch = next(data_loader)
            if iteration % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(iteration))
            with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(
                device_type='cuda', dtype=torch.bfloat16
            ):
                output = forward_step(batch, model, config)

            # Reduce across data-parallel processes.
            if parallel_state.is_pipeline_last_stage():
                torch.distributed.all_reduce(
                    output, group=parallel_state.get_data_parallel_group()
                )

                total_output += output
            iteration += 1

        print_rank_0("Computing scaling factors with FP8 autocast ...")
        # Run the one remaining batch (the loop above consumed num_examples - 1
        # of them) with FP8 enabled, and keep its output for the loss accounting.
        batch = next(data_loader)
        with te.fp8_autocast(enabled=True), torch.autocast(
            device_type='cuda', dtype=torch.bfloat16
        ):
            output = forward_step(batch, model, config)

        if parallel_state.is_pipeline_last_stage():
            torch.distributed.all_reduce(output, group=parallel_state.get_data_parallel_group())

            total_output += output

    print_rank_0("Saving calibrated checkpoint ...")
    save_checkpoint_and_time(
        iteration,
        [model],
        optimizer=None,
        opt_param_scheduler=None,
        num_floating_point_operations_so_far=0,
        checkpointing_context=None,
    )
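
    # The checkpoint written above carries the calibrated FP8 state (amax history
    # and scaling factors) inside the Transformer Engine modules' state_dict
    # (their `_extra_state` entries); exact key names vary with the TE version,
    # so treat this note as informational.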

    return total_output


def calibrate_and_print_results(task, data_loader, model):
    """Calibrate and print results on screen."""

    # Calibrate and save scaling factors.
    output = calibrate(data_loader, model)

    string = ' validation results on {} | '.format(task)
    if is_last_rank():
        num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
        num_original_tokens = data_loader.dataset.num_original_tokens
        val_loss = output / (num_tokenized_tokens - 1)
        ppl = math.exp(min(20, val_loss))
        # The adjusted perplexity rescales the loss by the ratio of tokenized to
        # original tokens so that results are comparable across tokenizers.
        token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
        adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
        string += 'avg loss: {:.4E} | '.format(val_loss)
        string += 'ppl: {:.4E} | '.format(ppl)
        string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
        string += 'token ratio: {} |'.format(token_ratio)

        length = len(string) + 1
        print('-' * length)
        print(string)
        print('-' * length)
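

# Worked illustration of the bookkeeping in calibrate_and_print_results
# (hypothetical numbers): with val_loss = 2.0, num_tokenized_tokens - 1 = 120000,
# and num_original_tokens - 1 = 100000, token_ratio = 1.2, so ppl = exp(2.0)
# ~= 7.39 while adjusted_ppl = exp(2.0 * 1.2) = exp(2.4) ~= 11.02; the loss is
# renormalized by the token ratio before exponentiation.
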

def add_calib_args(parser):
    """Add calibration-specific command-line arguments."""
    group = parser.add_argument_group(title='calibration')
    group.add_argument(
        "--task",
        type=str,
        default="WIKITEXT103",
        help="Calibration task to run. Defaults to WIKITEXT103.",
    )
    group.add_argument('--valid-data', nargs='*', default=None, help='Calibration dataset.')
    group.add_argument(
        '--overlapping-eval',
        type=int,
        default=32,  # Required for reusing _build_wikitext103_dataset()
        help='Sliding window for overlapping evaluation.',
    )
    group.add_argument(
        "--calib-size", type=int, default=512, help="Number of samples to use for calibration."
    )
    return parser


if __name__ == "__main__":
    initialize_megatron(
        extra_args_provider=add_calib_args,
        args_defaults={
            'tokenizer_type': 'GPT2BPETokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for calibration.")
        sys.exit()

    # Set up the model and load the checkpoint (weights only; optimizer and
    # scheduler state are skipped).
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "The interleaved-schedule check above should leave a single model chunk."
    model = model[0]

    # Set up the data loader.
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(
        dataset, args.micro_batch_size, args.num_workers, drop_last=False
    )

    # Run calibration.
    calibrate_and_print_results(args.task, dataloader, model)

    print_rank_0('Calibration successfully completed.')
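
# After calibration, the saved checkpoint can be loaded with FP8 enabled for
# inference or further training (for example via Megatron's --fp8-format flag);
# the exact flags depend on your Megatron-LM version, so treat this as a
# pointer rather than a verified recipe.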