Skip to content

Commit

Permalink
feat(oryx): add spatial pooling parameters and multi-round generation…
Browse files Browse the repository at this point in the history
… placeholder
  • Loading branch information
pufanyi committed Dec 26, 2024
1 parent 7f9c482 commit f072079
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion lmms_eval/models/oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
tokenizer_image_token,
)
from oryx.model.builder import load_pretrained_model
from oryx.model.language_model.oryx_llama import OryxConfig
except ImportError:
eval_logger.debug("Oryx is not installed. Please install Oryx to use this model.")

Expand Down Expand Up @@ -67,6 +66,9 @@ def __init__(
truncate_context=False,
max_frames_num: int = 32,
mm_resampler_type: str = "spatial_pool",
mm_spatial_pool_stride: int = 2,
mm_spatial_pool_out_channels: int = 1024,
mm_spatial_pool_mode: str = "average",
overwrite: bool = True,
video_decode_backend: str = "decord",
**kwargs,
Expand Down Expand Up @@ -98,6 +100,9 @@ def __init__(
overwrite_config["mm_resampler_type"] = self.mm_resampler_type
overwrite_config["patchify_video_feature"] = False
overwrite_config["attn_implementation"] = attn_implementation
overwrite_config["mm_spatial_pool_stride"] = mm_spatial_pool_stride
overwrite_config["mm_spatial_pool_out_channels"] = mm_spatial_pool_out_channels
overwrite_config["mm_spatial_pool_mode"] = mm_spatial_pool_mode

cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)

Expand Down Expand Up @@ -473,3 +478,6 @@ def generate_until(self, requests) -> List[str]:
pbar.update(1)
continue
return res

def generate_until_multi_round(self, requests) -> List[str]:
raise NotImplementedError("TODO: Implement multi-round generation")

0 comments on commit f072079

Please sign in to comment.