Switch to async for notebook client #26

Draft · wants to merge 4 commits into base: main

8 changes: 8 additions & 0 deletions README.md
@@ -183,6 +183,10 @@ to communicate between user clients and the ai agent.
"username": {
"title": "Unique identifier of the user making the prompt.",
"type": "string"
},
"timestamp": {
"title": "Number of milliseconds elapsed since the epoch; i.e. January 1st, 1970 at midnight UTC.",
"type": "integer"
}
},
"required": ["id", "prompt"]
@@ -204,6 +208,10 @@ to communicate between user clients and the ai agent.
"type": {
"title": "Type message",
"enum": [0, 1, 2]
},
"timestamp": {
"title": "Number of milliseconds elapsed since the epoch; i.e. January 1st, 1970 at midnight UTC.",
"type": "integer"
}
},
"required": ["id", "prompt"]
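For illustration, here is a minimal sketch of how a client might populate the new timestamp field; the concrete id, prompt, and username values are hypothetical, while the field names and the millisecond-epoch convention follow the schema above.

    # Hedged sketch: a prompt payload carrying the new "timestamp" field.
    from datetime import datetime, timezone

    def timestamp() -> int:
        """Milliseconds elapsed since January 1st, 1970 at midnight UTC."""
        return int(datetime.now(timezone.utc).timestamp() * 1000.0)

    prompt_payload = {
        "id": "prompt-1",               # required; hypothetical value
        "prompt": "Explain this cell",  # required
        "username": "jdoe",             # unique identifier of the user making the prompt
        "timestamp": timestamp(),       # e.g. 1717171717000
    }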
3 changes: 2 additions & 1 deletion jupyter_nbmodel_client/__init__.py
@@ -6,13 +6,14 @@

from nbformat import NotebookNode

from .agent import BaseNbAgent
from .agent import AIMessageType, BaseNbAgent
from .client import NbModelClient, get_jupyter_notebook_websocket_url
from .model import KernelClient, NotebookModel

__version__ = "0.6.0"

__all__ = [
"AIMessageType",
"BaseNbAgent",
"KernelClient",
"NbModelClient",
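Re-exporting AIMessageType from the package root lets downstream agents reference message types without importing from the agent module. A small usage sketch, using the enum values added in agent.py below:

    # AIMessageType is now part of the public API surface.
    from jupyter_nbmodel_client import AIMessageType

    assert AIMessageType.ERROR == -1        # error message
    assert AIMessageType.ACKNOWLEDGE == 0   # prompt is being processed
    assert AIMessageType.SUGGESTION == 1    # suggestion for the user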
206 changes: 173 additions & 33 deletions jupyter_nbmodel_client/agent.py
@@ -6,17 +6,28 @@

from __future__ import annotations

import asyncio
import os
from datetime import datetime, timezone
from enum import IntEnum
from logging import Logger
from typing import Any, Literal, cast

from pycrdt import ArrayEvent, Map, MapEvent

from .client import NbModelClient
from .client import REQUEST_TIMEOUT, NbModelClient


def timestamp() -> int:
"""Return the current timestamp in milliseconds since epoch."""
return int(datetime.now(timezone.utc).timestamp() * 1000.0)


class AIMessageType(IntEnum):
"""Type of AI agent message."""

ERROR = -1
"""Error message."""
ACKNOWLEDGE = 0
"""Prompt is being processed."""
SUGGESTION = 1
@@ -53,7 +64,7 @@ class BaseNbAgent(NbModelClient):
- method:`_on_cell_source_changes(self, cell_id: str, new_source: str, old_source: str, username: str | None = None)`:
Callback on cell source changes
- The agent can leverage the helper functions to send a reply to the user:
- method:`update_document(self, prompt: dict, message_type: AIMessageType, message: str, cell_id: str = "")`:
- method:`save_ai_message(self, message_type: AIMessageType, message: str, cell_id: str = "", parent_id: str | None = None)`:
Attach a message to the given cell (or to the notebook if no ``cell_id`` is provided).

Args:
@@ -79,9 +90,44 @@
"""

# FIXME implement username retrieval
def __init__(
self,
websocket_url: str,
path: str | None = None,
username: str = os.environ.get("USER", "username"),
timeout: float = REQUEST_TIMEOUT,
log: Logger | None = None,
) -> None:
super().__init__(websocket_url, path, username, timeout, log)
self._doc_events: asyncio.Queue[dict] = asyncio.Queue()
self._events_worker: asyncio.Task | None = None

async def start(self) -> None:
await super().start()
self._events_worker = asyncio.create_task(self._process_doc_events())

async def stop(self) -> None:
await super().stop()
if self._events_worker:
self._events_worker.cancel()
self._events_worker = None

while not self._doc_events.empty():
self._doc_events.get_nowait()

async def _process_doc_events(self) -> None:
while True:
event = await self._doc_events.get()
event_type = event.pop("type")
if event_type == "user":
self._on_user_prompt(**event)
if event_type == "source":
self._on_cell_source_changes(**event)

def _on_notebook_changes(
self, part: Literal["state"] | Literal["meta"] | Literal["cells"] | str, all_changes: Any
self,
part: Literal["state"] | Literal["meta"] | Literal["cells"] | str,
all_changes: Any,
) -> None:
# _debug_print_changes(part, all_changes)

@@ -96,7 +142,9 @@ def _on_notebook_changes(
for cell in delta["insert"]:
if "metadata" in cell:
new_metadata = cell["metadata"]
datalayer_ia = new_metadata.get("datalayer", {}).get("ai", {})
datalayer_ia = new_metadata.get(
"datalayer", {}
).get("ai", {})
prompts = datalayer_ia.get("prompts", [])
prompt_ids = {prompt["id"] for prompt in prompts}
new_prompts = prompt_ids.difference(
@@ -105,40 +153,86 @@
)
if new_prompts:
for prompt in filter(
lambda p: p.get("id") in new_prompts, prompts
lambda p: p.get("id") in new_prompts,
prompts,
):
self._on_user_prompt(cell["id"], prompt["prompt"])
self._doc_events.put_nowait(
{
"type": "user",
"cell_id": cell["id"],
"prompt_id": prompt["id"],
"prompt": prompt["prompt"],
"username": prompt.get("user"),
"timestamp": prompt.get(
"timestamp"
),
}
)
if "source" in cell:
self._on_cell_source_changes(cell["id"], cell["source"], "")
self._doc_events.put_nowait(
{
"type": "source",
"cell_id": cell["id"],
"new_source": cell["source"].to_py(),
"old_source": "",
}
)
elif path_length == 1:
# Change is on one cell
for key, change in changes.keys.items():
if key == "source":
if change["action"] == "add":
self._on_cell_source_changes(
changes.target["id"],
change["newValue"],
change.get("oldValue", ""),
self._doc_events.put_nowait(
{
"type": "source",
"cell_id": changes.target["id"],
"new_source": change["newValue"],
"old_source": change.get("oldValue", ""),
}
)
elif change["action"] == "update":
self._on_cell_source_changes(
changes.target["id"], change["newValue"], change["oldValue"]
self._doc_events.put_nowait(
{
"type": "source",
"cell_id": changes.target["id"],
"new_source": change["newValue"],
"old_source": change["oldValue"],
}
)
elif change["action"] == "delete":
self._on_cell_source_changes(
changes.target["id"], change.get("newValue"), change["oldValue"]
self._doc_events.put_nowait(
{
"type": "source",
"cell_id": changes.target["id"],
"new_source": change.get("newValue", ""),
"old_source": change["oldValue"],
}
)
elif key == "metadata":
new_metadata = change.get("newValue", {})
datalayer_ia = new_metadata.get("datalayer", {}).get("ai", {})
datalayer_ia = new_metadata.get("datalayer", {}).get(
"ai", {}
)
prompts = datalayer_ia.get("prompts", [])
prompt_ids = {prompt["id"] for prompt in prompts}
new_prompts = prompt_ids.difference(
message["parent_id"] for message in datalayer_ia.get("messages", [])
message["parent_id"]
for message in datalayer_ia.get("messages", [])
)
if new_prompts and change["action"] in {"add", "update"}:
for prompt in filter(lambda p: p.get("id") in new_prompts, prompts):
self._on_user_prompt(changes.target["id"], prompt["prompt"])
for prompt in filter(
lambda p: p.get("id") in new_prompts, prompts
):
self._doc_events.put_nowait(
{
"type": "user",
"cell_id": changes.target["id"],
"prompt_id": prompt["id"],
"prompt": prompt["prompt"],
"username": prompt.get("user"),
"timestamp": prompt.get("timestamp"),
}
)
# elif change["action"] == "delete":
# ...
# elif key == "outputs":
Expand All @@ -157,12 +251,24 @@ def _on_notebook_changes(
prompts = datalayer_ia.get("prompts")
prompt_ids = {prompt["id"] for prompt in prompts}
new_prompts = prompt_ids.difference(
message["parent_id"] for message in datalayer_ia.get("messages", [])
message["parent_id"]
for message in datalayer_ia.get("messages", [])
)
if new_prompts and change["action"] in {"add", "update"}:
for prompt in filter(lambda p: p.get("id") in new_prompts, prompts):
self._on_user_prompt(
self._doc.ycells[changes.path[0]]["id"], prompt["prompt"]
for prompt in filter(
lambda p: p.get("id") in new_prompts, prompts
):
self._doc_events.put_nowait(
{
"type": "user",
"cell_id": self._doc.ycells[
changes.path[0]
]["id"],
"prompt_id": prompt["id"],
"prompt": prompt["prompt"],
"username": prompt.get("user"),
"timestamp": prompt.get("timestamp"),
}
)
# elif change["action"] == "delete":
# ...
@@ -179,15 +285,28 @@ def _reset_y_model(self) -> None:
super()._reset_y_model()
self._doc.observe(self._on_notebook_changes)

def _on_user_prompt(self, cell_id: str, prompt: str, username: str | None = None) -> None:
def _on_user_prompt(
self,
cell_id: str,
prompt_id: str,
prompt: str,
username: str | None = None,
timestamp: int | None = None,
) -> None:
username = username or self._username
self._log.debug("New AI prompt sets by user [%s] in [%s]: [%s].", username, cell_id, prompt)
self._log.debug(
"New AI prompt sets by user [%s] in [%s]: [%s].", username, cell_id, prompt
)

def _on_cell_source_changes(
self, cell_id: str, new_source: str, old_source: str, username: str | None = None
self,
cell_id: str,
new_source: str,
old_source: str,
username: str | None = None,
) -> None:
username = username or self._username
self._log.debug("New cell source sets by user [%s] in [%s].", username, cell_id, new_source)
self._log.debug("New cell source sets by user [%s] in [%s].", username, cell_id)

# def _on_cell_outputs_changes(self, *args) -> None:
# print(args)
@@ -224,18 +343,30 @@ def get_cell_index(self, cell_id: str) -> int:

return -1

def update_document(
self, prompt: dict, message_type: AIMessageType, message: str, cell_id: str = ""
def save_ai_message(
self,
message_type: AIMessageType,
message: str,
cell_id: str = "",
parent_id: str | None = None,
) -> None:
"""Update the document.

If a message with the same ``parent_id`` already exists, it will be
overwritten.

Args:
prompt: User prompt
message_type: Type of message to insert in the document
message: Message to insert
cell_id: Cell targeted by the update; if empty, the notebook is the target
parent_id: Parent message id
"""
message_dict = {"parent_id": prompt["id"], "message": message, "type": message_type}
message_dict = {
"parent_id": parent_id,
"message": message,
"type": message_type,
"timestamp": timestamp(),
}

def set_message(metadata: Map, message: dict):
if "datalayer" not in metadata:
@@ -245,7 +376,14 @@ def set_message(metadata: Map, message: dict):
elif "messages" not in metadata["datalayer"]["ai"]:
metadata["datalayer"]["ai"] = {"messages": []}

metadata["datalayer"]["ai"]["messages"].append(message)
# Drop any existing message with the same parent_id before appending the new one.
messages = list(
filter(
lambda m: not m.get("parent_id") or m["parent_id"] != parent_id,
metadata["datalayer"]["ai"]["messages"],
)
)
messages.append(message)
metadata["datalayer"]["ai"]["messages"] = messages

metadata["datalayer"] = metadata["datalayer"].copy()

@@ -254,7 +392,9 @@ def set_message(metadata: Map, message: dict):
if not cell:
raise ValueError(f"Cell [{cell_id}] not found.")
if "metadata" not in cell:
cell["metadata"] = Map({"datalayer": {"ai": {"prompts": [], "messages": []}}})
cell["metadata"] = Map(
{"datalayer": {"ai": {"prompts": [], "messages": []}}}
)
set_message(cell["metadata"], message_dict)

else:
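For orientation, the prompt/reply exchange handled above lives under the datalayer.ai key of a cell's (or the notebook's) metadata: _on_notebook_changes watches prompts for entries whose id has no matching parent_id in messages, and save_ai_message writes (or overwrites) the reply. A sketch of what that metadata might look like after one round trip; the field names follow the code above, the concrete values are illustrative.

    cell_metadata = {
        "datalayer": {
            "ai": {
                "prompts": [
                    {
                        "id": "prompt-1",
                        "prompt": "Explain this cell",
                        "user": "jdoe",
                        "timestamp": 1717171717000,
                    }
                ],
                "messages": [
                    {
                        "parent_id": "prompt-1",   # links the reply to its prompt
                        "message": "This cell loads the dataset.",
                        "type": 1,                 # AIMessageType.SUGGESTION
                        "timestamp": 1717171720000,
                    }
                ],
            }
        }
    }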
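Taken together, the PR moves the agent onto an async lifecycle: document observers enqueue events with put_nowait, a worker task created in start() dispatches them to _on_user_prompt / _on_cell_source_changes, and replies go back through save_ai_message. A minimal, hypothetical subclass showing that flow; the echo behaviour and the placeholder websocket URL are made up, and in real code the URL would come from get_jupyter_notebook_websocket_url.

    from __future__ import annotations

    import asyncio

    from jupyter_nbmodel_client import AIMessageType, BaseNbAgent


    class EchoAgent(BaseNbAgent):
        """Toy agent: acknowledge every prompt, then echo it back as a suggestion."""

        def _on_user_prompt(
            self,
            cell_id: str,
            prompt_id: str,
            prompt: str,
            username: str | None = None,
            timestamp: int | None = None,
        ) -> None:
            super()._on_user_prompt(cell_id, prompt_id, prompt, username, timestamp)
            # First reply: acknowledge receipt of the prompt.
            self.save_ai_message(
                AIMessageType.ACKNOWLEDGE, "Working on it.", cell_id, parent_id=prompt_id
            )
            # Second reply with the same parent_id overwrites the acknowledgement.
            self.save_ai_message(
                AIMessageType.SUGGESTION, f"You asked: {prompt}", cell_id, parent_id=prompt_id
            )


    async def main() -> None:
        # Placeholder URL; normally built with get_jupyter_notebook_websocket_url(...).
        agent = EchoAgent("ws://localhost:8888/placeholder")
        await agent.start()          # also spawns the _process_doc_events worker
        try:
            await asyncio.sleep(60)  # let the agent react to incoming prompts
        finally:
            await agent.stop()       # cancels the worker and drains the event queue


    if __name__ == "__main__":
        asyncio.run(main())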