Skip to content

Commit

Permalink
Refactor code to use new AsyncPostgresSaver. Delete old PostgresSaver…
Browse files Browse the repository at this point in the history
… class.
  • Loading branch information
StreetLamb committed Aug 13, 2024
1 parent 96500b4 commit 3a6a429
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 585 deletions.
8 changes: 8 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Annotated, Any, Literal

from psycopg.rows import dict_row
from pydantic import (
AnyUrl,
BeforeValidator,
Expand Down Expand Up @@ -66,6 +67,13 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
path=self.POSTGRES_DB,
)

# For checkpointer
SQLALCHEMY_CONNECTION_KWARGS: dict[str, Any] = {
"autocommit": True,
"prepare_threshold": 0,
"row_factory": dict_row,
}

@computed_field # type: ignore[misc]
@property
def PG_DATABASE_URI(self) -> str:
Expand Down
11 changes: 7 additions & 4 deletions backend/app/core/graph/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
)
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import (
Expand All @@ -22,7 +23,6 @@
from psycopg import AsyncConnection

from app.core.config import settings
from app.core.graph.checkpoint.postgres import PostgresSaver
from app.core.graph.members import (
GraphLeader,
GraphMember,
Expand Down Expand Up @@ -471,8 +471,11 @@ async def generator(
]

try:
async with await AsyncConnection.connect(settings.PG_DATABASE_URI) as conn:
checkpointer = PostgresSaver(async_connection=conn)
async with await AsyncConnection.connect(
settings.PG_DATABASE_URI,
**settings.SQLALCHEMY_CONNECTION_KWARGS,
) as conn:
checkpointer = AsyncPostgresSaver(conn=conn)
if team.workflow == "hierarchical":
teams = convert_hierarchical_team_to_dict(team, members)
team_leader = list(teams.keys())[0]
Expand Down
Loading

0 comments on commit 3a6a429

Please sign in to comment.