Enable cudagraphs for Peft and Pretraining (#11290)
* Rebase

Signed-off-by: Jimmy Zhang <[email protected]>

* packed seqlen padding

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* Update sequence_packing_utils.py

Signed-off-by: JimmyZhang12 <[email protected]>

* fix ci

Signed-off-by: Jimmy Zhang <[email protected]>

* assert pad_to_max_length

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: malay-nagda <[email protected]>

* Apply suggestions from code review

addressing comments

Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: malay-nagda <[email protected]>

* remove pad_cu_seqlens param

Signed-off-by: Malay Nagda <[email protected]>

* fix chat ci test

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: jiemingz <[email protected]>

---------

Signed-off-by: Jimmy Zhang <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Signed-off-by: malay-nagda <[email protected]>
Signed-off-by: malay-nagda <[email protected]>
Signed-off-by: Malay Nagda <[email protected]>
Signed-off-by: jiemingz <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Jimmy Zhang <[email protected]>
Co-authored-by: JimmyZhang12 <[email protected]>
Co-authored-by: malay-nagda <[email protected]>
Co-authored-by: malay-nagda <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: Malay Nagda <[email protected]>
Co-authored-by: Jimmy Zhang <[email protected]>
Co-authored-by: jiemingz <[email protected]>
10 people authored Dec 24, 2024
1 parent 8c787fc commit fe4f39a
Showing 11 changed files with 167 additions and 34 deletions.
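
Taken together, the diffs below add three user-facing switches for cudagraph-compatible fine-tuning and pretraining: pad_cu_seqlens on PackedSequenceSpecs (data side), enable_cuda_graph plus use_te_rng_tracker on GPTConfig (model side), and use_te_rng_tracker on MegatronStrategy (parallel-state side). A minimal sketch of how they might be combined; import paths follow the files changed here, other arguments are illustrative, and the surrounding trainer/recipe setup is omitted:

```python
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

# Data: keep cu_seqlens at a constant shape so packed batches can be graph-captured.
packed_specs = PackedSequenceSpecs(packed_sequence_size=4096, pad_cu_seqlens=True)

# Model: enable cudagraphs together with Transformer Engine's RNG tracker
# (configure_model() below asserts that both are set together).
model_config_overrides = dict(enable_cuda_graph=True, use_te_rng_tracker=True)

# Strategy: forward the RNG-tracker flag into model-parallel initialization.
strategy = MegatronStrategy(tensor_model_parallel_size=1, use_te_rng_tracker=True)
```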
4 changes: 3 additions & 1 deletion nemo/collections/llm/gpt/data/chat.py
@@ -27,7 +27,7 @@ class ChatDataModule(FineTuningDataModule):
"""

@lru_cache
- def _create_dataset(self, path, is_test=False, **kwargs):
+ def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
# pylint: disable=C0115,C0116
return create_sft_dataset(
path,
@@ -37,5 +37,7 @@ def _create_dataset(self, path, is_test=False, **kwargs):
seed=self.seed,
chat=True,
is_test=is_test,
pack_metadata_file_path=None, # packing is not supported
pad_cu_seqlens=False,
**kwargs,
)
61 changes: 37 additions & 24 deletions nemo/collections/llm/gpt/data/core.py
@@ -47,42 +47,55 @@ def create_sft_dataset(
memmap_workers: int = 2,
hf_dataset: bool = False,
global_sample_mapping: bool = False,
pack_metadata_file_path: Path = None,
pad_cu_seqlens: bool = False,
chat: bool = False,
**kwargs,
) -> "GPTSFTDataset":
"""
Create the dataset class (GPTSFTDataset, GPTSFTChatDataset or GPTSFTPackedDataset)
"""

gpt_sft_dataset_kwargs = {
'file_path': str(path),
'tokenizer': tokenizer,
'max_seq_length': seq_length,
'memmap_workers': memmap_workers,
'hf_dataset': hf_dataset,
'global_sample_mapping': global_sample_mapping,
'add_bos': add_bos,
'add_eos': add_eos,
'add_sep': add_sep,
'seed': seed,
'label_key': label_key,
'answer_only_loss': answer_only_loss,
'truncation_field': truncation_field,
'pad_to_max_length': pad_to_max_length,
'index_mapping_dir': index_mapping_dir,
'prompt_template': prompt_template,
'truncation_method': truncation_method,
}

if chat:
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset

- dataset_cls = GPTSFTChatDataset
+ return GPTSFTChatDataset(
+ **gpt_sft_dataset_kwargs,
+ **kwargs,
+ )
elif path.suffix == '.npy':
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTPackedDataset

- dataset_cls = GPTSFTPackedDataset
+ return GPTSFTPackedDataset(
+ pack_metadata_file_path=pack_metadata_file_path,
+ pad_cu_seqlens=pad_cu_seqlens,
+ **gpt_sft_dataset_kwargs,
+ **kwargs,
+ )
else:
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset

- dataset_cls = GPTSFTDataset
-
- return dataset_cls(
- file_path=str(path),
- tokenizer=tokenizer,
- max_seq_length=seq_length,
- memmap_workers=memmap_workers,
- hf_dataset=hf_dataset,
- global_sample_mapping=global_sample_mapping,
- add_bos=add_bos,
- add_eos=add_eos,
- add_sep=add_sep,
- seed=seed,
- label_key=label_key,
- answer_only_loss=answer_only_loss,
- truncation_field=truncation_field,
- pad_to_max_length=pad_to_max_length,
- index_mapping_dir=index_mapping_dir,
- prompt_template=prompt_template,
- truncation_method=truncation_method,
- **kwargs,
- )
+ return GPTSFTDataset(
+ **gpt_sft_dataset_kwargs,
+ **kwargs,
+ )
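
A hedged usage sketch of the refactored factory (paths and the tokenizer are placeholders): a .jsonl path still routes to GPTSFTDataset, while a .npy path routes to GPTSFTPackedDataset, which is the only branch that consumes the new pack_metadata_file_path and pad_cu_seqlens arguments.

```python
from pathlib import Path

from nemo.collections.llm.gpt.data.core import create_sft_dataset

tokenizer = ...  # an existing TokenizerSpec instance

# Unpacked SFT data: the packing kwargs are simply not used by this branch.
plain_ds = create_sft_dataset(Path("data/training.jsonl"), tokenizer=tokenizer, seq_length=2048)

# Packed SFT data (.npy) plus its metadata file, required when pad_cu_seqlens=True.
packed_ds = create_sft_dataset(
    Path("data/packed/training_4096.npy"),
    tokenizer=tokenizer,
    seq_length=4096,
    pack_metadata_file_path=Path("data/packed/train_4096_metadata.jsonl"),
    pad_cu_seqlens=True,
    pad_to_max_length=True,  # the packed dataset asserts this when pad_cu_seqlens is on
)
```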
48 changes: 46 additions & 2 deletions nemo/collections/llm/gpt/data/fine_tuning.py
@@ -93,6 +93,7 @@ def __init__(
self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
self.validate_batch_size_for_packed_sequence()
self.dataset_kwargs = dataset_kwargs or {}
self._pad_cu_seqlens = False if not packed_sequence_specs else packed_sequence_specs.pad_cu_seqlens
self.init_global_step = 0

def validate_batch_size_for_packed_sequence(self):
@@ -128,6 +129,7 @@ def prepare_data(self) -> None:
tokenizer=self.tokenizer,
max_seq_length=self.seq_length,
seed=self.seed,
output_metadata_path=self.train_pack_metadata,
)

if not self.validation_path_packed.is_file():
@@ -138,6 +140,7 @@
tokenizer=self.tokenizer,
max_seq_length=self.seq_length,
seed=self.seed,
output_metadata_path=self.val_pack_metadata,
)

def setup(self, stage: str):
@@ -194,6 +197,7 @@ def train_dataloader(self) -> DataLoader:
return self._create_dataloader(
self._create_dataset(
self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed,
pack_metadata_path=None if self.packed_sequence_size <= 0 else self.train_pack_metadata,
max_num_samples=self.max_train_samples,
**self.dataset_kwargs,
),
@@ -205,6 +209,7 @@ def val_dataloader(self) -> DataLoader:
return self._create_dataloader(
self._create_dataset(
self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed,
pack_metadata_path=None if self.packed_sequence_size <= 0 else self.val_pack_metadata,
is_test=True,
**self.dataset_kwargs,
),
@@ -224,15 +229,18 @@ def test_dataloader(self) -> DataLoader:
)

@lru_cache
- def _create_dataset(self, path, is_test=False, **kwargs):
+ def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
# pylint: disable=C0115,C0116
is_not_packing = is_test or self.packed_sequence_size <= 0
return create_sft_dataset(
path,
tokenizer=self.tokenizer,
- seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),
+ seq_length=(self.seq_length if is_not_packing else self.packed_sequence_size),
memmap_workers=self.memmap_workers,
seed=self.seed,
is_test=is_test,
pack_metadata_file_path=None if is_not_packing else pack_metadata_path,
pad_cu_seqlens=False if is_not_packing else self.pad_cu_seqlens,
**kwargs,
)
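
At the data-module level these pieces are driven entirely by PackedSequenceSpecs. A hedged sketch follows; constructor arguments other than packed_sequence_specs are illustrative and may differ by NeMo version.

```python
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs

specs = PackedSequenceSpecs(
    packed_sequence_size=4096,
    pad_cu_seqlens=True,  # fixed-shape cu_seqlens for cudagraphs
)
data = FineTuningDataModule(
    dataset_root="/data/my_sft_dataset",  # expects training.jsonl / validation.jsonl / test.jsonl
    seq_length=4096,
    micro_batch_size=1,
    global_batch_size=8,
    packed_sequence_specs=specs,
)
# prepare_data() now also writes train/val packing metadata next to the packed .npy
# files, and the train/val dataloaders pass it through to GPTSFTPackedDataset.
```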

@@ -255,6 +263,32 @@ def train_path(self) -> Path:
"""Path to training dataset file"""
return self.dataset_root / "training.jsonl"

@property
def train_pack_metadata(self) -> Path:
"""Path to metadata dataset file for packed sequence."""
if self.packed_sequence_size > 0:
if self.packed_sequence_specs.packed_train_metadata_path is not None:
return self.packed_sequence_specs.packed_train_metadata_path
tokenizer_model_name = self._extract_tokenizer_model_name()
folder_name = self.dataset_root / "packed" / tokenizer_model_name
folder_name.mkdir(parents=True, exist_ok=True)
return folder_name / f"train_{self.packed_sequence_size}_metadata.jsonl"
else:
raise ValueError("`train_pack_metadata` invalid since packed sequence size is not specified.")

@property
def val_pack_metadata(self) -> Path:
"""Path to metadata dataset file for packed sequence."""
if self.packed_sequence_size > 0:
if self.packed_sequence_specs.packed_val_metadata_path is not None:
return self.packed_sequence_specs.packed_val_metadata_path
tokenizer_model_name = self._extract_tokenizer_model_name()
folder_name = self.dataset_root / "packed" / tokenizer_model_name
folder_name.mkdir(parents=True, exist_ok=True)
return folder_name / f"val_{self.packed_sequence_size}_metadata.jsonl"
else:
raise ValueError("val_pack_metadata invalid since packed sequence size is not specified.")

@property
def train_path_packed(self) -> Path:
"""Path to training dataset file for packed sequence. The file path contains a reference to the
@@ -293,6 +327,16 @@ def test_path(self) -> Path:
"""Path to test dataset file"""
return self.dataset_root / "test.jsonl"

@property
def pad_cu_seqlens(self) -> bool:
"""Whether to pad cu_seqlens to a constant shape"""
if self.packed_sequence_size > 0:
if self.packed_sequence_specs.pad_cu_seqlens is not None:
return self.packed_sequence_specs.pad_cu_seqlens
else:
return self._pad_cu_seqlens
return False

def _extract_tokenizer_model_name(self) -> str:
"""Automatically get the model name from model path."""
if self.packed_sequence_specs.tokenizer_model_name is not None:
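The resolution order for the new metadata paths and the pad_cu_seqlens flag can be summarised with a small standalone restatement (simplified, not the NeMo code itself): an explicit path or flag on PackedSequenceSpecs wins, otherwise the module derives a default under the packed dataset folder.

```python
from pathlib import Path

def resolve_train_pack_metadata(specs, dataset_root: Path, tokenizer_model_name: str) -> Path:
    """Illustrative mirror of FineTuningDataModule.train_pack_metadata, not the real property."""
    if specs.packed_sequence_size <= 0:
        raise ValueError("packed sequence size is not specified")
    if specs.packed_train_metadata_path is not None:
        return Path(specs.packed_train_metadata_path)
    folder = dataset_root / "packed" / tokenizer_model_name
    folder.mkdir(parents=True, exist_ok=True)
    return folder / f"train_{specs.packed_sequence_size}_metadata.jsonl"
```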
23 changes: 22 additions & 1 deletion nemo/collections/llm/gpt/data/packed_sequence.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@@ -50,6 +51,7 @@ def tokenize_dataset(path: Path, tokenizer: TokenizerSpec, max_seq_length: int,
def prepare_packed_sequence_data(
input_path: Path,
output_path: Path,
output_metadata_path: Path,
packed_sequence_size: int,
tokenizer: TokenizerSpec,
max_seq_length: int,
@@ -77,11 +79,15 @@ def prepare_packed_sequence_data(
dataset = tokenize_dataset(input_path, tokenizer, max_seq_length, seed)
sequences, histogram = create_hist(dataset, max_seq_length)

- assignments = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm)
+ assignments, packing_metadata = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm)
output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size, tokenizer.eos_id)

# save output data
np.save(output_path, output_data)
# save packing metadata
if output_metadata_path is not None:
with open(output_metadata_path, "w") as f:
json.dump(packing_metadata, f)
logging.info(f"Packed sequence is prepared and saved to {output_path}")


@@ -111,6 +117,21 @@ class PackedSequenceSpecs:
If specified, use this file for the packed validation dataset instead of the default path.
"""

packed_train_metadata_path: str = None
"""
If specified, use this file for the train packing metadata instead of the default path.
"""

packed_val_metadata_path: str = None
"""
If specified, use this file for the val packing metadata instead of the default path.
"""

pad_cu_seqlens: bool = False
"""
If True, pad cu_seqlens to a constant size, which is required for use with cudagraphs.
"""

def __post_init__(self):
if self.packed_train_data_path is not None:
self.packed_train_data_path = Path(self.packed_train_data_path)
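A hedged example of producing a packed dataset together with the new metadata file; paths, sizes, and the tokenizer are placeholders, and the keyword names follow the signature above. Based on the keys GPTSFTPackedDataset reads later in this diff, the metadata JSON contains at least dataset_max_seqlen and max_samples_per_bin.

```python
from pathlib import Path

from nemo.collections.llm.gpt.data.packed_sequence import prepare_packed_sequence_data

tokenizer = ...  # an existing TokenizerSpec instance

prepare_packed_sequence_data(
    input_path=Path("data/training.jsonl"),
    output_path=Path("data/packed/training_4096.npy"),
    output_metadata_path=Path("data/packed/train_4096_metadata.jsonl"),
    packed_sequence_size=4096,
    tokenizer=tokenizer,
    max_seq_length=4096,
    seed=1234,
)
# Example of what the metadata file might contain afterwards:
# {"dataset_max_seqlen": 2048, "max_samples_per_bin": 17}
```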
7 changes: 7 additions & 0 deletions nemo/collections/llm/gpt/model/base.py
@@ -184,6 +184,13 @@ class GPTConfig(TransformerConfig, io.IOMixin):
data_step_fn: Callable = gpt_data_step

def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreGPTModel":
if self.enable_cuda_graph:
assert HAVE_TE, "Transformer Engine is required for cudagraphs."
assert getattr(self, 'use_te_rng_tracker', False), (
"Transformer Engine's RNG tracker is required for cudagraphs; it can be "
"enabled with use_te_rng_tracker=True."
)

vp_size = self.virtual_pipeline_model_parallel_size
if vp_size:
p_size = self.pipeline_model_parallel_size
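A hedged config sketch showing what the new configure_model() checks require: enabling cudagraphs without Transformer Engine's RNG tracker now fails fast with a clear assertion. Field values are illustrative.

```python
from nemo.collections.llm.gpt.model.base import GPTConfig

config = GPTConfig(
    num_layers=12,
    hidden_size=768,
    num_attention_heads=12,
    seq_length=2048,
    enable_cuda_graph=True,    # requires Transformer Engine (HAVE_TE)
    use_te_rng_tracker=True,   # required alongside enable_cuda_graph by the new assert
)
```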
nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import re
from typing import List, Mapping, Optional
@@ -524,7 +525,15 @@ def collate_fn(self, batch):


class GPTSFTPackedDataset(GPTSFTDataset):
- def __init__(self, file_path: str, tokenizer: TokenizerSpec, return_cu_seqlen: bool = True, **kwargs):
+ def __init__(
+ self,
+ file_path: str,
+ tokenizer: TokenizerSpec,
+ return_cu_seqlen: bool = True,
+ pad_cu_seqlens: bool = False,
+ pack_metadata_file_path: Optional[str] = None,
+ **kwargs,
+ ):
"""
file_path: See `file_path` in the parent class.
tokenizer: See `tokenizer` in the parent class.
@@ -537,6 +546,20 @@ def __init__(self, file_path: str, tokenizer: TokenizerSpec, return_cu_seqlen: b
assert self.virtual_tokens == 0, "P-Tuning with packed sequence is not supported."
self.return_cu_seqlen = return_cu_seqlen

self.pad_cu_seqlens = pad_cu_seqlens
if self.pad_cu_seqlens:
assert (
pack_metadata_file_path is not None
), "a metadata json file is required when pad_cu_seqlens is enabled"
assert (
self.pad_to_max_length is True
), "'pad_to_max_length=True' is required when pad_cu_seqlens is enabled"

self.pack_metadata = None
if pack_metadata_file_path is not None:
with open(pack_metadata_file_path) as f:
self.pack_metadata = json.load(f)

def __getitem__(self, idx):
if self.samples_mapping is not None:
# assert idx < len(self.samples_mapping)
@@ -665,6 +688,11 @@ def collate_fn(self, batch):
if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]):
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1])

if self.pad_cu_seqlens:
# pad cu_seqlens with zero length sequences
pad_num = self.pack_metadata['max_samples_per_bin'] - len(cu_seqlens[-1])
cu_seqlens[-1].extend([max_length] * pad_num)

assert len(input_ids[0]) == len(
position_ids[0]
), "Dataset problem: input_ids and position_ids lengths don't match"
@@ -695,6 +723,15 @@ def collate_fn(self, batch):
cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded)
cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True)

if self.pad_cu_seqlens:
# Use the global max seqlen, as 'pad_cu_seqlens' is used mainly
# to support cudagraphs, and 'max_seqlen' is a cpu tensor, which means it should
# be the same across all batches.
max_seqlen = torch.IntTensor([self.pack_metadata['dataset_max_seqlen']] * len(cu_seqlens))
else:
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)

processed_batch.update(
{
'attention_mask': torch.LongTensor(
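A standalone illustration (not NeMo code) of what pad_cu_seqlens changes in collate_fn: each bin's cu_seqlens is topped up with zero-length sequences (repeats of the bin length) until it has max_samples_per_bin entries, and max_seqlen is taken from the packing metadata rather than the batch, so both tensors keep constant shapes and values across steps, which cudagraph capture needs.

```python
import torch

max_length = 4096  # packed bin size; pad_to_max_length=True keeps it constant
pack_metadata = {"max_samples_per_bin": 6, "dataset_max_seqlen": 2048}  # example values

cu_seqlens = [0, 1200, 2100, 4096]  # three real sequences packed into this bin
pad_num = pack_metadata["max_samples_per_bin"] - len(cu_seqlens)
cu_seqlens = cu_seqlens + [max_length] * pad_num  # -> [0, 1200, 2100, 4096, 4096, 4096]

cu_seqlens = torch.IntTensor([cu_seqlens])
max_seqlen = torch.IntTensor([pack_metadata["dataset_max_seqlen"]] * len(cu_seqlens))
# cu_seqlens always has max_samples_per_bin entries per bin; the trailing entries
# describe zero-length sequences, and max_seqlen is the same value every step.
```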
1 change: 1 addition & 0 deletions nemo/lightning/_strategy_lib.py
@@ -91,6 +91,7 @@ def init_parallel_ranks(
use_fp8=fp8,
init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False)
and getattr(parallel_config, "tp_comm_bootstrap_backend", None) == 'mpi',
use_te_rng_tracker=getattr(parallel_config, "use_te_rng_tracker", False),
# apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
)

4 changes: 4 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -99,6 +99,7 @@ class ParallelismConfig:
pipeline_dtype: torch.dtype
encoder_tensor_model_parallel_size: int = 0
encoder_pipeline_model_parallel_size: int = 0
use_te_rng_tracker: bool = False


class MegatronStrategy(DDPStrategy, io.IOMixin):
@@ -199,6 +200,7 @@ def __init__(
ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
lazy_init: bool = False,
pipeline_dtype: Optional[torch.dtype] = None,
use_te_rng_tracker: bool = False,
save_ckpt_format: str = "torch_dist",
ckpt_async_save: bool = True,
ckpt_torch_dist_multiproc: int = None, ## TODO(ashors): put elsewhere?
@@ -244,6 +246,7 @@ def __init__(
self.ckpt_load_optimizer = ckpt_load_optimizer
self.ckpt_save_optimizer = ckpt_save_optimizer
self.ckpt_load_strictness = ckpt_load_strictness
self.use_te_rng_tracker = use_te_rng_tracker
self._pipeline_dtype = pipeline_dtype
self._setup_optimizers = setup_optimizers
self._init_model_parallel = init_model_parallel
@@ -900,6 +903,7 @@ def parallelism(self) -> ParallelismConfig:
encoder_tensor_model_parallel_size=self.encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size=self.encoder_pipeline_model_parallel_size,
pipeline_dtype=self.pipeline_dtype,
use_te_rng_tracker=self.use_te_rng_tracker,
)

@contextmanager
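A hedged sketch of the new plumbing: the flag passed to MegatronStrategy is carried on ParallelismConfig and consumed by init_parallel_ranks in _strategy_lib.py above, which forwards it to the model-parallel RNG setup. Other arguments are illustrative.

```python
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

strategy = MegatronStrategy(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    use_te_rng_tracker=True,  # new flag; surfaced via strategy.parallelism
)
# strategy.parallelism is a ParallelismConfig dataclass that now includes
# use_te_rng_tracker, so parallel-state initialization sees the same setting.
```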