Commit

fix: test and avoid memory leaks by checking references to virtual kernel
maartenbreddels committed Nov 16, 2023
1 parent a991ed8 commit 945fb87
Showing 11 changed files with 233 additions and 40 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/integration.yml
@@ -38,6 +38,8 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Setup Graphviz
uses: ts-graphviz/setup-graphviz@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
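The Graphviz step supports the new `objgraph` dev dependency added in `pyproject.toml` below: objgraph shells out to Graphviz's `dot` when asked to render object reference graphs. A sketch of the kind of diagnostic this enables (illustrative, not the repository's test code):

```python
import objgraph

def dump_backrefs(leaked_object, filename="backrefs.png"):
    # renders every chain of references keeping leaked_object alive;
    # requires the Graphviz `dot` executable on PATH
    objgraph.show_backrefs([leaked_object], max_depth=5, filename=filename)
```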
1 change: 1 addition & 0 deletions pyproject.toml
@@ -80,6 +80,7 @@ dev = [
"dask[dataframe]; python_version < '3.7'",
"playwright; python_version > '3.6'",
"pytest-playwright; python_version > '3.6'",
"objgraph",
]
assets = [
"solara-assets==1.22.0"
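`objgraph` is what the new leak tests build on: after closing a virtual kernel, they can assert that no instances of a type remain reachable. A minimal sketch in that spirit (the test body is illustrative, not copied from the repository):

```python
import gc
import objgraph

def assert_no_live_instances(type_name: str):
    gc.collect()  # collect cycles first so only genuine leaks remain
    survivors = objgraph.by_type(type_name)
    assert not survivors, f"{len(survivors)} {type_name} objects still reachable"

# e.g. after a kernel context is closed:
# assert_no_live_instances("VirtualKernelContext")
```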
18 changes: 13 additions & 5 deletions solara/server/app.py
@@ -5,6 +5,7 @@
import sys
import threading
import traceback
import weakref
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, cast
@@ -125,7 +126,7 @@ def add_path():
else:
# the module itself will be added by reloader
# automatically
with reload.reloader.watch():
with kernel_context.without_context(), reload.reloader.watch():
self.type = AppType.MODULE
try:
spec = importlib.util.find_spec(self.name)
@@ -345,6 +346,9 @@ def solara_comm_target(comm, msg_first):

def on_msg(msg):
nonlocal app
comm = comm_ref()
assert comm is not None
context = kernel_context.get_current_context()
data = msg["content"]["data"]
method = data["method"]
if method == "run":
@@ -378,18 +382,22 @@ def on_msg(msg):
else:
logger.error("Unknown comm method called on solara.control comm: %s", method)

comm.on_msg(on_msg)

def reload():
comm = comm_ref()
assert comm is not None
context = kernel_context.get_current_context()
# we don't reload the app ourselves; we send a message to the client
# this ensures that we don't run code for a client that is still connected
# but no longer working. It also indirectly passes a message from the current
# thread (that of the Reloader/watchdog) to the thread of the client
logger.debug(f"Send reload to client: {context.id}")
comm.send({"method": "reload"})

context = kernel_context.get_current_context()
context.reload = reload
comm.on_msg(on_msg)
comm_ref = weakref.ref(comm)
del comm

kernel_context.get_current_context().reload = reload


def register_solara_comm_target(kernel: Kernel):
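The hunk above is the heart of the comm-side fix: `comm.on_msg(on_msg)` stores the handler on the comm itself, so if `on_msg` and `reload` captured `comm` as a strong reference, handler and comm would form a reference cycle that keeps the whole virtual kernel reachable. Instead they capture only `weakref.ref(comm)`, and the local name is deleted. A minimal sketch of the pattern, with a hypothetical handler body:

```python
import weakref

def register_handlers(comm):
    comm_ref = weakref.ref(comm)  # weak: does not keep the comm alive

    def on_msg(msg):
        comm = comm_ref()  # promote to a strong reference for this call only
        assert comm is not None
        comm.send({"method": "ack"})  # hypothetical reply

    comm.on_msg(on_msg)  # the comm stores the handler; no cycle is created
    del comm  # drop the last strong local reference eagerly
```

Once the kernel drops its side of the comm, nothing keeps the handler/comm pair alive, and reference counting reclaims them immediately.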
39 changes: 38 additions & 1 deletion solara/server/kernel.py
@@ -217,7 +217,10 @@ def send_websockets(websockets: Set[websocket.WebsocketWrapper], binary_msg):
ws.send(binary_msg)
except: # noqa
# in case of any issue, we simply remove it from the set
websockets.remove(ws)
try:
websockets.remove(ws)
except KeyError:
pass # already removed
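The `try`/`except KeyError` makes removal idempotent when two threads race to drop the same dead websocket. The standard `set.discard` gives the same tolerance in one call, for example:

```python
sockets = {"ws-a", "ws-b"}  # stand-ins for websocket wrappers
sockets.discard("ws-a")     # removes the element
sockets.discard("ws-a")     # a second (racing) call is a harmless no-op
```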


class SessionWebsocket(session.Session):
@@ -233,6 +236,8 @@ def close(self):
pass

def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, track=False, header=None, metadata=None):
if stream is None:
return # can happen when the kernel is closed but someone was still trying to send a message
try:
if isinstance(msg_or_type, dict):
msg = msg_or_type
@@ -290,6 +295,38 @@ def __init__(self):
self.shell.display_pub.session = self.session
self.shell.display_pub.pub_socket = self.iopub_socket

def close(self):
if self.comm_manager is None:
raise RuntimeError("Kernel already closed")
self.session.close()
self._cleanup_references()

def _cleanup_references(self):
try:
# all of these break circular references,
# making it easier for the garbage collector to clean up
self.shell_handlers.clear()
self.control_handlers.clear()
for comm_object in list(self.comm_manager.comms.values()): # type: ignore
comm_object.close()
self.comm_manager.targets.clear() # type: ignore
# self.comm_manager.kernel points to us, but we cannot set it to None
# so we remove the circular reference by setting the comm_manager to None
self.comm_manager = None # type: ignore
self.session.parent = None # type: ignore

self.shell.display_pub.session = None # type: ignore
self.shell.display_pub.pub_socket = None # type: ignore
self.shell = None # type: ignore
self.session.websockets.clear()
self.session.stream = None # type: ignore
self.session = None # type: ignore
self.stream.session = None # type: ignore
self.stream = None # type: ignore
self.iopub_socket = None # type: ignore
except Exception:
logger.exception("Error cleaning up references from kernel, not fatal")

async def _flush_control_queue(self):
pass

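Most of what `_cleanup_references` nulls out sits in reference cycles (kernel ↔ comm_manager, session ↔ stream), and cyclic garbage is reclaimed only by Python's cycle collector, potentially long after the kernel closed. Severing one edge per cycle lets plain reference counting free everything promptly. An illustrative sketch with hypothetical classes:

```python
class Kernel:
    def __init__(self):
        # kernel -> manager -> kernel: a reference cycle
        self.comm_manager = CommManager(self)

class CommManager:
    def __init__(self, kernel):
        self.kernel = kernel

k = Kernel()
k.comm_manager = None  # sever one edge, as _cleanup_references does
del k  # both objects are now freed by reference counting alone
```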
94 changes: 70 additions & 24 deletions solara/server/kernel_context.py
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import dataclasses
import enum
import logging
@@ -63,6 +64,7 @@ class VirtualKernelContext:
# only used for testing
_last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None
closed_event: threading.Event = dataclasses.field(default_factory=threading.Event)
lock: threading.RLock = dataclasses.field(default_factory=threading.RLock)

def display(self, *args):
print(args) # noqa
@@ -81,6 +83,9 @@ def __exit__(self, *args):

def close(self):
logger.info("Shut down virtual kernel: %s", self.id)
with self.lock:
for key in self.page_status:
self.page_status[key] = PageStatus.CLOSED
with self:
if self.app_object is not None:
if isinstance(self.app_object, reacton.core._RenderContext):
@@ -93,9 +98,11 @@ def close(self):
# what if we reference each other
# import gc
# gc.collect()
self.kernel.session.close()
self.kernel.close()
self.kernel = None # type: ignore
if self.id in contexts:
del contexts[self.id]
del current_context[get_current_thread_key()]
self.closed_event.set()

def _state_reset(self):
@@ -123,10 +130,12 @@ def state_save(self, state_directory: os.PathLike):

def page_connect(self, page_id: str):
logger.info("Connect page %s for kernel %s", page_id, self.id)
assert self.page_status.get(page_id) != PageStatus.CLOSED, "cannot connect with the same page_id after a close"
self.page_status[page_id] = PageStatus.CONNECTED
if self._last_kernel_cull_task:
self._last_kernel_cull_task.cancel()
with self.lock:
if self.page_status.get(page_id) == PageStatus.CLOSED:
raise RuntimeError("Cannot connect a page that is already closed")
self.page_status[page_id] = PageStatus.CONNECTED
if self._last_kernel_cull_task:
self._last_kernel_cull_task.cancel()

def page_disconnect(self, page_id: str) -> "asyncio.Future[None]":
"""Signal that a page has disconnected, and schedule a kernel cull if needed.
@@ -139,23 +148,36 @@ def page_disconnect(self, page_id: str) -> "asyncio.Future[None]":
"""
logger.info("Disconnect page %s for kernel %s", page_id, self.id)
future: "asyncio.Future[None]" = asyncio.Future()
self.page_status[page_id] = PageStatus.DISCONNECTED
with self.lock:
if self.page_status[page_id] == PageStatus.CLOSED:
logger.info("Page %s already closed for kernel %s", page_id, self.id)
future.set_result(None)
return future
assert self.page_status[page_id] == PageStatus.CONNECTED, "cannot disconnect a page that is in state: %r" % self.page_status[page_id]
self.page_status[page_id] = PageStatus.DISCONNECTED
current_event_loop = asyncio.get_event_loop()

async def kernel_cull():
try:
cull_timeout_sleep_seconds = solara.util.parse_timedelta(solara.server.settings.kernel.cull_timeout)
logger.info("Scheduling kernel cull, will wait for max %s before shutting down the virtual kernel %s", cull_timeout_sleep_seconds, self.id)
await asyncio.sleep(cull_timeout_sleep_seconds)
has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
if has_connected_pages:
logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id)
else:
logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id)
self.close()
current_event_loop.call_soon_threadsafe(future.set_result, None)
with self.lock:
has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
if has_connected_pages:
logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id)
else:
logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id)
self.close()
try:
current_event_loop.call_soon_threadsafe(future.set_result, None)
except RuntimeError:
pass # event loop already closed, happens during testing
except asyncio.CancelledError:
current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled")
try:
current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled")
except RuntimeError:
pass # event loop already closed, happens during testing
raise

has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
@@ -168,7 +190,10 @@ async def create_task():
task = asyncio.create_task(kernel_cull())
# create a reference to the task so we can cancel it later
self._last_kernel_cull_task = task
await task
try:
await task
except RuntimeError:
pass # event loop already closed, happens during testing

asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop)
else:
@@ -182,15 +207,21 @@ def page_close(self, page_id: str):
different from a websocket/page disconnect, which we might want to recover from.
"""
self.page_status[page_id] = PageStatus.CLOSED
logger.info("Close page %s for kernel %s", page_id, self.id)
has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values()
if not (has_connected_pages or has_disconnected_pages):
logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id)
if self._last_kernel_cull_task:
self._last_kernel_cull_task.cancel()
self.close()

logger.info("page status: %s", self.page_status)
with self.lock:
if self.page_status[page_id] == PageStatus.CLOSED:
logger.info("Page %s already closed for kernel %s", page_id, self.id)
return
self.page_status[page_id] = PageStatus.CLOSED
logger.info("Close page %s for kernel %s", page_id, self.id)
has_connected_pages = PageStatus.CONNECTED in self.page_status.values()
has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values()
if not (has_connected_pages or has_disconnected_pages):
logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id)
if self._last_kernel_cull_task:
self._last_kernel_cull_task.cancel()
self.close()


try:
@@ -267,6 +298,21 @@ def set_current_context(context: Optional[VirtualKernelContext]):
current_context[thread_key] = context


@contextlib.contextmanager
def without_context():
context = None
try:
context = get_current_context()
except RuntimeError:
pass
thread_key = get_current_thread_key()
current_context[thread_key] = None
try:
yield
finally:
current_context[thread_key] = context


def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websocket.WebsocketWrapper):
from solara.server import app as appmodule

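Two leak-related changes in this file work together: an `RLock` now serializes all `page_status` transitions (connect, disconnect, and close can race across threads), and the new `without_context` context manager lets process-wide work run without being attributed to whichever virtual kernel is current on the thread. A sketch of how `without_context` is used (the real call sites are in `app.py` above and `patch.py` below):

```python
from solara.server import kernel_context

with kernel_context.without_context():
    # code here sees no current virtual kernel, so nothing done on this
    # thread (e.g. registering files with the reloader) can pin a kernel
    pass  # e.g. reload.reloader.watcher.add_file(path)
```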
10 changes: 6 additions & 4 deletions solara/server/patch.py
@@ -13,6 +13,8 @@
import ipywidgets.widgets.widget_output
from IPython.core.interactiveshell import InteractiveShell

import solara.util

from . import app, kernel_context, reload, settings
from .utils import pdb_guard

@@ -235,7 +237,8 @@ def auto_watch_get_template(get_template):

def wrapper(abs_path):
template = get_template(abs_path)
reload.reloader.watcher.add_file(abs_path)
with kernel_context.without_context():
reload.reloader.watcher.add_file(abs_path)
return template

return wrapper
@@ -255,9 +258,8 @@ def WidgetContextAwareThread__init__(self, *args, **kwargs):


def Thread_debug_run(self):
if self.current_context:
kernel_context.set_context_for_thread(self.current_context, self)
with pdb_guard():
context = self.current_context or solara.util.nullcontext()
with pdb_guard(), context:
Thread__run(self)


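The rewritten `Thread_debug_run` leans on `VirtualKernelContext` being usable as a context manager (it defines `__enter__`/`__exit__`, as seen in the `kernel_context.py` diff above) and falls back to a null context when the thread has none; `solara.util.nullcontext` presumably mirrors `contextlib.nullcontext`, which only exists on Python 3.7+. The shape of the pattern, as a sketch:

```python
import contextlib

def run_with_optional_context(current_context, body):
    # enter the thread's kernel context if it has one, otherwise do nothing,
    # without duplicating the body in two branches
    ctx = current_context or contextlib.nullcontext()
    with ctx:
        body()
```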
3 changes: 2 additions & 1 deletion solara/server/server.py
@@ -139,7 +139,8 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s
message = await ws.receive()
except websocket.WebSocketDisconnect:
try:
context.kernel.session.websockets.remove(ws)
if context.kernel is not None and context.kernel.session is not None:
context.kernel.session.websockets.remove(ws)
except KeyError:
pass
logger.debug("Disconnected")
11 changes: 11 additions & 0 deletions solara/server/shell.py
@@ -1,3 +1,4 @@
import atexit
import sys
from threading import local
from unittest.mock import Mock
@@ -175,10 +176,20 @@ class SolaraInteractiveShell(InteractiveShell):
display_pub_class = Type(SolaraDisplayPublisher)
history_manager = Any() # type: ignore

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
atexit.unregister(self.atexit_operations)

magic = self.magics_manager.registry["ScriptMagics"]
atexit.unregister(magic.kill_bg_processes)

def set_parent(self, parent):
"""Tell the children about the parent message."""
self.display_pub.set_parent(parent)

def init_sys_modules(self):
pass # don't create a __main__ module, it would cause a memory leak

def init_history(self):
self.history_manager = Mock() # type: ignore

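`InteractiveShell.__init__` registers bound methods with `atexit` (`atexit_operations`, plus `kill_bg_processes` from the script magics); the `atexit` registry holds strong references until interpreter shutdown, so every virtual kernel's shell would stay reachable forever. `atexit.unregister` compares by equality, which is why unregistering a freshly created bound method works. An illustrative sketch with a hypothetical class:

```python
import atexit

class Shell:
    def __init__(self):
        atexit.register(self.cleanup)    # atexit now holds a strong ref to self
        atexit.unregister(self.cleanup)  # equality-based, so this drops it again

    def cleanup(self):
        pass

s = Shell()
del s  # reclaimable immediately; with the atexit ref it would live until exit
```

Skipping `init_sys_modules` similarly avoids a long-lived `__main__` entry in `sys.modules` that would keep the shell's namespace alive.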
(3 more changed files not loaded)