Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.68rc003"
version = "0.9.68dev100"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
54 changes: 37 additions & 17 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from functools import cached_property
from multiprocessing import Lock
from pathlib import Path
from threading import Thread
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union, cast

import opentelemetry.sdk.trace as sdk_trace
Expand All @@ -32,6 +31,7 @@
from shared import dynamic_config_resolver, serialization
from shared.lazy_data_resolver import LazyDataResolver
from shared.secrets_resolver import SecretsResolver
from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed

if sys.version_info >= (3, 9):
from typing import AsyncGenerator, Generator
Expand Down Expand Up @@ -416,23 +416,19 @@ def skip_input_parsing(self) -> bool:
def truss_schema(self) -> Optional[TrussSchema]:
return self.model_descriptor.truss_schema

def start_load_thread(self):
# Don't retry failed loads.
if self._status == ModelWrapper.Status.NOT_READY:
thread = Thread(target=self.load)
thread.start()

def load(self):
async def load(self):
if self.ready:
return

# if we are already loading, block on acquiring the lock;
# this worker will return 503 while the worker with the lock is loading
with self._load_lock:
self._status = ModelWrapper.Status.LOADING
self._logger.info("Executing model.load()...")
try:
start_time = time.perf_counter()
self._load_impl()
await self.try_load()

self._status = ModelWrapper.Status.READY
self._logger.info(
f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
Expand All @@ -441,7 +437,15 @@ def load(self):
self._logger.exception("Exception while loading model")
self._status = ModelWrapper.Status.FAILED

def _load_impl(self):
async def start_load(self):
    """Start loading the model in a background task, if eligible.

    Eligibility is decided by ``should_load``: failed loads are not
    retried and an already-ready model is not reloaded.
    """
    if self.should_load():
        # Keep a reference to the task: the event loop holds only weak
        # references to tasks, so an otherwise-unreferenced task may be
        # garbage-collected before the load completes.
        self._load_task = asyncio.create_task(self.load())

def should_load(self) -> bool:
    """Return True if a load should be attempted.

    A FAILED status is terminal (failed loads are never retried), and a
    READY model must not be loaded again.
    """
    return self._status != ModelWrapper.Status.FAILED and not self.ready

def _initialize_model(self):
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

Expand Down Expand Up @@ -520,17 +524,33 @@ def _load_impl(self):

self._maybe_model_descriptor = ModelDescriptor.from_model(self._model)

async def try_load(self):
    """Initialize the user model and invoke its ``load`` hook with retries.

    Blocking work (model construction and a synchronous ``load``) is
    dispatched to a worker thread via ``to_thread.run_sync`` so the event
    loop stays responsive; an async ``load`` is awaited directly under an
    equivalent retry policy (``NUM_LOAD_RETRIES`` attempts, 1s apart).
    """
    # Model construction may import heavy dependencies / touch disk:
    # run it off the event loop.
    await to_thread.run_sync(self._initialize_model)

    if self._maybe_model_descriptor.setup_environment:
        self._initialize_environment_before_load()

    if hasattr(self._model, "load"):
        if inspect.iscoroutinefunction(self._model.load):
            async for attempt in AsyncRetrying(
                stop=stop_after_attempt(NUM_LOAD_RETRIES),
                wait=wait_fixed(1),
                before_sleep=lambda retry_state: self._logger.info(
                    f"Model load failed (attempt {retry_state.attempt_number})...retrying"
                ),
            ):
                with attempt:
                    # Await the coroutine directly; the result is unused.
                    await self._model.load()
        else:
            await to_thread.run_sync(
                retry,
                self._model.load,
                NUM_LOAD_RETRIES,
                # `Logger.warn` is a deprecated alias — use `warning`.
                self._logger.warning,
                "Failed to load model.",
                1.0,
            )

def setup_polling_for_environment_updates(self):
self._poll_for_environment_updates_task = asyncio.create_task(
Expand Down
1 change: 1 addition & 0 deletions truss/templates/server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ psutil==5.9.4
python-json-logger==2.0.2
pyyaml==6.0.0
requests==2.31.0
tenacity==9.0.0
uvicorn==0.24.0
uvloop==0.19.0
websockets<=14.0
5 changes: 3 additions & 2 deletions truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,14 @@ def cleanup(self):
if INFERENCE_SERVER_FAILED_FILE.exists():
INFERENCE_SERVER_FAILED_FILE.unlink()

def on_startup(self):
async def on_startup(self):
"""
This method will be started inside the main process, so here is where
we want to setup our logging and model.
"""
self.cleanup()
self._model.start_load_thread()
await self._model.start_load()

asyncio.create_task(self._shutdown_if_load_fails())
self._model.setup_polling_for_environment_updates()

Expand Down
17 changes: 11 additions & 6 deletions truss/tests/templates/server/test_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@
config = yaml.safe_load((app_path / "config.yaml").read_text())
os.chdir(app_path)
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()
# Give background work kicked off by load() time to complete
time.sleep(1)
output = await model_wrapper.predict({}, MagicMock(spec=Request))
assert output == {}
assert model_wrapper._model.load_count == 2


def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
async def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
with helpers.env_var("NUM_LOAD_RETRIES_TRUSS", "0"):
if "model_wrapper" in sys.modules:
model_wrapper_module = sys.modules["model_wrapper"]
Expand All @@ -73,7 +73,7 @@
config = yaml.safe_load((app_path / "config.yaml").read_text())
os.chdir(app_path)
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()
# Give background work kicked off by load() time to complete
time.sleep(1)
assert model_wrapper.load_failed
Expand Down Expand Up @@ -111,7 +111,8 @@
model_wrapper_module, "_init_extension", return_value=mock_extension
) as mock_init_extension:
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()

called_with_specific_extension = any(
call_args[0][0] == "trt_llm"
for call_args in mock_init_extension.call_args_list
Expand Down Expand Up @@ -150,8 +151,10 @@
model_wrapper_module, "_init_extension", return_value=mock_extension
):
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()

resp = await model_wrapper.predict({}, MagicMock(spec=Request))

mock_extension.load.assert_called()
mock_extension.model_args.assert_called()
assert mock_predict_called
Expand Down Expand Up @@ -189,8 +192,10 @@
model_wrapper_module, "_init_extension", return_value=mock_extension
):
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
await model_wrapper.load()

