Skip to content

Commit

Permalink
Create React Code Agent (#30386)
Browse files Browse the repository at this point in the history
* Add message passing format

Co-authored-by: Aymeric <[email protected]>
Co-authored-by: Cyril Kondratenko <[email protected]>
Co-authored-by: joffrey <[email protected]>
  • Loading branch information
4 people committed Apr 24, 2024
1 parent 90f6949 commit d029168
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 53 deletions.
14 changes: 5 additions & 9 deletions docs/source/en/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ An agent is a system that uses an LLM as its engine, and it has access to functi

These *tools* are functions for performing a task, and they contain all necessary description for the agent to properly use them.

The agent can be programmed to:
- devise a series of actions/tools and run them all at once like the `CodeAgent` for example
- plan and execute actions/tools one by one and wait for the outcome of each action before launching the next one like the `ReactJSONAgent` for example
- devise a series of actions/tool calls and run them all at once, like our `CodeAgent`
- or plan and execute them one by one, waiting for the outcome of each action before launching the next one, thus following a Reflexion ⇒ Action ⇒ Perception cycle. Our `ReactJSONAgent` implements this latter framework.

### Types of agents

Expand Down Expand Up @@ -461,9 +460,7 @@ Before finally generating the image:

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">

This also works with its sibling `ReactAgent`:
<<<<<<< HEAD
=======
This also works with its sibling `ReactJSONAgent`:

```python
from transformers import ReactAgent
Expand All @@ -474,12 +471,11 @@ agent.run("Improve this prompt, then generate an image of it.", prompt="A rabbit
```

<Tip warning={true}>
>>>>>>> 2364c3bd3 (Support variable usage in ReactAgent)

```python
from transformers import ReactAgent
from transformers import ReactJSONAgent

agent = ReactAgent(llm_engine, tools=[tool], add_base_tools=True)
agent = ReactJSONAgent(llm_engine, tools=[tool], add_base_tools=True)

agent.run("Improve this prompt, then generate an image of it.", prompt="A rabbit wearing a space suit")
```
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,8 @@
],
"tools": [
"Agent",
"ReactJSONAgent",
"ReactCodeAgent",
"ReactAgent",
"CodeAgent",
"PipelineTool",
Expand Down
1 change: 1 addition & 0 deletions src/transformers/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool
from .default_tools import CalculatorTool, PythonEvaluatorTool
from .agents import ReactJSONAgent, CodeAgent, ReactCodeAgent
from .agents import ReactAgent, CodeAgent
else:
import sys
Expand Down
172 changes: 149 additions & 23 deletions src/transformers/tools/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from enum import Enum
from ast import literal_eval
from dataclasses import dataclass
from typing import Dict, Union, List, Optional
from huggingface_hub import hf_hub_download, list_spaces, InferenceClient
from math import sqrt
from typing import Dict, List, Union

Expand All @@ -37,10 +39,10 @@
supports_remote,
)
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_SYSTEM_PROMPT
from .prompts import DEFAULT_REACT_SYSTEM_PROMPT, DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT
from .python_interpreter import evaluate_python_code
from PIL import Image

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

_tools_are_initialized = False
Expand Down Expand Up @@ -86,8 +88,9 @@ class FinalAnswerTool(Tool):
inputs = {"answer": {"type": str, "description": "The final answer to the problem"}}
output_type = str

def __call__(self):
pass
def __call__(self, args):
return args


class MessageRole(str, Enum):
USER = "user"
Expand All @@ -99,7 +102,7 @@ class MessageRole(str, Enum):
@classmethod
def roles(cls):
return [r.value for r in cls]


def get_remote_tools(organization="huggingface-tools"):
if is_offline_mode():
Expand Down Expand Up @@ -178,11 +181,19 @@ def parse_json_blob(json_blob: str):
try:
first_accolade_index = json_blob.find("{")
last_accolade_index = [a.start() for a in list(re.finditer('}', json_blob))][-1]
json_blob = json_blob[first_accolade_index:last_accolade_index+1].replace("\\", "")
json_blob = json_blob[first_accolade_index:last_accolade_index+1]
return json.loads(json_blob)
except Exception as e:
raise ValueError(f"The JSON blob you used is invalid: due to the following error: {e}. Make sure to correct its formatting.")

