updated params for llm/lmm #34

Merged 2 commits on Apr 1, 2024
vision_agent/agent/vision_agent.py (25 changes: 15 additions & 10 deletions)

@@ -256,7 +256,6 @@ def retrieval(
     )
     if tool_id is None:
         return {}, ""
-    _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")
 
     tool_instructions = tools[tool_id]
     tool_usage = tool_instructions["usage"]
@@ -265,7 +264,6 @@ def retrieval(
     parameters = choose_parameter(
         model, question, tool_usage, previous_log, reflections
     )
-    _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
     if parameters is None:
         return {}, ""
     tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
@@ -290,7 +288,7 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
     tool_results["call_results"] = call_results
 
     call_results_str = str(call_results)
-    _LOGGER.info(f"\tCall Results: {call_results_str}")
+    # _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str
 
 
@@ -344,7 +342,9 @@ def self_reflect(
 
 def parse_reflect(reflect: str) -> bool:
     # GPT-4V has a hard time following directions, so make the criteria less strict
-    return "finish" in reflect.lower() and len(reflect) < 100
+    return (
+        "finish" in reflect.lower() and len(reflect) < 100
+    ) or "finish" in reflect.lower()[-10:]
 
 
 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
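
The relaxed check now also accepts a reflection whose last ten characters contain "finish", so a long GPT-4V reflection that still ends with the keyword counts as done. A quick sketch of the behavior on made-up inputs:

    # Hypothetical reflections illustrating the relaxed criteria.
    parse_reflect("Finish")  # True: short and mentions "finish"
    parse_reflect(("The tool outputs match the question. " * 3) + "Finish")
    # True: the text is over 100 characters, but "finish" falls in the last 10
    parse_reflect("Step 1 used the wrong tool; please retry")  # False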
@@ -423,10 +423,16 @@ def __init__(
         verbose: bool = False,
     ):
         self.task_model = (
-            OpenAILLM(json_mode=True) if task_model is None else task_model
+            OpenAILLM(json_mode=True, temperature=0.1)
+            if task_model is None
+            else task_model
         )
-        self.answer_model = OpenAILLM() if answer_model is None else answer_model
-        self.reflect_model = OpenAILMM() if reflect_model is None else reflect_model
+        self.answer_model = (
+            OpenAILLM(temperature=0.1) if answer_model is None else answer_model
+        )
+        self.reflect_model = (
+            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+        )
         self.max_retries = max_retries
 
         self.tools = TOOLS
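
These new defaults pin temperature to 0.1 for the task, answer, and reflection models, trading sampling diversity for more reproducible plans and tool choices. Models passed in explicitly keep their own settings; a hypothetical override, assuming the enclosing class is VisionAgent:

    # Hypothetical: a caller-supplied model is used as-is, defaults untouched.
    agent = VisionAgent(task_model=OpenAILLM(json_mode=True, temperature=0.0))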
@@ -466,7 +472,6 @@ def chat_with_workflow(
         for _ in range(self.max_retries):
             task_list = create_tasks(self.task_model, question, self.tools, reflections)
 
-            _LOGGER.info(f"Task Dependency: {task_list}")
             task_depend = {"Original Quesiton": question}
             previous_log = ""
             answers = []
@@ -477,7 +482,6 @@ def chat_with_workflow(
             for task in task_list:
                 task_str = task["task"]
                 previous_log = str(task_depend)
-                _LOGGER.info(f"\tSubtask: {task_str}")
                 tool_results, call_results = retrieval(
                     self.task_model,
                     task_str,
@@ -492,6 +496,7 @@ def chat_with_workflow(
                 tool_results["answer"] = answer
                 all_tool_results.append(tool_results)
 
+                _LOGGER.info(f"\tCall Result: {call_results}")
                 _LOGGER.info(f"\tAnswer: {answer}")
                 answers.append({"task": task_str, "answer": answer})
                 task_depend[task["id"]]["answer"] = answer  # type: ignore
@@ -510,7 +515,7 @@ def chat_with_workflow(
                 final_answer,
                 visualized_images[0] if len(visualized_images) > 0 else image,
             )
-            _LOGGER.info(f"\tReflection: {reflection}")
+            _LOGGER.info(f"Reflection: {reflection}")
             if parse_reflect(reflection):
                 break
             else:
vision_agent/llm/llm.py (17 changes: 10 additions & 7 deletions)

@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Mapping, Union, cast
+from typing import Any, Callable, Dict, List, Mapping, Union, cast
 
 from openai import OpenAI
 
@@ -31,30 +31,33 @@ class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""
 
     def __init__(
-        self, model_name: str = "gpt-4-turbo-preview", json_mode: bool = False
+        self,
+        model_name: str = "gpt-4-turbo-preview",
+        json_mode: bool = False,
+        **kwargs: Any
     ):
         self.model_name = model_name
         self.client = OpenAI()
-        self.json_mode = json_mode
+        self.kwargs = kwargs
+        if json_mode:
+            self.kwargs["response_format"] = {"type": "json_object"}
 
     def generate(self, prompt: str) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "user", "content": prompt},
             ],
-            **kwargs,  # type: ignore
+            **self.kwargs,
         )
 
         return cast(str, response.choices[0].message.content)
 
     def chat(self, chat: List[Dict[str, str]]) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=chat,  # type: ignore
-            **kwargs,
+            **self.kwargs,
        )
 
         return cast(str, response.choices[0].message.content)
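
OpenAILLM now stores arbitrary keyword arguments and splats them into every chat.completions.create call, folding json_mode into the same dict as a response_format entry. A minimal usage sketch with illustrative parameter values:

    # Illustrative: temperature and response_format both flow through self.kwargs.
    llm = OpenAILLM(json_mode=True, temperature=0.1)
    out = llm.generate("Return a JSON object with a single key 'ok'.")
    # The underlying request is roughly:
    # client.chat.completions.create(model="gpt-4-turbo-preview", messages=[...],
    #     response_format={"type": "json_object"}, temperature=0.1)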
vision_agent/lmm/lmm.py (17 changes: 14 additions & 3 deletions)

@@ -97,11 +97,15 @@ class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
     def __init__(
-        self, model_name: str = "gpt-4-vision-preview", max_tokens: int = 1024
+        self,
+        model_name: str = "gpt-4-vision-preview",
+        max_tokens: int = 1024,
+        **kwargs: Any,
     ):
         self.model_name = model_name
         self.max_tokens = max_tokens
         self.client = OpenAI()
+        self.kwargs = kwargs
 
     def __call__(
         self,
@@ -123,6 +127,13 @@ def chat(
 
         if image:
             extension = Path(image).suffix
+            if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                extension = "jpg"
+            elif extension.lower() == ".png":
+                extension = "png"
+            else:
+                raise ValueError(f"Unsupported image extension: {extension}")
+
             encoded_image = encode_image(image)
             fixed_chat[0]["content"].append(  # type: ignore
                 {
@@ -135,7 +146,7 @@ def chat(
             )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
         )
 
         return cast(str, response.choices[0].message.content)
@@ -163,7 +174,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
             )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
 
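
Together these changes make the LMM validate image extensions before encoding and forward extra keyword arguments to the completion call. A usage sketch with invented file names, assuming chat takes a message list plus an optional image path the way generate does:

    # Hypothetical usage of the updated OpenAILMM.
    lmm = OpenAILMM(temperature=0.1)  # temperature is forwarded via **kwargs
    msgs = [{"role": "user", "content": "Describe this image."}]
    lmm.chat(msgs, image="photo.jpg")   # .jpg, .jpeg, and .png pass the new check
    lmm.chat(msgs, image="photo.webp")  # raises ValueError: Unsupported image extension: .webp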