resp = await model_wrapper.predict({}, MagicMock(spec=Request))

mock_extension.load.assert_called()
mock_extension.model_override.assert_called()
assert mock_predict_called
Expand All @@ -213,7 +218,7 @@
model_wrapper.load()

mock_req = MagicMock(spec=Request)
predict_resp = await model_wrapper.predict({}, mock_req)

Check failure on line 221 in truss/tests/templates/server/test_model_wrapper.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_model_wrapper.test_open_ai_completion_endpoints

common.errors.ModelNotReady: Model with name model is not ready.
Raw output
open_ai_container_fs = PosixPath('/tmp/pytest-of-runner/pytest-0/test_open_ai_completion_endpoi0/truss_fs')
helpers = <truss.tests.conftest.Helpers object at 0x7fdb1dff6640>

    @pytest.mark.anyio
    async def test_open_ai_completion_endpoints(open_ai_container_fs, helpers):
        app_path = open_ai_container_fs / "app"
        with (
            _clear_model_load_modules(),
            helpers.sys_paths(app_path),
            _change_directory(app_path),
        ):
            model_wrapper_module = importlib.import_module("model_wrapper")
            model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")
            config = yaml.safe_load((app_path / "config.yaml").read_text())
    
            model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
            model_wrapper.load()
    
            mock_req = MagicMock(spec=Request)
>           predict_resp = await model_wrapper.predict({}, mock_req)

/home/runner/work/truss/truss/truss/tests/templates/server/test_model_wrapper.py:221: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/tmp/pytest-of-runner/pytest-0/test_open_ai_completion_endpoi0/truss_fs/app/model_wrapper.py:858: in predict
    if self.model_descriptor.preprocess:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <model_wrapper.ModelWrapper object at 0x7fdb1dfe89d0>

    @property
    def model_descriptor(self) -> ModelDescriptor:
        if self._maybe_model_descriptor:
            return self._maybe_model_descriptor
        else:
>           raise errors.ModelNotReady(self.name)
E           common.errors.ModelNotReady: Model with name model is not ready.

/tmp/pytest-of-runner/pytest-0/test_open_ai_completion_endpoi0/truss_fs/app/model_wrapper.py:397: ModelNotReady
assert predict_resp == "predict"

completions_resp = await model_wrapper.completions({}, mock_req)
Expand Down
66 changes: 66 additions & 0 deletions truss/tests/test_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,72 @@ def test_truss_with_error_stacktrace(test_data_path):
)


@pytest.mark.integration
def test_async_load_truss():
    """While an async `load()` runs, the server must report live (200) but
    not ready (503) for readiness/ping/predict; once the load completes,
    all endpoints return 200."""
    model = """
    import asyncio

    class Model:
        async def load(self):
            await asyncio.sleep(5)

        def predict(self, request):
            return {"a": "b"}
    """

    config = "model_name: async-load-truss"

    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
        truss_dir = Path(tmp_work_dir, "truss")
        create_truss(truss_dir, config, textwrap.dedent(model))

        tr = TrussHandle(truss_dir)
        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=False)

        base_url = "http://localhost:8090"

        def _check_get(path, expected_code):
            resp = requests.get(f"{base_url}{path}", timeout=1)
            assert resp.status_code == expected_code

        def _check_predict(expected_code):
            resp = requests.post(
                f"{base_url}/v1/models/model:predict", json={}, timeout=1
            )
            assert resp.status_code == expected_code

        SERVER_WARMUP_TIME = 3
        LOAD_TEST_TIME = 2
        LOAD_BUFFER_TIME = 5

        # Give the server a few seconds to wake up before probing.
        time.sleep(SERVER_WARMUP_TIME)

        # The model's load() sleeps for about 5 seconds; during that window
        # the liveness probe passes but everything else is 503.
        for _ in range(LOAD_TEST_TIME):
            _check_get("/", 200)
            _check_get("/v1/models/model", 503)
            _check_get("/ping", 503)
            _check_predict(503)
            time.sleep(1)

        # After the load has had time to finish, every endpoint is healthy.
        time.sleep(LOAD_BUFFER_TIME)
        _check_get("/", 200)
        _check_get("/v1/models/model", 200)
        _check_get("/ping", 200)
        _check_predict(200)


@pytest.mark.integration
def test_slow_truss(test_data_path):
with ensure_kill_all():
Expand Down
Loading