Merge branch 'tpoon/pp_llava_evaluation' into 'main'
Support online evaluation with pipeline-parallel size (pp) > 1 for the multimodal LLaVA example.

See merge request ADLR/megatron-lm!2289
ko3n1g committed Nov 23, 2024
2 parents d392f9c + 938e5c8 commit c10721e
Showing 6 changed files with 118 additions and 37 deletions.
58 changes: 48 additions & 10 deletions examples/multimodal/run_text_generation.py
@@ -22,7 +22,8 @@
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.inference.text_generation.api import generate_and_post_process
from megatron.inference.text_generation.forward_step import ForwardStep
from megatron.training import get_args, get_model
from megatron.inference.text_generation.communication import broadcast_int_list
from megatron.training import get_args, get_model, get_tokenizer, print_rank_0
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron

@@ -156,7 +157,7 @@ def generate_samples(model, config: EvaluationConfig, print_output):

conv = get_conversation(config.task, question)

forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles)
forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles, args.decoder_seq_length)

if is_first_rank():
resp_sentences, _, _, _ = generate_and_post_process(
@@ -316,6 +317,7 @@ def __init__(
num_img_embeddings_per_tile,
images,
num_tiles,
decoder_seq_length,
model,
max_batch_size,
max_sequence_length,
@@ -327,6 +329,18 @@
super().__init__(model, max_batch_size, max_sequence_length + num_img_embeddings)
self._images = images
self._num_tiles = num_tiles
self._num_img_embeddings = num_img_embeddings
self.decoder_seq_length = decoder_seq_length

self._recv_only_vision_embeds = False
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
# Checks whether the previous stage has only a vision encoder and the current stage has part of the LM decoder.
# In this case, the current stage should only receive vision embeddings.
if pp_rank > 0:
self._recv_only_vision_embeds = parallel_state.is_inside_encoder(pp_rank - 1) and (not parallel_state.is_inside_decoder(pp_rank - 1)) and parallel_state.is_inside_decoder()

# Checks if the current stage only has a vision encoder
self._encoder_only = parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()

def _forward(self, tokens, position_ids, attention_mask):
return self.model(
@@ -340,20 +354,44 @@ def _forward(self, tokens, position_ids, attention_mask):
)

def __call__(self, tokens, position_ids, attention_mask):
output = super().__call__(tokens, position_ids, attention_mask)
num_image_tokens = (tokens == self.model.image_token_index).sum().item()
num_tokens = tokens.size(1)
recv_buffer_seq_length = None
if num_image_tokens > 0:
# When there are image tokens and this stage only receives vision embeddings, adjust the recv buffer seq length to match the image embeddings sequence length.
# If there are image tokens and this stage receives full embeddings, make sure we compensate for expansion of image tokens.
# Note that this also sets a recv_buffer_seq_length for the encoder stage; that length is irrelevant since its recv buffer is never allocated.
if self._recv_only_vision_embeds:
recv_buffer_seq_length = self._num_img_embeddings
else:
recv_buffer_seq_length = min(self._num_img_embeddings + num_tokens - num_image_tokens, self.decoder_seq_length)
elif self._recv_only_vision_embeds:
# If this stage only receives vision embeddings and there are no image tokens, we won't run the encoder and therefore shouldn't try to recv.
recv_buffer_seq_length = 0

# If the pipeline stage only has a vision encoder, then it only needs to run when there are image tokens
if not (self._encoder_only and num_image_tokens == 0):
output = super().__call__(tokens, position_ids, attention_mask, recv_buffer_seq_length=recv_buffer_seq_length)
else:
output = None
if isinstance(output, tuple):
logits = output[0]
logits, _ = output
else:
logits = output

# On the first inference iteration, we compute image tokens.
# Update the sequence length offset by the number of image tokens.
num_image_tokens = (tokens == self.model.module.image_token_index).sum().item()
num_tokens = tokens.size(1)
# On every PP stage (although inference params should only matter for the decoder),
# update the sequence length offset by the number of image tokens.
if num_tokens > 1 and num_image_tokens > 0:
self.inference_params.sequence_len_offset += (
self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens
)
if "image_tokens_count" not in self.inference_params.key_value_memory_dict:
self.inference_params.key_value_memory_dict["image_tokens_count"] = self._num_img_embeddings

if self._num_img_embeddings + num_tokens - num_image_tokens > self.decoder_seq_length:
self.inference_params.sequence_len_offset += self.decoder_seq_length - num_tokens
else:
self.inference_params.sequence_len_offset += (
self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens
)

return logits

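For readers tracing the receive-buffer sizing above, the following is a minimal standalone sketch, not part of the commit, of the same decision logic; the function name, arguments, and the numbers in the example call are illustrative only.

def pick_recv_buffer_seq_length(num_image_tokens, num_tokens, num_img_embeddings,
                                decoder_seq_length, recv_only_vision_embeds):
    """Return the sequence length this pipeline stage should expect to receive,
    or None to keep the default buffer that matches the token length."""
    if num_image_tokens > 0:
        if recv_only_vision_embeds:
            # Only vision embeddings arrive from the previous (encoder-only) stage.
            return num_img_embeddings
        # Full embeddings arrive; account for image-token expansion, capped at the decoder length.
        return min(num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length)
    if recv_only_vision_embeds:
        # No image tokens, so the encoder does not run and nothing should be received.
        return 0
    return None

# Example: a 1024-token prompt whose single image token expands to 576 embeddings.
print(pick_recv_buffer_seq_length(1, 1024, 576, 4096, recv_only_vision_embeds=False))  # 1599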
3 changes: 3 additions & 0 deletions megatron/core/models/multimodal/llava_model.py
@@ -272,6 +272,7 @@ def _preprocess_data(
loss_mask,
labels,
use_inference_kv_cache,
inference_params,
image_token_index,
num_image_tiles,
attention_mask,
@@ -351,6 +352,7 @@ def _preprocess_data(
if (
self._language_is_pipeline_parallel
and max_seq_len < self._language_max_sequence_length
and inference_params is None
):
max_seq_len = self._language_max_sequence_length

@@ -696,6 +698,7 @@ def forward(
loss_mask,
labels,
use_inference_kv_cache,
inference_params,
image_token_index if image_token_index is not None else self.image_token_index,
num_image_tiles,
attention_mask,
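A small self-contained sketch, not taken from llava_model.py, of the padding rule this new check implements: with a pipeline-parallel language model, sequences are padded to the full language max sequence length during training, but kept at their combined length when inference_params is set.

def choose_padded_seq_len(max_seq_len, language_max_sequence_length,
                          language_is_pipeline_parallel, inference_params):
    # Pad to the full decoder length only for pipeline-parallel training;
    # during inference (inference_params is not None) the shorter length is kept.
    if (language_is_pipeline_parallel
            and max_seq_len < language_max_sequence_length
            and inference_params is None):
        return language_max_sequence_length
    return max_seq_len

print(choose_padded_seq_len(1599, 4096, True, None))      # 4096 (training)
print(choose_padded_seq_len(1599, 4096, True, object()))  # 1599 (inference)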
13 changes: 13 additions & 0 deletions megatron/core/parallel_state.py
@@ -88,6 +88,10 @@
# the first local rank in the tensor model parallel group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None

# A list of global ranks for each model parallel group to ease calculation of
# the first local rank in the model parallel group
_MODEL_PARALLEL_GLOBAL_RANKS = None

# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
@@ -762,13 +766,15 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):

# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
global _MODEL_PARALLEL_GLOBAL_RANKS
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for ranks in generator_wrapper('tp-pp'):
group = torch.distributed.new_group(
ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
_MODEL_PARALLEL_GLOBAL_RANKS = ranks

# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
@@ -1342,6 +1348,13 @@ def get_tensor_model_parallel_src_rank():
return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]


def get_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the model parallel group."""
assert _MODEL_PARALLEL_GLOBAL_RANKS is not None, "Model parallel group is not initialized"
return _MODEL_PARALLEL_GLOBAL_RANKS[0]


def get_data_parallel_src_rank(with_context_parallel=False):
"""Calculate the global rank corresponding to the first local rank
in the data parallel group."""
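A hedged usage sketch of the new helper, not from the repository, assuming torch.distributed and Megatron's parallel state have already been initialized; the wrapper function below is illustrative.

import torch
from megatron.core import parallel_state

def broadcast_within_model_parallel_group(tensor: torch.Tensor) -> torch.Tensor:
    # Broadcast from the first global rank of this rank's model-parallel (tp x pp)
    # group to every other rank in the same group.
    src = parallel_state.get_model_parallel_src_rank()
    group = parallel_state.get_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group=group)
    return tensor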
45 changes: 29 additions & 16 deletions megatron/inference/text_generation/communication.py
@@ -9,7 +9,6 @@
from megatron.core import mpu



# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
@@ -25,8 +24,6 @@ def recv_from_prev_pipeline_rank_(recv_buffer=None):
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()



# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
@@ -80,6 +77,29 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
return tensor


def _send_and_recv_from_last_to_first_pipeline_stage(tensor=None):
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()

if is_last_stage or is_first_stage:
if is_first_stage:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor,
mpu.get_pipeline_model_parallel_last_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
elif is_last_stage:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_first_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])

for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()

return tensor


def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
@@ -98,10 +118,7 @@ def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
tensor = _send_and_recv_from_last_to_first_pipeline_stage(tensor)
else:
tensor = None

@@ -123,8 +140,6 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
@@ -134,8 +149,7 @@
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
tensor_ = _send_and_recv_from_last_to_first_pipeline_stage(tensor_)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
@@ -150,7 +164,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0, data_parallel=False):
data_parallel (bool): Broadcast across a single data parallel model replica.
"""
if data_parallel:
rank = parallel_state.get_tensor_model_parallel_src_rank()
rank = parallel_state.get_model_parallel_src_rank()

if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
@@ -161,7 +175,7 @@

group = None
if data_parallel:
group = parallel_state.get_tensor_model_parallel_group()
group = parallel_state.get_model_parallel_group()

torch.distributed.broadcast(tensor, rank, group=group)

@@ -179,12 +193,11 @@ def broadcast_list(size, dtype, list_values=None, rank=0, data_parallel=False):
tensor = None

if data_parallel:
src_rank = parallel_state.get_data_parallel_src_rank()
if src_rank == 0:
if parallel_state.get_model_parallel_src_rank() == torch.distributed.get_rank():
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())

rank = parallel_state.get_tensor_model_parallel_src_rank()
rank = parallel_state.get_model_parallel_src_rank()
else:
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
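A usage sketch, not from this commit, of the modified broadcast: during generation the last pipeline stage sends the newly sampled tokens directly to the first stage with point-to-point ops instead of an embedding-group broadcast. It assumes Megatron's pipeline groups are initialized; the wrapper name is illustrative.

import torch
from megatron.inference.text_generation.communication import (
    broadcast_from_last_to_first_pipeline_stage,
)

def fetch_new_tokens(batch_size, new_tokens_on_last_stage=None):
    # On the last stage pass the sampled tokens; on the first stage pass None and
    # receive them; intermediate stages get None back.
    return broadcast_from_last_to_first_pipeline_stage(
        size=(batch_size,), dtype=torch.int64, tensor=new_tokens_on_last_stage
    )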
34 changes: 23 additions & 11 deletions megatron/inference/text_generation/forward_step.py
@@ -39,42 +39,54 @@ def __init__(self, model, max_batch_size, max_sequence_length):
def _forward(self, tokens, position_ids, attention_mask):
return self.model(tokens, position_ids, attention_mask, inference_params=self.inference_params)

def __call__(self, tokens, position_ids, attention_mask):
def __call__(self, tokens, position_ids, attention_mask, recv_buffer_seq_length=None):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
# This runs only if current_batch_x_seqlen > args.inference_batch_times_seqlen_threshold
# and requires setting args.pipeline_model_parallel > 1. The batch will be split into
# smaller microbatches to be pipelined through the stages.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
seq_len = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length
current_batch_x_seqlen = tokens.size(0) * seq_len
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
max(1, self.pipelining_batch_x_seqlen // seq_len)
return self._with_pipelining_forward_step(tokens,
position_ids,
attention_mask,
micro_batch_size)
# Do not pipeline the batch; the entire batch will be passed through all at once.
micro_batch_size,
recv_buffer_seq_length=recv_buffer_seq_length)

recv_buffer = None
if recv_buffer_seq_length is not None:
recv_buffer = _allocate_recv_buffer(tokens.size(0), recv_buffer_seq_length)

return self._no_pipelining_forward_step(tokens,
position_ids,
attention_mask)
attention_mask,
recv_buffer=recv_buffer)


def _forward_step_helper(self, tokens, position_ids, attention_mask, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)

if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
if recv_buffer is not None and torch.numel(recv_buffer) > 0:
recv_from_prev_pipeline_rank_(recv_buffer)

# Forward pass through the model.
self.model.set_input_tensor(recv_buffer)
if not mpu.is_pipeline_first_stage():
self.model.set_input_tensor(recv_buffer)
output_tensor = self._forward(tokens, position_ids, attention_mask)
if isinstance(output_tensor, tuple):
output_tensor = output_tensor[0]

# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
Expand All @@ -99,10 +111,10 @@ def _no_pipelining_forward_step(self, tokens, position_ids, attention_mask,
return logits


def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, micro_batch_size):
def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, micro_batch_size, recv_buffer_seq_length=None):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
sequence_length = tokens.size(1) if recv_buffer_seq_length is None else recv_buffer_seq_length

# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
@@ -143,7 +155,7 @@ def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask, mi

# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
self.inference_params.sequence_len_offset += sequence_length
self.inference_params.sequence_len_offset += tokens.size(1)
# and reset the batch size offset
self.inference_params.batch_size_offset = 0

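A minimal, self-contained sketch, not from forward_step.py, of the pipelining threshold check after this change: the decision now uses the receive-buffer sequence length, when one is given, instead of the raw token length, so stages that receive expanded vision embeddings are micro-batched correctly. The threshold value below is an arbitrary example.

def should_pipeline(batch_size, num_tokens, pipelining_batch_x_seqlen,
                    recv_buffer_seq_length=None):
    # Mirror of the updated check: use the recv-buffer length when provided.
    seq_len = num_tokens if recv_buffer_seq_length is None else recv_buffer_seq_length
    return batch_size * seq_len >= pipelining_batch_x_seqlen

print(should_pipeline(4, 1024, pipelining_batch_x_seqlen=8192))                               # False
print(should_pipeline(4, 1024, pipelining_batch_x_seqlen=8192, recv_buffer_seq_length=2599))  # True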
2 changes: 2 additions & 0 deletions tests/unit_tests/models/test_llava_model.py
@@ -126,6 +126,7 @@ def test_preprocess_data(self):

use_inference_kv_cache = False
attention_mask = None
inference_params = None

embeddings, labels, loss_mask, attention_mask = self.model._preprocess_data(
image_embeddings,
@@ -134,6 +135,7 @@
loss_mask,
labels,
use_inference_kv_cache,
inference_params,
image_token_index,
num_image_tiles,
attention_mask,
