Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "truss"
version = "0.11.26"
version = "0.11.27rc500"
description = "A seamless bridge from model development to model delivery"
authors = [
{ name = "Pankaj Gupta", email = "no-reply@baseten.co" },
Expand Down
2 changes: 1 addition & 1 deletion truss-transfer/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion truss-transfer/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "truss_transfer"
version = "0.0.38"
version = "0.0.39"
edition = "2021"

[lib]
Expand Down
4 changes: 2 additions & 2 deletions truss-transfer/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ impl PyModelRepo {
#[new]
#[pyo3(signature = (
repo_id,
revision,
volume_folder,
revision = "".to_string(),
volume_folder = "".to_string(),
kind = "hf".to_string(),
allow_patterns = None,
ignore_patterns = None,
Expand Down
4 changes: 2 additions & 2 deletions truss-transfer/src/create/basetenpointer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ mod tests {
},
ModelRepo {
repo_id: "gs://test-bucket/model-path".to_string(),
revision: "main".to_string(), // Ignored for GCS
revision: "".to_string(), // Not needed for GCS
allow_patterns: Some(vec!["*.safetensors".to_string()]),
ignore_patterns: Some(vec!["*.md".to_string()]),
kind: ResolutionType::Gcs,
Expand Down Expand Up @@ -249,7 +249,7 @@ mod tests {
// Test Azure support with a mock repository
let model_repos = vec![ModelRepo {
repo_id: "azure://testaccount/testcontainer/model.bin".to_string(),
revision: "main".to_string(),
revision: "".to_string(), // Not needed for Azure
runtime_secret_name: "azure-storage".to_string(),
volume_folder: "test_azure_model".to_string(),
kind: ResolutionType::Azure,
Expand Down
18 changes: 14 additions & 4 deletions truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class ModelRepoSourceKind(str, enum.Enum):

class ModelRepo(custom_types.ConfigModel):
repo_id: Annotated[str, pydantic.StringConstraints(min_length=1)]
revision: Optional[Annotated[str, pydantic.StringConstraints(min_length=1)]] = None
revision: str = ""
allow_patterns: Optional[list[str]] = None
ignore_patterns: Optional[list[str]] = None
volume_folder: Optional[
Expand All @@ -151,6 +151,13 @@ class ModelRepo(custom_types.ConfigModel):
kind: ModelRepoSourceKind = ModelRepoSourceKind.HF
runtime_secret_name: str = "hf_access_token"

@pydantic.field_validator("revision")
@classmethod
def _validate_revision(cls, v: str) -> str:
if len(v) == 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if revision == HF: `revision must not be empty.

I have a very useful idea here: I would try recommend importing huggingface hub, get the latest revsion, and suggest it to the user.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michaelfeil We do validation for v2, in runtime_path() below. For v1, if it's empty it gets from HEAD. (hf_hub_download() behavior if revision=None).

raise ValueError("revision must be empty or at least 2 characters")
return v

@property
def runtime_path(self) -> pathlib.Path:
assert self.volume_folder is not None
Expand All @@ -161,11 +168,14 @@ def _check_v2_requirements(cls, v) -> str:
use_volume = v.get("use_volume", False)
if not use_volume:
return v
if v.get("kind") == ModelRepoSourceKind.HF.value and v.get("revision") is None:
revision = v.get("revision") or ""
kind = v.get("kind")
is_hf = kind is None or kind == ModelRepoSourceKind.HF.value
if is_hf and not revision:
logger.warning(
"the key `revision: str` is required for use_volume=True huggingface repos. For S3/GCS/Azure repos, set it to any non-empty string."
"the key `revision: str` is required for use_volume=True huggingface repos."
)
raise_insufficent_revision(v.get("repo_id"), v.get("revision"))
raise_insufficent_revision(v.get("repo_id"), revision)
if v.get("volume_folder") is None or len(v["volume_folder"]) == 0:
raise ValueError(
"the key `volume_folder: str` is required for `use_volume=True` repos."
Expand Down
2 changes: 1 addition & 1 deletion truss/templates/cache.Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ COPY ./{{credential}} ${APP_HOME}/{{credential}}

{% for repo, hf_dir in models.items() %}
{% for file in hf_dir.files %}
{{ "RUN --mount=type=secret,id=" + hf_access_token_file_name + ",dst=/etc/secrets/" + hf_access_token_file_name if use_hf_secret else "RUN" }} python3 /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision != None %}{{hf_dir.revision}}{% endif %}
{{ "RUN --mount=type=secret,id=" + hf_access_token_file_name + ",dst=/etc/secrets/" + hf_access_token_file_name if use_hf_secret else "RUN" }} python3 /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision %}{{hf_dir.revision}}{% endif %}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

{% endfor %}
{% endfor %}
6 changes: 3 additions & 3 deletions truss/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def test_huggingface_cache_single_model_default_revision(default_config):
new_config["model_cache"] = [{"repo_id": "test/model", "use_volume": False}]

assert new_config == config.to_dict(verbose=False)
assert config.to_dict(verbose=True)["model_cache"][0].get("revision") is None
assert config.to_dict(verbose=True)["model_cache"][0].get("revision") == ""


def test_huggingface_cache_single_model_non_default_revision_v1():
Expand Down Expand Up @@ -354,7 +354,7 @@ def test_huggingface_cache_multiple_models_default_revision(default_config):
"model_cache"
]
assert config.to_dict(verbose=True)["model_cache"][0].get("revision") == "main"
assert config.to_dict(verbose=True)["model_cache"][1].get("revision") is None
assert config.to_dict(verbose=True)["model_cache"][1].get("revision") == ""


def test_huggingface_cache_multiple_models_mixed_revision(default_config):
Expand All @@ -377,7 +377,7 @@ def test_huggingface_cache_multiple_models_mixed_revision(default_config):
]

assert new_config == config.to_dict(verbose=False)
assert config.to_dict(verbose=True)["model_cache"][0].get("revision") is None
assert config.to_dict(verbose=True)["model_cache"][0].get("revision") == ""
assert config.to_dict(verbose=True)["model_cache"][1].get("revision") == "not-main2"


Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading