Skip to content

Commit

Permalink
Fix qwen vl eval (#2892)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang authored Jan 9, 2025
1 parent a0d0351 commit 005674f
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 41 deletions.
23 changes: 18 additions & 5 deletions swift/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
LazyLLMDataset, ConstantLengthDataset, standard_keys, load_dataset, DATASET_TYPE,
sample_dataset, RowPreprocessor, DatasetMeta)
from .utils import (deep_getattr, to_device, History, Messages, history_to_messages, messages_to_history, Processor,
save_checkpoint, ProcessorMixin)
save_checkpoint, ProcessorMixin, get_temporary_cache_files_directory)
from .base import SwiftPipeline
else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
Expand Down Expand Up @@ -57,13 +57,26 @@
'load_by_unsloth', 'git_clone_github', 'get_matched_model_meta'
],
'dataset': [
'AlpacaPreprocessor', 'MessagesPreprocessor', 'DATASET_MAPPING', 'MediaResource', 'register_dataset',
'register_dataset_info', 'EncodePreprocessor', 'LazyLLMDataset', 'ConstantLengthDataset', 'standard_keys',
'load_dataset', 'DATASET_TYPE', 'sample_dataset', 'RowPreprocessor', 'ResponsePreprocessor', 'DatasetMeta'
'AlpacaPreprocessor',
'MessagesPreprocessor',
'DATASET_MAPPING',
'MediaResource',
'register_dataset',
'register_dataset_info',
'EncodePreprocessor',
'LazyLLMDataset',
'ConstantLengthDataset',
'standard_keys',
'load_dataset',
'DATASET_TYPE',
'sample_dataset',
'RowPreprocessor',
'ResponsePreprocessor',
'DatasetMeta',
],
'utils': [
'deep_getattr', 'to_device', 'History', 'Messages', 'history_to_messages', 'messages_to_history',
'Processor', 'save_checkpoint', 'ProcessorMixin'
'Processor', 'save_checkpoint', 'ProcessorMixin', 'get_temporary_cache_files_directory'
],
'base': ['SwiftPipeline'],
}
Expand Down
22 changes: 3 additions & 19 deletions swift/llm/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tempfile

import datasets.config
import datasets.fingerprint
from datasets import disable_caching
from modelscope.hub.utils.utils import get_cache_dir

from swift.utils.torch_utils import _find_local_mac
from ..utils import get_temporary_cache_files_directory
from . import dataset
from .loader import DATASET_TYPE, load_dataset
from .media import MediaResource
Expand All @@ -32,20 +27,9 @@ def _update_fingerprint_mac(*args, **kwargs):
return fp


def _new_get_temporary_cache_files_directory(*args, **kwargs):
    """Return the path of a lazily created, module-wide temporary directory.

    The directory is created on the first call under ``<cache_dir>/tmp``
    (``cache_dir`` comes from ``get_cache_dir()``) and named with
    ``datasets.config.TEMP_CACHE_DIR_PREFIX``. The ``TemporaryDirectory``
    object is stored in the module-level ``DATASET_TEMP_DIR`` so later calls
    return the same path. Positional and keyword arguments are accepted and
    ignored.
    """
    global DATASET_TEMP_DIR
    # Fast path: the directory was already created by an earlier call.
    if DATASET_TEMP_DIR is not None:
        return DATASET_TEMP_DIR.name

    base_dir = os.path.join(get_cache_dir(), 'tmp')
    os.makedirs(base_dir, exist_ok=True)
    DATASET_TEMP_DIR = tempfile.TemporaryDirectory(
        prefix=datasets.config.TEMP_CACHE_DIR_PREFIX,
        dir=base_dir,
    )
    return DATASET_TEMP_DIR.name


