Give Vision to Anthropic models in Khoj (#948)
### Major
- Give Vision to Anthropic models in Khoj (payload shape sketched below)

### Minor
- Reuse logic to format messages for chat with Anthropic models
- Make the `get_image_from_url` function more versatile and reusable
- Encourage the output mode chat actor to output only JSON and nothing else
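
For orientation, here is a hedged sketch of the approximate Anthropic messages-API payload Khoj produces for a vision query after this change. The model name matches the new default in this commit; the system prompt, query text, and image data are illustrative assumptions:

```python
# Approximate payload shape after format_messages_for_anthropic runs
# (see src/khoj/processor/conversation/anthropic/utils.py below).
# Only the model name comes from this commit; other values are illustrative.
payload = {
    "model": "claude-3-5-sonnet-20241022",
    # The system prompt is passed separately, not in the messages list
    "system": "You are Khoj, a personal assistant.",
    "messages": [
        {
            "role": "user",
            "content": [
                # Each image is prefixed with a text block enumerating it
                {"type": "text", "text": "Image 1:"},
                {
                    "type": "image",
                    "source": {"type": "base64", "media_type": "image/png", "data": "<base64 data>"},
                },
                # Text comes after images, as Anthropic models prefer
                {"type": "text", "text": "What is in this image?"},
            ],
        }
    ],
}
```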
debanjum authored Oct 25, 2024
2 parents 37317e3 + 01d740d commit adee5a3
Showing 7 changed files with 123 additions and 41 deletions.
42 changes: 21 additions & 21 deletions src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -11,8 +11,12 @@
 from khoj.processor.conversation.anthropic.utils import (
     anthropic_chat_completion_with_backoff,
     anthropic_completion_with_backoff,
+    format_messages_for_anthropic,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

@@ -27,6 +31,8 @@ def extract_questions_anthropic(
     temperature=0.7,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """
@@ -68,6 +74,13 @@
         text=text,
     )

+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.ANTHROPIC,
+        vision_enabled=vision_enabled,
+    )
+
     messages = [ChatMessage(content=prompt, role="user")]

     response = anthropic_completion_with_backoff(
@@ -101,17 +114,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
     """
     Send message to model
     """
-    # Anthropic requires the first message to be a 'user' message, and the system prompt is not to be sent in the messages parameter
-    system_prompt = None
-
-    if len(messages) == 1:
-        messages[0].role = "user"
-    else:
-        system_prompt = ""
-        for message in messages.copy():
-            if message.role == "system":
-                system_prompt += message.content
-                messages.remove(message)
+    messages, system_prompt = format_messages_for_anthropic(messages)

     # Get Response from GPT. Don't use response_type because Anthropic doesn't support it.
     return anthropic_completion_with_backoff(
@@ -127,7 +130,7 @@ def converse_anthropic(
     user_query,
     online_results: Optional[Dict[str, Dict]] = None,
     conversation_log={},
-    model: Optional[str] = "claude-instant-1.2",
+    model: Optional[str] = "claude-3-5-sonnet-20241022",
     api_key: Optional[str] = None,
     completion_func=None,
     conversation_commands=[ConversationCommand.Default],
@@ -136,6 +139,8 @@ def converse_anthropic(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Anthropic's Claude
@@ -189,17 +194,12 @@ def converse_anthropic(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
+        model_type=ChatModelOptions.ModelType.ANTHROPIC,
     )

-    if len(messages) > 1:
-        if messages[0].role == "assistant":
-            messages = messages[1:]
-
-    for message in messages.copy():
-        if message.role == "system":
-            system_prompt += message.content
-            messages.remove(message)
+    messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

     truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
     logger.debug(f"Conversation Context for Claude: {truncated_messages}")
52 changes: 51 additions & 1 deletion src/khoj/processor/conversation/anthropic/utils.py
@@ -3,6 +3,7 @@
 from typing import Dict, List

 import anthropic
+from langchain.schema import ChatMessage
 from tenacity import (
     before_sleep_log,
     retry,
@@ -11,7 +12,8 @@
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
+from khoj.utils.helpers import is_none_or_empty

 logger = logging.getLogger(__name__)

@@ -115,3 +117,51 @@ def anthropic_llm_thread(
         logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
     finally:
         g.close()
+
+
+def format_messages_for_anthropic(messages: list[ChatMessage], system_prompt=None):
+    """
+    Format messages for Anthropic
+    """
+    # Extract system prompt
+    system_prompt = system_prompt or ""
+    for message in messages.copy():
+        if message.role == "system":
+            system_prompt += message.content
+            messages.remove(message)
+    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
+
+    # Anthropic requires the first message to be a 'user' message
+    if len(messages) == 1:
+        messages[0].role = "user"
+    elif len(messages) > 1 and messages[0].role == "assistant":
+        messages = messages[1:]
+
+    # Convert image urls to base64 encoded images in Anthropic message format
+    for message in messages:
+        if isinstance(message.content, list):
+            content = []
+            # Sort the content. Anthropic models prefer that text comes after images.
+            message.content.sort(key=lambda x: 0 if x["type"] == "image_url" else 1)
+            for idx, part in enumerate(message.content):
+                if part["type"] == "text":
+                    content.append({"type": "text", "text": part["text"]})
+                elif part["type"] == "image_url":
+                    image = get_image_from_url(part["image_url"]["url"], type="b64")
+                    # Prefix each image with text block enumerating the image number
+                    # This helps the model reference the image in its response. Recommended by Anthropic
+                    content.extend(
+                        [
+                            {
+                                "type": "text",
+                                "text": f"Image {idx + 1}:",
+                            },
+                            {
+                                "type": "image",
+                                "source": {"type": "base64", "media_type": image.type, "data": image.content},
+                            },
+                        ]
+                    )
+            message.content = content
+
+    return messages, system_prompt
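
A minimal usage sketch for the new helper; the message texts and image URL are illustrative assumptions, and the image fetch happens inside `get_image_from_url`:

```python
# Hedged usage sketch for format_messages_for_anthropic; values are illustrative.
from langchain.schema import ChatMessage

from khoj.processor.conversation.anthropic.utils import format_messages_for_anthropic

messages = [
    ChatMessage(role="system", content="You are Khoj, a helpful personal assistant."),
    ChatMessage(
        role="user",
        content=[
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},
        ],
    ),
]

# System messages are folded into a single system prompt; image URLs are
# fetched, base64 encoded, and rewritten into Anthropic image blocks.
formatted_messages, system_prompt = format_messages_for_anthropic(messages)
```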
17 changes: 2 additions & 15 deletions src/khoj/processor/conversation/google/utils.py
@@ -1,11 +1,8 @@
 import logging
 import random
-from io import BytesIO
 from threading import Thread

 import google.generativeai as genai
-import PIL.Image
-import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
@@ -22,7 +19,7 @@
     wait_random_exponential,
 )

-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
 from khoj.utils.helpers import is_none_or_empty

 logger = logging.getLogger(__name__)
@@ -207,7 +204,7 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         if isinstance(message.content, list):
             # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
             message.content = [
-                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                get_image_from_url(item["image_url"]["url"]).content if item["type"] == "image_url" else item["text"]
                 for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
             ]
         elif isinstance(message.content, str):
@@ -220,13 +217,3 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
         messages[0].role = "user"

     return messages, system_prompt
-
-
-def get_image_from_url(image_url: str) -> PIL.Image:
-    try:
-        response = requests.get(image_url)
-        response.raise_for_status()  # Check if the request was successful
-        return PIL.Image.open(BytesIO(response.content))
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Failed to get image from URL {image_url}: {e}")
-        return None
2 changes: 1 addition & 1 deletion src/khoj/processor/conversation/prompts.py
@@ -619,7 +619,7 @@
 Q: Share a painting using the weather for Bali every morning.
 Khoj: {{"output": "automation"}}

-Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON.
+Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a JSON. Do not say anything else.

 Chat History:
 {chat_history}
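
With the stricter instruction, callers can parse the actor's response as bare JSON. A hedged sketch using the `remove_json_codeblock` helper from `src/khoj/processor/conversation/utils.py`; the raw response value is an illustrative assumption:

```python
# Illustrative parsing sketch; raw_response is an assumed model output.
import json

from khoj.processor.conversation.utils import remove_json_codeblock

raw_response = '{"output": "image"}'  # the actor should now emit this and nothing else
output_mode = json.loads(remove_json_codeblock(raw_response).strip())["output"]
```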
40 changes: 39 additions & 1 deletion src/khoj/processor/conversation/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
+import base64
 import logging
 import math
+import mimetypes
 import queue
+from dataclasses import dataclass
 from datetime import datetime
+from io import BytesIO
 from time import perf_counter
 from typing import Any, Dict, List, Optional

+import PIL.Image
+import requests
 import tiktoken
 from langchain.schema import ChatMessage
 from llama_cpp.llama import Llama
@@ -152,7 +158,11 @@ def construct_structured_message(message: str, images: list[str], model_type: st
     if not images or not vision_enabled:
         return message

-    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
+    if model_type in [
+        ChatModelOptions.ModelType.OPENAI,
+        ChatModelOptions.ModelType.GOOGLE,
+        ChatModelOptions.ModelType.ANTHROPIC,
+    ]:
         return [
             {"type": "text", "text": message},
             *[{"type": "image_url", "image_url": {"url": image}} for image in images],
@@ -306,3 +316,31 @@ def reciprocal_conversation_to_chatml(message_pair):
 def remove_json_codeblock(response: str):
     """Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
     return response.removeprefix("```json").removesuffix("```")
+
+
+@dataclass
+class ImageWithType:
+    content: Any
+    type: str
+
+
+def get_image_from_url(image_url: str, type="pil"):
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+
+        # Get content type from response or infer from URL
+        content_type = response.headers.get("content-type") or mimetypes.guess_type(image_url)[0] or "image/webp"
+
+        # Convert image to desired format
+        if type == "b64":
+            image_data = base64.b64encode(response.content).decode("utf-8")
+        elif type == "pil":
+            image_data = PIL.Image.open(BytesIO(response.content))
+        else:
+            raise ValueError(f"Invalid image type: {type}")
+
+        return ImageWithType(content=image_data, type=content_type)
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return ImageWithType(content=None, type=None)
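
A short usage sketch for the generalized helper; the URL is an illustrative assumption. With `type="b64"` the result plugs directly into an Anthropic image block, while `type="pil"` preserves the behavior the Gemini path relies on:

```python
# Hedged usage sketch for the reworked get_image_from_url helper.
from khoj.processor.conversation.utils import get_image_from_url

image = get_image_from_url("https://example.com/photo.jpg", type="b64")
if image.content:  # content is None if the download failed
    block = {
        "type": "image",
        "source": {"type": "base64", "media_type": image.type, "data": image.content},
    }

pil_image = get_image_from_url("https://example.com/photo.jpg", type="pil").content  # PIL.Image or None
```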
2 changes: 2 additions & 0 deletions src/khoj/routers/api.py
@@ -448,11 +448,13 @@ async def extract_references_and_questions(
         chat_model = conversation_config.chat_model
         inferred_queries = extract_questions_anthropic(
             defiltered_query,
+            query_images=query_images,
             model=chat_model,
             api_key=api_key,
             conversation_log=meta_log,
             location_data=location_data,
             user=user,
+            vision_enabled=vision_enabled,
             personality_context=personality_context,
         )
     elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
9 changes: 7 additions & 2 deletions src/khoj/routers/helpers.py
@@ -820,10 +820,13 @@ async def send_message_to_model_wrapper(
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
     if not vision_available and query_images:
+        logger.warning(f"Vision is not enabled for default model: {conversation_config.chat_model}.")
         vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
         if vision_enabled_config:
             conversation_config = vision_enabled_config
             vision_available = True
+    if vision_available and query_images:
+        logger.info(f"Using {conversation_config.chat_model} model to understand {len(query_images)} images.")

     subscribed = await ais_user_subscribed(user)
     chat_model = conversation_config.chat_model
@@ -1104,8 +1107,9 @@ def generate_chat_response(
         chat_response = converse_anthropic(
             compiled_references,
             q,
-            online_results,
-            meta_log,
+            query_images=query_images,
+            online_results=online_results,
+            conversation_log=meta_log,
             model=conversation_config.chat_model,
             api_key=api_key,
             completion_func=partial_completion,
@@ -1115,6 +1119,7 @@ def generate_chat_response(
             location_data=location_data,
             user_name=user_name,
             agent=agent,
+            vision_available=vision_available,
         )
     elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
         api_key = conversation_config.openai_config.api_key
