Draft #3

Open · wants to merge 35 commits into main (showing changes from all 35 commits)
5 changes: 5 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -2328,6 +2328,11 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
    /// @return An optional Response
    std::optional<executor::Response> createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0);

    std::optional<executor::Result> createResult(bool useFastLogits = false, int32_t mpiWorldRank = 0);

    void createSerializedResult(
        std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits = false, int32_t mpiWorldRank = 0);

    void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen, SizeType32 vocabSizePadded,
        std::optional<SizeType32> maxEncoderInputLen = std::nullopt, bool enableKVCacheReuse = false);

35 changes: 29 additions & 6 deletions cpp/tensorrt_llm/batch_manager/llmRequest.cpp
@@ -16,6 +16,7 @@
*/

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"

namespace tensorrt_llm::batch_manager
@@ -39,8 +40,34 @@ runtime::SizeType32 GenericLlmRequest<TTensor, TStream>::getBeamWidthByIter(bool

template class GenericLlmRequest<runtime::ITensor::SharedPtr>;

/// Note that there is some dependency on the order of operations in this method. Modify with care!
std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits, int32_t mpiWorldRank)
{
    auto requestId = isChild() ? mParentRequestId : mRequestId;
    auto result = createResult(useFastLogits, mpiWorldRank);
    if (result.has_value())
    {
        return executor::Response(requestId, result.value(), mClientId);
    }
    return std::nullopt;
}

void LlmRequest::createSerializedResult(
    std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits, int32_t mpiWorldRank)
{
    auto result = createResult(useFastLogits, mpiWorldRank);
    if (result.has_value())
    {
        std::ostringstream oStream;
        executor::serialize_utils::serialize(result.value(), oStream);
        auto str = oStream.str();
        serializedResult.resize(str.size());
        std::copy(str.begin(), str.end(), serializedResult.begin());
        isFinal = result.value().isFinal;
    }
}

/// Note that there is some dependency on the order of operations in this method. Modify with care!
std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
{
    TLLM_CHECK(!isDisaggContextCompleteState());
    if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
@@ -192,11 +219,7 @@ std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits,

    // Update position of last sent response
    setMaxSentTokenLen(maxNbTokens);

    auto requestId = isChild() ? mParentRequestId : mRequestId;
    auto response = executor::Response(requestId, std::move(result), mClientId);

    return response;
    return result;
}

void LlmRequest::validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen,
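
The refactor above makes createResult the single place where an executor::Result is produced; createResponse and createSerializedResult become thin wrappers that either package the Result into a Response or serialize it into a byte buffer together with its isFinal flag. A minimal Python sketch of the same wrapper split, using pickle as a stand-in for executor::serialize_utils and hypothetical names throughout:

import pickle
from typing import Optional, Tuple

class ResultProducerSketch:
    # Illustrative only: mirrors the createResult / createResponse / createSerializedResult split.

    def _create_result(self) -> Optional[dict]:
        # Stand-in for LlmRequest::createResult(); returns None when no result is ready yet.
        return {"token_ids": [1, 2, 3], "is_final": True}

    def create_response(self, request_id: int, client_id: int) -> Optional[tuple]:
        result = self._create_result()
        if result is None:
            return None
        # A Response is just (request id, Result, client id), as in the C++ code above.
        return (request_id, result, client_id)

    def create_serialized_result(self) -> Tuple[bytes, bool]:
        result = self._create_result()
        if result is None:
            return b"", False
        return pickle.dumps(result), result["is_final"]
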
11 changes: 11 additions & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -36,6 +36,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
#include <tuple>

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
@@ -360,6 +361,16 @@ void initBindings(pybind11::module_& m)
py::arg("enable_kv_cache_reuse") = false)
.def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_serialized_result",
[](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0)
{
std::vector<char> serialized_result;
bool is_final = false;
self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank);
return std::make_tuple(py::bytes(serialized_result.data(), serialized_result.size()), is_final);
})
.def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, py::arg("manager"))
.def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
.def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"));
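
From the Python side, the new create_result and create_serialized_result bindings give callers a choice between a structured Result object and an opaque byte payload plus a finality flag. A hedged usage sketch (the request object and its state are assumed, since only the binding signatures appear in this diff):

def collect_serialized_result(llm_request, rank: int = 0):
    """Sketch: pull a serialized Result off an LlmRequest via the bindings added above."""
    # create_result exposes named arguments via py::arg; create_serialized_result is bound
    # through a lambda without py::arg, so its two arguments are passed positionally.
    serialized_bytes, is_final = llm_request.create_serialized_result(False, rank)
    if len(serialized_bytes) == 0:
        return None, False  # no response is ready for this iteration
    return serialized_bytes, is_final
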
9 changes: 9 additions & 0 deletions cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -19,6 +19,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/cudaStream.h"
@@ -29,6 +30,7 @@
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <sstream>

#include <optional>
#include <vector>
@@ -775,6 +777,13 @@ void initRequestBindings(pybind11::module_& m)
        .def_readwrite("context_phase_params", &tle::Result::contextPhaseParams)
        .def(py::pickle(resultGetstate, resultSetstate));

    m.def("deserialize_result",
        [](std::string& x)
        {
            std::istringstream is(x);
            return tle::serialize_utils::deserialize<tle::Result>(is);
        });

    auto responseGetstate = [](tle::Response const& self)
    { return py::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); };

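
deserialize_result is the consumer-side counterpart of create_serialized_result: it rebuilds an executor Result from the serialized bytes. A hedged round-trip helper (the import path of the executor bindings module is an assumption):

def deserialize_llm_result(serialized_bytes: bytes):
    # Assumed import path; the binding itself is the one registered above.
    from tensorrt_llm.bindings import executor as tllm_executor
    return tllm_executor.deserialize_result(serialized_bytes)
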
67 changes: 43 additions & 24 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -205,39 +205,38 @@ class LlmResult:
    py_result_properties = frozenset(
        ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs'))

    def __init__(self, result: tensorrt_llm.bindings.executor.Result,
                 py_result: PyResult):
    def __init__(self,
                 result: bytes,
                 py_result: PyResult,
                 is_final: bool = False):
        self._result = result
        self._py_result = py_result
        self.is_final = is_final

    def __getattr__(self, item):
        if item in self.py_result_properties:
            return getattr(self._py_result, item)
        return getattr(self._result, item)
        if item == 'is_final':
            return object.__getattribute__(self, 'is_final')
        result = object.__getattribute__(self, '_result')
        return getattr(result, item)


class LlmResponse:
    """LlmResponse wraps `bindings.executor.Response` but detour some features to Python implementation"""

    def __init__(self, response: tensorrt_llm.bindings.executor.Response,
                 py_result: PyResult):
        self._response = response
        self._py_result = py_result

    def __getstate__(self):
        return self._response, self._py_result

    def __setstate__(self, state):
        self._response, self._py_result = state
    def __init__(self,
                 request_id: int,
                 error_msg: str = None,
                 result: LlmResult = None,
                 client_id: int = None):
        self.request_id = request_id
        self.error_msg = error_msg
        self.result = result
        self.client_id = client_id

    @property
    def result(self) -> tensorrt_llm.bindings.executor.Result:
        return LlmResult(
            self._response.result,
            self._py_result)  # LlmResult masquerades bindings.executor.Result

    def __getattr__(self, item):
        return getattr(self._response, item)
    def has_error(self):
        return self.error_msg is not None


class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
@@ -269,6 +268,7 @@ def __init__(
            **kwargs)
        self.py_client_id = client_id
        self.py_request_id = self.request_id
        self.py_llm_request_type = self.llm_request_type
        self.py_end_id = self.end_id
        self.py_prompt_len = self.prompt_len
        self.py_orig_prompt_len = self.orig_prompt_len
@@ -282,6 +282,8 @@ def __init__(
        self.is_cuda_graph_dummy = False
        self.py_lora_task_layer_module_configs = None

        self.py_tokens = super().get_tokens()

        self.py_return_log_probs = return_log_probs
        self.py_return_context_logits = return_context_logits
        self.py_return_generation_logits = return_generation_logits
@@ -297,13 +299,30 @@ def __init__(
            return_generation_logits,
            exclude_last_generation_logits)

    def is_generation_only_request(self):
        return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY

    def get_tokens(self, beam: int) -> int:
        return self.py_tokens[beam]

    def get_last_tokens(self, beam: int) -> int:
        return self.py_tokens[beam][-1]

    def add_new_token(self, token: int, beam: int) -> int:
        self.py_tokens[beam].append(token)
        # sync to C++ side
        return super().add_new_token(token, beam)

    def create_response(
            self,
            use_fast_logits=False,
            mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
        response = super().create_response(use_fast_logits, mpi_world_rank)
        return LlmResponse(response,
                           self.py_result) if response is not None else None
        result, is_final = super().create_serialized_result(
            use_fast_logits, mpi_world_rank)
        return LlmResponse(
            request_id=self.py_request_id,
            result=LlmResult(result, self.py_result, is_final),
            client_id=self.py_client_id) if len(result) > 0 else None

    @property
    def is_dummy(self):
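
Because LlmResponse and LlmResult are now plain Python objects carrying a serialized Result (bytes) rather than wrapping a C++ binding object, a response can cross process boundaries with an ordinary pickle round trip. A small sketch, assuming the two classes above are importable and constructing them directly only for illustration:

import pickle

# Hypothetical values; in practice the bytes come from create_serialized_result.
response = LlmResponse(request_id=7,
                       result=LlmResult(result=b"<serialized Result bytes>",
                                        py_result=None,
                                        is_final=True),
                       client_id=42)
restored = pickle.loads(pickle.dumps(response))
assert restored.result.is_final and not restored.has_error()
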
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1184,7 +1184,7 @@ def _prepare_tp_inputs(
            gather_ids.append(len(input_ids) - 1)
            sequence_lengths.append(len(prompt_tokens))
            prompt_lengths.append(len(prompt_tokens))
            past_seen_token_num = request.context_current_position
            past_seen_token_num = begin_compute
            num_cached_tokens_per_seq.append(past_seen_token_num)
            multimodal_embedding = request.multimodal_embedding
            if multimodal_embedding is not None:
24 changes: 13 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -30,8 +30,8 @@

from ..distributed import Distributed
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, ExecutorResponse, LlmRequest,
                          LlmRequestState, executor_request_to_llm_request)
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
                          LlmResponse, executor_request_to_llm_request)
from .model_engine import ModelEngine
from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
from .scheduler import ScheduledRequests
@@ -323,14 +323,14 @@ def await_responses(
        self,
        id: Optional[Union[List[int], int]] = None,
        timeout: Optional[datetime.timedelta] = None,
    ) -> Union[List[List[ExecutorResponse]], List[ExecutorResponse]]:
    ) -> Union[List[List[LlmResponse]], List[LlmResponse]]:
        """
        Await for ready responses
        Args:
            id (Optional[Union[List[int], int]]): Request id
            timeout (Optional[datetime.timedelta]): The maximum time to wait for new responses
        Returns:
            Union[List[tensorrt_llm.bindings.executor.Response], List[List[tensorrt_llm.bindings.executor.Response]]]: Responses
            Union[List[LlmResponse], List[List[LlmResponse]]]: Responses
        """
        timeout = timeout.total_seconds() if timeout is not None else None
        if id is None:
@@ -1934,8 +1934,10 @@ def _handle_errors(self, error_msg: Optional[str] = None):
            req_id = request.py_request_id
            request.state = LlmRequestState.GENERATION_COMPLETE
            self._terminate_request(request)
            error_responses[req_id] = ExecutorResponse(
                req_id, error_msg, client_id=request.py_client_id)
            error_responses[req_id] = LlmResponse(
                request_id=req_id,
                error_msg=error_msg,
                client_id=request.py_client_id)
        self.active_requests.clear()
        self._enqueue_responses(error_responses)

@@ -1979,7 +1981,7 @@ def _handle_cancelled_requests(self):
        self._enqueue_responses(cancelled_responses)

    @nvtx_range("_enqueue_responses")
    def _enqueue_responses(self, responses: Dict[int, ExecutorResponse]):
    def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
        if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
            return

@@ -2036,7 +2038,7 @@ def _handle_responses(self):
                requests_to_terminate.append(request)
                continue

            if request.is_generation_only_request:
            if request.is_generation_only_request():
                # If request is in transmission, so we don't need to emit a response
                # Also, for the first iteration with overlap, we should skip since first
                # token has already been emitted previously
@@ -2048,7 +2050,7 @@ def _handle_responses(self):

            request.draft_tokens = request.py_draft_tokens
            request.decoding_iter = request.py_decoding_iter
            response: Response = request.create_response(False, self.dist.rank)
            response = request.create_response(False, self.dist.rank)
            request_done = False
            if response:
                request_done = response.result.is_final
@@ -2075,7 +2077,7 @@ def _terminate_ctx_finished_requests(self):

    def _await_any_response(self,
                            timeout: Optional[float] = None
                            ) -> List[ExecutorResponse]:
                            ) -> List[LlmResponse]:

        def any_responses_ready():
            return len(self.responses) > 0 or self.is_shutdown
@@ -2092,7 +2094,7 @@ def any_responses_ready():
    def _await_single_response(
            self,
            id: int,
            timeout: Optional[float] = None) -> List[ExecutorResponse]:
            timeout: Optional[float] = None) -> List[LlmResponse]:
        with self.response_cv:

            def key_has_response():
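
With these changes the executor only ever emits Python LlmResponse objects, whether a request failed or produced a result. A hedged sketch of the consumer-side checks used above (the classify_response helper is hypothetical):

def classify_response(resp):
    # Errors are reported via error_msg; successful responses carry an LlmResult
    # whose is_final flag marks the last chunk of a (possibly streaming) request.
    if resp.has_error():
        return "error", resp.error_msg
    return ("final" if resp.result.is_final else "partial"), None
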
4 changes: 4 additions & 0 deletions tensorrt_llm/executor/result.py
@@ -289,6 +289,10 @@ def _handle_response(self,
            handler(response.error_msg)

        response_result = response.result
        if hasattr(response_result, "_result"):
            response_result = tllm.deserialize_result(
                response_result._result)

        self._done = response_result.is_final
        context_phase_params = response_result.context_phase_params
        self.decoding_iter = response_result.decoding_iter
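
On the result-handling side, a serialized payload is detected by the presence of the _result attribute and deserialized before any fields are read. A minimal sketch of that branch, with the deserializer passed in to avoid assuming the binding's import path:

def resolve_result(response_result, deserialize):
    # deserialize is expected to behave like tllm.deserialize_result in the diff above.
    if hasattr(response_result, "_result"):
        # _result holds the bytes produced by create_serialized_result on the worker side.
        return deserialize(response_result._result)
    return response_result
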
5 changes: 3 additions & 2 deletions tensorrt_llm/executor/serialization.py
@@ -2,6 +2,7 @@
# pickle is not secure, but this whole file is a wrapper to make it
# possible to mitigate the primary risk of code injection via pickle.
import pickle # nosec B403
from functools import partial

# These are the base classes that are generally serialized by the ZeroMQ IPC.
# If a class is needed by ZMQ routinely it should be added here. If
@@ -155,8 +156,8 @@ def find_class(self, module, name):
# dump and dumps are just aliases because the security controls are on the deserialization
# side. However they are included here so that in the future if a more secure serialization
# solution is identified, it can be added with less impact to the rest of the application.
dump = pickle.dump # nosec B301
dumps = pickle.dumps # nosec B301
dump = partial(pickle.dump, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301
dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301


def load(file,
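
Pinning dump/dumps to pickle.HIGHEST_PROTOCOL makes the protocol choice explicit instead of relying on the interpreter's default, which matters most for large binary payloads such as the serialized Result bytes. A quick sketch of the aliasing pattern:

import pickle
from functools import partial

# Same pattern as the change above: keep the pickle call signature, pin the protocol.
dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)

blob = dumps({"serialized_result": b"\x00" * 1024, "is_final": True})
assert pickle.loads(blob)["is_final"] is True  # nosec B301 - trusted, locally produced data
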
7 changes: 1 addition & 6 deletions tensorrt_llm/executor/utils.py
@@ -8,7 +8,6 @@
from strenum import StrEnum

from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.bindings.executor import Response
from tensorrt_llm.llmapi.utils import print_colored_debug

from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
@@ -144,8 +143,4 @@ class WorkerCommIpcAddrs(NamedTuple):


def is_llm_response(instance):
    from tensorrt_llm._torch.pyexecutor.llm_request import \
        LlmResponse as PyLlmResponse

    from .result import ResponseWrapper
    return isinstance(instance, (Response, PyLlmResponse, ResponseWrapper))
    return hasattr(instance, "result")
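
The duck-typed check drops the isinstance test (and the imports that forced it), so anything exposing a result attribute, whether the C++ binding Response, the Python LlmResponse, or a ResponseWrapper, now qualifies. A small illustration with a hypothetical stub:

class _StubResponse:
    def __init__(self, result):
        self.result = result

assert is_llm_response(_StubResponse(result=None))
assert not is_llm_response(object())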