class OllamaLMM(LMM):
    r"""An LMM that talks to an Ollama server over its HTTP API.

    Prompts are POSTed as JSON to ``{base_url}/generate`` and chats to
    ``{base_url}/chat``. Media items are converted with ``encode_image``
    and sent in the request's ``images`` field.
    """

    def __init__(
        self,
        model_name: str = "llava",
        base_url: Optional[str] = "http://localhost:11434/api",
        json_mode: bool = False,
        **kwargs: Any,
    ):
        """Configure the client; no request is made here.

        Parameters:
            model_name (str): Value sent as ``model`` in every request.
            base_url (Optional[str]): Root of the API, without a trailing
                endpoint name. ``None`` selects the default local URL.
            json_mode (bool): When True, ``"format": "json"`` is added to
                every request payload.
            **kwargs: Accepted but currently ignored.
        """
        # `base_url` is annotated Optional; without this fallback a None
        # value would produce request URLs such as "None/chat".
        self.url = (
            base_url if base_url is not None else "http://localhost:11434/api"
        )
        self.model_name = model_name
        self.json_mode = json_mode
        # The response is parsed as one JSON document, so streaming is off.
        self.stream = False

    def __call__(
        self,
        input: Union[str, List[Message]],
    ) -> str:
        """Dispatch to `generate` for a string, otherwise to `chat`.

        A string input is forwarded to `generate` without any media.
        """
        if isinstance(input, str):
            return self.generate(input)
        return self.chat(input)

    def _post(self, endpoint: str, payload: dict) -> Any:
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded body.

        Raises:
            ValueError: If the server answers with a status other than 200.
        """
        if self.json_mode:
            # Ollama's request-level switch for JSON-formatted output.
            payload["format"] = "json"
        reply = requests.post(f"{self.url}/{endpoint}", data=json.dumps(payload))
        if reply.status_code != 200:
            raise ValueError(f"Request failed with status code {reply.status_code}")
        return reply.json()

    def chat(
        self,
        chat: List[Message],
    ) -> str:
        """Chat with the LMM model.

        Parameters:
            chat (List[Dict[str, str]]): A list of dictionaries containing the chat
                messages. The messages can be in the format:
                [{"role": "user", "content": "Hello!"}, ...]
                or if it contains media, it should be in the format:
                [{"role": "user", "content": "Hello!", "media": ["image1.jpg", ...]}, ...]

        Returns:
            str: The ``content`` of the ``message`` object in the response.

        Raises:
            ValueError: If the server answers with a status other than 200.
        """
        fixed_chat = []
        for message in chat:
            # Work on a shallow copy so the caller's messages keep their
            # "media" entries and can be reused for another call.
            fixed = dict(message)
            if "media" in fixed:
                fixed["images"] = [encode_image(m) for m in fixed.pop("media")]
            fixed_chat.append(fixed)

        payload = {
            "model": self.model_name,
            "messages": fixed_chat,
            "stream": self.stream,
        }
        return self._post("chat", payload)["message"]["content"]

    def generate(
        self,
        prompt: str,
        media: Optional[List[Union[str, Path]]] = None,
    ) -> str:
        """Generate a completion for a single prompt.

        Parameters:
            prompt (str): Text sent as ``prompt``.
            media (Optional[List[Union[str, Path]]]): Items passed through
                ``encode_image`` and sent as ``images``. ``None`` or an empty
                list sends an empty ``images`` list.

        Returns:
            str: The ``response`` field of the server's reply.

        Raises:
            ValueError: If the server answers with a status other than 200.
        """
        # Encode the images *before* the payload is serialized; otherwise
        # the request body is built without them.
        images = [encode_image(m) for m in media] if media else []
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "images": images,
            "stream": self.stream,
        }
        return self._post("generate", payload)["response"]