Skip to content

Commit

Permalink
Enable customizing max steps at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Feb 22, 2025
1 parent 14b6008 commit 1831733
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class MultiStepAgent:
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
max_steps (`int`, default `6`): Maximum number of steps the agent can take to solve the task.
max_steps (`int`, default `10`): Maximum number of steps the agent can take to solve the task.
tool_parser (`Callable`, *optional*): Function used to parse the tool calls from the LLM output.
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
verbosity_level (`LogLevel`, default `LogLevel.INFO`): Level of verbosity of the agent's logs.
Expand All @@ -193,7 +193,7 @@ def __init__(
tools: List[Tool],
model: Callable[[List[Dict[str, str]]], ChatMessage],
prompt_templates: Optional[PromptTemplates] = None,
max_steps: int = 6,
max_steps: int = 10,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbosity_level: LogLevel = LogLevel.INFO,
Expand Down Expand Up @@ -275,6 +275,7 @@ def run(
reset: bool = True,
images: Optional[List[str]] = None,
additional_args: Optional[Dict] = None,
max_steps: Optional[int] = None,
):
"""
Run the agent for the given task.
Expand All @@ -284,7 +285,8 @@ def run(
stream (`bool`): Whether to run in a streaming way.
reset (`bool`): Whether to reset the conversation or keep it going from previous run.
images (`list[str]`, *optional*): Paths to image(s).
additional_args (`dict`): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
additional_args (`dict`, *optional*): Any other variables that you want to pass to the agent run, for instance images or dataframes. Give them clear names!
max_steps (`int`, *optional*): Maximum number of steps the agent can take to solve the task. if not provided, will use the agent's default value.
Example:
```py
Expand All @@ -293,7 +295,7 @@ def run(
agent.run("What is the result of 2 power 3.7384?")
```
"""

max_steps = max_steps or self.max_steps
self.task = task
if additional_args is not None:
self.state.update(additional_args)
Expand All @@ -318,14 +320,16 @@ def run(

if stream:
# The steps are returned as they are executed through a generator to iterate on.
return self._run(task=self.task, images=images)
# Outputs are returned only at the end as a string. We only look at the last step
return deque(self._run(task=self.task, images=images), maxlen=1)[0]
return self._run(task=self.task, max_steps=max_steps, images=images)
# Outputs are returned only at the end. We only look at the last step.
return deque(self._run(task=self.task, max_steps=max_steps, images=images), maxlen=1)[0]

def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionStep | AgentType, None, None]:
def _run(
self, task: str, max_steps: int, images: List[str] | None = None
) -> Generator[ActionStep | AgentType, None, None]:
final_answer = None
self.step_number = 1
while final_answer is None and self.step_number <= self.max_steps:
while final_answer is None and self.step_number <= max_steps:
step_start_time = time.time()
memory_step = self._create_memory_step(step_start_time, images)
try:
Expand All @@ -337,7 +341,7 @@ def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionSt
yield memory_step
self.step_number += 1

if final_answer is None and self.step_number == self.max_steps + 1:
if final_answer is None and self.step_number == max_steps + 1:
final_answer = self._handle_max_steps_reached(task, images, step_start_time)
yield memory_step
yield handle_agent_output_types(final_answer)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,16 @@ def test_fails_max_steps(self):
assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
assert isinstance(answer, str)

agent = CodeAgent(
tools=[PythonInterpreterTool()],
model=fake_code_model_no_return, # use this callable because it never ends
max_steps=5,
)
answer = agent.run("What is 2 multiplied by 3.6452?", max_steps=3)
assert len(agent.memory.steps) == 5 # Task step + 3 action steps + Final answer
assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
assert isinstance(answer, str)

def test_tool_descriptions_get_baked_in_system_prompt(self):
tool = PythonInterpreterTool()
tool.name = "fake_tool_name"
Expand Down

0 comments on commit 1831733

Please sign in to comment.