Skip to content

Commit

Permalink
added streaming
Browse files: browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 9, 2024
1 parent 4ba464c commit f0a2b95
Showing 1 changed file with 122 additions and 48 deletions.
170 changes: 122 additions & 48 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, cast
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union, cast

import anthropic
import requests
Expand Down Expand Up @@ -104,15 +104,17 @@ def __init__(
def __call__(
    self,
    input: Union[str, List[Message]],
    **kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
    """Route the input to ``generate`` or ``chat`` based on its type.

    Parameters:
        input: a plain prompt string, or a list of chat messages.
        **kwargs: per-call options forwarded unchanged to the selected
            method (for example ``stream=True``).

    Returns:
        Whatever the selected method returns: the result of ``generate``
        when ``input`` is a ``str``, otherwise the result of ``chat``.
    """
    # A bare string is a single-turn prompt; anything else is treated as a
    # chat history.
    if isinstance(input, str):
        return self.generate(input, **kwargs)
    return self.chat(input, **kwargs)

def chat(
self,
chat: List[Message],
) -> str:
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
"""Chat with the LMM model.
Parameters:
Expand Down Expand Up @@ -141,17 +143,24 @@ def chat(
)
fixed_chat.append(fixed_c)

# prefers kwargs from second dictionary over first
tmp_kwargs = self.kwargs | kwargs
response = self.client.chat.completions.create(
model=self.model_name, messages=fixed_chat, **self.kwargs # type: ignore
model=self.model_name, messages=fixed_chat, **tmp_kwargs # type: ignore
)

return cast(str, response.choices[0].message.content)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
for chunk in response:
chunk_message = chunk.choices[0].delta.content
yield chunk_message
else:
return cast(str, response.choices[0].message.content)

def generate(
self,
prompt: str,
media: Optional[List[Union[str, Path]]] = None,
) -> str:
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
message: List[Dict[str, Any]] = [
{
"role": "user",
Expand All @@ -173,10 +182,17 @@ def generate(
},
)

# prefers kwargs from second dictionary over first
tmp_kwargs = self.kwargs | kwargs
response = self.client.chat.completions.create(
model=self.model_name, messages=message, **self.kwargs # type: ignore
model=self.model_name, messages=message, **tmp_kwargs # type: ignore
)
return cast(str, response.choices[0].message.content)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
for chunk in response:
chunk_message = chunk.choices[0].delta.content
yield chunk_message # type: ignore
else:
return cast(str, response.choices[0].message.content)

def generate_classifier(self, question: str) -> Callable:
api_doc = T.get_tool_documentation([T.clip])
Expand Down Expand Up @@ -309,20 +325,22 @@ def __init__(
self.url = base_url
self.model_name = model_name
self.json_mode = json_mode
self.stream = False
self.kwargs = kwargs

def __call__(
    self,
    input: Union[str, List[Message]],
    **kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
    """Route the input to ``generate`` or ``chat`` based on its type.

    Parameters:
        input: a plain prompt string, or a list of chat messages.
        **kwargs: per-call options forwarded unchanged to the selected
            method (for example ``stream=True``).

    Returns:
        The result of ``generate`` when ``input`` is a ``str``, otherwise
        the result of ``chat``.
    """
    if isinstance(input, str):
        # Single prompt string: use the one-shot generate path.
        return self.generate(input, **kwargs)
    # Otherwise the input is handed to chat as a message list.
    return self.chat(input, **kwargs)

def chat(
self,
chat: List[Message],
) -> str:
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
"""Chat with the LMM model.
Parameters:
Expand All @@ -341,40 +359,77 @@ def chat(
url = f"{self.url}/chat"
model = self.model_name
messages = fixed_chat
data = {"model": model, "messages": messages, "stream": self.stream}
data = {"model": model, "messages": messages}

tmp_kwargs = self.kwargs | kwargs
data.update(tmp_kwargs)
json_data = json.dumps(data)
response = requests.post(url, data=json_data)
if response.status_code != 200:
raise ValueError(f"Request failed with status code {response.status_code}")
response = response.json()
return response["message"]["content"] # type: ignore
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
with requests.post(url, data=json_data, stream=True) as stream:
if stream.status_code != 200:
raise ValueError(
f"Request failed with status code {stream.status_code}"
)

for chunk in stream.iter_content(chunk_size=None):
chunk_data = json.loads(chunk)
if chunk_data["done"]:
yield None
else:
yield chunk_data["message"]["content"]
else:
stream = requests.post(url, data=json_data)
if stream.status_code != 200:
raise ValueError(
f"Request failed with status code {stream.status_code}"
)
stream = stream.json()
return stream["message"]["content"] # type: ignore

def generate(
    self,
    prompt: str,
    media: Optional[List[Union[str, Path]]] = None,
    **kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
    """Send a single prompt (optionally with media) to ``{self.url}/generate``.

    Parameters:
        prompt: the text prompt placed in the request's ``"prompt"`` field.
        media: optional media paths; each is passed through
            ``encode_media`` and appended to the request's ``"images"``.
        **kwargs: per-call request options. They are merged over
            ``self.kwargs`` (per-call values win) and added to the request
            payload. A truthy ``stream`` selects streaming mode.

    Returns:
        Non-streaming: the ``"response"`` field of the JSON reply.
        Streaming: an iterator yielding the ``"response"`` field of each
        chunk, and ``None`` for a chunk whose ``"done"`` field is truthy.

    Raises:
        ValueError: if the server answers with a status code other than
            200. In streaming mode this is raised when the returned
            iterator is first advanced, not when ``generate`` is called.
    """
    url = f"{self.url}/generate"
    data: Dict[str, Any] = {
        "model": self.model_name,
        "prompt": prompt,
        "images": [encode_media(m) for m in media or []],
    }

    # Prefers kwargs from the second dictionary over the first.
    data.update(self.kwargs | kwargs)

    # Send the stream flag explicitly so the reply format always matches the
    # parsing below (one JSON object vs. one JSON object per line).
    streaming = bool(data.get("stream"))
    data["stream"] = streaming

    # Serialize only after images and kwargs are in the payload.
    json_data = json.dumps(data)

    if streaming:
        # The generator lives in a nested function: a ``yield`` directly in
        # this body would turn ``generate`` itself into a generator and the
        # non-streaming ``return`` below would never hand back a string.
        def _iter_chunks() -> Iterator[Optional[str]]:
            with requests.post(url, data=json_data, stream=True) as response:
                if response.status_code != 200:
                    raise ValueError(
                        f"Request failed with status code {response.status_code}"
                    )

                # Parse line by line: raw network chunks are not guaranteed
                # to line up with the boundaries of the JSON objects.
                for line in response.iter_lines():
                    if not line:
                        continue
                    chunk_data = json.loads(line)
                    if chunk_data["done"]:
                        yield None
                    else:
                        yield chunk_data["response"]

        return _iter_chunks()

    response = requests.post(url, data=json_data)
    if response.status_code != 200:
        raise ValueError(f"Request failed with status code {response.status_code}")
    return response.json()["response"]  # type: ignore


class ClaudeSonnetLMM(LMM):
Expand All @@ -385,27 +440,28 @@ def __init__(
api_key: Optional[str] = None,
model_name: str = "claude-3-sonnet-20240229",
max_tokens: int = 4096,
temperature: float = 0.7,
**kwargs: Any,
):
self.client = anthropic.Anthropic(api_key=api_key)
self.model_name = model_name
self.max_tokens = max_tokens
self.temperature = temperature
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = max_tokens
self.kwargs = kwargs

def __call__(
    self,
    input: Union[str, List[Dict[str, Any]]],
    **kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
    """Route the input to ``generate`` or ``chat`` based on its type.

    Parameters:
        input: a plain prompt string, or a list of message dictionaries.
        **kwargs: per-call options forwarded unchanged to the selected
            method (for example ``stream=True``).

    Returns:
        The result of ``generate`` when ``input`` is a ``str``, otherwise
        the result of ``chat``.
    """
    # Strings take the single-prompt path; message lists go to chat.
    handler = self.generate if isinstance(input, str) else self.chat
    return handler(input, **kwargs)

def chat(
self,
chat: List[Dict[str, Any]],
) -> str:
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
messages: List[MessageParam] = []
for msg in chat:
content: List[Union[TextBlockParam, ImageBlockParam]] = [
Expand All @@ -426,20 +482,28 @@ def chat(
)
messages.append({"role": msg["role"], "content": content})

# prefers kwargs from second dictionary over first
tmp_kwargs = self.kwargs | kwargs
response = self.client.messages.create(
model=self.model_name,
max_tokens=self.max_tokens,
temperature=self.temperature,
messages=messages,
**self.kwargs,
model=self.model_name, messages=messages, **tmp_kwargs
)
return cast(str, response.content[0].text)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
for chunk in response:
if chunk.type == "message_start" or chunk.type == "content_block_start":
continue
elif chunk.type == "content_block_delta":
yield chunk.delta.text
elif chunk.type == "message_stop":
yield None
else:
return cast(str, response.content[0].text)

def generate(
self,
prompt: str,
media: Optional[List[Union[str, Path]]] = None,
) -> str:
**kwargs: Any,
) -> Union[str, Iterator[Optional[str]]]:
content: List[Union[TextBlockParam, ImageBlockParam]] = [
TextBlockParam(type="text", text=prompt)
]
Expand All @@ -456,11 +520,21 @@ def generate(
},
)
)

# prefers kwargs from second dictionary over first
tmp_kwargs = self.kwargs | kwargs
response = self.client.messages.create(
model=self.model_name,
max_tokens=self.max_tokens,
temperature=self.temperature,
messages=[{"role": "user", "content": content}],
**self.kwargs,
**tmp_kwargs,
)
return cast(str, response.content[0].text)
if "stream" in tmp_kwargs and tmp_kwargs["stream"]:
for chunk in response:
if chunk.type == "message_start" or chunk.type == "content_block_start":
continue
elif chunk.type == "content_block_delta":
yield chunk.delta.text
elif chunk.type == "message_stop":
yield None
else:
return cast(str, response.content[0].text)

0 comments on commit f0a2b95

Please sign in to comment.