From 4aa17e7c8ec771f9d2a72f4d756627f931754955 Mon Sep 17 00:00:00 2001
From: shankar-landing-ai
Date: Wed, 28 Feb 2024 12:59:15 -0800
Subject: [PATCH] rename notebook and fix mypy errors

---
 examples/{lmm_example.ipynb => va_example.ipynb} | 2 +-
 vision_agent/lmm/lmm.py                          | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)
 rename examples/{lmm_example.ipynb => va_example.ipynb} (99%)

diff --git a/examples/lmm_example.ipynb b/examples/va_example.ipynb
similarity index 99%
rename from examples/lmm_example.ipynb
rename to examples/va_example.ipynb
index d0b2547d..d43173f3 100644
--- a/examples/lmm_example.ipynb
+++ b/examples/va_example.ipynb
@@ -82,7 +82,7 @@
     "prompt = \"Here is the image of a page from a document. Parse this document, if you can find a table and its columns in the page, print Table Title in json format and if not found Table title will be 'N/A'\"\n",
     "# The llava model can take two additional parameters\n",
     "# temperature (float): The temperature parameter for text generation. Higher values (e.g., 1.0) make the output more random, while lower values (e.g., 0.1) make it more deterministic. Default is 0.2.\n",
-    "# max_new_tokens (int): The maximum number of tokens to generate. Default is 256.\n",
+    "# max_new_tokens (int): The maximum number of tokens to generate. Default is 1500.\n",
     "resp = model.generate(prompt, image_path, temperature=0.1, max_new_tokens=1500)\n",
     "print(textwrap.fill(resp, 80))"
    ]
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index de4d255d..bdd4cc52 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -35,14 +35,14 @@
     def generate(
         self,
         prompt: str,
         image: Optional[Union[str, Path]] = None,
-        temperature: float = 0.2,
-        max_new_tokens: int = 256,
+        temperature: float = 0.1,
+        max_new_tokens: int = 1500,
     ) -> str:
         data = {"prompt": prompt}
         if image:
             data["image"] = encode_image(image)
-        data["temperature"] = temperature
-        data["max_new_tokens"] = max_new_tokens
+        data["temperature"] = temperature  # type: ignore
+        data["max_new_tokens"] = max_new_tokens  # type: ignore
         res = requests.post(
             _LLAVA_ENDPOINT,
             headers={"Content-Type": "application/json"},