Skip to content

Commit

Permalink
Fix bracket logging in rich
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Feb 22, 2025
1 parent 7927bca commit 62b794e
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/open_deep_research/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
FindNextTool,
PageDownTool,
PageUpTool,
SearchInformationTool,
SimpleTextBrowser,
VisitTool,
)
from scripts.visual_qa import visualizer

from smolagents import (
CodeAgent,
GoogleSearchTool,
# HfApiModel,
LiteLLMModel,
ToolCallingAgent,
Expand Down Expand Up @@ -98,7 +98,7 @@ def main():
browser = SimpleTextBrowser(**BROWSER_CONFIG)

WEB_TOOLS = [
SearchInformationTool(browser),
GoogleSearchTool(provider="serper"),
VisitTool(browser),
PageUpTool(browser),
PageDownTool(browser),
Expand Down
43 changes: 40 additions & 3 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections.abc import Mapping
from importlib import import_module
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -868,6 +868,41 @@ def inner_evaluate(generators: List[ast.comprehension], index: int, current_stat
return inner_evaluate(listcomp.generators, 0, state)


def evaluate_setcomp(
setcomp: ast.SetComp,
state: Dict[str, Any],
static_tools: Dict[str, Callable],
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Set[Any]:
result = set()
for gen in setcomp.generators:
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
for value in iter_value:
new_state = state.copy()
set_value(
gen.target,
value,
new_state,
static_tools,
custom_tools,
authorized_imports,
)
if all(
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
for if_clause in gen.ifs
):
element = evaluate_ast(
setcomp.elt,
new_state,
static_tools,
custom_tools,
authorized_imports,
)
result.add(element)
return result


def evaluate_try(
try_node: ast.Try,
state: Dict[str, Any],
Expand Down Expand Up @@ -1196,6 +1231,10 @@ def evaluate_ast(
return tuple((evaluate_ast(elt, *common_params) for elt in expression.elts))
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, *common_params)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, *common_params)
elif isinstance(expression, ast.SetComp):
return evaluate_setcomp(expression, *common_params)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, *common_params)
elif isinstance(expression, ast.Starred):
Expand Down Expand Up @@ -1268,8 +1307,6 @@ def evaluate_ast(
evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, *common_params)
elif isinstance(expression, ast.While):
return evaluate_while(expression, *common_params)
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
Expand Down
2 changes: 1 addition & 1 deletion src/smolagents/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class LogLevel(IntEnum):
class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()
self.console = Console(width=60)

def log(self, *args, level: str | LogLevel = LogLevel.INFO, **kwargs) -> None:
"""Logs a message to the console.
Expand Down
15 changes: 14 additions & 1 deletion src/smolagents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,26 @@ def _is_pillow_available():
]


def escape_code_brackets(text: str) -> str:
"""Escapes square brackets in code segments while preserving Rich styling tags."""

def replace_bracketed_content(match):
content = match.group(1)
cleaned = re.sub(
r"bold|red|green|blue|yellow|magenta|cyan|white|black|italic|dim|\s|#[0-9a-fA-F]{6}", "", content
)
return f"\\[{content}\\]" if cleaned.strip() else f"[{content}]"

return re.sub(r"\[([^\]]*)\]", replace_bracketed_content, text)


class AgentError(Exception):
"""Base class for other agent-related exceptions"""

def __init__(self, message, logger: "AgentLogger"):
super().__init__(message)
self.message = message
logger.log(f"[bold red]{message}[/bold red]", level="ERROR")
logger.log(f"[bold red]{escape_code_brackets(message)}[/bold red]", level="ERROR")

def dict(self) -> Dict[str, str]:
return {"type": self.__class__.__name__, "message": str(self.message)}
Expand Down
10 changes: 10 additions & 0 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,16 @@ def test_call_with_provide_run_summary(self, provide_run_summary):
)
assert result == expected_summary

def test_errors_logging(self):
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
return ChatMessage(role="assistant", content="Code:\n```py\nsecret=3;['1', '2'][secret]\n```")

agent = CodeAgent(tools=[], model=fake_code_model, verbosity_level=1)

with agent.logger.console.capture() as capture:
agent.run("Test request")
assert "secret\\\\" in repr(capture.get())


class MultiAgentsTests(unittest.TestCase):
def test_multiagents_save(self):
Expand Down
5 changes: 5 additions & 0 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ def test_listcomp(self):
result, _ = evaluate_python_code(code, {"range": range}, state={})
assert result == [0, 1, 2]

def test_setcomp(self):
code = "batman_times = {entry['time'] for entry in [{'time': 10}, {'time': 19}, {'time': 20}]}"
result, _ = evaluate_python_code(code, {}, state={})
assert result == {10, 19, 20}

def test_break_continue(self):
code = "for i in range(10):\n if i == 5:\n break\ni"
result, _ = evaluate_python_code(code, {"range": range}, state={})
Expand Down

0 comments on commit 62b794e

Please sign in to comment.