From 1304459410e0dc1874d8f739072efd1ab95cb7f2 Mon Sep 17 00:00:00 2001 From: Are Meisfjord Date: Tue, 18 Feb 2025 11:05:50 +0100 Subject: [PATCH] Improve function handling and name resolution in local Python executor --- src/smolagents/local_python_executor.py | 43 +++---------------------- tests/test_local_python_executor.py | 35 +++++++++++++++----- 2 files changed, 32 insertions(+), 46 deletions(-) diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index a7b2fb3f6..de42a9fec 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -57,7 +57,7 @@ class InterpreterError(ValueError): def custom_print(*args): return None - +custom_print.__name__ = "print" BASE_PYTHON_TOOLS = { "print": custom_print, @@ -314,6 +314,7 @@ def new_func(*args: Any, **kwargs: Any) -> Any: return result + new_func.__name__ = func_def.name return new_func @@ -562,43 +563,9 @@ def evaluate_call( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> Any: - if not ( - isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript) - ): - raise InterpreterError(f"This is not a correct function: {call.func}).") - if isinstance(call.func, ast.Attribute): - obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports) - func_name = call.func.attr - if not hasattr(obj, func_name): - raise InterpreterError(f"Object {obj} has no attribute {func_name}") - func = getattr(obj, func_name) - - elif isinstance(call.func, ast.Name): - func_name = call.func.id - if func_name in state: - func = state[func_name] - elif func_name in static_tools: - func = static_tools[func_name] - elif func_name in custom_tools: - func = custom_tools[func_name] - elif func_name in ERRORS: - func = ERRORS[func_name] - else: - raise InterpreterError( - f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})." - ) - - elif isinstance(call.func, ast.Subscript): - value = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports) - index = evaluate_ast(call.func.slice, state, static_tools, custom_tools, authorized_imports) - if isinstance(value, (list, tuple)): - func = value[index] - else: - raise InterpreterError(f"Cannot subscript object of type {type(value).__name__}") + func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports) + func_name = func.__name__ - if not callable(func): - raise InterpreterError(f"This is not a correct function: {call.func}).") - func_name = None args = [] for arg in call.args: if isinstance(arg, ast.Starred): @@ -704,7 +671,7 @@ def evaluate_name( close_matches = difflib.get_close_matches(name.id, list(state.keys())) if len(close_matches) > 0: return state[close_matches[0]] - raise InterpreterError(f"The variable `{name.id}` is not defined.") + raise InterpreterError(f"The name `{name.id}` is not defined.") def evaluate_condition( diff --git a/tests/test_local_python_executor.py b/tests/test_local_python_executor.py index f7ecc91ee..b00093306 100644 --- a/tests/test_local_python_executor.py +++ b/tests/test_local_python_executor.py @@ -85,12 +85,12 @@ def test_evaluate_call(self): state = {"x": 3} result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) assert result == 5 - self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": 3}) + self.assertDictEqualNoPrint(state, {"x": 3, "y": 5, "_operations_count": 4}) # Should not work without the tool with pytest.raises(InterpreterError) as e: evaluate_python_code(code, {}, state=state) - assert "tried to execute add_two" in str(e.value) + assert "The name `add_two` is not defined" in str(e.value) def test_evaluate_constant(self): code = "x = 3" @@ -104,7 +104,7 @@ def test_evaluate_dict(self): state = {"x": 3} result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) self.assertDictEqual(result, {"x": 3, "y": 5}) - self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 7}) + self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 8}) def test_evaluate_expression(self): code = "x = 3\ny = 5" @@ -141,7 +141,7 @@ def test_evaluate_list(self): state = {"x": 3} result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) self.assertListEqual(result, [3, 5]) - self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 5}) + self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 6}) def test_evaluate_name(self): code = "y = x" @@ -155,13 +155,13 @@ def test_evaluate_subscript(self): state = {"x": 3} result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) assert result == 5 - self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 9}) + self.assertDictEqualNoPrint(state, {"x": 3, "test_list": [3, 5], "_operations_count": 10}) code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']" state = {"x": 3} result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) assert result == 5 - self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 11}) + self.assertDictEqualNoPrint(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "_operations_count": 12}) code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" state = {} @@ -185,7 +185,7 @@ def test_evaluate_for(self): state = {} result, _ = evaluate_python_code(code, {"range": range}, state=state) assert result == 2 - self.assertDictEqualNoPrint(state, {"x": 2, "i": 2, "_operations_count": 11}) + self.assertDictEqualNoPrint(state, {"x": 2, "i": 2, "_operations_count": 12}) def test_evaluate_binop(self): code = "y + x" @@ -1108,7 +1108,7 @@ def __{operator_name}__(self, other): del x x """), - "The variable `x` is not defined", + "The name `x` is not defined", ), ( dedent("""\ @@ -1397,3 +1397,22 @@ def test_check_module_authorized(module: str, authorized_imports: list[str], exp "multiprocessing", ) assert check_module_authorized(module, authorized_imports, dangerous_patterns) == expected + + +def test__name__(): + code = dedent("""\ + def foo(): return 0 + foo.__name__ + """) + result, _ = evaluate_python_code(code, {}, {}) + assert result == "foo" + +def test_function_returning_function(): + code = dedent("""\ + def f(): + return lambda x: x + 1 + f()(1) + """) + result, _ = evaluate_python_code(code, {}, {}) + assert result == 2 +