ADLR/megatron-lm!2392 - Add streaming support for MCore inference API
santhnm2 authored and deepakn94 committed Feb 1, 2025
1 parent bc12efb commit db9527f
Showing 13 changed files with 482 additions and 34 deletions.
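
At a glance, the streaming path added here is: register a request with streaming=True, fetch its async generator, and drive the engine asynchronously while the generator yields incremental outputs. Below is a condensed sketch assembled from the diffs that follow (not a verbatim excerpt; it assumes inference_engine and sampling_params are built as in the example script).

import asyncio

async def stream_one(inference_engine, sampling_params, prompt):
    # Register the request with streaming enabled and grab its async generator.
    request_id = inference_engine.add_request(
        prompt=prompt, inference_parameters=sampling_params, streaming=True
    )
    stream_generator = inference_engine.get_stream_generator(request_id)

    async def consume():
        async for output in stream_generator:  # yields InferenceRequest snapshots
            # Full text so far; slice it for incremental printing as in the example below.
            print(output.generated_text, flush=True)

    task = asyncio.create_task(consume())
    await inference_engine.run_engine_async()  # runs the blocking engine loop in a worker thread
    await task
    return inference_engine.scheduler.completed_request_pool[request_id]
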
4 changes: 2 additions & 2 deletions LICENSE
@@ -36,8 +36,8 @@ OpenAI). Files from these organizations have notices at the top of each file.
Below are licenses used in those files, as indicated.


--------------------------------------------------------------------------------
-- LICENSE FOR Facebook, huggingface, Google Research, LLaVA, and Mamba code --
--------------------------------------------------------------------------------------
-- LICENSE FOR Facebook, huggingface, Google Research, LLaVA, Mamba, and vLLM code --


Apache License
61 changes: 49 additions & 12 deletions examples/inference/gpt/gpt_batch_inference.py
@@ -20,7 +20,6 @@
TextGenerationController,
)
from megatron.core.transformer.module import MegatronModule
from megatron.legacy.model.module import Float16Module

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
@@ -32,7 +31,9 @@
from megatron.core import mpu
from megatron.training.initialize import initialize_megatron
from megatron.training import get_model
from typing import List
import asyncio
from typing import AsyncIterator, List



def add_text_generate_args(parser):
@@ -64,6 +65,7 @@ def add_text_generate_args(parser):
group.add_argument(
"--max-batch-size", type=int, default=1, help='Max number of prompts to process at once'
)
group.add_argument("--stream", action="store_true", default=False, help="Stream output tokens")
return parser


@@ -90,13 +92,44 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngi
)

inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
)
return MCoreEngine(
text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size
)

text_generation_controller = TextGenerationController(inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer)
return MCoreEngine(text_generation_controller=text_generation_controller, max_batch_size=args.max_batch_size)


async def generate(
inference_engine: MCoreEngine,
sampling_params: SamplingParams,
prompts: List[str],
) -> List[InferenceRequest]:
async def collect_stream(prompt, request_id, stream_generator):
print(f"Request {request_id}: {prompt}", end="", flush=True)
prev_idx = 0
async for output in stream_generator:
print(output.generated_text[prev_idx:], end="", flush=True)
prev_idx = len(output.generated_text)
print()

request_ids: List[str] = [
inference_engine.add_request(
prompt=prompt, inference_parameters=sampling_params, streaming=True
)
for prompt in prompts
]
stream_generators = [inference_engine.get_stream_generator(request_id) for request_id in request_ids]

tasks = [
asyncio.create_task(collect_stream(prompt, request_id, stream_generator))
for (prompt, request_id, stream_generator) in zip(prompts, request_ids, stream_generators)
]

await inference_engine.run_engine_async()
await asyncio.gather(*tasks)

results: List[InferenceRequest] = [
inference_engine.scheduler.completed_request_pool[request_id] for request_id in request_ids
]

return results

def main():
"""Main program."""
@@ -137,9 +170,12 @@ def main():
)

start_time = time.perf_counter()
results: List[InferenceRequest] = inference_engine.generate(
prompts=args.prompts, sampling_params=sampling_params
)
if args.stream:
results: List[InferenceRequest] = asyncio.run(generate(inference_engine, sampling_params, args.prompts))
else:
results: List[InferenceRequest] = inference_engine.generate(
prompts=args.prompts, sampling_params=sampling_params,
)
end_time = time.perf_counter()
latency = end_time - start_time

@@ -155,6 +191,7 @@ def main():
}
print(result)

torch.distributed.destroy_process_group()

if __name__ == "__main__":
main()
67 changes: 67 additions & 0 deletions megatron/core/inference/async_stream.py
@@ -0,0 +1,67 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright 2025 The vLLM authors.
#
# This code was adopted from https://github.com/vllm-project/vllm/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union

from megatron.core.inference.inference_request import InferenceRequest

STOP_ITERATION = Exception()


class AsyncStream:
"""
Class for encapsulating an asynchronous stream of InferenceRequest outputs.
Adopted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py # pylint: disable=line-too-long
"""

def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self._request_id = request_id
self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
self._loop = asyncio.get_running_loop()

def put(self, item: Union[InferenceRequest, Exception]) -> None:
"""Adds a new value to the stream"""
if not self._finished:
self._loop.call_soon_threadsafe(self._queue.put_nowait, item)

def finish(self, exception: Optional[Union[BaseException, Type[BaseException]]] = None) -> None:
"""Completes the stream by adding a sentinel value"""
if not self._finished:
self._finished = True
self._loop.call_soon_threadsafe(
self._queue.put_nowait,
exception if self._is_raisable(exception) else STOP_ITERATION,
)

@property
def finished(self) -> bool:
"""Whether the stream has finished"""
return self._finished

async def generator(self) -> AsyncGenerator[InferenceRequest, None]:
"""Creates an AsyncGenerator over the stream queue"""
try:
while True:
result = await self._queue.get()
if self._is_raisable(result):
if result == STOP_ITERATION:
return
raise result
yield result
except GeneratorExit:
self._cancel()
raise asyncio.CancelledError from None

@staticmethod
def _is_raisable(value: Any):
return isinstance(value, BaseException) or (
isinstance(value, type) and issubclass(value, BaseException)
)
77 changes: 70 additions & 7 deletions megatron/core/inference/engines/mcore_engine.py
@@ -1,8 +1,11 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Dict, List, Optional
import asyncio
from collections import OrderedDict
from typing import AsyncGenerator, Dict, List, Optional, Union

import torch

from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.sampling_params import SamplingParams
@@ -37,6 +40,47 @@ def __init__(
self.random_seed = random_seed
self.scheduler = Scheduler(max_batch_size=max_batch_size)

def add_request(
self,
prompt: str,
add_BOS: bool = False,
encoder_prompt: Optional[str] = None,
inference_parameters: Optional[SamplingParams] = None,
streaming: bool = False,
) -> str:
"""
Adds a request to the scheduler and returns the request ID.
Args:
prompt (str): A prompt string
add_BOS (bool): Whether to add BOS token to beginning of the prompt
encoder_prompt (str): The encoder prompt string
inference_parameters (SamplingParams): The inference parameters
streaming (bool): Whether to stream incremental outputs for this request
Returns:
The newly created request ID.
"""

prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt, add_BOS)

return self.scheduler.add_request(
prompt=prompt,
prompt_tokens=prompt_tokens,
encoder_prompt=encoder_prompt,
inference_parameters=inference_parameters,
streaming=streaming,
)

def get_stream_generator(
self, request_id: str
) -> Union[AsyncGenerator[InferenceRequest, None], None]:
"""Returns the stream generator for the given request ID if it exists."""
stream = self.scheduler.streams.get(request_id, None)
if stream is not None:
return stream.generator()
return None

def generate(
self,
prompts: List[str],
@@ -73,15 +117,13 @@ def generate(
if self.random_seed:
torch.random.manual_seed(self.random_seed)

request_ids = []
request_ids: List[str] = []
for i in range(len(prompts)):
prompt = prompts[i]
encoder_prompt = encoder_prompts[i] if encoder_prompts is not None else None
prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt, add_BOS)

request_id = self.scheduler.add_request(
request_id = self.add_request(
prompt=prompt,
prompt_tokens=prompt_tokens,
add_BOS=add_BOS,
encoder_prompt=encoder_prompt,
inference_parameters=sampling_params,
)
@@ -106,9 +148,14 @@ def run_engine(self):
"""
while self.scheduler.have_requests_pending():
active_requests: Dict[str, InferenceRequest] = self.scheduler.active_request_pool.copy()
active_streams: Dict[str, AsyncStream] = OrderedDict()
for request_id in active_requests:
if (stream := self.scheduler.streams.get(request_id, None)) is not None:
assert isinstance(stream, AsyncStream), stream
active_streams[request_id] = stream
result_dict: Dict[str, InferenceRequest] = (
self.text_generation_controller.generate_all_output_tokens_static_batch(
active_requests
active_requests, active_streams
)
)

@@ -124,3 +171,19 @@
)
self.scheduler.update_requests_pools(result_dict=result_dict)
"""

def _wrapped_run_engine(self, cuda_device):
"""
Explicitly sets the CUDA device before running the engine.
This is to ensure that the CUDA device is correctly propagated when running
in a new thread context.
"""
torch.cuda.set_device(cuda_device)
self.run_engine()

async def run_engine_async(self):
"""Runs the engine asynchronously using asyncio"""
loop = asyncio.get_running_loop()

await loop.run_in_executor(None, self._wrapped_run_engine, torch.cuda.current_device())
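
run_engine() is a blocking loop, so run_engine_async() hands it to a worker thread via run_in_executor, leaving the asyncio event loop free to drain the per-request streams (the CUDA device is re-set in _wrapped_run_engine because the loop now runs in a new thread). The following is a stripped-down illustration of that hand-off pattern with invented names, not engine code.

import asyncio
import queue

results = queue.Queue()

def blocking_loop():
    # Stands in for MCoreEngine.run_engine(): a synchronous generation loop.
    for step in range(3):
        results.put(f"step-{step}")
    results.put(None)  # sentinel: the loop is done

async def main():
    loop = asyncio.get_running_loop()
    # Off-load the blocking loop to the default thread pool, as run_engine_async() does.
    engine_future = loop.run_in_executor(None, blocking_loop)
    # The event loop stays responsive, e.g. to consume streamed output.
    while (item := await loop.run_in_executor(None, results.get)) is not None:
        print(item)
    await engine_future

asyncio.run(main())
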
5 changes: 4 additions & 1 deletion megatron/core/inference/inference_request.py
@@ -32,8 +32,11 @@ class InferenceRequest:
prompt_tokens: List[int]
arrival_time: float
status: Status
prompt_log_probs: Optional[float] = None
encoder_prompt: Optional[str] = None
generated_text: Optional[str] = None
generated_segments: Optional[List[List[str]]] = None
generated_sequence_lengths: Optional[List[int]] = None
generated_tokens: Optional[torch.Tensor] = None
generated_log_probs: Optional[torch.Tensor] = None
generated_log_probs: Optional[float] = None
generated_length: int = 0
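
Once a request completes, these fields are read straight off the stored record; continuing the example script above (a sketch, assuming request_id came from add_request):

req = inference_engine.scheduler.completed_request_pool[request_id]
print(req.generated_text)              # detokenized output
print(req.generated_sequence_lengths)  # new field; defaults to None
print(req.prompt_log_probs)            # new field; defaults to None
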
27 changes: 25 additions & 2 deletions megatron/core/inference/scheduler.py
@@ -1,11 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import time
import typing
from collections import OrderedDict
from typing import Dict, Optional
from typing import Dict, Optional, Type, Union

import torch

from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter
@@ -23,6 +25,8 @@ class Scheduler:

def __init__(self, max_batch_size: int):
self.max_batch_size = max_batch_size
self.requests: Dict[str, InferenceRequest] = OrderedDict()
self.streams: Dict[str, AsyncStream] = OrderedDict()
self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict()
self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict()
self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict()
@@ -35,7 +39,8 @@ def add_request(
encoder_prompt: Optional[str] = None,
inference_parameters: Optional[SamplingParams] = None,
arrival_time: Optional[float] = None,
):
streaming: bool = False,
) -> str:
"""Add an incoming request
This method will add the request to either the active pool or the waiting pool
@@ -47,6 +52,7 @@
encoder_prompt (str): Encoder input string
inference_parameters (SamplingParams): The inference parameters
arrival_time (float, optional): The incoming request time. Defaults to None.
streaming (bool, optional): Whether to asynchronously stream tokens for this request.
Returns:
The request_id for the new request.
@@ -62,6 +68,10 @@
else Status.WAITING_IN_QUEUE
)

if streaming:
abort_request = functools.partial(self.abort_request, request_id=request_id)
self.streams[request_id] = AsyncStream(request_id, abort_request)

if inference_parameters is None:
inference_parameters = SamplingParams()

@@ -75,6 +85,8 @@
encoder_prompt=encoder_prompt,
)

self.requests[request_id] = inference_request

if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
self.active_request_pool[request_id] = inference_request
else:
@@ -135,3 +147,14 @@ def update_requests_pools(
and len(self.waiting_request_pool) > 0
):
self.add_earliest_waiting_request_to_active_pool()

def abort_request(
self,
request_id: str,
*,
exception: Optional[Union[BaseException, Type[BaseException]]] = None
):
"""Cancels the given request"""
stream = self.streams.get(request_id, None)
if stream is not None:
stream.finish(exception=exception)
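
The new abort_request hook lets a caller end a streaming request's output early; a short sketch (reusing inference_engine and request_id from the example script):

# Finish the request's AsyncStream so the consumer's `async for` loop exits cleanly.
inference_engine.scheduler.abort_request(request_id)

# Or surface an error to the consumer instead; the stream generator re-raises it.
inference_engine.scheduler.abort_request(request_id, exception=TimeoutError)
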
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from megatron.core.inference.text_generation_controllers.text_generation_controller import ( # noqa: F401 # pylint: disable=unused-import
TextGenerationController as SimpleTextGenerationController,