def parse_code_blob(code_blob: str) -> str:
    """
    Extract the code contained in the first Markdown code fence found in `code_blob`.

    The opening fence may be bare (```) or tagged with `py` or `python`, and must be
    immediately followed by a newline.

    Args:
        code_blob (`str`): Text expected to contain a fenced code block.

    Returns:
        `str`: Everything between the opening fence line and the closing fence
        (the newline preceding the closing fence is kept).

    Raises:
        ValueError: If `code_blob` is not a string or contains no fenced code block.
    """
    pattern = r"```(?:py|python)?\n(.*?)```"
    try:
        match = re.search(pattern, code_blob, re.DOTALL)
    except TypeError as e:
        # Non-string input: keep reporting it as a ValueError, like any other invalid blob.
        match, reason = None, e
    else:
        reason = "no fenced code block was found"

    if match is None:
        # Report the actual problem instead of leaking an AttributeError on `None.group`.
        raise ValueError(
            f"The code blob you used is invalid: due to the following error: {reason}. "
            f"This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting."
        )
    return match.group(1)


def parse_json_tool_call(json_blob: str):
json_blob = json_blob.replace("```json", "").replace("```", "")
Expand Down Expand Up @@ -325,7 +336,7 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
message["role"] = role_conversions[role]

if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
final_message_list[-1]["content"] += "\n" + message["content"]
final_message_list[-1]["content"] += "\n============\n" + message["content"]
else:
final_message_list.append(message)
return final_message_list
Expand All @@ -336,11 +347,12 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
MessageRole.TOOL_RESPONSE: MessageRole.USER,
}

class LLMEngine:
def __init__(self, client):
self.client = client

def call(self, messages: List[Dict[str, str]], stop=["Output:"]) -> str:
class HfEngine:
def __init__(self, repo_id: str = "meta-llama/Meta-Llama-3-70B-Instruct"):
self.client = InferenceClient(model=repo_id, timeout=120)

def call(self, messages: List[Dict[str, str]], stop=["Output:", "assistant"]) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)

Expand Down Expand Up @@ -381,7 +393,6 @@ def __init__(
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)

self.system_prompt = format_prompt(self._toolbox, self.system_prompt_template, self.tool_description_template)
self.messages = []
self.prompt = None
self.logs = []

Expand All @@ -395,9 +406,10 @@ def toolbox(self) -> Dict[str, Tool]:
return self._toolbox


