Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve function handling and name resolution in local Python executor #695

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 5 additions & 38 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class InterpreterError(ValueError):

def custom_print(*args):
return None

custom_print.__name__ = "print"

BASE_PYTHON_TOOLS = {
"print": custom_print,
Expand Down Expand Up @@ -314,6 +314,7 @@ def new_func(*args: Any, **kwargs: Any) -> Any:

return result

new_func.__name__ = func_def.name
return new_func


Expand Down Expand Up @@ -562,43 +563,9 @@ def evaluate_call(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Any:
if not (
isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript)
):
raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
func_name = call.func.attr
if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name)

elif isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in state:
func = state[func_name]
elif func_name in static_tools:
func = static_tools[func_name]
elif func_name in custom_tools:
func = custom_tools[func_name]
elif func_name in ERRORS:
func = ERRORS[func_name]
else:
raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})."
)

elif isinstance(call.func, ast.Subscript):
value = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
index = evaluate_ast(call.func.slice, state, static_tools, custom_tools, authorized_imports)
if isinstance(value, (list, tuple)):
func = value[index]
else:
raise InterpreterError(f"Cannot subscript object of type {type(value).__name__}")
func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
func_name = func.__name__

if not callable(func):
raise InterpreterError(f"This is not a correct function: {call.func}).")
func_name = None
args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
Expand Down Expand Up @@ -704,7 +671,7 @@ def evaluate_name(
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
if len(close_matches) > 0:
return state[close_matches[0]]
raise InterpreterError(f"The variable `{name.id}` is not defined.")
raise InterpreterError(f"The name `{name.id}` is not defined.")


def evaluate_condition(
Expand Down
35 changes: 27 additions & 8 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ def test_evaluate_call(self):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": 3})
self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": 4})

# Should not work without the tool
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, {}, state=state)
assert "tried to execute add_two" in str(e.value)
assert "The name `add_two` is not defined" in str(e.value)

def test_evaluate_constant(self):
code = "x = 3"
Expand All @@ -104,7 +104,7 @@ def test_evaluate_dict(self):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 7})
self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 8})

def test_evaluate_expression(self):
code = "x = 3\ny = 5"
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_evaluate_list(self):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertListEqual(result, [3, 5])
self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 5})
self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 6})

def test_evaluate_name(self):
code = "y = x"
Expand All @@ -155,13 +155,13 @@ def test_evaluate_subscript(self):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 9})
self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 10})

code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 11})
self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 12})

code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
Expand All @@ -185,7 +185,7 @@ def test_evaluate_for(self):
state = {}
result, _ = evaluate_python_code(code, {"range": range}, state=state)
assert result == 2
self.assertDictEqualNoPrint(state, {"x": 2, "i": 2, "_operations_count": 11})
self.assertDictEqualNoPrint(state, {"x": 2, "i": 2, "_operations_count": 12})

def test_evaluate_binop(self):
code = "y + x"
Expand Down Expand Up @@ -1108,7 +1108,7 @@ def __{operator_name}__(self, other):
del x
x
"""),
"The variable `x` is not defined",
"The name `x` is not defined",
),
(
dedent("""\
Expand Down Expand Up @@ -1397,3 +1397,22 @@ def test_check_module_authorized(module: str, authorized_imports: list[str], exp
"multiprocessing",
)
assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected


def test__name__():
code = dedent("""\
def foo(): return 0
foo.__name__
""")
result, _ = evaluate_python_code(code, {}, {})
assert result == "foo"

def test_function_returning_function():
code = dedent("""\
def f():
return lambda x: x + 1
f()(1)
""")
result, _ = evaluate_python_code(code, {}, {})
assert result == 2