Skip to content

Commit

Permalink
Add persistence for hierarchical teams
Browse files Browse the repository at this point in the history
  • Loading branch information
StreetLamb committed Jun 1, 2024
1 parent 4c01bdd commit 24259f9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
28 changes: 21 additions & 7 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def convert_hierarchical_team_to_dict(
provider=member.provider,
model=member.model,
temperature=member.temperature,
interrupt=member.interrupt,
)
elif member.type == "leader":
teams[leader_name].members[member_name] = GraphLeader(
Expand Down Expand Up @@ -192,7 +193,6 @@ def enter_chain(state: TeamState, team: GraphTeam) -> dict[str, Any]:
This makes it so that the states of each graph don't get intermixed.
"""
task = state["task"]

results = {
"messages": task,
"team_name": team.name,
Expand Down Expand Up @@ -240,7 +240,9 @@ def create_tools_condition(


def create_hierarchical_graph(
teams: dict[str, GraphTeam], leader_name: str
teams: dict[str, GraphTeam],
leader_name: str,
memory: BaseCheckpointSaver | None = None,
) -> CompiledGraph:
"""Create the team's graph.
Expand All @@ -256,6 +258,8 @@ def create_hierarchical_graph(
dict: A dictionary representing the graph of teams.
"""
build = StateGraph(TeamState)
# Create a list to store member names that require human intervention before tool calling
interrupt_member_names = []
# Add the start and end node
build.add_node(
leader_name,
Expand Down Expand Up @@ -294,22 +298,32 @@ def create_hierarchical_graph(
# if member can call tools, then add tool node
if len(member.tools) >= 1:
build.add_node(
f"{member.name}_tools",
f"{name}_tools",
ToolNode([all_skills[tool].tool for tool in member.tools]),
)
# After tools node is called, agent node is called next.
build.add_edge(f"{name}_tools", name)
# Check if member requires human intervention before tool calling
if member.interrupt:
interrupt_member_names.append(f"{name}_tools")
elif isinstance(member, GraphLeader):
subgraph = create_hierarchical_graph(teams, leader_name=name)
# subgraphs do not require memory
subgraph = create_hierarchical_graph(teams, leader_name=name, memory=None)
enter = partial(enter_chain, team=teams[name])
build.add_node(name, enter | subgraph | exit_chain)
build.add_node(
name,
enter | subgraph | exit_chain,
)
else:
continue
# If member has tools, we create conditional edge to either tool node or back to leader.
if isinstance(member, GraphMember) and len(member.tools) >= 1:
build.add_conditional_edges(
name, should_continue, create_tools_condition(name, leader_name)
)
# Check if member requires human intervention before tool calling
if member.interrupt:
interrupt_member_names.append(f"{member.name}_tools")
else:
build.add_edge(name, leader_name)
conditional_mapping = {v: v for v in members}
Expand All @@ -318,7 +332,7 @@ def create_hierarchical_graph(

build.set_entry_point(leader_name)
build.set_finish_point("FinalAnswer")
graph = build.compile()
graph = build.compile(checkpointer=memory, interrupt_before=interrupt_member_names)
return graph


Expand Down Expand Up @@ -419,7 +433,7 @@ async def generator(
if team.workflow == "hierarchical":
teams = convert_hierarchical_team_to_dict(team, members)
team_leader = list(teams.keys())[0]
root = create_hierarchical_graph(teams, leader_name=team_leader)
root = create_hierarchical_graph(teams, leader_name=team_leader, memory=memory)
state = {
"messages": formatted_messages,
"team_name": teams[team_leader].name,
Expand Down
10 changes: 8 additions & 2 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ async def work(self, state: TeamState) -> ReturnTeamState:
class LeaderNode(BaseNode):
leader_prompt = ChatPromptTemplate.from_messages(
[
MessagesPlaceholder(variable_name="messages"),
(
"system",
(
Expand All @@ -230,7 +231,6 @@ class LeaderNode(BaseNode):
"{team_members_info}"
),
),
MessagesPlaceholder(variable_name="messages"),
(
"system",
"Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}.",
Expand Down Expand Up @@ -342,7 +342,13 @@ async def summarise(self, state: TeamState) -> dict[str, list[BaseMessage]]:
team_members_name = self.get_team_members_name(state["team_members"])
team_name = state["team_name"]
team_responses = self.get_team_responses(state["messages"])
team_task = state["messages"][0].content
# TODO: optimise looking for task
# The most recent human message is the team's most recent task
team_task = ""
for message in state["messages"][::-1]:
if isinstance(message, HumanMessage) and isinstance(message.content, str):
team_task = message.content
break

summarise_chain: RunnableSerializable[Any, Any] = (
self.summariser_prompt.partial(
Expand Down

0 comments on commit 24259f9

Please sign in to comment.