Skip to content

Commit

Permalink
@data_layer API hook to configure data layer. (#1463)
Browse files Browse the repository at this point in the history
* @data_layer API hook to configure data layer.
* Use @data_layer in e2e test.
* Unit tests for get_data_layer() and @data_layer.
* Test cleanup should never fail.
  • Loading branch information
dokterbob authored Nov 7, 2024
1 parent fc5acde commit 82f08bf
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 69 deletions.
11 changes: 7 additions & 4 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import asyncio
from typing import TYPE_CHECKING, Any, Dict

from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
from pydantic.dataclasses import dataclass

import chainlit.input_widget as input_widget
from chainlit.action import Action
from chainlit.cache import cache
Expand Down Expand Up @@ -44,22 +47,21 @@
)
from chainlit.step import Step, step
from chainlit.sync import make_async, run_sync
from chainlit.types import InputAudioChunk, OutputAudioChunk, ChatProfile, Starter
from chainlit.types import ChatProfile, InputAudioChunk, OutputAudioChunk, Starter
from chainlit.user import PersistedUser, User
from chainlit.user_session import user_session
from chainlit.utils import make_module_getattr
from chainlit.version import __version__
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
from pydantic.dataclasses import dataclass

from .callbacks import (
action_callback,
author_rename,
data_layer,
header_auth_callback,
oauth_callback,
on_audio_start,
on_audio_chunk,
on_audio_end,
on_audio_start,
on_chat_end,
on_chat_resume,
on_chat_start,
Expand Down Expand Up @@ -186,6 +188,7 @@ def acall(self):
"on_stop",
"action_callback",
"on_settings_update",
"data_layer",
]


Expand Down
29 changes: 23 additions & 6 deletions backend/chainlit/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import inspect
from typing import Any, Awaitable, Callable, Dict, List, Optional

from fastapi import Request, Response
from starlette.datastructures import Headers

from chainlit.action import Action
from chainlit.config import config
from chainlit.data.base import BaseDataLayer
from chainlit.message import Message
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.step import Step, step
from chainlit.telemetry import trace
from chainlit.types import ChatProfile, Starter, ThreadDict
from chainlit.user import User
from chainlit.utils import wrap_user_function
from fastapi import Request, Response
from starlette.datastructures import Headers


