[WIP][V1/0][P/D] XpYd based on p2p communication without cache store #15806

Open · wants to merge 74 commits into base: main
Commits (74)

0a60364  runnable (Abatom, Mar 31, 2025)
016d004  format (Abatom, Mar 31, 2025)
448bed9  format (Abatom, Mar 31, 2025)
825fe06  format (Abatom, Mar 31, 2025)
477fe2b  format (Abatom, Mar 31, 2025)
dd6dcf9  pass (Abatom, Mar 31, 2025)
7eb1575  format (Abatom, Mar 31, 2025)
a0d37bb  format (Abatom, Mar 31, 2025)
da335ea  move some args to kv_connector_extra_config (Abatom, Apr 1, 2025)
178ca2f  format (Abatom, Apr 1, 2025)
f03ac47  remove some code comments (Abatom, Apr 6, 2025)
603a355  Replace pickle with msgpack (Abatom, Apr 6, 2025)
2acb321  fix bug (Abatom, Apr 7, 2025)
b957dd7  Out Of Memory (Abatom, Apr 7, 2025)
f6407b3  _send_sync (Abatom, Apr 7, 2025)
5c165d9  add p2p_nccl_connector.py based on V1 (Abatom, Apr 9, 2025)
1e9fab6  add shape and size log for recv_tensor (Abatom, Apr 9, 2025)
33601e2  fix hang & oom (Abatom, Apr 11, 2025)
5a888fc  format (Abatom, Apr 11, 2025)
d3e7194  Merge branch 'main' into xpyd (Abatom, Apr 11, 2025)
bddc4e1  ping thread (Abatom, Apr 12, 2025)
e516a70  add code comments. (Abatom, Apr 12, 2025)
8dadfb4  GET (Abatom, Apr 14, 2025)
659ead7  format (Abatom, Apr 14, 2025)
fd596dc  send_queue (Abatom, Apr 14, 2025)
fe81aae  modify log (Abatom, Apr 15, 2025)
bb5d23f  Merge branch 'main' into xpyd (Abatom, Apr 15, 2025)
21818fe  fix bug for PUT_ASYNC (Abatom, Apr 15, 2025)
dcb637b  modify log (Abatom, Apr 15, 2025)
5cbb299  format (Abatom, Apr 15, 2025)
976f51b  format (Abatom, Apr 15, 2025)
a2470ac  fix bug (Abatom, Apr 15, 2025)
b0facb9  format (Abatom, Apr 15, 2025)
20f2e7a  fix bug (Abatom, Apr 16, 2025)
8fd3eca  Merge branch 'main' into xpyd (Abatom, Apr 17, 2025)
f1a5183  fix bug (Abatom, Apr 17, 2025)
1c74857  format (Abatom, Apr 17, 2025)
41b0ae6  rm popitem (Abatom, Apr 17, 2025)
ec364e5  rm disagg_prefill_xpyd.sh (Abatom, Apr 17, 2025)
06281db  Merge branch 'main' into xpyd (Abatom, Apr 18, 2025)
ed2fbb6  V1 (Abatom, Apr 18, 2025)
616ce48  bugfix and format (Abatom, Apr 18, 2025)
49336a2  bugfix (Abatom, Apr 18, 2025)
3d8d7b6  bugfix (Abatom, Apr 18, 2025)
0b9a2ac  format (Abatom, Apr 18, 2025)
e3f858f  add rank and local_rank (Abatom, Apr 19, 2025)
6e088e6  Merge branch 'main' into xpyd (Abatom, Apr 22, 2025)
e8b8f36  bugfix (Abatom, Apr 22, 2025)
8623e3c  runnable for V1 (Abatom, Apr 22, 2025)
d11388a  format (Abatom, Apr 22, 2025)
17e9905  rm valid_num_tokens (Abatom, Apr 23, 2025)
e13094b  wait_for_save (Abatom, Apr 23, 2025)
d96ecc3  inject_kv_into_layer (Abatom, Apr 23, 2025)
eaaf50c  get_num_new_matched_tokens (Abatom, Apr 24, 2025)
6a2af6c  make_meta (Abatom, Apr 24, 2025)
7d0f562  format (Abatom, Apr 24, 2025)
ca9724a  add send_stream and recv_stream (Abatom, Apr 27, 2025)
738c14c  Merge branch 'main' into xpyd (Abatom, Apr 27, 2025)
8d41359  Each NCCL connects to a stream. (Abatom, Apr 28, 2025)
1ad9579  bugfix for GET and revert Each NCCL connects to a stream. (Abatom, Apr 28, 2025)
13fa8b6  add mem pool (Abatom, Apr 30, 2025)
d715d6b  improve mempool (Abatom, May 6, 2025)
3259540  bugfix (Abatom, May 7, 2025)
f525001  torch.cuda.Event (Abatom, May 7, 2025)
1c98a49  Merge branch 'main' into xpyd (Abatom, May 8, 2025)
ac810f3  Merge branch 'xpyd' into xpyd-mempool (Abatom, May 8, 2025)
28ef7cc  load_stream.synchronize and store_stream.synchronize (Abatom, May 10, 2025)
fab1d33  stream.synchronize (Abatom, May 10, 2025)
2400d0b  build_connector_meta (Abatom, May 11, 2025)
6eab9df  bugfix (Abatom, May 12, 2025)
acfa2ac  bugfix (Abatom, May 12, 2025)
bb596f7  proxy add round robin (Abatom, May 12, 2025)
7b710f4  mem_pool_size (Abatom, May 14, 2025)
942e1d5  add tensor_mem_pool (Abatom, May 14, 2025)

csrc/ops.h (2 additions, 0 deletions)

@@ -308,3 +308,5 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
void store_tensor(torch::Tensor device_tensor, torch::Tensor host_tensor);
void load_tensor(torch::Tensor host_tensor, torch::Tensor device_tensor);

csrc/tensor_store_load_mem.cu (100 additions, 0 deletions)

@@ -0,0 +1,100 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

// Template-based CUDA kernel: copy from device memory to pinned host memory
template <typename scalar_t>
__global__ void store_kernel(const scalar_t* device_ptr, scalar_t* host_ptr, size_t num_elements) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    host_ptr[idx] = device_ptr[idx];
  }
}

// Templated CUDA kernel: copy from pinned host memory to device memory
template <typename scalar_t>
__global__ void load_kernel(const scalar_t* host_ptr, scalar_t* device_ptr, size_t num_elements) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    device_ptr[idx] = host_ptr[idx];
  }
}

// Templated wrapper function: store a tensor to pinned memory
template <typename scalar_t>
void store_tensor_impl(torch::Tensor device_tensor, torch::Tensor host_tensor) {
  const auto num_elements = device_tensor.numel();
  const int threads = 256;
  const int blocks = (num_elements + threads - 1) / threads;

  auto device_ptr = device_tensor.data_ptr<scalar_t>();
  auto host_ptr = host_tensor.data_ptr<scalar_t>();

  store_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      device_ptr, host_ptr, num_elements);
}

// Templated wrapper function: load a tensor from pinned memory
template <typename scalar_t>
void load_tensor_impl(torch::Tensor host_tensor, torch::Tensor device_tensor) {
  const auto num_elements = host_tensor.numel();
  const int threads = 256;
  const int blocks = (num_elements + threads - 1) / threads;

  auto host_ptr = host_tensor.data_ptr<scalar_t>();
  auto device_ptr = device_tensor.data_ptr<scalar_t>();

  load_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      host_ptr, device_ptr, num_elements);
}

// Type-dispatched wrapper function
void store_tensor(torch::Tensor device_tensor, torch::Tensor host_tensor) {
  // Validate arguments
  AT_ASSERT(device_tensor.is_cuda(), "Input tensor must be a CUDA tensor");
  AT_ASSERT(host_tensor.is_pinned(), "Output tensor must be pinned memory");
  AT_ASSERT(device_tensor.numel() == host_tensor.numel(), "Tensors must have same number of elements");
  AT_ASSERT(device_tensor.dtype() == host_tensor.dtype(), "Tensors must have same dtype");

  // Type-based dispatch to different implementations
  switch (device_tensor.scalar_type()) {
    case torch::kFloat:
      store_tensor_impl<float>(device_tensor, host_tensor);
      break;
    case torch::kHalf:
      store_tensor_impl<at::Half>(device_tensor, host_tensor);
      break;
    case torch::kBFloat16:
      store_tensor_impl<at::BFloat16>(device_tensor, host_tensor);
      break;
    default:
      AT_ERROR("Unsupported data type: ", device_tensor.scalar_type());
  }
}

void load_tensor(torch::Tensor host_tensor, torch::Tensor device_tensor) {
  // Validate arguments
  AT_ASSERT(device_tensor.is_cuda(), "Output tensor must be a CUDA tensor");
  AT_ASSERT(host_tensor.is_pinned(), "Input tensor must be pinned memory");
  AT_ASSERT(device_tensor.numel() == host_tensor.numel(), "Tensors must have same number of elements");
  AT_ASSERT(device_tensor.dtype() == host_tensor.dtype(), "Tensors must have same dtype");

  // Type-based dispatch to different implementations
  switch (host_tensor.scalar_type()) {
    case torch::kFloat:
      load_tensor_impl<float>(host_tensor, device_tensor);
      break;
    case torch::kHalf:
      load_tensor_impl<at::Half>(host_tensor, device_tensor);
      break;
    case torch::kBFloat16:
      load_tensor_impl<at::BFloat16>(host_tensor, device_tensor);
      break;
    default:
      AT_ERROR("Unsupported data type: ", host_tensor.scalar_type());
  }
}

// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("store_tensor", &store_tensor, "Store CUDA tensor to pinned memory (supports float32, float16, bfloat16)");
//   m.def("load_tensor", &load_tensor, "Load CUDA tensor from pinned memory (supports float32, float16, bfloat16)");
// }

csrc/torch_bindings.cpp (5 additions, 0 deletions)

@@ -665,4 +665,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  custom_ar.def("free_shared_buffer", &free_shared_buffer);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _mem_pool), mem_pool) {
  mem_pool.def("store_tensor", &store_tensor);
  mem_pool.def("load_tensor", &load_tensor);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py (157 additions, 0 deletions)

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0

import os
import random
import socket
import threading
import uuid

import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request

count = 0
prefill_instances: dict[str, str] = {}  # http_address: zmq_address
decode_instances: dict[str, str] = {}  # http_address: zmq_address

prefill_cv = threading.Condition()
decode_cv = threading.Condition()


def _listen_for_register(poller, router_socket):
    while True:
        socks = dict(poller.poll())
        if router_socket in socks:
            remote_address, message = router_socket.recv_multipart()
            # data: {"type": "P", "http_address": "ip:port",
            #        "zmq_address": "ip:port"}
            data = msgpack.loads(message)
            # print("Received message from %s, data: %s",
            #       remote_address.decode(), data)
            if data["type"] == "P":
                global prefill_instances
                global prefill_cv
                with prefill_cv:
                    prefill_instances[
                        data["http_address"]] = data["zmq_address"]
            elif data["type"] == "D":
                global decode_instances
                global decode_cv
                with decode_cv:
                    decode_instances[
                        data["http_address"]] = data["zmq_address"]
            else:
                print("Unexpected, Received message from %s, data: %s",
                      remote_address, data)


def start_service_discovery(hostname, port):
    if not hostname:
        hostname = socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")

    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    _listener_thread = threading.Thread(target=_listen_for_register,
                                        args=[poller, router_socket],
                                        daemon=True)
    _listener_thread.start()
    return _listener_thread


AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


async def forward_request(url, data, request_id):
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id
        }
        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status == 200:
                # if response.headers.get('Transfer-Encoding') == 'chunked':
                if True:
                    async for chunk_bytes in response.content.iter_chunked(
                            1024):
                        yield chunk_bytes
                else:
                    content = await response.read()
                    yield content


@app.route('/v1/completions', methods=['POST'])
async def handle_request():
    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request['max_tokens'] = 1

        global count
        global prefill_instances
        global prefill_cv
        with prefill_cv:
            # prefill_addr, prefill_zmq_addr = random.choice(
            #     list(prefill_instances.items()))
            prefill_list = list(prefill_instances.items())
            prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]

        global decode_instances
        global decode_cv
        with decode_cv:
            # decode_addr, decode_zmq_addr = random.choice(
            #     list(decode_instances.items()))
            decode_list = list(decode_instances.items())
            decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]

        print(f"handle_request count: {count}, [HTTP:{prefill_addr}, "
              f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
              f"ZMQ:{decode_zmq_addr}]")
        count += 1

        request_id = (
            f"___prefill_addr_{prefill_zmq_addr}___decode_addr_{decode_zmq_addr}_{random_uuid()}"
        )

        # finish prefill
        async for _ in forward_request(f'http://{prefill_addr}/v1/completions',
                                       prefill_request, request_id):
            continue

        # return decode
        generator = forward_request(f'http://{decode_addr}/v1/completions',
                                    original_request_data, request_id)
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import sys
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))


if __name__ == '__main__':
    t = start_service_discovery("0.0.0.0", 30001)
    app.run(host='0.0.0.0', port=10001)
    t.join()
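
For reference, the registration handshake that `_listen_for_register()` above expects can be exercised with a few lines of ZMQ and msgpack. This is an illustrative sketch only: in the actual deployment the vLLM instances register themselves with the proxy, and the HTTP/ZMQ addresses shown here are placeholders.

```python
import msgpack
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)          # the proxy's ROUTER socket sees (identity, message)
sock.connect("tcp://127.0.0.1:30001")  # service-discovery port used in __main__ above

sock.send(msgpack.dumps({
    "type": "P",                        # "P" = prefill instance, "D" = decode instance
    "http_address": "127.0.0.1:20001",  # placeholder: where this instance serves /v1/completions
    "zmq_address": "127.0.0.1:21001",   # placeholder: address used for the KV-transfer handshake
}))
```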

vllm/_custom_ops.py (8 additions, 0 deletions)

@@ -1493,6 +1493,14 @@ def free_shared_buffer(ptr: int) -> None:
    torch.ops._C_custom_ar.free_shared_buffer(ptr)


def store_tensor(device_tensor: torch.Tensor, host_tensor: torch.Tensor):
    torch.ops._C_mem_pool.store_tensor(device_tensor, host_tensor)


def load_tensor(host_tensor: torch.Tensor, device_tensor: torch.Tensor):
    torch.ops._C_mem_pool.load_tensor(host_tensor, device_tensor)


def get_flash_mla_metadata(
        cache_seqlens: torch.Tensor,
        num_heads_per_head_k: int,
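
A minimal usage sketch for the new mem-pool ops, assuming a CUDA build of vLLM with the `_C_mem_pool` bindings compiled in. The pinned-host allocation and the explicit stream synchronization are caller-side assumptions; the kernels above only require matching numel and dtype, a CUDA tensor on the device side, and pinned memory on the host side.

```python
import torch

from vllm import _custom_ops as ops

device_tensor = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
host_tensor = torch.empty(device_tensor.shape,
                          dtype=device_tensor.dtype,
                          device="cpu",
                          pin_memory=True)

ops.store_tensor(device_tensor, host_tensor)  # GPU -> pinned host copy
torch.cuda.current_stream().synchronize()     # kernel launches on the current stream

restored = torch.empty_like(device_tensor)    # CUDA tensor, same shape/dtype
ops.load_tensor(host_tensor, restored)        # pinned host -> GPU copy
torch.cuda.current_stream().synchronize()
assert torch.equal(device_tensor, restored)
```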

vllm/distributed/device_communicators/pynccl_wrapper.py (21 additions, 0 deletions)

@@ -271,6 +271,27 @@ def ncclGetUniqueId(self) -> ncclUniqueId:
            ctypes.byref(unique_id)))
        return unique_id

    def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
        """
        Reconstructs an `ncclUniqueId` object from bytes data.

        Args:
            data: Must be a 128-byte data block (matching NCCL's unique_id).

        Returns:
            ncclUniqueId: The reconstructed NCCL Unique ID object.

        Raises:
            ValueError: If the input data length is not 128 bytes.
        """
        if len(data) != 128:
            raise ValueError(
                f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")

        unique_id = ncclUniqueId()
        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
        return unique_id

    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                         rank: int) -> ncclComm_t:
        comm = ncclComm_t()
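
A sketch of how two instances could exchange an NCCL unique id out of band and rebuild it with the new `unique_id_from_bytes()` helper. `NCCLLibrary` is the wrapper class defined in this file; the transport used to move the 128-byte blob (ZMQ, the proxy, etc.) is not prescribed here.

```python
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary

nccl = NCCLLibrary()

# Rank 0 creates the unique id and ships its raw 128 bytes to the peer
# (e.g. msgpack over the proxy's ZMQ channel -- the transport is not shown).
unique_id = nccl.ncclGetUniqueId()
payload = bytes(unique_id.internal)  # 128-byte blob

# The peer rebuilds an equivalent ncclUniqueId and can pass it to
# ncclCommInitRank() to join the communicator.
peer_unique_id = nccl.unique_id_from_bytes(payload)
assert bytes(peer_unique_id.internal) == payload
```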

vllm/distributed/kv_transfer/kv_connector/factory.py (16 additions, 6 deletions)

@@ -49,11 +49,11 @@ def create_connector_v0(cls, rank: int, local_rank: int,
         return connector_cls(rank, local_rank, config)
 
     @classmethod
-    def create_connector_v1(
-        cls,
-        config: "VllmConfig",
-        role: KVConnectorRole,
-    ) -> KVConnectorBase_V1:
+    def create_connector_v1(cls,
+                            config: "VllmConfig",
+                            role: KVConnectorRole,
+                            rank: int = 0,
+                            local_rank: int = 0) -> KVConnectorBase_V1:
         if not envs.VLLM_USE_V1:
             raise ValueError("Attempting to initialize a V1 Connector, "
                              f"but found {envs.VLLM_USE_V1=}")
@@ -70,12 +70,13 @@ def create_connector_v1(
         # - Co-locate with worker process
         # - Should only be used inside the forward context & attention layer
         # We build separately to enforce strict separation
-        return connector_cls(config, role)
+        return connector_cls(config, role, rank, local_rank)
 
 
 # Register various connectors here.
 # The registration should not be done in each individual file, as we want to
 # only load the files corresponding to the current connector.
+
 KVConnectorFactory.register_connector(
     "PyNcclConnector",
     "vllm.distributed.kv_transfer.kv_connector.simple_connector",
@@ -96,11 +96,20 @@ def create_connector_v1(
     "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
     "MooncakeStoreConnector")
 
+KVConnectorFactory.register_connector(
+    "P2pConnector", "vllm.distributed.kv_transfer.kv_connector.p2p_connector",
+    "P2pConnector")
+
 KVConnectorFactory.register_connector(
     "SharedStorageConnector",
     "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
     "SharedStorageConnector")
 
+KVConnectorFactory.register_connector(
+    "P2pNcclConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.p2p_nccl_connector",
+    "P2pNcclConnector")
+
 KVConnectorFactory.register_connector(
     "LMCacheConnectorV1",
     "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
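
A hypothetical launch-configuration sketch for selecting the newly registered `P2pNcclConnector` through `--kv-transfer-config`. `kv_connector`, `kv_role`, and `kv_connector_extra_config` are standard `KVTransferConfig` fields; the keys inside `kv_connector_extra_config` below are placeholders, not values taken from this PR.

```python
import json

# Example producer-side (prefill) configuration; a decode instance would use
# "kv_consumer" instead. The extra-config keys are placeholders.
prefill_kv_config = {
    "kv_connector": "P2pNcclConnector",
    "kv_role": "kv_producer",
    "kv_connector_extra_config": {
        "proxy_ip": "127.0.0.1",   # placeholder key, not taken from this PR
        "proxy_port": "30001",     # placeholder key, not taken from this PR
    },
}

# e.g. vllm serve <model> --kv-transfer-config '<the JSON printed below>'
print(json.dumps(prefill_kv_config))
```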