Commit 3af591d

added json mode for lmm, upgraded gpt-4-turbo
1 parent dae1031 commit 3af591d

File tree

3 files changed: +24 −14 lines

vision_agent/agent/vision_agent.py
vision_agent/llm/llm.py
vision_agent/lmm/lmm.py

vision_agent/agent/vision_agent.py

Lines changed: 15 additions & 9 deletions
```diff
@@ -255,7 +255,7 @@ def self_reflect(
 ) -> str:
     prompt = VISION_AGENT_REFLECTION.format(
         question=question,
-        tools=format_tools(tools),
+        tools=format_tools({k: v["description"] for k, v in tools.items()}),
         tool_results=str(tool_result),
         final_answer=final_answer,
     )
@@ -268,11 +268,16 @@ def self_reflect(
     return reflect_model(prompt)
 
 
-def parse_reflect(reflect: str) -> bool:
-    # GPT-4V has a hard time following directions, so make the criteria less strict
-    return (
+def parse_reflect(reflect: str) -> Dict[str, Any]:
+    try:
+        return parse_json(reflect)
+    except Exception:
+        _LOGGER.error(f"Failed parse json reflection: {reflect}")
+    # LMMs have a hard time following directions, so make the criteria less strict
+    finish = (
         "finish" in reflect.lower() and len(reflect) < 100
     ) or "finish" in reflect.lower()[-10:]
+    return {"Finish": finish, "Reflection": reflect}
 
 
 def visualize_result(all_tool_results: List[Dict]) -> List[str]:
@@ -389,7 +394,7 @@ def __init__(
             OpenAILLM(temperature=0.1) if answer_model is None else answer_model
         )
         self.reflect_model = (
-            OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
+            OpenAILMM(json_mode=True, temperature=0.1) if reflect_model is None else reflect_model
         )
         self.max_retries = max_retries
         self.tools = TOOLS
@@ -485,13 +490,14 @@ def chat_with_workflow(
                 visualized_output[0] if len(visualized_output) > 0 else image,
             )
             self.log_progress(f"Reflection: {reflection}")
-            if parse_reflect(reflection):
+            parsed_reflection = parse_reflect(reflection)
+            if parsed_reflection["Finish"]:
                 break
             else:
-                reflections += "\n" + reflection
-        # '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + parsed_reflection["Reflection"]
+        # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
-            f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</<ANSWER>"
+            f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
        )
 
         if visualize_output:
```
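For illustration, a sketch of how the reworked `parse_reflect` behaves under the two branches above. Both input strings are hypothetical; `parse_json` and `_LOGGER` come from the surrounding module.

```python
# Hypothetical reflections illustrating both branches of the new parse_reflect.

# A json_mode reflect model returns well-formed JSON, which parse_json handles:
parse_reflect('{"Finish": true, "Reflection": "The detections look correct."}')
# -> {"Finish": True, "Reflection": "The detections look correct."}

# Free-form text fails JSON parsing and falls back to the keyword heuristic:
parse_reflect("I believe we can finish")
# -> {"Finish": True, "Reflection": "I believe we can finish"}
```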

vision_agent/llm/llm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ class OpenAILLM(LLM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-turbo-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         json_mode: bool = False,
         **kwargs: Any
```
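A minimal usage sketch, assuming `OpenAILLM` is re-exported from `vision_agent.llm` and that its pre-existing `json_mode` flag requests a JSON-only response, as `OpenAILMM` now does below:

```python
from vision_agent.llm import OpenAILLM  # import path assumed

# The default model is now the GA "gpt-4-turbo" rather than the preview alias,
# so callers that never pass model_name pick up the upgrade automatically.
# Assumes OPENAI_API_KEY is set in the environment.
llm = OpenAILLM(json_mode=True, temperature=0.1)
```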

vision_agent/lmm/lmm.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -99,9 +99,10 @@ class OpenAILMM(LMM):
 
     def __init__(
         self,
-        model_name: str = "gpt-4-vision-preview",
+        model_name: str = "gpt-4-turbo",
         api_key: Optional[str] = None,
         max_tokens: int = 1024,
+        json_mode: bool = False,
         **kwargs: Any,
     ):
         if not api_key:
@@ -111,7 +112,10 @@ def __init__(
 
         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-        self.max_tokens = max_tokens
+        if "max_tokens" not in kwargs:
+            kwargs["max_tokens"] = max_tokens
+        if json_mode:
+            kwargs["response_format"] = {"type": "json_object"}
         self.kwargs = kwargs
 
     def __call__(
@@ -153,7 +157,7 @@ def chat(
         )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=fixed_chat, **self.kwargs  # type: ignore
         )
 
         return cast(str, response.choices[0].message.content)
@@ -181,7 +185,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
         )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs  # type: ignore
+            model=self.model_name, messages=message, **self.kwargs  # type: ignore
         )
         return cast(str, response.choices[0].message.content)
```
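A minimal sketch of the new constructor behavior (import path assumed; an OpenAI API key is expected in the environment):

```python
from vision_agent.lmm import OpenAILMM  # import path assumed

lmm = OpenAILMM(json_mode=True)  # model_name now defaults to "gpt-4-turbo"

# max_tokens is folded into kwargs, so an explicit caller value wins and the
# whole dict is forwarded to chat.completions.create in chat() and generate():
assert lmm.kwargs["max_tokens"] == 1024
# json_mode=True asks the OpenAI API for a JSON-only completion:
assert lmm.kwargs["response_format"] == {"type": "json_object"}
```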
