Commit: Basic distillation running
AAnoosheh committed Jan 11, 2025
1 parent bba88bb commit 80cbfd2
Showing 2 changed files with 96 additions and 71 deletions.
14 changes: 7 additions & 7 deletions nemo/lightning/megatron_parallel.py
@@ -180,7 +180,7 @@ def __init__(
convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
) -> None:
from megatron.core import parallel_state
-from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
+from megatron.core.transformer.module import Float16Module as McoreFloat16Module

from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.utils.model_utils import unwrap_model
@@ -216,9 +216,9 @@ def __init__(
self.convert_module_fn = convert_module_fn

# [ModelOpt]: Detect Pipeline-parallel Distillation mode.
-self._unwrapped_model = [unwrap_model(self[0].module, (Float16Module, MCoreFloat16Module))]
+self._unwrapped_model = unwrap_model(self.module, (DDP, Float16Module, McoreFloat16Module))
if (
-hasattr(self._unwrapped_model[0], "teacher_model")
+hasattr(self._unwrapped_model, "teacher_model")
and parallel_state.get_pipeline_model_parallel_world_size() > 1
):
self._kd_teacher_in_pp = True
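Note on the hunk above: the change stops indexing into a single-element list and instead unwraps the distributed/precision wrappers directly. A minimal sketch of the unwrapping pattern (hypothetical `unwrap` helper; NeMo's real one is `nemo.utils.model_utils.unwrap_model`):

```python
import torch.nn as nn

def unwrap(module: nn.Module, wrapper_types: tuple) -> nn.Module:
    # Wrappers such as DDP and Float16Module hold the real model
    # in a `.module` attribute; peel them off until none remain.
    while isinstance(module, wrapper_types):
        module = module.module
    return module
```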
@@ -317,10 +317,10 @@ def _dummy_reduction(output_tensor, *args, **kwargs):

self.callbacks.event("on_megatron_microbatches_start", step=step)
if self._kd_teacher_in_pp:
-with self._unwrapped_model[0].only_teacher_forward():
-with self._unwrapped_model[0].swap_teacher_config(self[0].module):
+with self._unwrapped_model.only_teacher_forward():
+with self._unwrapped_model.swap_teacher_config(self.module):
teacher_step()

[CodeQL / Code scanning — check failure] Potentially uninitialized local variable: local variable 'teacher_step' may be used before it is initialized.
-with self._unwrapped_model[0].only_student_forward():
+with self._unwrapped_model.only_student_forward():
microbatch_outputs = step()
else:
microbatch_outputs = step()
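The CodeQL failure flagged above concerns `teacher_step`, which is bound on a code path outside this hunk. A common remedy for this class of warning is to bind the name unconditionally and guard its use; a hedged sketch of that pattern (not the author's fix):

```python
def run_kd_microbatches(kd_teacher_in_pp: bool, make_teacher_step=None):
    # Bind the name on every path so later uses are provably initialized.
    teacher_step = make_teacher_step() if make_teacher_step is not None else None

    if kd_teacher_in_pp:
        if teacher_step is None:
            raise RuntimeError("KD pipeline mode requires a teacher step function")
        teacher_step()  # run the teacher's forward for the current microbatches
```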
@@ -332,7 +332,7 @@ def _dummy_reduction(output_tensor, *args, **kwargs):
)

if isinstance(_loss_reduction, _ModuleStepFunction):
-_loss_reduction = _loss_reduction(self[0])
+_loss_reduction = _loss_reduction(self.module)

reduced = _loss_reduction.reduce(microbatch_outputs)
self.callbacks.event(
153 changes: 89 additions & 64 deletions scripts/llm/gpt_distillation.py
@@ -26,18 +26,21 @@
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
from megatron.core.optimizer import OptimizerConfig
+from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor
from torch.nn.modules.loss import _Loss

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank
-from nemo.collections.llm.quantization import load_with_modelopt_layer_spec
from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec
+from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
-from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
+from nemo.lightning import io
+from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction
from nemo.utils import logging
+from nemo.utils.model_utils import unwrap_model


def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]:
@@ -115,9 +118,9 @@ def configure_model(self, *args, **kwargs) -> MCoreGPTModel:
class _DistillationLossReduction(MaskedTokenLossReduction):
"""Custom masking and reduction callable used only in training mode."""

-def __init__(self, model, *args, **kwargs):
+def __init__(self, distillation_loss_fn, *args, **kwargs):
super().__init__(*args, **kwargs)
-self._distillation_model: mtd.DistillationModel = model.module
+self._distillation_loss_fn = distillation_loss_fn
self._cp_size = parallel_state.get_context_parallel_world_size()

def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:

[CodeQL / Code scanning — notice] Unused local variable: variable 'forward_out' is not used.
@@ -126,8 +129,10 @@ def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor
forward_out, batch["loss_mask"] = forward_out

# [ModelOpt]: KD loss calculation.
-loss_for_ub = self._distillation_model.compute_kd_loss(
-loss_reduction_fn=lambda x: self._masked_token_loss(x, batch["loss_mask"], batch['num_valid_tokens_in_ub'])
+loss_for_ub = self._distillation_loss_fn(
+loss_reduction_fn=lambda x: self._masked_token_loss(
+x, batch["loss_mask"], batch.get("num_valid_tokens_in_ub")
+)
)

reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
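For context, `_masked_token_loss` (defined further down, partly outside this diff) averages a per-token loss over valid positions only. A minimal single-rank sketch of that masking arithmetic, assuming a `[batch, seq]` loss tensor (the context-parallel summation is omitted):

```python
from typing import Optional

import torch

def masked_token_loss(loss_per_token: torch.Tensor,
                      mask: torch.Tensor,
                      num_valid_tokens: Optional[int] = None) -> torch.Tensor:
    # Zero out padded/prompt positions, then average over valid tokens only.
    masked_sum = (loss_per_token.reshape(-1) * mask.reshape(-1).float()).sum()
    denom = num_valid_tokens if num_valid_tokens is not None else mask.sum()
    return masked_sum / denom
```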
@@ -159,19 +164,6 @@ def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens
return loss


-class _LoopingCachedDataIterator:
-def __init__(self, data):
-self.data = data
-self.it = iter(self.data)
-
-def __next__(self):
-try:
-return next(self.it)
-except StopIteration:
-self.it = iter(self.data)
-return next(self.it)


class DistillationGPTModel(llm.GPTModel):
"""Custom GPT subclass for distillation-related modifications."""

@@ -189,7 +181,10 @@ def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) ->
@property
def training_loss_reduction(self) -> _DistillationLossReduction:
if not self._training_loss_reduction:
-self._training_loss_reduction = _DistillationLossReduction()
+core_module = unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module))
+self._training_loss_reduction = _DistillationLossReduction(
+distillation_loss_fn=core_module.compute_kd_loss
+)

return self._training_loss_reduction

@@ -357,7 +352,7 @@ def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.d
########################################################


-def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]:
+def load_distillation_config(cfg: DistillationGPTConfig) -> Dict[str, Any]:
"""Create a default distillation config for MCore GPT Models.
Args:
@@ -370,23 +365,49 @@ def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]:
"skip_lm_loss": True,
}
if cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
distill_cfg["criterion"][tuple(logit_pair)] = LogitsKLLoss(cfg)
distill_cfg["criterion"][logit_pair] = LogitsKLLoss(cfg)

return distill_cfg


-def _teacher_provider(cfg: TransformerConfig) -> MCoreGPTModel:
+def _teacher_provider(cfg: DistillationGPTConfig) -> MCoreGPTModel:
"""Teacher model factory (must be a non-local function to pickle)."""

logging.info("Distillation: Loading teacher weights...")
-model = load_with_modelopt_layer_spec(
-cfg.kd_teacher_restore_from_path,
+strategy = nl.MegatronStrategy(
tensor_model_parallel_size=cfg.tensor_model_parallel_size,
context_parallel_size=cfg.context_parallel_size,
pipeline_model_parallel_size=cfg.pipeline_model_parallel_size,
-inference_only=True,
+ckpt_load_optimizer=False,
+ckpt_parallel_save_optim=False,
+setup_optimizers=False,
+ddp="pytorch",
+)
+trainer = nl.Trainer(
+devices=cfg.tensor_model_parallel_size,
+num_nodes=cfg.context_parallel_size * cfg.pipeline_model_parallel_size,
+strategy=strategy,
+plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+)

+model, _ = io.ModelConnector().nemo_load(cfg.kd_teacher_restore_from_path, trainer, cpu=False)
+model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module))

logging.info("Distillation: ... teacher weights loaded.")
-return model.module
+return model


+class _LoopingCachedDataIterator:
+def __init__(self, data):
+self.data = data
+self.it = iter(self.data)
+
+def __next__(self):
+try:
+return next(self.it)
+except StopIteration:
+self.it = iter(self.data)
+return next(self.it)
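This class (moved down from earlier in the file) cycles endlessly over a cached list of batches, much like `itertools.cycle`, so the teacher can re-consume exactly the microbatches the student saw. Usage example:

```python
it = _LoopingCachedDataIterator(["mb0", "mb1"])
assert [next(it) for _ in range(5)] == ["mb0", "mb1", "mb0", "mb1", "mb0"]
```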


def adjust_distillation_model_for_mcore(
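Aside: the `LogitsKLLoss(cfg)` registered in `load_distillation_config` above is ModelOpt's logits-distillation criterion. Conceptually it is a temperature-scaled KL divergence between the teacher and student token distributions; a rough sketch of that idea (not ModelOpt's actual implementation, which also handles tensor parallelism and masking):

```python
import torch
import torch.nn.functional as F

def logits_kl_loss(student_logits: torch.Tensor,
                   teacher_logits: torch.Tensor,
                   temperature: float = 1.0) -> torch.Tensor:
    # KL(teacher || student) on temperature-softened distributions.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2
```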
@@ -458,6 +479,8 @@ def _swap_teacher_config(self, model_wrapper):
if __name__ == "__main__":
logging.info("Distillation enabled.")

+TEACHER_PATH = "./test_teacher/"

seq_length = 2048
global_batch_size = 16
tp = 1
@@ -466,58 +489,68 @@ def _swap_teacher_config(self, model_wrapper):
# TODO: setup the dummy dataset
data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size)

-TEACHER_PATH = "./test_teacher/"
-#
-import os
-import sys
+## initialize the strategy
+strategy = nl.MegatronStrategy(
+tensor_model_parallel_size=tp,
+pipeline_model_parallel_size=pp,
+)
+trainer = nl.Trainer(
+devices=1, ## you can change the number of devices to suit your setup
+max_steps=50,
+accelerator="gpu",
+strategy=strategy,
+plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+)

-from megatron.core import dist_checkpointing
+common_model_kwargs = dict(
+seq_length=seq_length,
+init_method_std=0.023,
+hidden_dropout=0.0,
+attention_dropout=0.0,
+layernorm_epsilon=1e-5,
+make_vocab_size_divisible_by=128,
+transformer_layer_spec=get_gpt_layer_modelopt_spec(),
+)

-from nemo.lightning.io.pl import ckpt_to_weights_subdir
+############# TEACHER HACK #############
+import os
+import sys

if not os.path.exists(TEACHER_PATH):
+from lightning.pytorch.trainer.states import TrainerFn

gpt_config = llm.GPTConfig(
num_layers=9,
hidden_size=384,
ffn_hidden_size=1536,
num_attention_heads=6,
-seq_length=seq_length,
-init_method_std=0.023,
-hidden_dropout=0.0,
-attention_dropout=0.0,
-layernorm_epsilon=1e-5,
-make_vocab_size_divisible_by=128,
-transformer_layer_spec=get_gpt_layer_modelopt_spec(),
+**common_model_kwargs,
)
model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)
-dist_checkpointing.save(model.sharded_state_dict(), str(ckpt_to_weights_subdir(TEACHER_PATH, is_saving=True)))

+strategy.ckpt_save_optimizer = False  # otherwise need to do `model._trainer = trainer`
+trainer.state.fn = TrainerFn.FITTING  # needed for proper save.
+trainer.strategy.connect(model)
+trainer.strategy.setup_environment()
+with trainer.init_module():
+model.configure_model()
+
+io.ModelConnector().nemo_save(TEACHER_PATH, trainer)

sys.exit(0)
-#
+##########################################

## initialize a small GPT model
gpt_config = DistillationGPTConfig(
num_layers=6,
hidden_size=384,
ffn_hidden_size=1536,
num_attention_heads=6,
-seq_length=seq_length,
-init_method_std=0.023,
-hidden_dropout=0.0,
-attention_dropout=0.0,
-layernorm_epsilon=1e-5,
-make_vocab_size_divisible_by=128,
-transformer_layer_spec=get_gpt_layer_modelopt_spec(),
+**common_model_kwargs,
kd_teacher_restore_from_path=TEACHER_PATH,
)
model = DistillationGPTModel(gpt_config, tokenizer=data.tokenizer)

-## initialize the strategy
-strategy = nl.MegatronStrategy(
-tensor_model_parallel_size=tp,
-pipeline_model_parallel_size=pp,
-pipeline_dtype=torch.bfloat16,
-)

## setup the optimizer
opt_config = OptimizerConfig(
optimizer='adam',
@@ -526,14 +559,6 @@ def _swap_teacher_config(self, model_wrapper):
)
opt = nl.MegatronOptimizerModule(config=opt_config)

-trainer = nl.Trainer(
-devices=1, ## you can change the number of devices to suit your setup
-max_steps=50,
-accelerator="gpu",
-strategy=strategy,
-plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
-)

nemo_logger = nl.NeMoLogger(
log_dir="test_logdir", ## logs and checkpoints will be written here
)
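The tail of the script is collapsed in this view. Purely for orientation, and not as the hidden remainder of this file: in NeMo 2.0 these pieces are typically handed to the `llm.train` entry point, roughly as sketched below (the exact arguments here are assumptions):

```python
# Hypothetical wiring sketch — the actual collapsed code may differ.
llm.train(
    model=model,
    data=data,
    trainer=trainer,
    log=nemo_logger,
    optim=opt,
    tokenizer="data",
)
```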
