Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a single code server #26818

Open
wants to merge 1 commit into
base: dpeng817/use_code_server_start
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 120 additions & 98 deletions python_modules/dagster/dagster/_cli/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import os
import subprocess
import sys
import tempfile
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from typing import Iterator, Optional

import click
import yaml

import dagster._check as check
from dagster._annotations import deprecated
Expand All @@ -21,6 +24,12 @@
working_directory_option,
workspace_option,
)
from dagster._core.remote_representation.origin import (
GrpcServerCodeLocationOrigin,
ManagedGrpcPythonEnvCodeLocationOrigin,
)
from dagster._core.workspace.context import WorkspaceProcessContext
from dagster._grpc.server import GrpcServerCommand
from dagster._serdes import serialize_value
from dagster._serdes.ipc import interrupt_ipc_subprocess, open_ipc_subprocess
from dagster._utils.log import configure_loggers
Expand Down Expand Up @@ -122,8 +131,7 @@ def dev_command(
configure_loggers(formatter=log_format, log_level=log_level.upper())
logger = logging.getLogger("dagster")

# Sanity check workspace args
get_workspace_load_target(kwargs)
workspace_target = get_workspace_load_target(kwargs)

dagster_home_path = os.getenv("DAGSTER_HOME")

Expand All @@ -140,103 +148,117 @@ def dev_command(
)

with get_possibly_temporary_instance_for_cli("dagster dev", logger=logger) as instance:
logger.info("Launching Dagster services...")

args = [
"--instance-ref",
serialize_value(instance.get_ref()),
"--code-server-log-level",
code_server_log_level,
]

if kwargs.get("workspace"):
for workspace in check.tuple_elem(kwargs, "workspace"):
args.extend(["--workspace", workspace])

if kwargs.get("python_file"):
for python_file in check.tuple_elem(kwargs, "python_file"):
args.extend(["--python-file", python_file])

if kwargs.get("module_name"):
for module_name in check.tuple_elem(kwargs, "module_name"):
args.extend(["--module-name", module_name])

if kwargs.get("working_directory"):
args.extend(["--working-directory", check.str_elem(kwargs, "working_directory")])

if kwargs.get("grpc_port"):
args.extend(["--grpc-port", str(kwargs["grpc_port"])])

if kwargs.get("grpc_host"):
args.extend(["--grpc-host", str(kwargs["grpc_host"])])

if kwargs.get("grpc_socket"):
args.extend(["--grpc-socket", str(kwargs["grpc_socket"])])

if kwargs.get("use_ssl"):
args.extend(["--use-ssl"])

webserver_process = open_ipc_subprocess(
[sys.executable, "-m", "dagster_webserver"]
+ (["--port", port] if port else [])
+ (["--host", host] if host else [])
+ (["--dagster-log-level", log_level])
+ (["--log-format", log_format])
+ (["--live-data-poll-rate", live_data_poll_rate] if live_data_poll_rate else [])
+ args
)
daemon_process = open_ipc_subprocess(
[
sys.executable,
"-m",
"dagster._daemon",
"run",
"--log-level",
log_level,
"--log-format",
log_format,
]
+ args
)
try:
while True:
time.sleep(_CHECK_SUBPROCESS_INTERVAL)

if webserver_process.poll() is not None:
raise Exception(
"dagster-webserver process shut down unexpectedly with return code"
f" {webserver_process.returncode}"
)
with WorkspaceProcessContext(
instance,
workspace_target,
code_server_log_level=code_server_log_level,
server_command=GrpcServerCommand.CODE_SERVER_START,
) as context:
with _temp_grpc_socket_workspace_file(context) as workspace_file:
logger.info("Launching Dagster services...")

if daemon_process.poll() is not None:
raise Exception(
"dagster-daemon process shut down unexpectedly with return code"
f" {daemon_process.returncode}"
)
args = [
"--instance-ref",
serialize_value(instance.get_ref()),
"--workspace",
workspace_file,
"--code-server-log-level",
code_server_log_level,
]

except KeyboardInterrupt:
logger.info("KeyboardInterrupt received")
except:
logger.exception("An unexpected exception has occurred")
finally:
logger.info("Shutting down Dagster services...")
interrupt_ipc_subprocess(daemon_process)
interrupt_ipc_subprocess(webserver_process)

try:
webserver_process.wait(timeout=_SUBPROCESS_WAIT_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(
"dagster-webserver process did not terminate cleanly, killing the process"
)
webserver_process.kill()
if kwargs.get("use_ssl"):
args.extend(["--use-ssl"])

try:
daemon_process.wait(timeout=_SUBPROCESS_WAIT_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(
"dagster-daemon process did not terminate cleanly, killing the process"
webserver_process = open_ipc_subprocess(
[sys.executable, "-m", "dagster_webserver"]
+ (["--port", port] if port else [])
+ (["--host", host] if host else [])
+ (["--dagster-log-level", log_level])
+ (["--log-format", log_format])
+ (
["--live-data-poll-rate", live_data_poll_rate]
if live_data_poll_rate
else []
)
+ args
)
daemon_process = open_ipc_subprocess(
[
sys.executable,
"-m",
"dagster._daemon",
"run",
"--log-level",
log_level,
"--log-format",
log_format,
]
+ args
)
daemon_process.kill()
try:
while True:
time.sleep(_CHECK_SUBPROCESS_INTERVAL)

if webserver_process.poll() is not None:
raise Exception(
"dagster-webserver process shut down unexpectedly with return code"
f" {webserver_process.returncode}"
)

if daemon_process.poll() is not None:
raise Exception(
"dagster-daemon process shut down unexpectedly with return code"
f" {daemon_process.returncode}"
)

except KeyboardInterrupt:
logger.info("KeyboardInterrupt received")
except:
logger.exception("An unexpected exception has occurred")
finally:
logger.info("Shutting down Dagster services...")
interrupt_ipc_subprocess(daemon_process)
interrupt_ipc_subprocess(webserver_process)

try:
webserver_process.wait(timeout=_SUBPROCESS_WAIT_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(
"dagster-webserver process did not terminate cleanly, killing the process"
)
webserver_process.kill()

try:
daemon_process.wait(timeout=_SUBPROCESS_WAIT_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(
"dagster-daemon process did not terminate cleanly, killing the process"
)
daemon_process.kill()

logger.info("Dagster services shut down.")


logger.info("Dagster services shut down.")
@contextmanager
def _temp_grpc_socket_workspace_file(context: WorkspaceProcessContext) -> Iterator[str]:
location_specs = []
with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
for origin in context._origins: # noqa: SLF001
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add this as a real code_location_origins property on WorkspaceProcessContext?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a method on WorkspaceProcessContext that returns a list of server specs and handles the origin switching there?

if isinstance(origin, ManagedGrpcPythonEnvCodeLocationOrigin):
grpc_endpoint = context._grpc_server_registry.get_grpc_endpoint(origin) # noqa: SLF001
server_spec = {
"location_name": origin.location_name,
"socket": grpc_endpoint.socket,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

local grpc servers can also use ports (notably on windows)

}
elif isinstance(origin, GrpcServerCodeLocationOrigin):
server_spec = {
"location_name": origin.location_name,
"host": origin.host,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can include host either way (it will just be localhost in the managed case)

"port": origin.port,
}
else:
check.failed(f"Unexpected origin type {origin}")
location_specs.append({"grpc_server": server_spec})
temp_file.write(yaml.dump({"load_from": location_specs}))
temp_file.flush()
yield temp_file.name
3 changes: 2 additions & 1 deletion python_modules/dagster/dagster/_core/workspace/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@ def __init__(
read_only: bool = False,
grpc_server_registry: Optional[GrpcServerRegistry] = None,
code_server_log_level: str = "INFO",
server_command: GrpcServerCommand = GrpcServerCommand.API_GRPC,
):
self._stack = ExitStack()

Expand Down Expand Up @@ -657,7 +658,7 @@ def __init__(
self._grpc_server_registry = self._stack.enter_context(
GrpcServerRegistry(
instance_ref=self._instance.get_ref(),
server_command=GrpcServerCommand.API_GRPC,
server_command=server_command,
heartbeat_ttl=WEBSERVER_GRPC_SERVER_HEARTBEAT_TTL,
startup_timeout=instance.code_server_process_startup_timeout,
log_level=code_server_log_level,
Expand Down
8 changes: 6 additions & 2 deletions python_modules/dagster/dagster/_grpc/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,12 @@ def ListRepositories(self, request, context):
def Ping(self, request, context):
return self._query("Ping", request, context)

def GetServerId(self, request, context):
return self._fixed_server_id or self._query("GetServerId", request, context)
def GetServerId(self, request, context) -> api_pb2.GetServerIdReply:
return (
api_pb2.GetServerIdReply(server_id=self._fixed_server_id)
if self._fixed_server_id
else self._query("GetServerId", request, context)
)

def GetCurrentImage(self, request, context):
return self._query("GetCurrentImage", request, context)
Expand Down
Loading