def get_inner_memory_from_logs(self) -> str:
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
"""
Reads past llm_outputs, actions, and observations or errors from the logs.
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM.
"""
prompt_message = {
"role": MessageRole.SYSTEM,
Expand All @@ -417,7 +429,11 @@ def get_inner_memory_from_logs(self) -> str:
memory.append(thought_message)

if 'error' in step_log:
message_content = "Error: " + str(step_log["error"]) + "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches if you can.\n"
message_content = (
"Error: "
+ str(step_log["error"])
+ "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches if you can.\n"
)
else:
message_content = f"Observation: {step_log['observation']}"
tool_response_message = {
Expand All @@ -432,7 +448,6 @@ def show_message_history(self):
self.log.info('\n'.join(self.messages))



def extract_action(self, llm_output: str, split_token: str) -> str:
"""
Parse action from the LLM output
Expand Down Expand Up @@ -478,7 +493,7 @@ def execute(self, tool_name: str, arguments: Dict[str, str]) -> None:

except Exception as e:
raise AgentExecutionError(
f"Error in tool call execution: {e}.\nYour input was probably incorrect.\n"
f"Error in tool call execution: {e}.\nYou provided an incorrect input to the tool.\n"
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(self.toolbox.tools[tool_name])}"
)

Expand Down Expand Up @@ -512,7 +527,7 @@ def default_tool_description_template(self)-> str:
"""
This template is taking can desbribe a tool as it is expected by the model
"""
logger.warning_once(
logger.info_once(
"\nNo tool description template is defined for this tokenizer - using a default tool description template "
"that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for "
"your model, please set `tokenizer.tool_description_template` to an appropriate template. "
Expand Down Expand Up @@ -559,7 +574,7 @@ def run(self, task, return_generated_code=False, **kwargs):

self.logs.append({"task": task_message, "system_prompt": self.system_prompt})

memory = self.get_inner_memory_from_logs()
memory = self.write_inner_memory_from_logs()

self.log.info("====Executing with these messages====")
self.log.info(memory)
Expand Down Expand Up @@ -598,9 +613,9 @@ def run(self, task, return_generated_code=False, **kwargs):

class ReactAgent(Agent):
"""
A class for an agent that solves the given task step by step, using the ReAct framework.
This agent that solves the given task step by step, using the ReAct framework.
While the objective is not reached, the agent will perform a cycle of thinking and acting.
The action will be parsed from the LLM output, it will be the call of a tool from the toolbox, with arguments provided by the LLM.
The action will be parsed from the LLM output, it consists in calls to tools from the toolbox, with arguments chosen by the LLM engine.
"""
def __init__(
self,
Expand Down Expand Up @@ -662,6 +677,8 @@ def run(self, task, **kwargs):
'<<additional_args>>',
f"You have been provided with these initial arguments, that you should absolutely use if needed rather than hallucinating arguments: {str(self.state)}."
)
else:
self.system_prompt = self.system_prompt.replace('<<additional_args>>', '')

self.log.info("=====New task=====")
self.log.debug("System prompt is as follows:")
Expand All @@ -688,12 +705,30 @@ def run(self, task, **kwargs):

return final_answer

class ReactJSONAgent(ReactAgent):
def __init__(
self,
llm_engine,
system_prompt=DEFAULT_REACT_SYSTEM_PROMPT,
tool_description_template=None,
max_iterations=5,
llm_engine_grammar=None,
**kwargs
):
super().__init__(
llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template if tool_description_template else self.default_tool_description_template,
max_iterations=max_iterations,
llm_engine_grammar=llm_engine_grammar,
**kwargs
)

def step(self):
"""
Runs agent step with the current prompt (task + state).
"""
agent_memory = self.get_inner_memory_from_logs()
agent_memory = self.write_inner_memory_from_logs()
self.logs[-1]["agent_memory"] = agent_memory.copy()

self.prompt = agent_memory
Expand All @@ -707,9 +742,9 @@ def step(self):
self.log.info(agent_memory)

if self.llm_engine_grammar:
llm_output = self.llm_engine(self.prompt, stop=["Observation:"], grammar=self.llm_engine_grammar)
llm_output = self.llm_engine(self.prompt, stop=["Observation:", "assistant"], grammar=self.llm_engine_grammar)
else:
llm_output = self.llm_engine(self.prompt, stop=["Observation:"])
llm_output = self.llm_engine(self.prompt, stop=["Observation:", "assistant"])
self.log.debug("=====Output message of the LLM:=====")
self.log.debug(llm_output)
self.logs[-1]["llm_output"] = llm_output
Expand Down Expand Up @@ -760,3 +795,94 @@ def step(self):
self.log.info(updated_information)
self.logs[-1]["observation"] = updated_information
return None


class ReactCodeAgent(ReactAgent):
    """
    Agent that solves the given task step by step, using the ReAct framework:
    while the objective is not reached, the agent will perform a cycle of thinking and acting.
    To run its actions, this agent can execute a whole blob of code, thus performing many actions at a time.
    """

    def __init__(
        self,
        llm_engine,
        system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
        tool_description_template=None,
        max_iterations=5,
        llm_engine_grammar=None,
        **kwargs,
    ):
        """
        Forward all arguments to the parent constructor.

        Args:
            llm_engine: Callable invoked in `step` with the message list and a `stop` list
                (plus `grammar` when `llm_engine_grammar` is set).
            system_prompt: System prompt; defaults to `DEFAULT_REACT_CODE_SYSTEM_PROMPT`.
            tool_description_template: Template used to describe tools; when falsy,
                `self.default_tool_description_template` is used instead.
            max_iterations: Forwarded unchanged to the parent constructor.
            llm_engine_grammar: When truthy, passed as `grammar=` on every LLM call.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(
            llm_engine,
            system_prompt=system_prompt,
            tool_description_template=(
                tool_description_template
                if tool_description_template
                else self.default_tool_description_template
            ),
            max_iterations=max_iterations,
            llm_engine_grammar=llm_engine_grammar,
            **kwargs,
        )

    def step(self):
        """
        Runs agent step with the current prompt (task + state).

        The LLM output is split on "Code:" into a rationale and a code blob; the code is
        extracted from its Markdown fence and run with `evaluate_python_code`.

        Returns:
            The result of the executed code if any line of that code starts with
            `final_answer`, otherwise `None`.

        Raises:
            AgentParsingError: If no code could be extracted from the LLM output.
            AgentExecutionError: If evaluating the code raised an exception.
        """
        agent_memory = self.write_inner_memory_from_logs()
        # NOTE(review): this is written to the last *existing* log entry, since the entry
        # for the new step is only appended below -- confirm this is intended.
        self.logs[-1]["agent_memory"] = agent_memory.copy()

        self.prompt = agent_memory

        self.log.debug("=====New step=====")

        # Add new step in logs
        self.logs.append({})

        self.log.info("=====Calling LLM with these messages:=====")
        self.log.info(agent_memory)

        # Only pass `grammar` when one was configured, so engines that do not accept it keep working.
        engine_kwargs = {"stop": ["Observation:", "assistant", "<end_code>"]}
        if self.llm_engine_grammar:
            engine_kwargs["grammar"] = self.llm_engine_grammar
        llm_output = self.llm_engine(self.prompt, **engine_kwargs)

        self.log.debug("=====Output message of the LLM:=====")
        self.log.debug(llm_output)
        self.logs[-1]["llm_output"] = llm_output

        # Split the output into rationale and code action
        self.log.debug("=====Extracting action=====")
        rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")

        self.logs[-1]["rationale"] = rationale
        self.logs[-1]["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}

        # Parse: extract the code from its Markdown fence
        try:
            code_action = parse_code_blob(code_action)
        except Exception as e:
            error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
            self.log.error(error_msg)
            raise AgentParsingError(error_msg) from e

        # Execute
        try:
            self.log.info("\n\n==Executing the code below:==")
            self.log.info(code_action)
            # Tools from the toolbox take precedence over the base Python tools on name clashes.
            available_tools = {**BASE_PYTHON_TOOLS, **self.toolbox.tools}
            result = evaluate_python_code(code_action, available_tools, state=self.state)
            self.logs[-1]["observation"] = result
        except Exception as e:
            error_msg = f"Error in execution: {e}. Be sure to provide correct code."
            self.log.error(error_msg, exc_info=True)
            raise AgentExecutionError(error_msg) from e

        # The step is final only if some line of the executed code starts with `final_answer`;
        # every line is checked, not just the first one.
        if any(line.startswith("final_answer") for line in code_action.split("\n")):
            return result
        return None
3 changes: 1 addition & 2 deletions src/transformers/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,9 @@ def get_tool_description_with_args(
tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
) -> str:
compiled_template = compile_jinja_template(description_template)
rendered = compiled_template.render(
return compiled_template.render(
tool=tool, # **self.special_tokens_map
)
return rendered


@lru_cache
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/tools/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self):
self.numexpr = numexpr

def __call__(self, expression):
if type(expression) != str:
expression = expression['expression']
local_dict = {"pi": math.pi, "e": math.e}
output = str(
self.numexpr.evaluate(
Expand Down

0 comments on commit d029168

Please sign in to comment.