Skip to content

Commit 3695717

Browse files
Frédéric Collonvalfcollonvalgithub-actions[bot]
authored
Add reacting model and test it (#22)
* Add reacting model and test it * Automatic application of license header * Add get_cell_index * Add pytest-asynico * Fixing unit tests * Undo test_client changes --------- Co-authored-by: Frédéric Collonval <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 97b0320 commit 3695717

File tree

9 files changed

+565
-8
lines changed

9 files changed

+565
-8
lines changed

README.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,68 @@ pip uninstall jupyter_nbmodel_client
154154

155155
## Contributing
156156

157+
### Data models
158+
159+
The following json schema describe the data model used in cells and notebook metadata
160+
to communicate between user clients and the ai agent.
161+
162+
```json
163+
{
164+
"datalayer": {
165+
"type": "object",
166+
"properties": {
167+
"ai": {
168+
"type": "object",
169+
"properties": {
170+
"prompts": {
171+
"type": "array",
172+
"items": {
173+
"type": "object",
174+
"properties": {
175+
"id": {
176+
"title": "Prompt unique identifier",
177+
"type": "string"
178+
},
179+
"prompt": {
180+
"title": "User prompt",
181+
"type": "string"
182+
},
183+
"username": {
184+
"title": "Unique identifier of the user making the prompt.",
185+
"type": "string"
186+
}
187+
},
188+
"required": ["id", "prompt"]
189+
}
190+
},
191+
"messages": {
192+
"type": "array",
193+
"items": {
194+
"type": "object",
195+
"properties": {
196+
"parent_id": {
197+
"title": "Prompt unique identifier",
198+
"type": "string"
199+
},
200+
"message": {
201+
"title": "AI reply",
202+
"type": "string"
203+
},
204+
"type": {
205+
"title": "Type message",
206+
"enum": [0, 1, 2]
207+
}
208+
},
209+
"required": ["id", "prompt"]
210+
}
211+
}
212+
}
213+
}
214+
}
215+
}
216+
}
217+
```
218+
157219
### Development install
158220

159221
```bash

jupyter_nbmodel_client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from nbformat import NotebookNode
88

9+
from .agent import BaseNbAgent
910
from .client import NbModelClient, get_jupyter_notebook_websocket_url
1011
from .model import KernelClient, NotebookModel
1112

1213
__version__ = "0.6.0"
1314

1415
__all__ = [
16+
"BaseNbAgent",
1517
"KernelClient",
1618
"NbModelClient",
1719
"NotebookModel",

jupyter_nbmodel_client/agent.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# Copyright (c) 2023-2024 Datalayer, Inc.
2+
#
3+
# BSD 3-Clause License
4+
5+
"""This module provides a base class agent to interact with collaborative Jupyter notebook."""
6+
7+
from __future__ import annotations
8+
9+
from enum import IntEnum
10+
from typing import Any, Literal, cast
11+
12+
from pycrdt import ArrayEvent, Map, MapEvent
13+
14+
from .client import NbModelClient
15+
16+
17+
class AIMessageType(IntEnum):
18+
"""Type of AI agent message."""
19+
20+
ACKNOWLEDGE = 0
21+
"""Prompt is being processed."""
22+
SUGGESTION = 1
23+
"""Message suggesting a new cell content."""
24+
EXPLANATION = 2
25+
"""Message explaining a content."""
26+
27+
28+
# def _debug_print_changes(part: str, changes: Any) -> None:
29+
# print(f"{part}")
30+
31+
# def print_change(changes):
32+
# if isinstance(changes, MapEvent):
33+
# print(f"{type(changes.target)} {changes.target} {changes.keys} {changes.path}")
34+
# elif isinstance(changes, ArrayEvent):
35+
# print(f"{type(changes.target)} {changes.target} {changes.delta} {changes.path}")
36+
# else:
37+
# print(changes)
38+
39+
# if isinstance(changes, list):
40+
# for c in changes:
41+
# print_change(c)
42+
# else:
43+
# print_change(changes)
44+
45+
46+
class BaseNbAgent(NbModelClient):
47+
"""Base class to react to user prompt and notebook changes based on CRDT changes.
48+
49+
Notes:
50+
- Agents are expected to extend this base class and override either
51+
- method:`_on_user_prompt(self, cell_id: str, prompt: str, username: str | None = None)`:
52+
Callback on user prompt
53+
- method:`_on_cell_source_changes(self, cell_id: str, new_source: str, old_source: str, username: str | None = None):
54+
Callback on cell source changes
55+
- The agent can leverage the helper functions to send a reply to the user:
56+
- method:`update_document(self, prompt: dict, message_type: AIMessageType, message: str, cell_id: str = "")`:
57+
Attach a message to the given cell (or to the notebook if no ``cell_id`` is provided).
58+
59+
Args:
60+
ws_url: Endpoint to connect to the collaborative Jupyter notebook.
61+
path: [optional] Notebook path relative to the server root directory; default None
62+
username: [optional] Client user name; default to environment variable USER
63+
timeout: [optional] Request timeout in seconds; default to environment variable REQUEST_TIMEOUT
64+
log: [optional] Custom logger; default local logger
65+
66+
Examples:
67+
68+
When connection to a Jupyter notebook server, you can leverage the get_jupyter_notebook_websocket_url
69+
helper:
70+
71+
>>> from jupyter_nbmodel_client import NbModelClient, get_jupyter_notebook_websocket_url
72+
>>> client = NbModelClient(
73+
>>> get_jupyter_notebook_websocket_url(
74+
>>> "http://localhost:8888",
75+
>>> "path/to/notebook.ipynb",
76+
>>> "your-server-token"
77+
>>> )
78+
>>> )
79+
"""
80+
81+
# FIXME implement username retrieval
82+
83+
def _on_notebook_changes(
84+
self, part: Literal["state"] | Literal["meta"] | Literal["cells"] | str, all_changes: Any
85+
) -> None:
86+
# _debug_print_changes(part, all_changes)
87+
88+
if part == "cells":
89+
for changes in all_changes:
90+
path_length = len(changes.path)
91+
if path_length == 0:
92+
# Change is on the cell list
93+
for delta in changes.delta:
94+
if "insert" in delta:
95+
# New cells got added
96+
for cell in delta["insert"]:
97+
if "metadata" in cell:
98+
new_metadata = cell["metadata"]
99+
datalayer_ia = new_metadata.get("datalayer", {}).get("ai", {})
100+
prompts = datalayer_ia.get("prompts", [])
101+
prompt_ids = {prompt["id"] for prompt in prompts}
102+
new_prompts = prompt_ids.difference(
103+
message["parent_id"]
104+
for message in datalayer_ia.get("messages", [])
105+
)
106+
if new_prompts:
107+
for prompt in filter(
108+
lambda p: p.get("id") in new_prompts, prompts
109+
):
110+
self._on_user_prompt(cell["id"], prompt["prompt"])
111+
if "source" in cell:
112+
self._on_cell_source_changes(cell["id"], cell["source"], "")
113+
elif path_length == 1:
114+
# Change is on one cell
115+
for key, change in changes.keys.items():
116+
if key == "source":
117+
if change["action"] == "add":
118+
self._on_cell_source_changes(
119+
changes.target["id"],
120+
change["newValue"],
121+
change.get("oldValue", ""),
122+
)
123+
elif change["action"] == "update":
124+
self._on_cell_source_changes(
125+
changes.target["id"], change["newValue"], change["oldValue"]
126+
)
127+
elif change["action"] == "delete":
128+
self._on_cell_source_changes(
129+
changes.target["id"], change.get("newValue"), change["oldValue"]
130+
)
131+
elif key == "metadata":
132+
new_metadata = change.get("newValue", {})
133+
datalayer_ia = new_metadata.get("datalayer", {}).get("ai", {})
134+
prompts = datalayer_ia.get("prompts", [])
135+
prompt_ids = {prompt["id"] for prompt in prompts}
136+
new_prompts = prompt_ids.difference(
137+
message["parent_id"] for message in datalayer_ia.get("messages", [])
138+
)
139+
if new_prompts and change["action"] in {"add", "update"}:
140+
for prompt in filter(lambda p: p.get("id") in new_prompts, prompts):
141+
self._on_user_prompt(changes.target["id"], prompt["prompt"])
142+
# elif change["action"] == "delete":
143+
# ...
144+
# elif key == "outputs":
145+
# # TODO
146+
# ...
147+
elif (
148+
path_length == 2
149+
and isinstance(changes.path[0], int)
150+
and changes.path[1] == "metadata"
151+
):
152+
# Change in cell metadata
153+
for key, change in changes.keys.items():
154+
if key == "datalayer":
155+
new_metadata = change.get("newValue", {})
156+
datalayer_ia = new_metadata.get("ai", {})
157+
prompts = datalayer_ia.get("prompts")
158+
prompt_ids = {prompt["id"] for prompt in prompts}
159+
new_prompts = prompt_ids.difference(
160+
message["parent_id"] for message in datalayer_ia.get("messages", [])
161+
)
162+
if new_prompts and change["action"] in {"add", "update"}:
163+
for prompt in filter(lambda p: p.get("id") in new_prompts, prompts):
164+
self._on_user_prompt(
165+
self._doc.ycells[changes.path[0]]["id"], prompt["prompt"]
166+
)
167+
# elif change["action"] == "delete":
168+
# ...
169+
170+
# elif part == "meta":
171+
# # FIXME handle notebook metadata
172+
173+
def _reset_y_model(self) -> None:
174+
try:
175+
self._doc.unobserve()
176+
except AttributeError:
177+
pass
178+
finally:
179+
super()._reset_y_model()
180+
self._doc.observe(self._on_notebook_changes)
181+
182+
def _on_user_prompt(self, cell_id: str, prompt: str, username: str | None = None) -> None:
183+
username = username or self._username
184+
self._log.debug("New AI prompt sets by user [%s] in [%s]: [%s].", username, cell_id, prompt)
185+
186+
def _on_cell_source_changes(
187+
self, cell_id: str, new_source: str, old_source: str, username: str | None = None
188+
) -> None:
189+
username = username or self._username
190+
self._log.debug("New cell source sets by user [%s] in [%s].", username, cell_id, new_source)
191+
192+
# def _on_cell_outputs_changes(self, *args) -> None:
193+
# print(args)
194+
195+
def get_cell(self, cell_id: str) -> Map | None:
196+
"""Find the cell with the given ID.
197+
198+
If the cell cannot be found it will return ``None``.
199+
200+
Args:
201+
cell_id: str
202+
Returns:
203+
Cell or None
204+
"""
205+
for cell in self._doc.ycells:
206+
if cell["id"] == cell_id:
207+
return cast(Map, cell)
208+
209+
return None
210+
211+
def get_cell_index(self, cell_id: str) -> int:
212+
"""Find the cell with the given ID.
213+
214+
If the cell cannot be found it will return ``-1``.
215+
216+
Args:
217+
cell_id: str
218+
Returns:
219+
Cell index or -1
220+
"""
221+
for index, cell in enumerate(self._doc.ycells):
222+
if cell["id"] == cell_id:
223+
return index
224+
225+
return -1
226+
227+
def update_document(
228+
self, prompt: dict, message_type: AIMessageType, message: str, cell_id: str = ""
229+
) -> None:
230+
"""Update the document.
231+
232+
Args:
233+
prompt: User prompt
234+
message_type: Type of message to insert in the document
235+
message: Message to insert
236+
cell_id: Cell targeted by the update; if empty, the notebook is the target
237+
"""
238+
message_dict = {"parent_id": prompt["id"], "message": message, "type": message_type}
239+
240+
def set_message(metadata: Map, message: dict):
241+
if "datalayer" not in metadata:
242+
metadata["datalayer"] = {"ai": {"prompts": [], "messages": []}}
243+
elif "ai" not in metadata["datalayer"]:
244+
metadata["datalayer"] = {"ai": {"prompts": [], "messages": []}}
245+
elif "messages" not in metadata["datalayer"]["ai"]:
246+
metadata["datalayer"]["ai"] = {"messages": []}
247+
248+
metadata["datalayer"]["ai"]["messages"].append(message)
249+
250+
metadata["datalayer"] = metadata["datalayer"].copy()
251+
252+
if cell_id:
253+
cell = self.get_cell(cell_id)
254+
if not cell:
255+
raise ValueError(f"Cell [{cell_id}] not found.")
256+
if "metadata" not in cell:
257+
cell["metadata"] = Map({"datalayer": {"ai": {"prompts": [], "messages": []}}})
258+
set_message(cell["metadata"], message_dict)
259+
260+
else:
261+
notebook_metadata = self._doc._ymeta["metadata"]
262+
set_message(notebook_metadata, message_dict)
263+
264+
# def notify(self, message: str, cell_id: str = "") -> None:
265+
# """Send a transient message to users.
266+
267+
# Args:
268+
# message: Notification message
269+
# cell_id: Cell targeted by the notification; if empty the notebook is the target
270+
# """

0 commit comments

Comments
 (0)