Commit

updated params for llm/lmm (#34)
* updated params for llm/lmm

* type checking fixes
dillonalaird authored Apr 1, 2024
1 parent 456add4 commit 075aa2f
Showing 3 changed files with 39 additions and 20 deletions.
25 changes: 15 additions & 10 deletions vision_agent/agent/vision_agent.py
@@ -256,7 +256,6 @@ def retrieval(
     )
     if tool_id is None:
         return {}, ""
-    _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

     tool_instructions = tools[tool_id]
     tool_usage = tool_instructions["usage"]
@@ -265,7 +264,6 @@ def retrieval(
     parameters = choose_parameter(
         model, question, tool_usage, previous_log, reflections
     )
-    _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
     if parameters is None:
         return {}, ""
     tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}
@@ -290,7 +288,7 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
     tool_results["call_results"] = call_results

     call_results_str = str(call_results)
-    _LOGGER.info(f"\tCall Results: {call_results_str}")
+    # _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str


@@ -344,7 +342,9 @@ def self_reflect(

 def parse_reflect(reflect: str) -> bool:
     # GPT-4V has a hard time following directions, so make the criteria less strict
-    return "finish" in reflect.lower() and len(reflect) < 100
+    return (
+        "finish" in reflect.lower() and len(reflect) < 100
+    ) or "finish" in reflect.lower()[-10:]


 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
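
For illustration, here is the relaxed check as a standalone sketch: the new second clause also accepts reflections that merely end with "finish", since GPT-4V often pads its answer past the 100-character limit. The sample strings below are made up.

    def parse_reflect(reflect: str) -> bool:
        # Accept short reflections that contain "finish", or any reflection
        # whose final 10 characters contain "finish".
        return (
            "finish" in reflect.lower() and len(reflect) < 100
        ) or "finish" in reflect.lower()[-10:]

    assert parse_reflect("Finish")  # short and explicit
    assert parse_reflect("The result looks correct. " * 10 + "Finish")  # long, but ends with it
    assert not parse_reflect("The bounding boxes are wrong; retry the detection.")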
@@ -423,10 +423,16 @@ def __init__(
         verbose: bool = False,
     ):
         self.task_model = (
-            OpenAILLM(json_mode=True) if task_model is None else task_model
+            OpenAILLM(json_mode=True, temperature=0.1)
+            if task_model is None
+            else task_model
         )
+        self.answer_model = (
+            OpenAILLM(temperature=0.1) if answer_model is None else answer_model
+        )
+        self.reflect_model = (
+            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+        )
-        self.answer_model = OpenAILLM() if answer_model is None else answer_model
-        self.reflect_model = OpenAILMM() if reflect_model is None else reflect_model
         self.max_retries = max_retries

         self.tools = TOOLS
@@ -466,7 +472,6 @@ def chat_with_workflow(
         for _ in range(self.max_retries):
             task_list = create_tasks(self.task_model, question, self.tools, reflections)

-            _LOGGER.info(f"Task Dependency: {task_list}")
             task_depend = {"Original Quesiton": question}
             previous_log = ""
             answers = []
@@ -477,7 +482,6 @@ def chat_with_workflow(
             for task in task_list:
                 task_str = task["task"]
                 previous_log = str(task_depend)
-                _LOGGER.info(f"\tSubtask: {task_str}")
                 tool_results, call_results = retrieval(
                     self.task_model,
                     task_str,
@@ -492,6 +496,7 @@ def chat_with_workflow(
                 tool_results["answer"] = answer
                 all_tool_results.append(tool_results)

+                _LOGGER.info(f"\tCall Result: {call_results}")
                 _LOGGER.info(f"\tAnswer: {answer}")
                 answers.append({"task": task_str, "answer": answer})
                 task_depend[task["id"]]["answer"] = answer  # type: ignore
@@ -510,7 +515,7 @@ def chat_with_workflow(
                 final_answer,
                 visualized_images[0] if len(visualized_images) > 0 else image,
             )
-            _LOGGER.info(f"\tReflection: {reflection}")
+            _LOGGER.info(f"Reflection: {reflection}")
             if parse_reflect(reflection):
                 break
             else:
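
Net effect of the constructor change: a default-constructed agent now runs its task, answer, and reflection models at temperature 0.1 instead of the OpenAI default. Below is a minimal sketch of the two equivalent spellings, assuming VisionAgent and the model classes are importable as shown; the explicit form is only illustrative.

    from vision_agent.agent import VisionAgent
    from vision_agent.llm import OpenAILLM
    from vision_agent.lmm import OpenAILMM

    # Default construction: all three models at temperature=0.1.
    agent = VisionAgent()

    # Equivalent explicit construction:
    agent = VisionAgent(
        task_model=OpenAILLM(json_mode=True, temperature=0.1),
        answer_model=OpenAILLM(temperature=0.1),
        reflect_model=OpenAILMM(temperature=0.1),
    )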
17 changes: 10 additions & 7 deletions vision_agent/llm/llm.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Mapping, Union, cast
+from typing import Any, Callable, Dict, List, Mapping, Union, cast

 from openai import OpenAI

@@ -31,30 +31,33 @@ class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""

     def __init__(
-        self, model_name: str = "gpt-4-turbo-preview", json_mode: bool = False
+        self,
+        model_name: str = "gpt-4-turbo-preview",
+        json_mode: bool = False,
+        **kwargs: Any
     ):
         self.model_name = model_name
         self.client = OpenAI()
-        self.json_mode = json_mode
+        self.kwargs = kwargs
+        if json_mode:
+            self.kwargs["response_format"] = {"type": "json_object"}

     def generate(self, prompt: str) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "user", "content": prompt},
             ],
-            **kwargs,  # type: ignore
+            **self.kwargs,
         )

         return cast(str, response.choices[0].message.content)

     def chat(self, chat: List[Dict[str, str]]) -> str:
-        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=chat,  # type: ignore
-            **kwargs,
+            **self.kwargs,
         )

         return cast(str, response.choices[0].message.content)
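
With the constructor now storing arbitrary keyword arguments, any OpenAI chat-completions parameter can be pinned once and reused by both generate and chat; json_mode becomes sugar for stashing response_format in the same dict. A small usage sketch follows, with illustrative parameter values and an OPENAI_API_KEY assumed to be set in the environment.

    from vision_agent.llm import OpenAILLM

    # Extra kwargs are forwarded verbatim to client.chat.completions.create
    # on every call.
    llm = OpenAILLM(json_mode=True, temperature=0.1, top_p=0.9, seed=42)

    # Both entry points share the stored kwargs; with json_mode the prompt
    # must mention JSON, per the OpenAI API contract.
    print(llm.generate("Return a JSON object listing two subtasks."))
    print(llm.chat([{"role": "user", "content": "Return an empty JSON object."}]))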
17 changes: 14 additions & 3 deletions vision_agent/lmm/lmm.py
@@ -97,11 +97,15 @@ class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""

     def __init__(
-        self, model_name: str = "gpt-4-vision-preview", max_tokens: int = 1024
+        self,
+        model_name: str = "gpt-4-vision-preview",
+        max_tokens: int = 1024,
+        **kwargs: Any,
     ):
         self.model_name = model_name
         self.max_tokens = max_tokens
         self.client = OpenAI()
+        self.kwargs = kwargs

     def __call__(
         self,
@@ -123,6 +127,13 @@ def chat(

         if image:
             extension = Path(image).suffix
+            if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
+                extension = "jpg"
+            elif extension.lower() == ".png":
+                extension = "png"
+            else:
+                raise ValueError(f"Unsupported image extension: {extension}")
+
             encoded_image = encode_image(image)
             fixed_chat[0]["content"].append(  # type: ignore
                 {
@@ -135,7 +146,7 @@ def chat(
             )

         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
         )

         return cast(str, response.choices[0].message.content)
@@ -163,7 +174,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
             )

         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens  # type: ignore
+            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
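
A matching sketch for the LMM side: constructor kwargs ride along on both chat and generate, and chat now normalizes image suffixes before building the data URL, raising instead of sending an unsupported format. The chat signature with an image keyword is inferred from the diff, and the file paths are hypothetical.

    from vision_agent.lmm import OpenAILMM

    # Constructor kwargs (e.g. temperature) are forwarded to every completion call.
    lmm = OpenAILMM(max_tokens=512, temperature=0.1)

    # .jpg/.jpeg/.png suffixes are accepted and normalized ...
    answer = lmm.chat(
        [{"role": "user", "content": "Describe this image."}],
        image="examples/dog.png",
    )
    print(answer)

    # ... anything else now raises instead of producing a malformed payload.
    try:
        lmm.chat([{"role": "user", "content": "Describe this image."}], image="scan.tiff")
    except ValueError as err:
        print(err)  # Unsupported image extension: .tiff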
