
Commit

synced code with new code interpreter arg
dillonalaird committed Oct 11, 2024
1 parent 969c420 commit ad6edf1
Showing 3 changed files with 75 additions and 45 deletions.
38 changes: 26 additions & 12 deletions vision_agent/agent/vision_agent.py
@@ -195,9 +195,8 @@ def __init__(
agent: Optional[LMM] = None,
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
code_interpreter: Optional[CodeInterpreter] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
"""Initialize the VisionAgent.
@@ -207,14 +206,17 @@ def __init__(
verbosity (int): The verbosity level of the agent.
local_artifacts_path (Optional[Union[str, Path]]): The path to the local
artifacts file.
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use.
code_interpreter (Optional[CodeInterpreter]): if not None, use this CodeInterpreter
callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback
function to send intermediate update messages.
code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
it can be one of: None, "local" or "e2b". If None, it will read from
the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter
object is provided it will use that.
"""

self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent
self.max_iterations = 12
self.verbosity = verbosity
self.code_sandbox_runtime = code_sandbox_runtime
self.code_interpreter = code_interpreter
self.callback_message = callback_message
if self.verbosity >= 1:
@@ -305,11 +307,13 @@ def chat_with_artifacts(
# this is setting remote artifacts path
artifacts = Artifacts(WORKSPACE / "artifacts.pkl")

# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
code_interpreter = (
self.code_interpreter
if self.code_interpreter is not None
and not isinstance(self.code_interpreter, str)
else CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_sandbox_runtime,
code_sandbox_runtime=self.code_interpreter,
)
)
with code_interpreter:
@@ -498,8 +502,8 @@ def __init__(
agent: Optional[LMM] = None,
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
"""Initialize the VisionAgent using OpenAI LMMs.
@@ -509,16 +513,21 @@ def __init__(
verbosity (int): The verbosity level of the agent.
local_artifacts_path (Optional[Union[str, Path]]): The path to the local
artifacts file.
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use.
callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback
function to send intermediate update messages.
code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
it can be one of: None, "local" or "e2b". If None, it will read from
the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter
object is provided it will use that.
"""

agent = OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent
super().__init__(
agent,
verbosity,
local_artifacts_path,
code_sandbox_runtime,
callback_message,
code_interpreter,
)


@@ -528,8 +537,8 @@ def __init__(
agent: Optional[LMM] = None,
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
"""Initialize the VisionAgent using Anthropic LMMs.
@@ -539,14 +548,19 @@ def __init__(
verbosity (int): The verbosity level of the agent.
local_artifacts_path (Optional[Union[str, Path]]): The path to the local
artifacts file.
code_sandbox_runtime (Optional[str]): The code sandbox runtime to use.
callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback
function to send intermediate update messages.
code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
it can be one of: None, "local" or "e2b". If None, it will read from
the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter
object is provided it will use that.
"""

agent = AnthropicLMM(temperature=0.0) if agent is None else agent
super().__init__(
agent,
verbosity,
local_artifacts_path,
code_sandbox_runtime,
callback_message,
code_interpreter,
)
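The vision_agent.py changes above fold the old code_sandbox_runtime string into code_interpreter, which now accepts either a runtime name or a ready-made CodeInterpreter. A minimal usage sketch of both call styles, assuming the import paths implied by the file layout (they may differ in the installed package):

from vision_agent.agent.vision_agent import VisionAgent  # assumed import path

# A string behaves like the removed code_sandbox_runtime argument:
# "local", "e2b", or None to fall back to the CODE_SANDBOX_RUNTIME env var.
agent = VisionAgent(code_interpreter="local")

# A CodeInterpreter instance is used directly instead of being built through
# CodeInterpreterFactory (the factory's module path here is an assumption).
# from vision_agent.utils.execute import CodeInterpreterFactory
# sandbox = CodeInterpreterFactory.new_instance(code_sandbox_runtime="e2b")
# agent = VisionAgent(code_interpreter=sandbox)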
50 changes: 31 additions & 19 deletions vision_agent/agent/vision_agent_coder.py
@@ -337,7 +337,7 @@ def __init__(
debugger: Optional[LMM] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
"""Initialize the Vision Agent Coder.
@@ -355,11 +355,10 @@ def __init__(
in a web application where multiple VisionAgentCoder instances are
running in parallel. This callback ensures that the progress are not
mixed up.
code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A
code sandbox is used to run the generated code. It can be one of the
following values: None, "local" or "e2b". If None, VisionAgentCoder
will read the value from the environment variable CODE_SANDBOX_RUNTIME.
If it's also None, the local python runtime environment will be used.
code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
it can be one of: None, "local" or "e2b". If None, it will read from
the environment variable "CODE_SANDBOX_RUNTIME". If a CodeInterpreter
object is provided it will use that.
"""

self.planner = (
@@ -375,7 +374,7 @@ def __init__(
_LOGGER.setLevel(logging.INFO)

self.report_progress_callback = report_progress_callback
self.code_sandbox_runtime = code_sandbox_runtime
self.code_interpreter = code_interpreter

def __call__(
self,
@@ -441,13 +440,15 @@ def generate_code_from_plan(
raise ValueError("Chat cannot be empty.")

# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
with (
code_interpreter
if code_interpreter is not None
code_interpreter = (
self.code_interpreter
if self.code_interpreter is not None
and not isinstance(self.code_interpreter, str)
else CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_sandbox_runtime
code_sandbox_runtime=self.code_interpreter,
)
) as code_interpreter:
)
with code_interpreter:
chat = copy.deepcopy(chat)
media_list = []
for chat_i in chat:
@@ -556,9 +557,16 @@ def generate_code(
if not chat:
raise ValueError("Chat cannot be empty.")

with CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_sandbox_runtime
) as code_interpreter:
# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
code_interpreter = (
self.code_interpreter
if self.code_interpreter is not None
and not isinstance(self.code_interpreter, str)
else CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_interpreter,
)
)
with code_interpreter:
plan_context = self.planner.generate_plan( # type: ignore
chat,
test_multi_plan=test_multi_plan,
@@ -595,7 +603,7 @@ def __init__(
debugger: Optional[LMM] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
self.planner = (
OpenAIVisionAgentPlanner(verbosity=verbosity)
@@ -610,7 +618,7 @@ def __init__(
_LOGGER.setLevel(logging.INFO)

self.report_progress_callback = report_progress_callback
self.code_sandbox_runtime = code_sandbox_runtime
self.code_interpreter = code_interpreter


class AnthropicVisionAgentCoder(VisionAgentCoder):
@@ -624,7 +632,7 @@ def __init__(
debugger: Optional[LMM] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
# NOTE: Claude doesn't have an official JSON mode
self.planner = (
@@ -640,7 +648,7 @@ def __init__(
_LOGGER.setLevel(logging.INFO)

self.report_progress_callback = report_progress_callback
self.code_sandbox_runtime = code_sandbox_runtime
self.code_interpreter = code_interpreter


class OllamaVisionAgentCoder(VisionAgentCoder):
@@ -668,6 +676,7 @@ def __init__(
debugger: Optional[LMM] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
super().__init__(
planner=(
@@ -692,6 +701,7 @@ def __init__(
),
verbosity=verbosity,
report_progress_callback=report_progress_callback,
code_interpreter=code_interpreter,
)


@@ -717,6 +727,7 @@ def __init__(
debugger: Optional[LMM] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
"""Initialize the Vision Agent Coder.
@@ -747,4 +758,5 @@ def __init__(
),
verbosity=verbosity,
report_progress_callback=report_progress_callback,
code_interpreter=code_interpreter,
)
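The coder diff repeats one resolution step in generate_code and generate_code_from_plan: reuse a configured CodeInterpreter instance as-is, otherwise let the factory build a fresh sandbox for the chat. A sketch of that logic as a standalone helper (the helper name is hypothetical and the import module path is assumed; CodeInterpreter and CodeInterpreterFactory are the names the diff itself uses):

from typing import Optional, Union

from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory  # assumed path


def resolve_code_interpreter(
    configured: Optional[Union[str, CodeInterpreter]],
) -> CodeInterpreter:
    # Reuse an already-constructed interpreter directly; hand strings
    # ("local"/"e2b") and None to the factory, which also honors the
    # CODE_SANDBOX_RUNTIME environment variable when given None.
    if configured is not None and not isinstance(configured, str):
        return configured
    return CodeInterpreterFactory.new_instance(code_sandbox_runtime=configured)

The resolved interpreter is then used inside a with block, so factory-built sandboxes are opened and closed once per chat, matching the NOTE about dedicated interpreters in the diff.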
32 changes: 18 additions & 14 deletions vision_agent/agent/vision_agent_planner.py
@@ -318,7 +318,7 @@ def __init__(
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
self.verbosity = verbosity
@@ -331,7 +331,7 @@ def __init__(
else tool_recommender
)
self.report_progress_callback = report_progress_callback
self.code_sandbox_runtime = code_sandbox_runtime
self.code_interpreter = code_interpreter

def __call__(
self, input: Union[str, List[Message]], media: Optional[Union[str, Path]] = None
@@ -353,13 +353,17 @@ def generate_plan(
if not chat:
raise ValueError("Chat cannot be empty")

with (
code_interpreter = (
code_interpreter
if code_interpreter is not None
else CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_sandbox_runtime
else (
self.code_interpreter
if not isinstance(self.code_interpreter, str)
else CodeInterpreterFactory.new_instance(self.code_interpreter)
)
) as code_interpreter:
)
code_interpreter = cast(CodeInterpreter, code_interpreter)
with code_interpreter:
chat = copy.deepcopy(chat)
media_list = []
for chat_i in chat:
@@ -464,14 +468,14 @@ def __init__(
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
super().__init__(
planner=AnthropicLMM(temperature=0.0) if planner is None else planner,
tool_recommender=tool_recommender,
verbosity=verbosity,
report_progress_callback=report_progress_callback,
code_sandbox_runtime=code_sandbox_runtime,
code_interpreter=code_interpreter,
)


@@ -482,7 +486,7 @@ def __init__(
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
super().__init__(
planner=(
@@ -493,7 +497,7 @@ def __init__(
tool_recommender=tool_recommender,
verbosity=verbosity,
report_progress_callback=report_progress_callback,
code_sandbox_runtime=code_sandbox_runtime,
code_interpreter=code_interpreter,
)


@@ -504,7 +508,7 @@ def __init__(
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
super().__init__(
planner=(
@@ -519,7 +523,7 @@ def __init__(
),
verbosity=verbosity,
report_progress_callback=report_progress_callback,
code_sandbox_runtime=code_sandbox_runtime,
code_interpreter=code_interpreter,
)


@@ -530,7 +534,7 @@ def __init__(
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_sandbox_runtime: Optional[str] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
super().__init__(
planner=(
@@ -545,5 +549,5 @@ def __init__(
),
verbosity=verbosity,
report_progress_callback=report_progress_callback,
code_sandbox_runtime=code_sandbox_runtime,
code_interpreter=code_interpreter,
)
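In the planner, generate_plan adds one more layer of precedence: an interpreter passed to the call wins over the one stored on the planner, and only a stored string (or None) triggers the factory. A hypothetical per-call override; the import paths and the chat payload shape are guesses, not taken from the diff:

from vision_agent.agent.vision_agent_planner import VisionAgentPlanner  # assumed path
from vision_agent.utils.execute import CodeInterpreterFactory  # assumed path

planner = VisionAgentPlanner(code_interpreter="local")

# Per-call override: generate_plan prefers this instance over the
# planner-level setting and enters/exits it itself.
sandbox = CodeInterpreterFactory.new_instance(code_sandbox_runtime="e2b")
plan = planner.generate_plan(
    [{"role": "user", "content": "Count the cars in street.jpg", "media": ["street.jpg"]}],
    code_interpreter=sandbox,
)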
