diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index dfb0af65e..6479d7fdc 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -2707,7 +2707,7 @@ def last_image_embed_free():
     def load_image(self, image_url: str) -> bytes:
         return self._load_image(image_url)

-    def _embed_image_bytes(self, image_bytes: bytes, n_threads_batch: int = 1):
+    def _embed_image_bytes(self, image_bytes: bytes, n_threads: int = 1):
         if (
             self._last_image_embed is not None
             and self._last_image_hash is not None
@@ -2722,7 +2722,7 @@ def _embed_image_bytes(self, image_bytes: bytes, n_threads_batch: int = 1):
                 self._last_image_hash = None
             embed = self._llava_cpp.llava_image_embed_make_with_bytes(
                 self.clip_ctx,
-                n_threads_batch,
+                n_threads,
                 (ctypes.c_uint8 * len(image_bytes)).from_buffer(
                     bytearray(image_bytes)
                 ),
@@ -2813,7 +2813,7 @@ def __call__(
                 llama.eval(tokens)
             else:
                 image_bytes = self.load_image(value)
-                embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
+                embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads)
                 if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
                     raise ValueError(
                         f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"