Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add Sequential and AsyncSequential agents #270

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lagent/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .agent import Agent, AgentDict, AgentList, AsyncAgent
from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
from .react import AsyncReAct, ReAct
from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder

__all__ = [
'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
'AsyncReAct'
'AsyncReAct', 'Sequential', 'AsyncSequential'
]
96 changes: 93 additions & 3 deletions lagent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from collections import OrderedDict, UserDict, UserList, abc
from functools import wraps
from itertools import chain, repeat
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

from lagent.agents.aggregator import DefaultAggregator
Expand Down Expand Up @@ -169,7 +170,22 @@ def reset(self, session_id=0):
self.memory.reset(session_id=session_id)

def __repr__(self):
return f"{self.__class__.__name__}(name='{self.name}', description='{self.description or ''}')"

def _rcsv_repr(agent, n_indent=1):
res = agent.__class__.__name__ + (f"(name='{agent.name}')"
if agent.name else '')
modules = [
f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
for name, agent in getattr(agent, '_agents', {}).items()
]
if modules:
res += '(\n' + '\n'.join(
modules) + f'\n{(n_indent - 1) * " "})'
elif not res.endswith(')'):
res += '()'
return res

return _rcsv_repr(self)


class AsyncAgent(Agent):
Expand Down Expand Up @@ -225,6 +241,78 @@ async def forward(self,
return llm_response


class Sequential(Agent):
"""Sequential is an agent container that forwards messages to each agent
in the order they are added."""

def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
super().__init__(**kwargs)
self._agents = OrderedDict()
if not agents:
raise ValueError('At least one agent should be provided')
if isinstance(agents[0],
Iterable) and not isinstance(agents[0], Agent):
if not agents[0]:
raise ValueError('At least one agent should be provided')
agents = agents[0]
for key, agent in enumerate(agents):
if isinstance(agents, Mapping):
key, agent = agent, agents[agent]
elif isinstance(agent, tuple):
key, agent = agent
self.add_agent(key, agent)

def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
assert isinstance(
agent, (Agent, AsyncAgent
)), f'{type(agent)} is not an Agent or AsyncAgent subclass'
self._agents[str(name)] = agent

def forward(self,
*message: AgentMessage,
session_id=0,
exit_at: Optional[int] = None,
**kwargs) -> AgentMessage:
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
if exit_at is None:
exit_at = len(self) - 1
iterator = chain.from_iterable(repeat(self._agents.values()))
for _ in range(exit_at + 1):
agent = next(iterator)
if isinstance(message, AgentMessage):
message = (message, )
message = agent(*message, session_id=session_id, **kwargs)
return message

def __getitem__(self, key):
if isinstance(key, int) and key < 0:
assert key >= -len(self), 'index out of range'
key = len(self) + key
return self._agents[str(key)]

def __len__(self):
return len(self._agents)


class AsyncSequential(Sequential, AsyncAgent):

async def forward(self,
*message: AgentMessage,
session_id=0,
exit_at: Optional[int] = None,
**kwargs) -> AgentMessage:
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
if exit_at is None:
exit_at = len(self) - 1
iterator = chain.from_iterable(repeat(self._agents.values()))
for _ in range(exit_at + 1):
agent = next(iterator)
if isinstance(message, AgentMessage):
message = (message, )
message = await agent(*message, session_id=session_id, **kwargs)
return message


class AgentContainerMixin:

def __init_subclass__(cls):
Expand Down Expand Up @@ -276,18 +364,20 @@ def _backup(d):
setattr(cls, method, wrap_api(getattr(cls, method)))


class AgentList(UserList, Agent, AgentContainerMixin):
class AgentList(Agent, UserList, AgentContainerMixin):

def __init__(self,
agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
Agent.__init__(self, memory=None)
UserList.__init__(self, agents)
self.name = None


class AgentDict(UserDict, Agent, AgentContainerMixin):
class AgentDict(Agent, UserDict, AgentContainerMixin):

def __init__(self,
agents: Optional[Mapping[str, Union[Agent,
AsyncAgent]]] = None):
Agent.__init__(self, memory=None)
UserDict.__init__(self, agents)
self.name = None