@trace
def password_auth_callback(
func: Callable[[str, str], Awaitable[Optional[User]]]
func: Callable[[str, str], Awaitable[Optional[User]]],
) -> Callable:
"""
Framework agnostic decorator to authenticate the user.
Expand All @@ -38,7 +40,7 @@ async def password_auth_callback(username: str, password: str) -> Optional[User]

@trace
def header_auth_callback(
func: Callable[[Headers], Awaitable[Optional[User]]]
func: Callable[[Headers], Awaitable[Optional[User]]],
) -> Callable:
"""
Framework agnostic decorator to authenticate the user via a header
Expand Down Expand Up @@ -177,7 +179,7 @@ def set_chat_profiles(

@trace
def set_starters(
func: Callable[[Optional["User"]], Awaitable[List["Starter"]]]
func: Callable[[Optional["User"]], Awaitable[List["Starter"]]],
) -> Callable:
"""
Programmatic declaration of the available starter (can depend on the User from the session if authentication is setup).
Expand Down Expand Up @@ -221,6 +223,7 @@ def on_audio_start(func: Callable) -> Callable:
config.code.on_audio_start = wrap_user_function(func, with_task=False)
return func


@trace
def on_audio_chunk(func: Callable) -> Callable:
"""
Expand Down Expand Up @@ -254,7 +257,7 @@ def on_audio_end(func: Callable) -> Callable:

@trace
def author_rename(
func: Callable[[str], Awaitable[str]]
func: Callable[[str], Awaitable[str]],
) -> Callable[[str], Awaitable[str]]:
"""
Useful to rename the author of message to display more friendly author names in the UI.
Expand Down Expand Up @@ -315,3 +318,17 @@ def on_settings_update(

config.code.on_settings_update = wrap_user_function(func, with_task=True)
return func


def data_layer(
func: Callable[[], BaseDataLayer],
) -> Callable[[], BaseDataLayer]:
"""
Hook to configure custom data layer.
"""

# We don't use wrap_user_function here because:
# 1. We don't need to support async here and;
# 2. We don't want to change the API for get_data_layer() to be async, everywhere (at this point).
config.code.data_layer = func
return func
26 changes: 15 additions & 11 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,24 @@
)

import tomli
from chainlit.logger import logger
from chainlit.translations import lint_translation_json
from chainlit.version import __version__
from dataclasses_json import DataClassJsonMixin
from pydantic.dataclasses import Field, dataclass
from starlette.datastructures import Headers

from chainlit.data.base import BaseDataLayer
from chainlit.logger import logger
from chainlit.translations import lint_translation_json
from chainlit.version import __version__

from ._utils import is_path_inside

if TYPE_CHECKING:
from fastapi import Request, Response

from chainlit.action import Action
from chainlit.message import Message
from chainlit.types import InputAudioChunk, ChatProfile, Starter, ThreadDict
from chainlit.types import ChatProfile, InputAudioChunk, Starter, ThreadDict
from chainlit.user import User
from fastapi import Request, Response


BACKEND_ROOT = os.path.dirname(__file__)
Expand Down Expand Up @@ -272,9 +275,9 @@ class CodeSettings:
password_auth_callback: Optional[
Callable[[str, str], Awaitable[Optional["User"]]]
] = None
header_auth_callback: Optional[
Callable[[Headers], Awaitable[Optional["User"]]]
] = None
header_auth_callback: Optional[Callable[[Headers], Awaitable[Optional["User"]]]] = (
None
)
oauth_callback: Optional[
Callable[[str, str, Dict[str, str], "User"], Awaitable[Optional["User"]]]
] = None
Expand All @@ -293,9 +296,10 @@ class CodeSettings:
set_chat_profiles: Optional[
Callable[[Optional["User"]], Awaitable[List["ChatProfile"]]]
] = None
set_starters: Optional[
Callable[[Optional["User"]], Awaitable[List["Starter"]]]
] = None
set_starters: Optional[Callable[[Optional["User"]], Awaitable[List["Starter"]]]] = (
None
)
data_layer: Optional[Callable[[], BaseDataLayer]] = None


@dataclass()
Expand Down
9 changes: 8 additions & 1 deletion backend/chainlit/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@

def get_data_layer():
global _data_layer
print("Getting data layer", _data_layer)

if not _data_layer:
if api_key := os.environ.get("LITERAL_API_KEY"):
from chainlit.config import config

if config.code.data_layer:
# When @data_layer is configured, call it to get data layer.
_data_layer = config.code.data_layer()
elif api_key := os.environ.get("LITERAL_API_KEY"):
# When LITERAL_API_KEY is defined, use LiteralAI data layer
from .literalai import LiteralDataLayer

# support legacy LITERAL_SERVER variable as fallback
Expand Down
34 changes: 34 additions & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import datetime
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Callable
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio

from chainlit import config
from chainlit.callbacks import data_layer
from chainlit.context import ChainlitContext, context_var
from chainlit.data.base import BaseDataLayer
from chainlit.session import HTTPSession, WebsocketSession
from chainlit.user import PersistedUser
from chainlit.user_session import UserSession
Expand Down Expand Up @@ -79,3 +84,32 @@ def mock_websocket_session():
@pytest.fixture
def mock_http_session():
return Mock(spec=HTTPSession)


@pytest.fixture
def mock_data_layer(monkeypatch: pytest.MonkeyPatch) -> AsyncMock:
mock_data_layer = AsyncMock(spec=BaseDataLayer)

return mock_data_layer


@pytest.fixture
def mock_get_data_layer(mock_data_layer: AsyncMock, test_config: config.ChainlitConfig):
# Instantiate mock data layer
mock_get_data_layer = Mock(return_value=mock_data_layer)

# Configure it using @data_layer decorator
return data_layer(mock_get_data_layer)


@pytest.fixture
def test_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
monkeypatch.setenv("CHAINLIT_ROOT_PATH", str(tmp_path))

test_config = config.load_config()

monkeypatch.setattr("chainlit.callbacks.config", test_config)
monkeypatch.setattr("chainlit.server.config", test_config)
monkeypatch.setattr("chainlit.config.config", test_config)

return test_config
18 changes: 18 additions & 0 deletions backend/tests/data/test_get_data_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from unittest.mock import AsyncMock, Mock

from chainlit.data import get_data_layer


async def test_get_data_layer(
mock_data_layer: AsyncMock,
mock_get_data_layer: Mock,
):
# Check whether the data layer is properly set
assert mock_data_layer == get_data_layer()

mock_get_data_layer.assert_called_once()

# Getting the data layer again, should not result in additional call
assert mock_data_layer == get_data_layer()

mock_get_data_layer.assert_called_once()
Loading

0 comments on commit 82f08bf

Please sign in to comment.