Commit

added json mode for lmm, upgraded gpt-4-turbo
dillonalaird committed Apr 15, 2024
1 parent dae1031 · commit 3af591d
Showing 3 changed files with 24 additions and 14 deletions.
vision_agent/agent/vision_agent.py (24 changes: 15 additions & 9 deletions)
@@ -255,7 +255,7 @@ def self_reflect(
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
-        tools=format_tools(tools),
+        tools=format_tools({k: v["description"] for k, v in tools.items()}),
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
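For context, each entry in the agent's TOOLS registry carries more than a description, so the comprehension strips the dict down to the one field the reflection prompt needs. A toy illustration (these registry entries are hypothetical, not the real TOOLS contents):

    # Hypothetical registry in the style of vision_agent's TOOLS.
    tools = {
        0: {"name": "grounding_dino_", "description": "Detects objects matching a text prompt.", "class": object},
        1: {"name": "clip_", "description": "Classifies an image against text labels.", "class": object},
    }

    # What self_reflect now hands to format_tools:
    descriptions = {k: v["description"] for k, v in tools.items()}
    print(descriptions)
    # {0: 'Detects objects matching a text prompt.', 1: 'Classifies an image against text labels.'}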
@@ -268,11 +268,16 @@ def self_reflect(
     return reflect_model(prompt)


-def parse_reflect(reflect: str) -> bool:
-    # GPT-4V has a hard time following directions, so make the criteria less strict
-    return (
+def parse_reflect(reflect: str) -> Dict[str, Any]:
+    try:
+        return parse_json(reflect)
+    except Exception:
+        _LOGGER.error(f"Failed parse json reflection: {reflect}")
+    # LMMs have a hard time following directions, so make the criteria less strict
+    finish = (
         "finish" in reflect.lower() and len(reflect) < 100
     ) or "finish" in reflect.lower()[-10:]
+    return {"Finish": finish, "Reflection": reflect}


 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
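The new return shape means reflections flow through as structured data, with the old keyword check kept as a fallback for models that ignore the JSON instruction. A minimal sketch of both paths, assuming parse_json is a json.loads-style helper (the stand-in below is not the package's real implementation):

    import json
    import logging
    from typing import Any, Dict

    _LOGGER = logging.getLogger(__name__)

    def parse_json(text: str) -> Dict[str, Any]:
        # Stand-in for the package's parse_json helper; assumed json.loads-like.
        return json.loads(text)

    def parse_reflect(reflect: str) -> Dict[str, Any]:
        try:
            return parse_json(reflect)
        except Exception:
            _LOGGER.error(f"Failed parse json reflection: {reflect}")
        # LMMs have a hard time following directions, so make the criteria less strict
        finish = (
            "finish" in reflect.lower() and len(reflect) < 100
        ) or "finish" in reflect.lower()[-10:]
        return {"Finish": finish, "Reflection": reflect}

    # JSON-mode output parses directly:
    parse_reflect('{"Finish": true, "Reflection": "The answer is correct."}')
    # -> {'Finish': True, 'Reflection': 'The answer is correct.'}

    # Free-text output falls back to the keyword heuristic:
    parse_reflect("The answer looks correct. Finish.")
    # -> {'Finish': True, 'Reflection': 'The answer looks correct. Finish.'}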
@@ -389,7 +394,7 @@ def __init__(
             OpenAILLM(temperature=0.1) if answer_model is None else answer_model
         )
         self.reflect_model = (
-            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+            OpenAILMM(json_mode=True, temperature=0.1) if reflect_model is None else reflect_model
         )
         self.max_retries = max_retries
         self.tools = TOOLS
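Note the JSON-mode default only applies when reflect_model is None; callers passing their own model must opt in themselves, or parse_reflect will have to rely on the keyword fallback. A hedged usage sketch (assumes VisionAgent and OpenAILMM are importable as shown and an OpenAI API key is configured):

    from vision_agent.agent import VisionAgent
    from vision_agent.lmm import OpenAILMM

    # Default: reflect_model becomes OpenAILMM(json_mode=True, temperature=0.1).
    agent = VisionAgent()

    # Custom reflect model: remember json_mode=True so reflections come back as JSON.
    agent = VisionAgent(reflect_model=OpenAILMM(json_mode=True, temperature=0.0))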
@@ -485,13 +490,14 @@ def chat_with_workflow(
                 visualized_output[0] if len(visualized_output) > 0 else image,
             )
             self.log_progress(f"Reflection: {reflection}")
-            if parse_reflect(reflection):
+            parsed_reflection = parse_reflect(reflection)
+            if parsed_reflection["Finish"]:
                 break
             else:
-                reflections += "\n" + reflection
-        # '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + parsed_reflection["Reflection"]
+        # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
-            f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</<ANSWER>"
+            f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
         )

         if visualize_output:
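The sentinel fix is small but load-bearing: with the old </<ANSWER> typo, a consumer scanning streamed logs for a well-formed <ANSWER>...</ANSWER> pair would never match. A minimal consumer sketch (the log line and regex are illustrative, not part of the package):

    import re

    log_line = "The Vision Agent has concluded this chat. <ANSWER>There are 2 dogs.</ANSWER>"
    match = re.search(r"<ANSWER>(.*?)</ANSWER>", log_line, re.DOTALL)
    if match:
        print(match.group(1))  # There are 2 dogs.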
vision_agent/llm/llm.py (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@ class OpenAILLM(LLM):

     def __init__(
         self,
-        model_name: str = "gpt-4-turbo-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         json_mode: bool = False,
         **kwargs: Any
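This is the bump from the preview alias to the GA gpt-4-turbo model OpenAI shipped in April 2024. The old default can still be pinned explicitly; a sketch (assumes OpenAILLM is exported from vision_agent.llm):

    from vision_agent.llm import OpenAILLM

    llm = OpenAILLM()                                     # now defaults to "gpt-4-turbo"
    pinned = OpenAILLM(model_name="gpt-4-turbo-preview")  # previous default, still selectable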
vision_agent/lmm/lmm.py (12 changes: 8 additions & 4 deletions)
@@ -99,9 +99,10 @@ class OpenAILMM(LMM):

     def __init__(
         self,
-        model_name: str = "gpt-4-vision-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
+        json_mode: bool = False,
         **kwargs: Any,
     ):
         if not api_key:
@@ -111,7 +112,10 @@ def __init__(

         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-        self.max_tokens = max_tokens
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        if json_mode:
+            kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs

     def __call__(
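Folding max_tokens and the JSON flag into self.kwargs keeps every chat.completions.create call uniform and lets callers override max_tokens per instance. The LMM can accept json_mode at all now because gpt-4-turbo supports response_format={"type": "json_object"}, which gpt-4-vision-preview did not. A usage sketch (assumes OpenAILMM is exported from vision_agent.lmm and an API key is configured; the printed dict is what gets splatted into create):

    from vision_agent.lmm import OpenAILMM

    lmm = OpenAILMM(json_mode=True, temperature=0.1)
    print(lmm.kwargs)
    # {'temperature': 0.1, 'max_tokens': 1024, 'response_format': {'type': 'json_object'}}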
@@ -153,7 +157,7 @@ def chat(
         )

         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
         )

         return cast(str, response.choices[0].message.content)
@@ -181,7 +185,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
         )

         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=message, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
