Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lint errors #9

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions backend/app/api/routes/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def check_duplicate_names_on_create(
session: SessionDep, team_id: int, member_in: MemberCreate
):
) -> None:
"""Check if (name, team_id) is unique"""
statement = select(Member).where(
Member.name == member_in.name,
Expand All @@ -35,7 +35,7 @@ def check_duplicate_names_on_create(

def check_duplicate_names_on_update(
session: SessionDep, team_id: int, member_in: MemberUpdate, id: int
):
) -> None:
"""Check if (name, team_id) is unique"""
statement = select(Member).where(
Member.name == member_in.name,
Expand Down Expand Up @@ -133,6 +133,8 @@ def create_member(
"""
if not current_user.is_superuser:
team = session.get(Team, team_id)
if not team:
raise HTTPException(status_code=404, detail="Team not found.")
if team.owner_id != current_user.id:
raise HTTPException(status_code=400, detail="Not enough permissions")
member = Member.model_validate(member_in, update={"belongs_to": team_id})
Expand Down Expand Up @@ -181,7 +183,7 @@ def update_member(
if member_in.skills is not None:
skill_ids = [skill.id for skill in member_in.skills]
skills = session.exec(select(Skill).where(col(Skill.id).in_(skill_ids))).all()
member.skills = skills
member.skills = list(skills)

update_dict = member_in.model_dump(exclude_unset=True)
member.sqlmodel_update(update_dict)
Expand Down
8 changes: 5 additions & 3 deletions backend/app/api/routes/teams.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@
router = APIRouter()


async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreate):
async def check_duplicate_name_on_create(
session: SessionDep, team_in: TeamCreate
) -> None:
"""Validate that team name is unique"""
statement = select(Team).where(Team.name == team_in.name)
team = session.exec(statement).first()
Expand All @@ -72,7 +74,7 @@ async def check_duplicate_name_on_create(session: SessionDep, team_in: TeamCreat

async def check_duplicate_name_on_update(
session: SessionDep, team_in: TeamUpdate, id: int
):
) -> None:
"""Validate that team name is unique"""
statement = select(Team).where(Team.name == team_in.name, Team.id != id)
team = session.exec(statement).first()
Expand Down Expand Up @@ -199,7 +201,7 @@ def delete_team(session: SessionDep, current_user: CurrentUser, id: int) -> Any:
@router.post("/{id}/stream")
async def stream(
session: SessionDep, current_user: CurrentUser, id: int, team_chat: TeamChat
):
) -> StreamingResponse:
"""
Stream a response to a user's input.
"""
Expand Down
41 changes: 23 additions & 18 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from collections import defaultdict, deque
from collections.abc import AsyncGenerator
from functools import partial
from typing import Any

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables import RunnableLambda
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph

from app.core.graph.members import (
Leader,
Expand Down Expand Up @@ -41,19 +44,20 @@ def convert_team_to_dict(
"""
teams: dict[str, Team] = {}

in_counts = defaultdict(int)
out_counts = defaultdict(list[int])
in_counts: defaultdict[int, int] = defaultdict(int)
out_counts: defaultdict[int, list[int]] = defaultdict(list[int])
members_lookup: dict[int, MemberModel] = {}

for member in members:
assert member.id is not None, "member.id is unexpectedly None"
if member.source:
in_counts[member.id] += 1
out_counts[member.source].append(member.id)
else:
in_counts[member.id] = 0
members_lookup[member.id] = member

queue = deque()
queue: deque[int] = deque()

for member_id in in_counts:
if in_counts[member_id] == 0:
Expand All @@ -73,12 +77,11 @@ def convert_team_to_dict(
temperature=member.temperature,
)
# If member is not root team leader, add as a member
if member.type != "root":
if member.type != "root" and member.source:
member_name = member.name
leader = members_lookup[member.source]
leader_name = leader.name
teams[leader_name].members[member_name] = Member(
type=member.type,
name=member_name,
backstory=member.backstory or "",
role=member.role,
Expand All @@ -96,7 +99,7 @@ def convert_team_to_dict(
return teams


def format_teams(teams: dict[str, any]) -> dict[str, Team]:
def format_teams(teams: dict[str, dict[str, Any]]) -> dict[str, Team]:
"""
FOR TESTING PURPOSES ONLY!

Expand All @@ -113,7 +116,7 @@ def format_teams(teams: dict[str, any]) -> dict[str, Team]:
for team_name, team in teams.items():
if not isinstance(team, dict):
raise ValueError(f"Invalid team {team_name}. Teams must be dictionaries.")
members = team.get("members", {})
members: dict[str, dict[str, Any]] = team.get("members", {})
for k, v in members.items():
if v["type"] == "leader":
teams[team_name]["members"][k] = Leader(**v)
Expand All @@ -122,11 +125,13 @@ def format_teams(teams: dict[str, any]) -> dict[str, Team]:
return {team_name: Team(**team) for team_name, team in teams.items()}


def router(state: TeamState):
def router(state: TeamState) -> str:
return state["next"]


def enter_chain(state: TeamState, team: dict[str, str | list[Member | Leader]]):
def enter_chain(
state: TeamState, team: dict[str, str | list[Member | Leader]]
) -> dict[str, Any]:
"""
Initialise the sub-graph state.
This makes it so that the states of each graph don't get intermixed.
Expand All @@ -143,15 +148,15 @@ def enter_chain(state: TeamState, team: dict[str, str | list[Member | Leader]]):
return results


def exit_chain(state: TeamState):
def exit_chain(state: TeamState) -> dict[str, list[BaseMessage]]:
"""
Pass the final response back to the top-level graph's state.
"""
answer = state["messages"][-1]
return {"messages": [answer]}


def create_graph(teams: dict[str, Team], leader_name: str):
def create_graph(teams: dict[str, Team], leader_name: str) -> CompiledGraph:
"""Create the team's graph.

This function creates a graph representation of the given teams. The graph is represented as a dictionary where each key is a team name,
Expand Down Expand Up @@ -220,23 +225,23 @@ def create_graph(teams: dict[str, Team], leader_name: str):


async def generator(
team: TeamModel, members: list[Member], messages: list[ChatMessage]
):
team: TeamModel, members: list[MemberModel], messages: list[ChatMessage]
) -> AsyncGenerator[Any, Any]:
"""Create the graph and stream responses as JSON."""
teams = convert_team_to_dict(team, members)
team_leader = list(teams.keys())[0]
root = create_graph(teams, leader_name=team_leader)
messages = [
HumanMessage(message.content)
formatted_messages = [
HumanMessage(content=message.content)
if message.type == "human"
else AIMessage(message.content)
else AIMessage(content=message.content)
for message in messages
]

# TODO: Figure out how to use async_stream to stream responses from subgraphs
async for output in root.astream(
{
"messages": messages,
"messages": formatted_messages,
"team_name": teams[team_leader].name,
"team_members": teams[team_leader].members,
}
Expand Down
61 changes: 34 additions & 27 deletions backend/app/core/graph/members.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import operator
from typing import Annotated, TypedDict
from collections.abc import Sequence
from typing import Annotated, Any, TypedDict

from langchain.agents import (
AgentExecutor,
Expand All @@ -9,7 +10,8 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableLambda, RunnableSerializable
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from app.core.graph.models import all_models
Expand Down Expand Up @@ -50,7 +52,7 @@ class Team(BaseModel):
)


def update_name(name: str, new_name: str):
def update_name(name: str, new_name: str) -> str:
"""Update name at the onset."""
if not name:
return new_name
Expand All @@ -59,7 +61,7 @@ def update_name(name: str, new_name: str):

def update_members(
members: dict[str, Member | Leader] | None, new_members: dict[str, Member | Leader]
):
) -> dict[str, Member | Leader]:
"""Update members at the onset"""
if not members:
members = {}
Expand All @@ -79,15 +81,15 @@ class TeamState(TypedDict):

class BaseNode:
def __init__(self, provider: str, model: str, temperature: float):
self.model = all_models[provider](model=model, temperature=temperature)
self.final_answer_model = all_models[provider](model=model, temperature=0)
self.model = all_models[provider](model=model, temperature=temperature) # type: ignore[call-arg]
self.final_answer_model = all_models[provider](model=model, temperature=0) # type: ignore[call-arg]

def tag_with_name(self, ai_message: AIMessage, name: str):
def tag_with_name(self, ai_message: AIMessage, name: str) -> AIMessage:
"""Tag a name to the AI message"""
ai_message.name = name
return ai_message

def get_team_members_name(self, team_members: dict[str, Person]):
def get_team_members_name(self, team_members: dict[str, Member | Leader]) -> str:
"""Get the names of all team members as a string"""
return ",".join(list(team_members))

Expand All @@ -111,23 +113,24 @@ class WorkerNode(BaseNode):
]
)

def convert_output_to_ai_message(self, state: TeamState):
def convert_output_to_ai_message(self, agent_output: dict[str, str]) -> AIMessage:
"""Convert agent executor output to ai message"""
output = state["output"]
output = agent_output["output"]
return AIMessage(content=output)

def create_agent(
self, llm: BaseChatModel, prompt: ChatPromptTemplate, tools: list[str]
):
) -> AgentExecutor:
"""Create the agent executor. Tools must non-empty."""
tools = [all_skills[tool].tool for tool in tools]
agent = create_tool_calling_agent(llm, tools, prompt)
executor = AgentExecutor(agent=agent, tools=tools)
formatted_tools: Sequence[BaseTool] = [all_skills[tool].tool for tool in tools]
agent = create_tool_calling_agent(llm, formatted_tools, prompt)
executor = AgentExecutor(agent=agent, tools=formatted_tools) # type: ignore[arg-type]
return executor

async def work(self, state: TeamState):
async def work(self, state: TeamState) -> dict[str, list[BaseMessage]]:
name = state["next"]
member = state["team_members"][name]
assert isinstance(member, Member), "member is unexpectedly not a Member"
tools = member.tools
team_members_name = self.get_team_members_name(state["team_members"])
prompt = self.worker_prompt.partial(
Expand All @@ -139,9 +142,13 @@ async def work(self, state: TeamState):
agent = self.create_agent(self.model, prompt, tools)
chain = agent | RunnableLambda(self.convert_output_to_ai_message)
else:
chain = prompt.partial(agent_scratchpad=[]) | self.model
work_chain = chain | RunnableLambda(self.tag_with_name).bind(name=member.name)
result = await work_chain.ainvoke(state)
chain: RunnableSerializable[dict[str, Any], BaseMessage] = ( # type: ignore[no-redef]
prompt.partial(agent_scratchpad=[]) | self.model
)
work_chain: RunnableSerializable[dict[str, Any], Any] = chain | RunnableLambda(
self.tag_with_name # type: ignore[arg-type]
).bind(name=member.name)
result = await work_chain.ainvoke(state) # type: ignore[arg-type]
return {"messages": [result]}


Expand All @@ -165,14 +172,14 @@ class LeaderNode(BaseNode):
]
)

def get_team_members_info(self, team_members: list[Member]):
def get_team_members_info(self, team_members: dict[str, Member | Leader]) -> str:
"""Create a string containing team members name and role."""
result = ""
for member in team_members.values():
result += f"name: {member.name}\nrole: {member.role}\n\n"
return result

def get_tool_definition(self, options: list[str]):
def get_tool_definition(self, options: list[str]) -> dict[str, Any]:
"""Return the tool definition to choose next team member and provide the task."""
return {
"type": "function",
Expand All @@ -199,14 +206,14 @@ def get_tool_definition(self, options: list[str]):
},
}

async def delegate(self, state: TeamState):
async def delegate(self, state: TeamState) -> dict[str, Any]:
team_members_name = self.get_team_members_name(state["team_members"])
team_name = state["team_name"]
team_members_info = self.get_team_members_info(state["team_members"])
options = list(state["team_members"]) + ["FINISH"]
tools = [self.get_tool_definition(options)]

delegate_chain = (
delegate_chain: RunnableSerializable[Any, Any] = (
self.leader_prompt.partial(
team_name=team_name,
team_members_name=team_members_name,
Expand All @@ -216,7 +223,7 @@ async def delegate(self, state: TeamState):
| self.model.bind_tools(tools=tools)
| JsonOutputKeyToolsParser(key_name="route", first_tool_only=True)
)
result = await delegate_chain.ainvoke(state)
result: dict[str, Any] = await delegate_chain.ainvoke(state)
if not result:
return {
"task": [HumanMessage(content="No further tasks.", name=team_name)],
Expand Down Expand Up @@ -254,28 +261,28 @@ class SummariserNode(BaseNode):
]
)

def get_team_responses(self, messages: list[BaseMessage]):
def get_team_responses(self, messages: list[BaseMessage]) -> str:
"""Create a string containing the team's responses."""
result = ""
for message in messages:
result += f"{message.name}: {message.content}\n"
return result

async def summarise(self, state: TeamState):
async def summarise(self, state: TeamState) -> dict[str, list[BaseMessage]]:
team_members_name = self.get_team_members_name(state["team_members"])
team_name = state["team_name"]
team_responses = self.get_team_responses(state["messages"])
team_task = state["messages"][0].content

summarise_chain = (
summarise_chain: RunnableSerializable[Any, Any] = (
self.summariser_prompt.partial(
team_name=team_name,
team_members_name=team_members_name,
team_task=team_task,
team_responses=team_responses,
)
| self.final_answer_model
| RunnableLambda(self.tag_with_name).bind(name="FinalAnswer")
| RunnableLambda(self.tag_with_name).bind(name="FinalAnswer") # type: ignore[arg-type]
)
result = await summarise_chain.ainvoke(state)
return {"messages": [result]}
Loading
Loading