Skip to content

Commit

Permalink
adding lmm support for ollama (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
hpbtql authored Jun 19, 2024
1 parent 8cdbf38 commit 4890c10
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import json
import requests
import logging
import os
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -267,3 +268,89 @@ def __init__(
if json_mode:
kwargs["response_format"] = {"type": "json_object"}
self.kwargs = kwargs


class OllamaLMM(LMM):
r"""An LMM class for the ollama."""

def __init__(
self,
model_name: str = "llava",
base_url: Optional[str] = "http://localhost:11434/api",
json_mode: bool = False,
**kwargs: Any,
):
self.url = base_url
self.model_name = model_name
self.json_mode = json_mode
self.stream = False

def __call__(
self,
input: Union[str, List[Message]],
) -> str:
if isinstance(input, str):
return self.generate(input)
return self.chat(input)

def chat(
self,
chat: List[Message],
) -> str:
"""Chat with the LMM model.
Parameters:
chat (List[Dict[str, str]]): A list of dictionaries containing the chat
messages. The messages can be in the format:
[{"role": "user", "content": "Hello!"}, ...]
or if it contains media, it should be in the format:
[{"role": "user", "content": "Hello!", "media": ["image1.jpg", ...]}, ...]
"""
fixed_chat = []
for message in chat:
if "media" in message:
message["images"] = [encode_image(m) for m in message["media"]]
del message["media"]
fixed_chat.append(message)
url = f"{self.url}/chat"
model = self.model_name
messages = fixed_chat
data = {
"model": model,
"messages": messages,
"stream": self.stream
}
json_data = json.dumps(data)
response = requests.post(url, data=json_data)
if response.status_code != 200:
raise ValueError(f"Request failed with status code {response.status_code}")
response = response.json()
return response["message"]["content"]

def generate(
self,
prompt: str,
media: Optional[List[Union[str, Path]]] = None,
) -> str:

url = f"{self.url}/generate"
data = {
"model": self.model_name,
"prompt": prompt,
"images": [],
"stream": self.stream
}

json_data = json.dumps(data)
if media and len(media) > 0:
for m in media:
data["images"].append(encode_image(m))

response = requests.post(url, data=json_data)

if response.status_code != 200:
raise ValueError(f"Request failed with status code {response.status_code}")

response = response.json()
return response["response"]

0 comments on commit 4890c10

Please sign in to comment.