datasets.fingerprint.update_fingerprint = _update_fingerprint_mac
datasets.arrow_dataset.update_fingerprint = _update_fingerprint_mac
datasets.fingerprint.get_temporary_cache_files_directory = _new_get_temporary_cache_files_directory
datasets.arrow_dataset.get_temporary_cache_files_directory = _new_get_temporary_cache_files_directory
DATASET_TEMP_DIR = None
datasets.fingerprint.get_temporary_cache_files_directory = get_temporary_cache_files_directory
datasets.arrow_dataset.get_temporary_cache_files_directory = get_temporary_cache_files_directory
register_dataset_info()
disable_caching()
38 changes: 21 additions & 17 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.hub.utils.utils import get_cache_dir
from peft import PeftModel
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
Expand Down Expand Up @@ -116,22 +117,21 @@ def __init__(
self._deepspeed_initialize = None

@staticmethod
def _load_images(images, load_images: bool) -> None:
for i, image in enumerate(images):
if load_images:
if isinstance(image, dict) and 'bytes' in image:
image = image['bytes'] or image['path']
def _load_image(image, load_images: bool):
if load_images:
if isinstance(image, dict) and 'bytes' in image:
image = image['bytes'] or image['path']
image = load_image(image)
else:
if isinstance(image, dict):
path = image['path']
if path and (path.startswith('http') or os.path.exists(path)):
image = path
else:
image = load_image(image['bytes'])
elif not isinstance(image, str):
image = load_image(image)
else:
if isinstance(image, dict):
path = image['path']
if path and (path.startswith('http') or os.path.exists(path)):
image = path
else:
image = load_image(image['bytes'])
elif not isinstance(image, str):
image = load_image(image)
images[i] = image
return image

def _preprocess_inputs(
self,
Expand All @@ -143,7 +143,8 @@ def _preprocess_inputs(
if self.max_pixels is not None or inputs.objects:
load_images = True
if images:
self._load_images(images, load_images)
for i, image in enumerate(images):
images[i] = self._load_image(images[i], load_images)
if self.max_pixels is not None:
assert self.grounding_type != 'real', 'not support' # TODO:check
images = [rescale_image(img, self.max_pixels) for img in images]
Expand Down Expand Up @@ -298,7 +299,10 @@ def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None
def _save_pil_image(image: Image.Image) -> str:
    """Save a PIL image as a PNG in the cache directory and return its path.

    The file lives in ``<cache_dir>/tmp/images`` (``cache_dir`` comes from
    ``get_cache_dir()``) and is named after a SHA-256 digest of the image, so
    saving the same image twice reuses the existing file.

    Args:
        image: The PIL image to persist.

    Returns:
        The path of the PNG file holding the image.
    """
    # Hash mode and size together with the raw pixel data: ``tobytes()`` alone
    # is identical for images that share a byte stream but differ in shape
    # (e.g. 2x8 vs 4x4), which would make them collide on the same file name.
    hasher = hashlib.sha256(f'{image.mode}:{image.size[0]}x{image.size[1]}:'.encode())
    hasher.update(image.tobytes())
    img_hash = hasher.hexdigest()
    tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images')
    logger.info_once(f'create tmp_dir: {tmp_dir}')
    os.makedirs(tmp_dir, exist_ok=True)
    img_path = os.path.join(tmp_dir, f'{img_hash}.png')
    # Content-addressed name: an existing file already holds this image.
    if not os.path.exists(img_path):
        image.save(img_path)
    return img_path
Expand Down
6 changes: 6 additions & 0 deletions swift/llm/template/template/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class QwqTemplateMeta(QwenTemplateMeta):
class QwenVLTemplate(Template):
load_images = False

@staticmethod
def _load_image(image, load_images: bool):
    """Delegate to ``Template._load_image``, forcing a load for some strings.

    When ``load_images`` is falsy and ``image`` is a string that either starts
    with ``'data:'`` or is longer than 200 characters, ``True`` is passed on
    instead, so the base implementation loads the image rather than keeping
    the string as-is. Otherwise ``load_images`` is forwarded unchanged.
    """
    # NOTE(review): such strings are presumably inline payloads (data URIs /
    # base64) rather than file paths or URLs; confirm against callers.
    force_load = (not load_images and isinstance(image, str)
                  and (image.startswith('data:') or len(image) > 200))
    return Template._load_image(image, True if force_load else load_images)

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
inputs: StdTemplateInputs) -> List[Context]:
assert media_type == 'image'
Expand Down
22 changes: 22 additions & 0 deletions swift/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import inspect
import os
import shutil
import tempfile
from types import MethodType
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
from modelscope.hub.utils.utils import get_cache_dir
from transformers import FeatureExtractionMixin, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers import ProcessorMixin as HfProcessorMixin

Expand Down Expand Up @@ -223,3 +225,23 @@ def save_checkpoint(model: Optional[PreTrainedModel],
elif os.path.isdir(src_path):
shutil.copytree(src_path, tgt_path)
break


# Maps a directory-name prefix to the TemporaryDirectory created for it, so
# each prefix gets exactly one directory and the object stays referenced.
TEMP_DIR_POOL = {}


def get_temporary_cache_files_directory(prefix=None):
    """Return the path of the temporary directory associated with ``prefix``.

    On the first call for a given prefix a ``tempfile.TemporaryDirectory`` is
    created under ``<cache_dir>/tmp`` (``cache_dir`` comes from
    ``get_cache_dir()``), logged, and stored in ``TEMP_DIR_POOL``; later calls
    with the same prefix return the same path.

    Args:
        prefix: Prefix for the directory name. When ``None``,
            ``datasets.config.TEMP_CACHE_DIR_PREFIX`` is used.

    Returns:
        The path of the temporary directory as a string.
    """
    if prefix is None:
        # Imported lazily, only when the default prefix is needed.
        import datasets.config
        prefix = datasets.config.TEMP_CACHE_DIR_PREFIX

    if prefix not in TEMP_DIR_POOL:
        base_dir = os.path.join(get_cache_dir(), 'tmp')
        os.makedirs(base_dir, exist_ok=True)
        temp_dir = tempfile.TemporaryDirectory(prefix=prefix, dir=base_dir)
        logger.info(f'create tmp_dir: {temp_dir.name}')
        TEMP_DIR_POOL[prefix] = temp_dir

    return TEMP_DIR_POOL[prefix].name

0 comments on commit 005674f

Please sign in to comment.