Skip to content

Commit

Permalink
[single-grpc-server] introduce a heartbeat to proxy server
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Jan 2, 2025
1 parent 942e7a0 commit 55bc33f
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 16 deletions.
11 changes: 11 additions & 0 deletions python_modules/dagster/dagster/_cli/code_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from dagster._utils.interrupts import setup_interrupt_handlers
from dagster._utils.log import configure_loggers

DEFAULT_SERVER_HEARTBEAT_TIMEOUT = 30


@click.group(name="code-server")
def code_server_cli():
Expand Down Expand Up @@ -144,6 +146,13 @@ def code_server_cli():
help="How long to wait for code to load or reload before timing out. Defaults to no timeout.",
envvar="DAGSTER_CODE_SERVER_STARTUP_TIMEOUT",
)
@click.option(
"--server-heartbeat-timeout",
type=click.INT,
required=False,
default=DEFAULT_SERVER_HEARTBEAT_TIMEOUT,
help="How long to wait for a heartbeat from the caller before timing out. Defaults to 30 seconds.",
)
@click.option(
"--instance-ref",
type=click.STRING,
Expand All @@ -165,6 +174,7 @@ def start_command(
location_name: Optional[str] = None,
inject_env_vars_from_instance: bool = False,
startup_timeout: int = 0,
server_heartbeat_timeout: int = DEFAULT_SERVER_HEARTBEAT_TIMEOUT,
instance_ref=None,
**kwargs,
):
Expand Down Expand Up @@ -231,6 +241,7 @@ def start_command(
instance_ref=deserialize_value(instance_ref, InstanceRef) if instance_ref else None,
server_termination_event=server_termination_event,
logger=logger,
server_heartbeat_timeout=server_heartbeat_timeout,
)
server = DagsterGrpcServer(
server_termination_event=server_termination_event,
Expand Down
53 changes: 38 additions & 15 deletions python_modules/dagster/dagster/_grpc/proxy_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import sys
import threading
import time
from contextlib import ExitStack
from typing import TYPE_CHECKING, Dict, Optional

Expand Down Expand Up @@ -48,9 +49,9 @@ def __init__(
server_termination_event: threading.Event,
instance_ref: Optional[InstanceRef],
logger: logging.Logger,
server_heartbeat_timeout: int,
):
super(DagsterProxyApiServicer, self).__init__()

self._loadable_target_origin = loadable_target_origin
self._fixed_server_id = fixed_server_id
self._container_image = container_image
Expand All @@ -63,8 +64,8 @@ def __init__(

self._client = None
self._load_error = None
self._heartbeat_shutdown_event = None
self._heartbeat_thread = None
self._client_heartbeat_shutdown_event = None
self._client_heartbeat_thread = None

self._exit_stack = ExitStack()

Expand Down Expand Up @@ -100,6 +101,15 @@ def __init__(
daemon=True,
)

self.__last_heartbeat_time = time.time()
self.__server_heartbeat_thread = threading.Thread(
target=self._server_heartbeat_thread,
args=(server_heartbeat_timeout,),
name="grpc-server-heartbeat",
daemon=True,
)
self.__server_heartbeat_thread.start()

self.__cleanup_thread.start()

# Map runs to the client that launched them, so that we can route
Expand All @@ -121,22 +131,22 @@ def _reload_location(self):
self._logger.exception("Failure while loading code")

if self._client:
self._heartbeat_shutdown_event = threading.Event()
self._heartbeat_thread = threading.Thread(
self._client_heartbeat_shutdown_event = threading.Event()
self._client_heartbeat_thread = threading.Thread(
target=client_heartbeat_thread,
args=(
self._client,
self._heartbeat_shutdown_event,
self._client_heartbeat_shutdown_event,
),
name="grpc-client-heartbeat",
daemon=True,
)
self._heartbeat_thread.start()
self._client_heartbeat_thread.start()

def ReloadCode(self, request, context):
with self._reload_lock: # can only call this method once at a time
old_heartbeat_shutdown_event = self._heartbeat_shutdown_event
old_heartbeat_thread = self._heartbeat_thread
old_heartbeat_shutdown_event = self._client_heartbeat_shutdown_event
old_heartbeat_thread = self._client_heartbeat_thread
old_client = self._client

self._reload_location() # Creates and starts a new heartbeat thread
Expand All @@ -156,13 +166,13 @@ def cleanup(self):
# In case ShutdownServer was not called
self._shutdown_once_executions_finish_event.set()

if self._heartbeat_shutdown_event:
self._heartbeat_shutdown_event.set()
self._heartbeat_shutdown_event = None
if self._client_heartbeat_shutdown_event:
self._client_heartbeat_shutdown_event.set()
self._client_heartbeat_shutdown_event = None

if self._heartbeat_thread:
self._heartbeat_thread.join()
self._heartbeat_thread = None
if self._client_heartbeat_thread:
self._client_heartbeat_thread.join()
self._client_heartbeat_thread = None

self._exit_stack.close()

Expand All @@ -186,6 +196,18 @@ def _query(self, api_name: str, request, _context, timeout: int = DEFAULT_GRPC_T
raise Exception("No available client to code serer")
return check.not_none(self._client)._get_response(api_name, request, timeout) # noqa

def _server_heartbeat_thread(self, heartbeat_timeout: int) -> None:
while True:
if self._server_termination_event.is_set():
break

self._shutdown_once_executions_finish_event.wait(heartbeat_timeout)
if self._shutdown_once_executions_finish_event.is_set():
break

if self.__last_heartbeat_time < time.time() - heartbeat_timeout:
self._shutdown_once_executions_finish_event.set()

def _streaming_query(
self, api_name: str, request, _context, timeout: int = DEFAULT_GRPC_TIMEOUT
):
Expand Down Expand Up @@ -216,6 +238,7 @@ def StreamingExternalRepository(self, request, context):
return self._streaming_query("StreamingExternalRepository", request, context)

def Heartbeat(self, request, context):
self.__last_heartbeat_time = time.time()
return self._query("Heartbeat", request, context)

def StreamingPing(self, request, context):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import subprocess
import sys
import time

import pytest
from dagster import _seven
Expand Down Expand Up @@ -548,7 +549,45 @@ def test_load_timeout():
assert "StatusCode.UNAVAILABLE" in str(timeout_exception)


def test_load_timeout_code_server_cli():
def test_server_heartbeat_timeout_code_server_cli() -> None:
"""Test that without a heartbeat from the calling process, the server will eventually time out."""
port = find_free_port()
python_file = file_relative_path(__file__, "grpc_repo.py")

subprocess_args = [
"dagster",
"code-server",
"start",
"--port",
str(port),
"--python-file",
python_file,
"--server-heartbeat-timeout",
"1",
]

process = subprocess.Popen(subprocess_args)

try:
client = DagsterGrpcClient(port=port, host="localhost")
wait_for_grpc_server(
process,
DagsterGrpcClient(port=port, host="localhost"),
subprocess_args,
)
# Send out an initial heartbeat, ensure server is alive to begin with.
client.ping("foobar")
client.shutdown_server()
assert process.poll() is None
time.sleep(2)
assert process.poll() == 0

finally:
process.terminate()
process.wait()


def test_load_timeout_code_server_cli() -> None:
port = find_free_port()
python_file = file_relative_path(__file__, "grpc_repo_that_times_out.py")

Expand All @@ -562,6 +601,8 @@ def test_load_timeout_code_server_cli():
python_file,
"--startup-timeout",
"1",
"--server-heartbeat-timeout",
"600",
]

process = subprocess.Popen(subprocess_args)
Expand Down

0 comments on commit 55bc33f

Please sign in to comment.