Fy/sllm checkpoint #2

Open
wants to merge 11 commits into base: main

Changes from all commits
56 changes: 56 additions & 0 deletions examples/load_sllm_state.py
@@ -0,0 +1,56 @@
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
--model /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save

Then, the model can be loaded with

llm = LLM(
model="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import argparse

from vllm import LLM, EngineArgs

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
required=True,
type=str,
help="path to output checkpoint")

if __name__ == "__main__":
args = parser.parse_args()

# Load the checkpoint written by save_sllm_state.py from the --output path.
llm = LLM(
model=args.output,
load_format="serverless_llm",
gpu_memory_utilization=0.9,
distributed_executor_backend="mp",
max_model_len=512,
tensor_parallel_size=args.tensor_parallel_size,
)

input_text = "Explain thread and process in python."

print(llm.generate(input_text))
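
Note: llm.generate() returns RequestOutput objects, so the print above shows the
raw objects rather than just the completion text. A minimal sketch (not part of
this PR) for printing the generated text, assuming greedy sampling is acceptable:

from vllm import SamplingParams

outputs = llm.generate(input_text, SamplingParams(temperature=0.0, max_tokens=64))
for output in outputs:
    print(output.outputs[0].text)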
9 changes: 9 additions & 0 deletions examples/save_load_sllm_state.sh
@@ -0,0 +1,9 @@
CUDA_VISIBLE_DEVICES=0,1 python save_sllm_state.py \
--model /mnt/raid0sata1/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6 \
--tensor-parallel-size 2 \
--output /mnt/raid0nvme1/xly/test_data/vllm/opt-125m

CUDA_VISIBLE_DEVICES=0,1 python load_sllm_state.py \
--model /home/fuji/.cache/huggingface/hub/models--facebook--opt-1.3b/snapshots/3f5c25d0bc631cb57ac65913f76e22c2dfb61d62 \
--tensor-parallel-size 2 \
--output /home/fuji/sllm_models/opt-1.3b
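
After the save step, the --output directory is expected to contain one rank_{i}
sub-directory per tensor-parallel rank plus the copied metadata files (config,
tokenizer files, etc.), following save_sllm_state.py below. A small sketch for
inspecting it, with a placeholder path:

import os

output_dir = "/path/to/save"  # placeholder for the --output directory
print(sorted(os.listdir(output_dir)))  # e.g. ['config.json', 'rank_0', 'rank_1', ...]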
92 changes: 92 additions & 0 deletions examples/save_sllm_state.py
@@ -0,0 +1,92 @@
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
--model /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save

Then, the model can be loaded with

llm = LLM(
model="/path/to/save",
load_format="sharded_state",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
required=True,
type=str,
help="path to output checkpoint")
parser.add_argument("--file-pattern",
type=str,
help="string pattern of saved filenames")
parser.add_argument("--max-file-size",
type=str,
default=5 * 1024**3,
help="max size (in bytes) of each safetensors file")


def main(args):
engine_args = EngineArgs.from_cli_args(args)
engine_args.distributed_executor_backend = "mp"
engine_args.gpu_memory_utilization = 0.4
engine_args.max_seq_len_to_capture = 512
engine_args.max_model_len = 512
engine_args.max_num_seqs = 1
engine_args.num_gpu_blocks_override = 128
if engine_args.enable_lora:
raise ValueError("Saving with enable_lora=True is not supported!")
model_path = engine_args.model
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
print(dataclasses.asdict(engine_args))
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare output directory
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_serverless_llm_state(path=args.output,
pattern=args.file_pattern,
max_size=args.max_file_size)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(os.path.join(model_path, file),
os.path.join(args.output, file))
else:
shutil.copy(os.path.join(model_path, file), args.output)

if __name__ == "__main__":
args = parser.parse_args()
main(args)

1 change: 1 addition & 0 deletions vllm/config.py
@@ -498,6 +498,7 @@ class LoadFormat(str, enum.Enum):
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
BITSANDBYTES = "bitsandbytes"
SERVERLESS_LLM = "serverless_llm"


@dataclass
11 changes: 11 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
@@ -113,6 +113,17 @@ def save_sharded_state(
path=path,
pattern=pattern,
max_size=max_size)

def save_serverless_llm_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
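# Each tensor-parallel worker saves its own shard; see
# ServerlessLLMLoader.save_model, which writes to a per-rank rank_{rank} dir.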
self._run_workers("save_serverless_llm_state",
path=path,
pattern=pattern,
max_size=max_size)

@abstractmethod
def _driver_execute_model(
10 changes: 10 additions & 0 deletions vllm/executor/gpu_executor.py
@@ -106,6 +106,16 @@ def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return

def save_serverless_llm_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
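# Single-GPU executor: only the driver worker exists, so save its state directly.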
self.driver_worker.save_serverless_llm_state(
path=path, pattern=pattern, max_size=max_size
)


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
144 changes: 143 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -14,6 +14,7 @@
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
import gc

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ParallelConfig,
@@ -418,7 +419,6 @@ def save_model(
tensorizer_config=tensorizer_config,
)


class ShardedStateLoader(BaseModelLoader):
"""
Model loader that directly loads each worker's model state dict, which
@@ -577,6 +577,145 @@ def save_model(
)


class ServerlessLLMLoader(BaseModelLoader):
# DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy())
# self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{load_config.model_loader_extra_config.keys()}")

@staticmethod
def _filter_subtensors(
tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups = collections.defaultdict(list)
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
same_storage_groups[tensor.device, ptr].append((key, tensor))

def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

result = {}
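# Within each storage group, a tensor is dropped when another contiguous
# tensor already covers its memory range (equal ranges keep the smaller key),
# so every byte of storage is serialized at most once.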
for group in same_storage_groups.values():
for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t)
for k2, t2 in group:
if not t2.is_contiguous():
continue
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
if a < a2 or b2 < b:
continue
if a2 < a or b < b2 or not t.is_contiguous():
break # t2 covers strictly more memory than t.
if k2 < k:
# Same tensors, keep the one with the smaller key.
break
else:
result[k] = t
return result

def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
from serverless_llm_store.torch import load_dict
from vllm.distributed import get_tensor_model_parallel_rank

assert os.path.isdir(model_config.model)

rank = get_tensor_model_parallel_rank()

local_model_path = model_config.model
local_model_path = os.path.join(local_model_path, f"rank_{rank}")

def remove_prefix(path, prefix):
# Normalize the paths to ensure consistency across different platforms
path = os.path.normpath(path)
prefix = os.path.normpath(prefix)

# Check if the path starts with the prefix
if path.startswith(prefix):
# Return the path without the prefix
return path[len(prefix):].lstrip(os.sep)

# Return the original path if the prefix doesn't exist
return path

# vLLM needs a local filesystem path to read the model config, while
# ServerlessLLM Store identifies the checkpoint by its path relative to
# STORAGE_PATH, so strip that prefix here.
storage_path = os.getenv("STORAGE_PATH", "./models")
model_path = remove_prefix(local_model_path, storage_path)
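# Illustration: remove_prefix("/models/opt-125m/rank_0", "/models") returns
# "opt-125m/rank_0" (placeholder paths).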

with set_default_torch_dtype(model_config.dtype):
# with torch.device(device_config.device):
with torch.device("cpu"):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
model = model.eval()
# Record the parameter keys, then swap each parameter's data for a tiny
# CUDA placeholder so the CPU-initialized weights can be freed (gc below).
state_dict = self._filter_subtensors(model.state_dict())
key_list = list(state_dict.keys())

for key, param in model.named_parameters(recurse=True):
if key in key_list:
param.data = torch.empty(1, device="cuda")
gc.collect()

device_id = torch.cuda.current_device()
device_map = {"": device_id}
# Note: storage path is already included in the local model path
sllm_state_dict = load_dict(model_path, device_map)

for key, param in model.named_parameters(recurse=True):
if key in key_list:
tensor = sllm_state_dict[key]
param.data = tensor
state_dict.pop(key)
if state_dict:
raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!")

return model

@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
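# Note: `pattern` and `max_size` are accepted for interface compatibility
# with the other save paths but are not used by save_dict below.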
from vllm.distributed import get_tensor_model_parallel_rank
from serverless_llm_store.torch import save_dict

rank = get_tensor_model_parallel_rank()
state_dict = ServerlessLLMLoader._filter_subtensors(model.state_dict())

# move all tensors to CPU
for key, tensor in state_dict.items():
state_dict[key] = tensor.cpu().contiguous()

save_path = os.path.join(path, f"rank_{rank}")
if not os.path.exists(save_path):
os.makedirs(save_path)

save_dict(state_dict, save_path)


class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""

@@ -826,6 +965,9 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:

if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)

if load_config.load_format == LoadFormat.SERVERLESS_LLM:
return ServerlessLLMLoader(load_config)

if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)
14 changes: 14 additions & 0 deletions vllm/worker/model_runner.py
@@ -222,6 +222,20 @@ def save_sharded_state(
pattern=pattern,
max_size=max_size,
)

def save_serverless_llm_state(
self,
path: str,
pattern: Optional[str] = None,
max_size: Optional[int] = None,
) -> None:
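# Local import, mirroring save_sharded_state above, presumably to keep the
# loader module out of the worker's import path until it is needed.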
from vllm.model_executor.model_loader.loader import ServerlessLLMLoader
ServerlessLLMLoader.save_model(
self.model,
path,
pattern=pattern,
max_size=max_size,
)

def save_tensorized_model(
self,