diff --git a/examples/frontend/app.py b/examples/frontend/app.py index 8654cc60..2d4ce496 100644 --- a/examples/frontend/app.py +++ b/examples/frontend/app.py @@ -2,9 +2,10 @@ import sys import streamlit as st -from codeinterpreterapi import File from utils import get_images # type: ignore +from codeinterpreterapi import File + # Page configuration st.set_page_config(layout="wide") diff --git a/examples/frontend/chainlitui.py b/examples/frontend/chainlitui.py index 268dbe59..36b66f6c 100644 --- a/examples/frontend/chainlitui.py +++ b/examples/frontend/chainlitui.py @@ -1,4 +1,5 @@ import chainlit as cl # type: ignore + from codeinterpreterapi import CodeInterpreterSession from codeinterpreterapi import File as CIFile diff --git a/examples/frontend/utils.py b/examples/frontend/utils.py index b33792df..e6c18c7c 100644 --- a/examples/frontend/utils.py +++ b/examples/frontend/utils.py @@ -4,6 +4,7 @@ from typing import Optional import streamlit as st + from codeinterpreterapi import CodeInterpreterSession diff --git a/examples/use_additional_tools.py b/examples/use_additional_tools.py index 96cf0915..d8c48bb7 100644 --- a/examples/use_additional_tools.py +++ b/examples/use_additional_tools.py @@ -4,13 +4,15 @@ so it can download the bitcoin chart from yahoo finance and plot it for you """ + import csv import io from typing import Any -from codeinterpreterapi import CodeInterpreterSession from langchain_core.tools import BaseTool +from codeinterpreterapi import CodeInterpreterSession + class ExampleKnowledgeBaseTool(BaseTool): name: str = "salary_database" diff --git a/src/codeinterpreterapi/__init__.py b/src/codeinterpreterapi/__init__.py index 2afd77d1..6c6061c3 100644 --- a/src/codeinterpreterapi/__init__.py +++ b/src/codeinterpreterapi/__init__.py @@ -1,9 +1,8 @@ -from . import _patch_parser # noqa - from codeinterpreterapi.config import settings from codeinterpreterapi.schema import File from codeinterpreterapi.session import CodeInterpreterSession +from . import _patch_parser # noqa __all__ = [ "CodeInterpreterSession", diff --git a/src/codeinterpreterapi/session.py b/src/codeinterpreterapi/session.py index 6cf742a2..81f96a10 100644 --- a/src/codeinterpreterapi/session.py +++ b/src/codeinterpreterapi/session.py @@ -177,16 +177,18 @@ def _choose_agent(self) -> BaseSingleActionAgent: ], ) if isinstance(self.llm, ChatOpenAI) or isinstance(self.llm, AzureChatOpenAI) - else ConversationalChatAgent.from_llm_and_tools( - llm=self.llm, - tools=self.tools, - system_message=settings.SYSTEM_MESSAGE.content.__str__(), - ) - if isinstance(self.llm, BaseChatModel) - else ConversationalAgent.from_llm_and_tools( - llm=self.llm, - tools=self.tools, - prefix=settings.SYSTEM_MESSAGE.content.__str__(), + else ( + ConversationalChatAgent.from_llm_and_tools( + llm=self.llm, + tools=self.tools, + system_message=settings.SYSTEM_MESSAGE.content.__str__(), + ) + if isinstance(self.llm, BaseChatModel) + else ConversationalAgent.from_llm_and_tools( + llm=self.llm, + tools=self.tools, + prefix=settings.SYSTEM_MESSAGE.content.__str__(), + ) ) ) @@ -194,17 +196,21 @@ def _history_backend(self) -> BaseChatMessageHistory: return ( CodeBoxChatMessageHistory(codebox=self.codebox) if settings.HISTORY_BACKEND == "codebox" - else RedisChatMessageHistory( - session_id=str(self.session_id), - url=settings.REDIS_URL, - ) - if settings.HISTORY_BACKEND == "redis" - else PostgresChatMessageHistory( - session_id=str(self.session_id), - connection_string=settings.POSTGRES_URL, + else ( + RedisChatMessageHistory( + session_id=str(self.session_id), + url=settings.REDIS_URL, + ) + if settings.HISTORY_BACKEND == "redis" + else ( + PostgresChatMessageHistory( + session_id=str(self.session_id), + connection_string=settings.POSTGRES_URL, + ) + if settings.HISTORY_BACKEND == "postgres" + else ChatMessageHistory() + ) ) - if settings.HISTORY_BACKEND == "postgres" - else ChatMessageHistory() ) def _agent_executor(self) -> AgentExecutor: diff --git a/tests/chain_test.py b/tests/chain_test.py index bf301562..e655dd8d 100644 --- a/tests/chain_test.py +++ b/tests/chain_test.py @@ -1,12 +1,13 @@ from asyncio import run as _await +from langchain_openai import ChatOpenAI + from codeinterpreterapi.chains import ( aget_file_modifications, aremove_download_link, get_file_modifications, remove_download_link, ) -from langchain_openai import ChatOpenAI llm = ChatOpenAI(model="gpt-3.5-turbo")