From f07207935061edf6b877e1bbe1b56b8ee38822ba Mon Sep 17 00:00:00 2001 From: Pu Fanyi Date: Thu, 26 Dec 2024 16:44:54 +0800 Subject: [PATCH] feat(oryx): add spatial pooling parameters and multi-round generation placeholder --- lmms_eval/models/oryx.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lmms_eval/models/oryx.py b/lmms_eval/models/oryx.py index 28923c57..5d4dfcf2 100644 --- a/lmms_eval/models/oryx.py +++ b/lmms_eval/models/oryx.py @@ -38,7 +38,6 @@ tokenizer_image_token, ) from oryx.model.builder import load_pretrained_model - from oryx.model.language_model.oryx_llama import OryxConfig except ImportError: eval_logger.debug("Oryx is not installed. Please install Oryx to use this model.") @@ -67,6 +66,9 @@ def __init__( truncate_context=False, max_frames_num: int = 32, mm_resampler_type: str = "spatial_pool", + mm_spatial_pool_stride: int = 2, + mm_spatial_pool_out_channels: int = 1024, + mm_spatial_pool_mode: str = "average", overwrite: bool = True, video_decode_backend: str = "decord", **kwargs, @@ -98,6 +100,9 @@ def __init__( overwrite_config["mm_resampler_type"] = self.mm_resampler_type overwrite_config["patchify_video_feature"] = False overwrite_config["attn_implementation"] = attn_implementation + overwrite_config["mm_spatial_pool_stride"] = mm_spatial_pool_stride + overwrite_config["mm_spatial_pool_out_channels"] = mm_spatial_pool_out_channels + overwrite_config["mm_spatial_pool_mode"] = mm_spatial_pool_mode cfg_pretrained = AutoConfig.from_pretrained(self.pretrained) @@ -473,3 +478,6 @@ def generate_until(self, requests) -> List[str]: pbar.update(1) continue return res + + def generate_until_multi_round(self, requests) -> List[str]: + raise NotImplementedError("TODO: Implement multi-round generation")