diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 635a2d2cd9..279762d3eb 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -14,6 +14,9 @@ import asyncio from typing import TYPE_CHECKING, Any, Dict +from literalai import ChatGeneration, CompletionGeneration, GenerationMessage +from pydantic.dataclasses import dataclass + import chainlit.input_widget as input_widget from chainlit.action import Action from chainlit.cache import cache @@ -44,22 +47,21 @@ ) from chainlit.step import Step, step from chainlit.sync import make_async, run_sync -from chainlit.types import InputAudioChunk, OutputAudioChunk, ChatProfile, Starter +from chainlit.types import ChatProfile, InputAudioChunk, OutputAudioChunk, Starter from chainlit.user import PersistedUser, User from chainlit.user_session import user_session from chainlit.utils import make_module_getattr from chainlit.version import __version__ -from literalai import ChatGeneration, CompletionGeneration, GenerationMessage -from pydantic.dataclasses import dataclass from .callbacks import ( action_callback, author_rename, + data_layer, header_auth_callback, oauth_callback, - on_audio_start, on_audio_chunk, on_audio_end, + on_audio_start, on_chat_end, on_chat_resume, on_chat_start, @@ -186,6 +188,7 @@ def acall(self): "on_stop", "action_callback", "on_settings_update", + "data_layer", ] diff --git a/backend/chainlit/callbacks.py b/backend/chainlit/callbacks.py index 02c6feb124..f03625da64 100644 --- a/backend/chainlit/callbacks.py +++ b/backend/chainlit/callbacks.py @@ -1,8 +1,12 @@ import inspect from typing import Any, Awaitable, Callable, Dict, List, Optional +from fastapi import Request, Response +from starlette.datastructures import Headers + from chainlit.action import Action from chainlit.config import config +from chainlit.data.base import BaseDataLayer from chainlit.message import Message from chainlit.oauth_providers import get_configured_oauth_providers from chainlit.step import Step, step @@ -10,13 +14,11 @@ from chainlit.types import ChatProfile, Starter, ThreadDict from chainlit.user import User from chainlit.utils import wrap_user_function -from fastapi import Request, Response -from starlette.datastructures import Headers @trace def password_auth_callback( - func: Callable[[str, str], Awaitable[Optional[User]]] + func: Callable[[str, str], Awaitable[Optional[User]]], ) -> Callable: """ Framework agnostic decorator to authenticate the user. @@ -38,7 +40,7 @@ async def password_auth_callback(username: str, password: str) -> Optional[User] @trace def header_auth_callback( - func: Callable[[Headers], Awaitable[Optional[User]]] + func: Callable[[Headers], Awaitable[Optional[User]]], ) -> Callable: """ Framework agnostic decorator to authenticate the user via a header @@ -177,7 +179,7 @@ def set_chat_profiles( @trace def set_starters( - func: Callable[[Optional["User"]], Awaitable[List["Starter"]]] + func: Callable[[Optional["User"]], Awaitable[List["Starter"]]], ) -> Callable: """ Programmatic declaration of the available starter (can depend on the User from the session if authentication is setup). @@ -221,6 +223,7 @@ def on_audio_start(func: Callable) -> Callable: config.code.on_audio_start = wrap_user_function(func, with_task=False) return func + @trace def on_audio_chunk(func: Callable) -> Callable: """ @@ -254,7 +257,7 @@ def on_audio_end(func: Callable) -> Callable: @trace def author_rename( - func: Callable[[str], Awaitable[str]] + func: Callable[[str], Awaitable[str]], ) -> Callable[[str], Awaitable[str]]: """ Useful to rename the author of message to display more friendly author names in the UI. @@ -315,3 +318,17 @@ def on_settings_update( config.code.on_settings_update = wrap_user_function(func, with_task=True) return func + + +def data_layer( + func: Callable[[], BaseDataLayer], +) -> Callable[[], BaseDataLayer]: + """ + Hook to configure custom data layer. + """ + + # We don't use wrap_user_function here because: + # 1. We don't need to support async here and; + # 2. We don't want to change the API for get_data_layer() to be async, everywhere (at this point). + config.code.data_layer = func + return func diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index ad9c1aaf88..71fa8fca1f 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -17,21 +17,24 @@ ) import tomli -from chainlit.logger import logger -from chainlit.translations import lint_translation_json -from chainlit.version import __version__ from dataclasses_json import DataClassJsonMixin from pydantic.dataclasses import Field, dataclass from starlette.datastructures import Headers +from chainlit.data.base import BaseDataLayer +from chainlit.logger import logger +from chainlit.translations import lint_translation_json +from chainlit.version import __version__ + from ._utils import is_path_inside if TYPE_CHECKING: + from fastapi import Request, Response + from chainlit.action import Action from chainlit.message import Message - from chainlit.types import InputAudioChunk, ChatProfile, Starter, ThreadDict + from chainlit.types import ChatProfile, InputAudioChunk, Starter, ThreadDict from chainlit.user import User - from fastapi import Request, Response BACKEND_ROOT = os.path.dirname(__file__) @@ -272,9 +275,9 @@ class CodeSettings: password_auth_callback: Optional[ Callable[[str, str], Awaitable[Optional["User"]]] ] = None - header_auth_callback: Optional[ - Callable[[Headers], Awaitable[Optional["User"]]] - ] = None + header_auth_callback: Optional[Callable[[Headers], Awaitable[Optional["User"]]]] = ( + None + ) oauth_callback: Optional[ Callable[[str, str, Dict[str, str], "User"], Awaitable[Optional["User"]]] ] = None @@ -293,9 +296,10 @@ class CodeSettings: set_chat_profiles: Optional[ Callable[[Optional["User"]], Awaitable[List["ChatProfile"]]] ] = None - set_starters: Optional[ - Callable[[Optional["User"]], Awaitable[List["Starter"]]] - ] = None + set_starters: Optional[Callable[[Optional["User"]], Awaitable[List["Starter"]]]] = ( + None + ) + data_layer: Optional[Callable[[], BaseDataLayer]] = None @dataclass() diff --git a/backend/chainlit/data/__init__.py b/backend/chainlit/data/__init__.py index 3082f91744..dfc359c24b 100644 --- a/backend/chainlit/data/__init__.py +++ b/backend/chainlit/data/__init__.py @@ -11,9 +11,16 @@ def get_data_layer(): global _data_layer + print("Getting data layer", _data_layer) if not _data_layer: - if api_key := os.environ.get("LITERAL_API_KEY"): + from chainlit.config import config + + if config.code.data_layer: + # When @data_layer is configured, call it to get data layer. + _data_layer = config.code.data_layer() + elif api_key := os.environ.get("LITERAL_API_KEY"): + # When LITERAL_API_KEY is defined, use LiteralAI data layer from .literalai import LiteralDataLayer # support legacy LITERAL_SERVER variable as fallback diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 7cc4266371..5eda37063b 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,11 +1,16 @@ import datetime from contextlib import asynccontextmanager +from pathlib import Path from typing import Callable from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio + +from chainlit import config +from chainlit.callbacks import data_layer from chainlit.context import ChainlitContext, context_var +from chainlit.data.base import BaseDataLayer from chainlit.session import HTTPSession, WebsocketSession from chainlit.user import PersistedUser from chainlit.user_session import UserSession @@ -79,3 +84,32 @@ def mock_websocket_session(): @pytest.fixture def mock_http_session(): return Mock(spec=HTTPSession) + + +@pytest.fixture +def mock_data_layer(monkeypatch: pytest.MonkeyPatch) -> AsyncMock: + mock_data_layer = AsyncMock(spec=BaseDataLayer) + + return mock_data_layer + + +@pytest.fixture +def mock_get_data_layer(mock_data_layer: AsyncMock, test_config: config.ChainlitConfig): + # Instantiate mock data layer + mock_get_data_layer = Mock(return_value=mock_data_layer) + + # Configure it using @data_layer decorator + return data_layer(mock_get_data_layer) + + +@pytest.fixture +def test_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): + monkeypatch.setenv("CHAINLIT_ROOT_PATH", str(tmp_path)) + + test_config = config.load_config() + + monkeypatch.setattr("chainlit.callbacks.config", test_config) + monkeypatch.setattr("chainlit.server.config", test_config) + monkeypatch.setattr("chainlit.config.config", test_config) + + return test_config diff --git a/backend/tests/data/test_get_data_layer.py b/backend/tests/data/test_get_data_layer.py new file mode 100644 index 0000000000..403e86ae9a --- /dev/null +++ b/backend/tests/data/test_get_data_layer.py @@ -0,0 +1,18 @@ +from unittest.mock import AsyncMock, Mock + +from chainlit.data import get_data_layer + + +async def test_get_data_layer( + mock_data_layer: AsyncMock, + mock_get_data_layer: Mock, +): + # Check whether the data layer is properly set + assert mock_data_layer == get_data_layer() + + mock_get_data_layer.assert_called_once() + + # Getting the data layer again, should not result in additional call + assert mock_data_layer == get_data_layer() + + mock_get_data_layer.assert_called_once() diff --git a/backend/tests/test_callbacks.py b/backend/tests/test_callbacks.py index 3c41000137..c00674ba9b 100644 --- a/backend/tests/test_callbacks.py +++ b/backend/tests/test_callbacks.py @@ -1,22 +1,17 @@ from __future__ import annotations +from unittest.mock import AsyncMock, Mock + import pytest -from chainlit.callbacks import password_auth_callback -from chainlit.user import User from chainlit import config +from chainlit.callbacks import data_layer, password_auth_callback +from chainlit.data import get_data_layer +from chainlit.data.base import BaseDataLayer +from chainlit.user import User -@pytest.fixture -def test_config(monkeypatch: pytest.MonkeyPatch): - test_config = config.load_config() - - monkeypatch.setattr("chainlit.callbacks.config", test_config) - - return test_config - - -async def test_password_auth_callback(test_config): +async def test_password_auth_callback(test_config: config.ChainlitConfig): @password_auth_callback async def auth_func(username: str, password: str) -> User | None: if username == "testuser" and password == "testpass": # nosec B105 @@ -36,10 +31,11 @@ async def auth_func(username: str, password: str) -> User | None: assert result is None -async def test_header_auth_callback(test_config): - from chainlit.callbacks import header_auth_callback +async def test_header_auth_callback(test_config: config.ChainlitConfig): from starlette.datastructures import Headers + from chainlit.callbacks import header_auth_callback + @header_auth_callback async def auth_func(headers: Headers) -> User | None: if headers.get("Authorization") == "Bearer valid_token": @@ -66,7 +62,7 @@ async def auth_func(headers: Headers) -> User | None: assert result is None -async def test_oauth_callback(test_config): +async def test_oauth_callback(test_config: config.ChainlitConfig): from unittest.mock import patch from chainlit.callbacks import oauth_callback @@ -107,7 +103,7 @@ async def auth_func( assert result is None -async def test_on_message(mock_chainlit_context, test_config): +async def test_on_message(mock_chainlit_context, test_config: config.ChainlitConfig): from chainlit.callbacks import on_message from chainlit.message import Message @@ -137,7 +133,7 @@ async def handle_message(message: Message): context.session.emit.assert_called() -async def test_on_stop(mock_chainlit_context, test_config): +async def test_on_stop(mock_chainlit_context, test_config: config.ChainlitConfig): from chainlit.callbacks import on_stop from chainlit.config import config @@ -159,7 +155,9 @@ async def handle_stop(): assert stop_called -async def test_action_callback(mock_chainlit_context, test_config): +async def test_action_callback( + mock_chainlit_context, test_config: config.ChainlitConfig +): from chainlit.action import Action from chainlit.callbacks import action_callback from chainlit.config import config @@ -184,7 +182,9 @@ async def handle_action(action: Action): assert action_handled -async def test_on_settings_update(mock_chainlit_context, test_config): +async def test_on_settings_update( + mock_chainlit_context, test_config: config.ChainlitConfig +): from chainlit.callbacks import on_settings_update from chainlit.config import config @@ -207,7 +207,7 @@ async def handle_settings_update(settings: dict): assert settings_updated -async def test_author_rename(test_config): +async def test_author_rename(test_config: config.ChainlitConfig): from chainlit.callbacks import author_rename from chainlit.config import config @@ -238,7 +238,7 @@ async def rename_author(author: str) -> str: assert result == "Human" -async def test_on_chat_start(mock_chainlit_context, test_config): +async def test_on_chat_start(mock_chainlit_context, test_config: config.ChainlitConfig): from chainlit.callbacks import on_chat_start from chainlit.config import config @@ -263,7 +263,9 @@ async def handle_chat_start(): context.session.emit.assert_called() -async def test_on_chat_resume(mock_chainlit_context, test_config): +async def test_on_chat_resume( + mock_chainlit_context, test_config: config.ChainlitConfig +): from chainlit.callbacks import on_chat_resume from chainlit.config import config from chainlit.types import ThreadDict @@ -299,7 +301,9 @@ async def handle_chat_resume(thread: ThreadDict): assert chat_resumed -async def test_set_chat_profiles(mock_chainlit_context, test_config): +async def test_set_chat_profiles( + mock_chainlit_context, test_config: config.ChainlitConfig +): from chainlit.callbacks import set_chat_profiles from chainlit.config import config from chainlit.types import ChatProfile @@ -327,7 +331,7 @@ async def get_chat_profiles(user): assert result[0].markdown_description == "A test profile" -async def test_set_starters(mock_chainlit_context, test_config): +async def test_set_starters(mock_chainlit_context, test_config: config.ChainlitConfig): from chainlit.callbacks import set_starters from chainlit.config import config from chainlit.types import Starter @@ -358,7 +362,7 @@ async def get_starters(user): assert result[0].message == "Test Message" -async def test_on_chat_end(mock_chainlit_context, test_config): +async def test_on_chat_end(mock_chainlit_context, test_config: config.ChainlitConfig): from chainlit.callbacks import on_chat_end from chainlit.config import config @@ -381,3 +385,22 @@ async def handle_chat_end(): # Check that the emit method was called context.session.emit.assert_called() + + +async def test_data_layer_config( + mock_data_layer: AsyncMock, + test_config: config.ChainlitConfig, + mock_get_data_layer: Mock, +): + """Test whether we can properly configure a data layer.""" + + # Test that the callback is properly registered + assert test_config.code.data_layer is not None + + # Call the registered callback + result = test_config.code.data_layer() + + # Check that the result is an instance of MockDataLayer + assert isinstance(result, BaseDataLayer) + + mock_get_data_layer.assert_called_once() diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index 36c65124d6..c61827c8a3 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -1,17 +1,18 @@ +import datetime # Added import for datetime import os -from pathlib import Path import pathlib +import tempfile +from pathlib import Path from typing import Callable from unittest.mock import AsyncMock, Mock, create_autospec, mock_open -import datetime # Added import for datetime import pytest -import tempfile -from chainlit.session import WebsocketSession +from fastapi.testclient import TestClient + from chainlit.auth import get_current_user from chainlit.config import APP_ROOT, ChainlitConfig, load_config from chainlit.server import app -from fastapi.testclient import TestClient +from chainlit.session import WebsocketSession from chainlit.types import FileReference from chainlit.user import PersistedUser # Added import for PersistedUser @@ -21,17 +22,6 @@ def test_client(): return TestClient(app) -@pytest.fixture -def test_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): - monkeypatch.setenv("CHAINLIT_ROOT_PATH", str(tmp_path)) - - config = load_config() - - monkeypatch.setattr("chainlit.server.config", config) - - return config - - @pytest.fixture def mock_load_translation(test_config: ChainlitConfig, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr( diff --git a/cypress/e2e/custom_data_layer/sql_alchemy.py b/cypress/e2e/custom_data_layer/sql_alchemy.py index 430b02a58e..809a109e86 100644 --- a/cypress/e2e/custom_data_layer/sql_alchemy.py +++ b/cypress/e2e/custom_data_layer/sql_alchemy.py @@ -1,6 +1,5 @@ from typing import Optional -import chainlit.data as cl_data from chainlit.data.sql_alchemy import SQLAlchemyDataLayer from chainlit.data.storage_clients.azure import AzureStorageClient @@ -10,9 +9,12 @@ account_url="", container="" ) -cl_data._data_layer = SQLAlchemyDataLayer( - conninfo="", storage_provider=storage_client -) + +@cl.data_layer +def data_layer(): + return SQLAlchemyDataLayer( + conninfo="", storage_provider=storage_client + ) @cl.on_chat_start diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py index 2fc6181ffa..e03adea92c 100644 --- a/cypress/e2e/data_layer/main.py +++ b/cypress/e2e/data_layer/main.py @@ -194,7 +194,9 @@ async def build_debug_url(self) -> str: return "" -cl_data._data_layer = TestDataLayer() +@cl.data_layer +def data_layer(): + return TestDataLayer() async def send_count(): diff --git a/cypress/e2e/data_layer/spec.cy.ts b/cypress/e2e/data_layer/spec.cy.ts index 0320e89e98..394bb6e745 100644 --- a/cypress/e2e/data_layer/spec.cy.ts +++ b/cypress/e2e/data_layer/spec.cy.ts @@ -116,7 +116,7 @@ describe('Data Layer', () => { afterEach(() => { cy.get('@threadHistoryFile').then((threadHistoryFile) => { // Clean up the thread history file - cy.exec(`rm ${threadHistoryFile}`); + cy.exec(`rm -f ${threadHistoryFile}`); }); });