Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Separate hf dataset sampling function from benchmark_serving.py #12447

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
133 changes: 8 additions & 125 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
"""
import argparse
import asyncio
import base64
import gc
import io
import json
import os
import random
Expand All @@ -39,8 +37,7 @@
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
RequestFuncOutput)
from datasets import load_dataset
from PIL.Image import Image
from dataset_sample_func import get_hf_dataset_sampler
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -200,50 +197,6 @@ def sample_sonnet_requests(
return sampled_requests


def sample_vision_arena_requests(
    dataset,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
    """Sample requests from the lmarena-ai/vision-arena-bench-v0.1 dataset.

    Each sampled request is a tuple of
    (prompt, prompt_len, output_len, multi-modal content dict), where the
    image is inlined as a base64 JPEG data URL.
    """
    sampled_requests: List[Tuple[str, int, int, Dict[str,
                                                     Collection[str]]]] = []
    for data in dataset:
        if len(sampled_requests) == num_requests:
            break

        prompt = data["turns"][0][0]['content']

        prompt_token_ids = tokenizer(prompt).input_ids
        if fixed_output_len is None:
            # Default max output len is set to 128
            print("--hf-output-len is not provided. Using default value 128.")
            fixed_output_len = 128

        prompt_len = len(prompt_token_ids)
        output_len = fixed_output_len

        # Bug fix: the failure message previously formatted
        # type(data['image']), but this dataset has no 'image' key (only
        # 'images'), so a failing assert raised KeyError instead of the
        # intended message.
        assert isinstance(
            data["images"][0],
            Image), ("Input image format must be `PIL.Image.Image`, "
                     f"given {type(data['images'][0])}.")
        image: Image = data["images"][0]
        image = image.convert("RGB")
        image_data = io.BytesIO()
        image.save(image_data, format='JPEG')
        image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        mm_content = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
        }

        sampled_requests.append((prompt, prompt_len, output_len, mm_content))

    return sampled_requests


def sample_hf_requests(
dataset_path: str,
dataset_subset: Optional[str],
Expand All @@ -252,82 +205,12 @@ def sample_hf_requests(
tokenizer: PreTrainedTokenizerBase,
random_seed: int,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:

# Special case for vision_arena dataset
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
and dataset_subset is None:
assert dataset_split == "train"
dataset = load_dataset(dataset_path,
name=dataset_subset,
split=dataset_split,
streaming=True)
dataset = dataset.shuffle(seed=random_seed)
return sample_vision_arena_requests(dataset, num_requests, tokenizer,
fixed_output_len)

dataset = load_dataset(dataset_path,
name=dataset_subset,
split=dataset_split,
streaming=True)
assert "conversations" in dataset.features, (
"HF Dataset must have 'conversations' column.")
filter_func = lambda x: len(x["conversations"]) >= 2
filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
sampled_requests: List[Tuple[str, int, int, Dict[str,
Collection[str]]]] = []
for data in filtered_dataset:
if len(sampled_requests) == num_requests:
break

# Tokenize the prompts and completions.
prompt = data["conversations"][0]["value"]
prompt_token_ids = tokenizer(prompt).input_ids
completion = data["conversations"][1]["value"]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
# Prune too short sequences.
continue
if fixed_output_len is None and \
(prompt_len > 1024 or prompt_len + output_len > 2048):
# Prune too long sequences.
continue

if "image" in data and isinstance(data["image"], Image):
image: Image = data["image"]
image = image.convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
image_base64 = base64.b64encode(
image_data.getvalue()).decode("utf-8")
mm_content = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
}
elif "image" in data and isinstance(data["image"], str):
if (data["image"].startswith("http://") or \
data["image"].startswith("file://")):
image_url = data["image"]
else:
image_url = f"file://{data['image']}"

mm_content = {
"type": "image_url",
"image_url": {
"url": image_url
},
}
else:
mm_content = None

sampled_requests.append((prompt, prompt_len, output_len, mm_content))

return sampled_requests
) -> list[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
hf_dataset_sampler = get_hf_dataset_sampler(dataset_path,
dataset_subset,
dataset_split,
seed=random_seed)
return hf_dataset_sampler.sample(num_requests, tokenizer, fixed_output_len)


def sample_random_requests(
Expand Down Expand Up @@ -1214,7 +1097,7 @@ def main(args: argparse.Namespace):
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
default=128,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)
Expand Down
178 changes: 178 additions & 0 deletions benchmarks/dataset_sample_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import base64
import io
from abc import ABC, abstractmethod
from typing import Collection, Optional

from datasets import IterableDataset, load_dataset, load_dataset_builder
from PIL import Image
from transformers import PreTrainedTokenizerBase


def pil_image_to_mm_content(image: Image.Image):
    """Encode a PIL image as an OpenAI-style ``image_url`` content dict.

    The image is converted to RGB, serialized as JPEG, and embedded
    inline as a base64 data URL.
    """
    with io.BytesIO() as buffer:
        image.convert("RGB").save(buffer, format='JPEG')
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return {
        "type": "image_url",
        "image_url": {
            "url": f"data:image/jpeg;base64,{encoded}"
        },
    }


def image_url_to_mm_content(image_url: str):
    """Wrap an image URL in an OpenAI-style ``image_url`` content dict.

    URLs that already carry a scheme (``http://``, ``https://``,
    ``file://``) are passed through unchanged; anything else is treated
    as a local path and prefixed with ``file://``.
    """
    # Bug fix: https:// URLs were previously not recognized as remote and
    # were mangled into "file://https://...". Also removed the no-op
    # `image_url = image_url` branch.
    if not image_url.startswith(("http://", "https://", "file://")):
        image_url = f"file://{image_url}"

    return {
        "type": "image_url",
        "image_url": {
            "url": image_url
        },
    }


class HFDatasetSampler(ABC):
    """Abstract base class for HuggingFace dataset request samplers.

    On construction the dataset is shuffled with the given seed and then
    filtered through ``filter_func``; subclasses implement both
    ``filter_func`` and ``sample``.
    """

    def __init__(self, dataset: IterableDataset, seed: Optional[int] = None):
        shuffled = dataset.shuffle(seed=seed)
        self.dataset = shuffled.filter(self.filter_func)

    @abstractmethod
    def filter_func(self, data: dict) -> bool:
        """Return True for dataset rows that are usable for sampling."""
        raise NotImplementedError

    @abstractmethod
    def sample(
        self,
        num_samples: int,
        tokenizer: PreTrainedTokenizerBase,
        fixed_output_len: Optional[int] = None
    ) -> list[tuple[str, int, int, dict[str, Collection[str]]]]:
        """Draw up to ``num_samples`` requests from the dataset."""
        raise NotImplementedError


class ShareGPTSampler(HFDatasetSampler):
    """
    Dataset sampler for ShareGPT-style datasets.
    - Text-only dataset like: 'RyokoAI/ShareGPT52K' etc.
    - Vision dataset like: 'lmms-lab/LLaVA-OneVision-Data' etc.

    Rows must provide a 'conversations' list with at least a prompt and a
    completion; an optional 'image' field (PIL image or URL/path string)
    is forwarded as multi-modal content.
    """

    def __init__(self, dataset: IterableDataset, seed: Optional[int] = None):
        # Bug fix: the assertion message previously said "Sonnet-style",
        # which is a different benchmark dataset format; this sampler
        # handles ShareGPT-style datasets.
        assert "conversations" in dataset.features, (
            "ShareGPT-style Dataset must have 'conversations' column.")
        super().__init__(dataset, seed=seed)

    def filter_func(self, data: dict) -> bool:
        # Need at least one prompt/completion pair.
        return len(data["conversations"]) >= 2

    def _get_mm_content(self,
                        data: dict) -> Optional[dict[str, Collection[str]]]:
        # An 'image' field may hold a decoded PIL image or a URL/path string;
        # text-only rows return None.
        if "image" in data and isinstance(data["image"], Image.Image):
            return pil_image_to_mm_content(data["image"])
        elif "image" in data and isinstance(data["image"], str):
            return image_url_to_mm_content(data["image"])
        return None

    def sample(
        self,
        num_samples: int,
        tokenizer: PreTrainedTokenizerBase,
        fixed_output_len: Optional[int] = 128,
    ) -> list[tuple[str, int, int, dict[str, Collection[str]]]]:
        """Sample up to ``num_samples`` requests.

        Returns (prompt, prompt_len, output_len, mm_content) tuples. When
        ``fixed_output_len`` is None the completion length is used as the
        output length, and too-short/too-long sequences are pruned.
        """
        sampled_requests: list[tuple[str, int, int,
                                     dict[str, Collection[str]]]] = []
        for data in self.dataset:
            if len(sampled_requests) == num_samples:
                break

            # Tokenize the prompts and completions.
            prompt = data["conversations"][0]["value"]
            prompt_token_ids = tokenizer(prompt).input_ids
            completion = data["conversations"][1]["value"]
            completion_token_ids = tokenizer(completion).input_ids
            prompt_len = len(prompt_token_ids)
            output_len = len(
                completion_token_ids
            ) if fixed_output_len is None else fixed_output_len
            if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
                # Prune too short sequences.
                continue
            if fixed_output_len is None and \
                (prompt_len > 1024 or prompt_len + output_len > 2048):
                # Prune too long sequences.
                continue

            mm_content = self._get_mm_content(data)

            sampled_requests.append(
                (prompt, prompt_len, output_len, mm_content))
        return sampled_requests


class VisionArenaBenchSampler(HFDatasetSampler):
    """Dataset sampler for 'lmarena-ai/vision-arena-bench-v0.1' dataset."""

    def filter_func(self, data: dict) -> bool:
        # vision-arena-bench always has an image and one turn conversation.
        return True

    def sample(
        self,
        num_samples: int,
        tokenizer: PreTrainedTokenizerBase,
        fixed_output_len: Optional[int] = 128,
    ):
        """Sample (prompt, prompt_len, output_len, mm_content) tuples."""
        requests: list[tuple[str, int, int,
                             dict[str, Collection[str]]]] = []
        for row in self.dataset:
            if len(requests) == num_samples:
                break

            prompt = row["turns"][0][0]['content']
            n_prompt_tokens = len(tokenizer(prompt).input_ids)

            # Every row of this dataset carries at least one image.
            content = pil_image_to_mm_content(row["images"][0])

            requests.append(
                (prompt, n_prompt_tokens, fixed_output_len, content))

        return requests


# Maps dataset paths that need bespoke sampling to their sampler class.
# Bug fix: the values are sampler *classes* (instantiated lazily in
# get_hf_dataset_sampler), so the annotation is type[HFDatasetSampler],
# not HFDatasetSampler.
DATASET_SAMPLE_FUNC: dict[str, type[HFDatasetSampler]] = {
    "lmarena-ai/vision-arena-bench-v0.1": VisionArenaBenchSampler,
}


def get_hf_dataset_sampler(
    dataset_path: str,
    dataset_subset: Optional[str],
    dataset_split: str,
    seed: Optional[int] = None,
) -> HFDatasetSampler:
    """Build the sampler registered for ``dataset_path``.

    Validates that the requested split exists, opens the dataset in
    streaming mode, and falls back to ShareGPTSampler for dataset paths
    without a registered sampler.
    """
    # Validate the requested split before streaming any data.
    info = load_dataset_builder(dataset_path, name=dataset_subset).info
    assert dataset_split in info.splits, (
        f"Split '{dataset_split}' not found in dataset '{dataset_path}'")

    dataset = load_dataset(dataset_path,
                           name=dataset_subset,
                           split=dataset_split,
                           streaming=True)

    sampler_cls = DATASET_SAMPLE_FUNC.get(dataset_path, ShareGPTSampler)
    return sampler_cls(dataset, seed=seed)
Loading