From f358b44e673e4c91c733bcbf7293b2c3ea80fd3a Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Fri, 23 Aug 2024 13:41:42 +0200 Subject: [PATCH] Multi-LoRA inference --- .github/workflows/ci.yaml | 59 ++++---- .github/workflows/push.yaml | 52 +++++++ predict.py | 213 +++++++++++++++------------- requirements-test.txt | 4 + unit-tests/test_weights.py | 180 +++++++++++++++++++++++ weights.py | 276 ++++++++++++++++++++++++------------ 6 files changed, 570 insertions(+), 214 deletions(-) create mode 100644 .github/workflows/push.yaml create mode 100644 requirements-test.txt create mode 100644 unit-tests/test_weights.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 414b3ec..b0c8013 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,20 +1,14 @@ name: CI on: - workflow_dispatch: - inputs: - test_only: - description: 'Test only, without pushing to prod' - type: boolean - default: true - compare_outputs: - description: 'Compare outputs between existing version and new version' - type: boolean - default: true + push: + branches: [main] + pull_request: + branches: [main] jobs: - cog-safe-push: - runs-on: ubuntu-latest-4-cores + lint: + runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -24,28 +18,33 @@ jobs: with: python-version: '3.12' - - name: Install Cog + - name: Install dependencies run: | - sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" - sudo chmod +x /usr/local/bin/cog + pip install -r requirements-test.txt - - name: cog login + - name: Run ruff run: | - echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin + ruff check --exclude=ai-toolkit/ --exclude=LLaVA/ --ignore=E402 - - name: Install cog-safe-push + - name: Run black run: | - pip install git+https://github.com/replicate/cog-safe-push.git + black --check --exclude="ai-toolkit/|LLaVA/" . 
- - name: Run cog-safe-push - env: - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} + unit-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + pip install -r requirements-test.txt + + - name: Run pytest run: | - cog-safe-push -vv \ - --test-model=replicate-internal/test-flux-fine-tuner \ - ${{ github.event.inputs.test_only == 'true' && '--test-only' || '' }} \ - ${{ github.event.inputs.compare_outputs == 'false' && '--no-compare-outputs' || '' }} \ - --test-hardware=cpu \ - -i replicate_weights="https://replicate.delivery/yhqm/iWjMZHd2T35kI5jaUkaG3Jb43MeA67PpYjKZQeifvTEf9yTNB/trained_model.tar" \ - ostris/flux-dev-lora-trainer + pytest unit-tests/ diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml new file mode 100644 index 0000000..14f334b --- /dev/null +++ b/.github/workflows/push.yaml @@ -0,0 +1,52 @@ +name: Push + +on: + workflow_dispatch: + branches: [main] + inputs: + test_only: + description: 'Test only, without pushing to prod' + type: boolean + default: true + compare_outputs: + description: 'Compare outputs between existing version and new version' + type: boolean + default: true + +jobs: + cog-safe-push: + runs-on: ubuntu-latest-4-cores + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install Cog + run: | + sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" + sudo chmod +x /usr/local/bin/cog + + - name: cog login + run: | + echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin + + - name: Install cog-safe-push + run: | + pip install git+https://github.com/replicate/cog-safe-push.git + + - name: Run cog-safe-push + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} + run: | + cog-safe-push -vv \ + --test-model=replicate-internal/test-flux-fine-tuner \ + ${{ github.event.inputs.test_only == 'true' && '--test-only' || '' }} \ + ${{ github.event.inputs.compare_outputs == 'false' && '--no-compare-outputs' || '' }} \ + --test-hardware=cpu \ + -i replicate_weights="https://replicate.delivery/yhqm/iWjMZHd2T35kI5jaUkaG3Jb43MeA67PpYjKZQeifvTEf9yTNB/trained_model.tar" \ + ostris/flux-dev-lora-trainer diff --git a/predict.py b/predict.py index 3129034..2d5912d 100644 --- a/predict.py +++ b/predict.py @@ -1,22 +1,22 @@ -from cog import BasePredictor, Input, Path +from dataclasses import dataclass import os import time import torch import subprocess from typing import List -import base64 -import tempfile -import tarfile -from io import BytesIO +from cog import BasePredictor, Input, Path import numpy as np from diffusers import FluxPipeline -from weights import WeightsDownloadCache from transformers import CLIPImageProcessor from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) +from weights import WeightsDownloadCache + -MODEL_URL_DEV = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/files.tar" +MODEL_URL_DEV = ( + "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/files.tar" +) MODEL_URL_SCHNELL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-schnell/slim.tar" SAFETY_CACHE = "safety-cache" FEATURE_EXTRACTOR = 
"/src/feature-extractor" @@ -37,55 +37,13 @@ } -def download_weights(url, dest): - start = time.time() - print("downloading url: ", url) - print("downloading to: ", dest) - subprocess.check_call(["pget", "-xf", url, dest], close_fds=False) - print("downloading took: ", time.time() - start) +@dataclass +class LoadedLoRAs: + main: str | None + extra: str | None class Predictor(BasePredictor): - def load_trained_weights( - self, weights: Path | str, pipe: FluxPipeline, lora_scale: float - ): - if isinstance(weights, str) and weights.startswith("data:"): - # Handle data URL - print("Loading LoRA weights from data URL") - - # not caching data URIs, can revisit if this becomes common - pipe.unload_lora_weights() - self.set_loaded_weights_string(pipe, "loading") - _, encoded = weights.split(",", 1) - data = base64.b64decode(encoded) - with tempfile.TemporaryDirectory() as temp_dir: - with tarfile.open(fileobj=BytesIO(data), mode="r:*") as tar: - tar.extractall(path=temp_dir) - lora_path = os.path.join( - temp_dir, "output/flux_train_replicate/lora.safetensors" - ) - pipe.load_lora_weights(lora_path) - pipe.fuse_lora(lora_scale=lora_scale) - self.set_loaded_weights_string(pipe, "data_uri") - else: - # Handle local path - print("Loading LoRA weights") - weights = str(weights) - if weights == self.get_loaded_weights_string(pipe): - print("Weights already loaded") - return - pipe.unload_lora_weights() - - self.set_loaded_weights_string(pipe, "loading") - local_weights_cache = self.weights_cache.ensure(weights) - lora_path = os.path.join( - local_weights_cache, "output/flux_train_replicate/lora.safetensors" - ) - pipe.load_lora_weights(lora_path) - self.set_loaded_weights_string(pipe, weights) - - print("LoRA weights loaded successfully") - def setup(self) -> None: """Load the model into memory to make running multiple predictions efficient""" start = time.time() @@ -96,7 +54,7 @@ def setup(self) -> None: print("Loading safety checker...") if not os.path.exists(SAFETY_CACHE): - download_weights(SAFETY_URL, SAFETY_CACHE) + download_base_weights(SAFETY_URL, SAFETY_CACHE) self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( SAFETY_CACHE, torch_dtype=torch.float16 ).to("cuda") @@ -104,56 +62,34 @@ def setup(self) -> None: print("Loading Flux dev pipeline") if not os.path.exists("FLUX.1-dev"): - download_weights(MODEL_URL_DEV, ".") - self.dev_pipe = FluxPipeline.from_pretrained( + download_base_weights(MODEL_URL_DEV, ".") + dev_pipe = FluxPipeline.from_pretrained( "FLUX.1-dev", torch_dtype=torch.bfloat16, ).to("cuda") - self.dev_weights = "" print("Loading Flux schnell pipeline") if not os.path.exists("FLUX.1-schnell"): - download_weights(MODEL_URL_SCHNELL, "FLUX.1-schnell") - self.schnell_pipe = FluxPipeline.from_pretrained( + download_base_weights(MODEL_URL_SCHNELL, "FLUX.1-schnell") + schnell_pipe = FluxPipeline.from_pretrained( "FLUX.1-schnell", - text_encoder=self.dev_pipe.text_encoder, - text_encoder_2=self.dev_pipe.text_encoder_2, - tokenizer=self.dev_pipe.tokenizer, - tokenizer_2=self.dev_pipe.tokenizer_2, + text_encoder=dev_pipe.text_encoder, + text_encoder_2=dev_pipe.text_encoder_2, + tokenizer=dev_pipe.tokenizer, + tokenizer_2=dev_pipe.tokenizer_2, torch_dtype=torch.bfloat16, ).to("cuda") - self.schnell_weights = "" - - print("setup took: ", time.time() - start) - - def get_loaded_weights_string(self, pipe: FluxPipeline): - return ( - self.dev_weights - if pipe.transformer.config.guidance_embeds - else self.schnell_weights - ) - def set_loaded_weights_string(self, pipe: 
FluxPipeline, new_weights: str):
-        if pipe.transformer.config.guidance_embeds:
-            self.dev_weights = new_weights
-        else:
-            self.schnell_weights = new_weights
-        return
-
-    @torch.amp.autocast("cuda")
-    def run_safety_checker(self, image):
-        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
-            "cuda"
-        )
-        np_image = [np.array(val) for val in image]
-        image, has_nsfw_concept = self.safety_checker(
-            images=np_image,
-            clip_input=safety_checker_input.pixel_values.to(torch.float16),
-        )
-        return image, has_nsfw_concept
+        self.pipes = {
+            "dev": dev_pipe,
+            "schnell": schnell_pipe,
+        }
+        self.loaded_lora_urls = {
+            "dev": LoadedLoRAs(main=None, extra=None),
+            "schnell": LoadedLoRAs(main=None, extra=None),
+        }
 
-    def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
-        return ASPECT_RATIOS[aspect_ratio]
+        print("setup took: ", time.time() - start)
 
     @torch.inference_mode()
     def predict(
@@ -183,7 +119,7 @@ def predict(
             default=1,
         ),
         lora_scale: float = Input(
-            description="Determines how strongly the LoRA should be applied. Sane results between 0 and 1.",
+            description="Determines how strongly the main LoRA should be applied. Sane results between 0 and 1.",
            default=1.0,
             le=2.0,
             ge=-1.0,
@@ -208,6 +144,16 @@ def predict(
         seed: int = Input(
             description="Random seed. Set for reproducible generation", default=None
         ),
+        extra_lora: str = Input(
+            description="Combine this fine-tune with another LoRA. Supports Replicate models in the format <owner>/<model> or <owner>/<model>/<version>, HuggingFace URLs in the format huggingface.co/<owner>/<model-name>, CivitAI URLs in the format civitai.com/models/<id>[/<model-name>], or arbitrary .safetensors URLs from the Internet. For example, 'fofr/flux-pixar-cars'",
+            default=None,
+        ),
+        extra_lora_scale: float = Input(
+            description="Determines how strongly the extra LoRA should be applied.",
+            ge=0,
+            le=1,
+            default=0.8,
+        ),
         output_format: str = Input(
             description="Format of the output images",
             choices=["webp", "jpg", "png"],
@@ -254,17 +200,30 @@ def predict(
         if model == "dev":
             print("Using dev model")
             max_sequence_length = 512
-            pipe = self.dev_pipe
         else:
             print("Using schnell model")
             max_sequence_length = 256
-            pipe = self.schnell_pipe
             guidance_scale = 0
 
+        pipe = self.pipes[model]
+
         if replicate_weights:
-            self.load_trained_weights(replicate_weights, pipe, lora_scale)
+            start_time = time.time()
+            if extra_lora:
+                flux_kwargs["joint_attention_kwargs"] = {"scale": 1.0}
+                print(f"Loading extra LoRA weights from: {extra_lora}")
+                self.load_multiple_loras(replicate_weights, extra_lora, model)
+                pipe.set_adapters(
+                    ["main", "extra"], adapter_weights=[lora_scale, extra_lora_scale]
+                )
+            else:
+                flux_kwargs["joint_attention_kwargs"] = {"scale": lora_scale}
+                self.load_single_lora(replicate_weights, model)
+                pipe.set_adapters(["main"], adapter_weights=[lora_scale])
+            print(f"Loaded LoRAs in {time.time() - start_time:.2f}s")
         else:
             pipe.unload_lora_weights()
+            self.loaded_lora_urls[model] = LoadedLoRAs(main=None, extra=None)
 
         generator = torch.Generator("cuda").manual_seed(seed)
 
@@ -301,6 +260,66 @@ def predict(
 
         return output_paths
 
+    def load_single_lora(self, lora_url: str, model: str):
+        # If no change, skip
+        if lora_url == self.loaded_lora_urls[model].main:
+            print("Weights already loaded")
+            return
+
+        pipe = self.pipes[model]
+        pipe.unload_lora_weights()
+        lora_path = self.weights_cache.ensure(lora_url)
+        pipe.load_lora_weights(lora_path, adapter_name="main")
+        self.loaded_lora_urls[model] = LoadedLoRAs(main=lora_url, extra=None)
+
+    def load_multiple_loras(self, main_lora_url: str, 
extra_lora_url: str, model: str):
+        pipe = self.pipes[model]
+        loaded_lora_urls = self.loaded_lora_urls[model]
+
+        # If no change, skip
+        if (
+            main_lora_url == loaded_lora_urls.main
+            and extra_lora_url == loaded_lora_urls.extra
+        ):
+            print("Weights already loaded")
+            return
+
+        # unload_lora_weights() drops all adapters, so reload both LoRAs fresh
+        pipe.unload_lora_weights()
+
+        main_lora_path = self.weights_cache.ensure(main_lora_url)
+        pipe.load_lora_weights(main_lora_path, adapter_name="main")
+
+        extra_lora_path = self.weights_cache.ensure(extra_lora_url)
+        pipe.load_lora_weights(extra_lora_path, adapter_name="extra")
+
+        self.loaded_lora_urls[model] = LoadedLoRAs(
+            main=main_lora_url, extra=extra_lora_url
+        )
+
+    @torch.amp.autocast("cuda")
+    def run_safety_checker(self, image):
+        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
+            "cuda"
+        )
+        np_image = [np.array(val) for val in image]
+        image, has_nsfw_concept = self.safety_checker(
+            images=np_image,
+            clip_input=safety_checker_input.pixel_values.to(torch.float16),
+        )
+        return image, has_nsfw_concept
+
+    def aspect_ratio_to_width_height(self, aspect_ratio: str) -> tuple[int, int]:
+        return ASPECT_RATIOS[aspect_ratio]
+
+
+def download_base_weights(url, dest):
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
 
 def make_multiple_of_16(n):
     return ((n + 15) // 16) * 16
diff --git a/requirements-test.txt b/requirements-test.txt
new file mode 100644
index 0000000..18f5b31
--- /dev/null
+++ b/requirements-test.txt
@@ -0,0 +1,4 @@
+pytest
+requests-mock
+ruff
+black
diff --git a/unit-tests/test_weights.py b/unit-tests/test_weights.py
new file mode 100644
index 0000000..ae4a438
--- /dev/null
+++ b/unit-tests/test_weights.py
@@ -0,0 +1,180 @@
+from pathlib import Path
+import sys
+import pytest
+import requests_mock
+from unittest.mock import patch, MagicMock
+
+sys.path.append(str(Path(__file__).parent.parent))
+from weights import make_download_url, WeightsDownloadCache
+
+
+def test_replicate_model_url():
+    assert make_download_url("owner/model") == "https://replicate.com/owner/model/_weights"
+    assert (
+        make_download_url("https://replicate.com/owner/model")
+        == "https://replicate.com/owner/model/_weights"
+    )
+
+
+def test_replicate_version_url():
+    assert (
+        make_download_url("owner/model/version123")
+        == "https://replicate.com/owner/model/versions/version123/_weights"
+    )
+    assert (
+        make_download_url("owner/model/versions/version123")
+        == "https://replicate.com/owner/model/versions/version123/_weights"
+    )
+    assert (
+        make_download_url("https://replicate.com/owner/model/versions/version123")
+        == "https://replicate.com/owner/model/versions/version123/_weights"
+    )
+
+
+def test_replicate_com_url():
+    url = "https://replicate.com/owner/model"
+    assert make_download_url(url) == "https://replicate.com/owner/model/_weights"
+
+
+def test_replicate_com_version_url():
+    url = "https://replicate.com/owner/model/versions/123abc"
+    assert make_download_url(url) == "https://replicate.com/owner/model/versions/123abc/_weights"
+
+
+def test_huggingface_url():
+    with requests_mock.Mocker() as m:
+        m.get(
+            "https://huggingface.co/api/models/owner/model/tree/main",
+            json=[{"path": "model.safetensors", "type": "file"}],
+        )
+        assert (
+            make_download_url("https://huggingface.co/owner/model")
+            == "https://huggingface.co/owner/model/resolve/main/model.safetensors"
+        )
+
+
+def 
test_civitai_url(): + assert ( + make_download_url("https://civitai.com/models/12345") + == "https://civitai.com/api/download/models/12345?type=Model&format=SafeTensor" + ) + assert ( + make_download_url("civitai.com/models/12345/model-name") + == "https://civitai.com/api/download/models/12345?type=Model&format=SafeTensor" + ) + + +def test_direct_safetensors_url(): + assert ( + make_download_url("https://example.com/model.safetensors") + == "https://example.com/model.safetensors" + ) + assert ( + make_download_url("https://example.com/model.safetensors?download=true") + == "https://example.com/model.safetensors" + ) + + +def test_replicate_delivery_url(): + url = "https://replicate.delivery/pbxt/ABC123/model.tar" + assert make_download_url(url) == url + +def test_data_url(): + data_url = "data:application/x-tar;base64,SGVsbG8gV29ybGQh" + assert make_download_url(data_url) == data_url + + +def test_invalid_huggingface_url(): + with pytest.raises(ValueError, match="Failed to parse HuggingFace URL"): + make_download_url("https://huggingface.co/invalid/url/format") + + +def test_invalid_civitai_url(): + with pytest.raises(ValueError, match="Failed to parse CivitAI URL"): + make_download_url("https://civitai.com/invalid/url/format") + + +def test_unsupported_url(): + with pytest.raises(ValueError, match="Failed to parse URL"): + make_download_url("https://unsupported.com/model") + + +def test_huggingface_no_safetensors(): + with requests_mock.Mocker() as m: + m.get( + "https://huggingface.co/api/models/owner/model/tree/main", + json=[{"path": "model.bin", "type": "file"}], + ) + with pytest.raises(ValueError, match="No .safetensors file found"): + make_download_url("https://huggingface.co/owner/model") + + +def test_huggingface_multiple_safetensors(): + with requests_mock.Mocker() as m: + m.get( + "https://huggingface.co/api/models/owner/model/tree/main", + json=[ + {"path": "model1.safetensors", "type": "file"}, + {"path": "model2.safetensors", "type": "file"}, + ], + ) + with pytest.raises(ValueError, match="Multiple .safetensors files found"): + make_download_url("https://huggingface.co/owner/model") + + +@pytest.fixture +def mock_base_dir(tmp_path): + return tmp_path / "weights-cache" + + +@pytest.fixture +def cache(mock_base_dir): + return WeightsDownloadCache(min_disk_free=1000, base_dir=mock_base_dir) + + +@patch("weights.download_weights") +@patch("shutil.disk_usage") +@patch("pathlib.Path.unlink") +def test_weights_download_cache( + mock_unlink, mock_disk_usage, mock_download_weights, cache, mock_base_dir +): + # Setup + disk_space = [1500, 1000, 500, 1000] # Simulate changing disk space + mock_disk_usage.side_effect = [MagicMock(free=space) for space in disk_space] + + # Test ensure method + url1 = "https://example.com/weights1.tar" + url2 = "https://example.com/weights2.tar" + + # First call should download + path1 = cache.ensure(url1) + mock_download_weights.assert_called_once_with(url1, path1) + assert path1.parent == mock_base_dir + assert cache._hits == 0 + assert cache._misses == 1 + + # Second call to same URL should hit cache + cache.ensure(url1) + assert cache._hits == 1 + assert cache._misses == 1 + + # Call with new URL should download again + _ = cache.ensure(url2) + assert mock_download_weights.call_count == 2 + assert cache._hits == 1 + assert cache._misses == 2 + + # Test LRU behavior + url3 = "https://example.com/weights3.tar" + cache.ensure(url3) + mock_unlink.assert_called_once_with() # Check that unlink was called + + # Test cache_info + info = cache.cache_info() + 
assert "hits=1" in info + assert "misses=3" in info + assert str(mock_base_dir) in info + + +def test_weights_download_cache_initialization(mock_base_dir): + cache = WeightsDownloadCache(base_dir=mock_base_dir) + assert cache.base_dir == mock_base_dir + assert mock_base_dir.exists() diff --git a/weights.py b/weights.py index a25f71c..aee3cab 100644 --- a/weights.py +++ b/weights.py @@ -1,27 +1,25 @@ +import os +import base64 +from io import BytesIO +import tempfile +import tarfile +from pathlib import Path +import re from collections import deque import hashlib -import os import shutil import subprocess import time +import requests + + +DEFAULT_CACHE_BASE_DIR = Path("/src/weights-cache") class WeightsDownloadCache: def __init__( - self, min_disk_free: int = 10 * (2**30), base_dir: str = "/src/weights-cache" + self, min_disk_free: int = 10 * (2**30), base_dir: Path = DEFAULT_CACHE_BASE_DIR ): - """ - WeightsDownloadCache is meant to track and download weights files as fast - as possible, while ensuring there's enough disk space. - - It tries to keep the most recently used weights files in the cache, so - ensure you call ensure() on the weights each time you use them. - - It will not re-download weights files that are already in the cache. - - :param min_disk_free: Minimum disk space required to start download, in bytes. - :param base_dir: The base directory to store weights files. - """ self.min_disk_free = min_disk_free self.base_dir = base_dir self._hits = 0 @@ -29,98 +27,202 @@ def __init__( # Least Recently Used (LRU) cache for paths self.lru_paths = deque() - if not os.path.exists(base_dir): - os.makedirs(base_dir) + base_dir.mkdir(parents=True, exist_ok=True) - def _remove_least_recent(self) -> None: - """ - Remove the least recently used weights file from the cache and disk. - """ - oldest = self.lru_paths.popleft() - self._rm_disk(oldest) + def ensure(self, url: str) -> Path: + path = self._weights_path(url) - def cache_info(self) -> str: - """ - Get cache information. + if path in self.lru_paths: + # here we remove to re-add to the end of the LRU (marking it as recently used) + self._hits += 1 + self.lru_paths.remove(path) + else: + self._misses += 1 - :return: Cache information. - """ + while not self._has_enough_space() and len(self.lru_paths) > 0: + self._remove_least_recent() + download_weights(url, path) + + self.lru_paths.append(path) # Add file to end of cache + return path + + def cache_info(self) -> str: return f"CacheInfo(hits={self._hits}, misses={self._misses}, base_dir='{self.base_dir}', currsize={len(self.lru_paths)})" - def _rm_disk(self, path: str) -> None: - """ - Remove a weights file or directory from disk. - :param path: Path to remove. - """ - if os.path.isfile(path): - os.remove(path) - elif os.path.isdir(path): - shutil.rmtree(path) + def _remove_least_recent(self) -> None: + oldest = self.lru_paths.popleft() + print("removing oldest", oldest) + oldest.unlink() def _has_enough_space(self) -> bool: - """ - Check if there's enough disk space. - - :return: True if there's more than min_disk_free free, False otherwise. - """ disk_usage = shutil.disk_usage(self.base_dir) - print(f"Free disk space: {disk_usage.free}") - return disk_usage.free >= self.min_disk_free - def ensure(self, url: str) -> str: - """ - Ensure weights file is in the cache and return its path. + free = disk_usage.free + print(f"{free=}") # TODO(andreas): remove debug - This also updates the LRU cache to mark the weights as recently used. 
+ return free >= self.min_disk_free - :param url: URL to download weights file from, if not in cache. - :return: Path to weights. - """ - path = self.weights_path(url) + def _weights_path(self, url: str) -> Path: + hashed_url = hashlib.sha256(url.encode()).hexdigest() + short_hash = hashed_url[:16] # Use the first 16 characters of the hash + return self.base_dir / short_hash - if path in self.lru_paths: - # here we remove to re-add to the end of the LRU (marking it as recently used) - self._hits += 1 - self.lru_paths.remove(path) - else: - self._misses += 1 - self.download_weights(url, path) - self.lru_paths.append(path) # Add file to end of cache - return path +def download_weights(url: str, path: Path): + download_url = make_download_url(url) + download_weights_url(download_url, path) - def weights_path(self, url: str) -> str: - """ - Generate path to store a weights file based hash of the URL. - :param url: URL to download weights file from. - :return: Path to store weights file. - """ - hashed_url = hashlib.sha256(url.encode()).hexdigest() - short_hash = hashed_url[:16] # Use the first 16 characters of the hash - return os.path.join(self.base_dir, short_hash) +def download_weights_url(url: str, path: Path): + path = Path(path) - def download_weights(self, url: str, dest: str) -> None: - """ - Download weights file from a URL, ensuring there's enough disk space. + print("Downloading weights") + start_time = time.time() - :param url: URL to download weights file from. - :param dest: Path to store weights file. - """ - print("Ensuring enough disk space...") - while not self._has_enough_space() and len(self.lru_paths) > 0: - self._remove_least_recent() + if url.startswith("data:"): + download_data_url(url, path) + elif url.endswith(".tar"): + download_safetensors_tarball(url, path) + elif url.endswith(".safetensors"): + download_safetensors(url, path) + elif "://civitai.com/api/download" in url: + download_safetensors(url, path) + elif url.endswith("/_weights"): + download_safetensors_tarball(url, path) + else: + raise ValueError("URL must end with either .tar or .safetensors") - print("Downloading weights") + print(f"Downloaded weights in {time.time() - start_time:.2f}s") + + +def find_safetensors(directory: Path) -> list[Path]: + safetensors_paths = [] + for root, _, files in os.walk(directory): + root = Path(root) + for filename in files: + path = root / filename + if path.suffix == ".safetensors": + safetensors_paths.append(path) + return safetensors_paths + + +def download_safetensors_tarball(url: str, path: Path): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = Path(temp_dir) + extract_dir = temp_dir / "weights" - st = time.time() - # maybe retry with the real url if this doesn't work try: - subprocess.check_output(["pget", "--log-level", "warn", "-x", url, dest], close_fds=True) + subprocess.run(["pget", "-x", url, extract_dir], check=True) except subprocess.CalledProcessError as e: - # If download fails, clean up and re-raise exception - print(e.output) - self._rm_disk(dest) - raise e - print(f"Downloaded weights in {time.time() - st:.2f}s") + raise RuntimeError(f"Failed to download tarball: {e}") + + safetensors_paths = find_safetensors(extract_dir) + if not safetensors_paths: + raise ValueError("No .safetensors file found in tarball") + if len(safetensors_paths) > 1: + raise ValueError("Multiple .safetensors files found in tarball") + safetensors_path = safetensors_paths[0] + + shutil.move(safetensors_path, path) + + +def download_data_url(url: str, path: Path): + 
_, encoded = url.split(",", 1)
+    data = base64.b64decode(encoded)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with tarfile.open(fileobj=BytesIO(data), mode="r:*") as tar:
+            tar.extractall(path=temp_dir)
+
+        safetensors_paths = find_safetensors(Path(temp_dir))
+        if not safetensors_paths:
+            raise ValueError("No .safetensors file found in data URI")
+        if len(safetensors_paths) > 1:
+            raise ValueError("Multiple .safetensors files found in data URI")
+        safetensors_path = safetensors_paths[0]
+
+        shutil.move(safetensors_path, path)
+
+
+def download_safetensors(url: str, path: Path):
+    try:
+        subprocess.run(["pget", url, str(path)], check=True)
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"Failed to download safetensors file: {e}")
+
+
+def make_download_url(url: str) -> str:
+    if url.startswith("data:"):
+        return url
+    if m := re.match(r"^(?:https?://)?huggingface\.co/([^/]+)/([^/]+)/?$", url):
+        owner, model_name = m.groups()
+        return make_huggingface_download_url(owner, model_name)
+    if m := re.match(r"^(?:https?://)?civitai\.com/models/(\d+)(?:/[^/?]+)?/?$", url):
+        model_id = m.groups()[0]
+        return make_civitai_download_url(model_id)
+    if m := re.match(r"^((?:https?://)?civitai\.com/api/download/models/.*)$", url):
+        return url
+    if m := re.match(r"^(https?://.*\.safetensors)(?:\?|$)", url):
+        safetensors_url = m.groups()[0]
+        return safetensors_url
+    if m := re.match(r"^(?:https?://replicate\.com/)?([^/]+)/([^/]+)/?$", url):
+        owner, model_name = m.groups()
+        return make_replicate_model_download_url(owner, model_name)
+    if m := re.match(
+        r"^(?:https?://replicate\.com/)?([^/]+)/([^/]+)/(?:versions/)?([^/]+)/?$", url
+    ):
+        owner, model_name, version_id = m.groups()
+        return make_replicate_version_download_url(owner, model_name, version_id)
+    if m := re.match(r"^(https?://replicate\.delivery/.*\.tar)$", url):
+        replicate_tar_url = m.groups()[0]
+        return replicate_tar_url
+
+    if "huggingface.co" in url:
+        raise ValueError(
+            "Failed to parse HuggingFace URL. Expected huggingface.co/<owner>/<model-name>"
+        )
+    if "civitai.com" in url:
+        raise ValueError(
+            "Failed to parse CivitAI URL. Expected civitai.com/models/<id>[/<model-name>]"
+        )
+    raise ValueError(
+        """Failed to parse URL. 
Expected either:
+* Replicate model in the format <owner>/<model> or <owner>/<model>/<version>
+* HuggingFace URL in the format huggingface.co/<owner>/<model-name>
+* CivitAI URL in the format civitai.com/models/<id>[/<model-name>]
+* Arbitrary .safetensors URLs from the Internet"""
+    )
+
+
+def make_replicate_model_download_url(owner: str, model_name: str) -> str:
+    return f"https://replicate.com/{owner}/{model_name}/_weights"
+
+
+def make_replicate_version_download_url(
+    owner: str, model_name: str, version_id: str
+) -> str:
+    return f"https://replicate.com/{owner}/{model_name}/versions/{version_id}/_weights"
+
+
+def make_huggingface_download_url(owner: str, model_name: str) -> str:
+    url = f"https://huggingface.co/api/models/{owner}/{model_name}/tree/main"
+    response = requests.get(url)
+    response.raise_for_status()
+
+    files = response.json()
+    safetensors_files = [f for f in files if f["path"].endswith(".safetensors")]
+
+    if len(safetensors_files) == 0:
+        raise ValueError("No .safetensors file found in the repository")
+    elif len(safetensors_files) > 1:
+        raise ValueError("Multiple .safetensors files found in the repository")
+
+    safetensors_path = safetensors_files[0]["path"]
+    return (
+        f"https://huggingface.co/{owner}/{model_name}/resolve/main/{safetensors_path}"
+    )
+
+
+def make_civitai_download_url(model_id: str) -> str:
+    return f"https://civitai.com/api/download/models/{model_id}?type=Model&format=SafeTensor"
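
Usage sketch (reviewer's note, not part of the patch): the snippet below shows how the new weights.py API fits together. The example URLs and the /tmp cache directory are illustrative assumptions; the expected values in the comments are the ones pinned down by unit-tests/test_weights.py.

    from pathlib import Path

    from weights import WeightsDownloadCache, make_download_url

    # make_download_url normalizes user-supplied identifiers into direct
    # download URLs:
    make_download_url("owner/model")
    # -> "https://replicate.com/owner/model/_weights"
    make_download_url("civitai.com/models/12345/model-name")
    # -> "https://civitai.com/api/download/models/12345?type=Model&format=SafeTensor"

    # WeightsDownloadCache downloads on the first ensure() and serves repeats
    # from an on-disk LRU cache, evicting the least recently used file when
    # free disk space drops below min_disk_free.
    cache = WeightsDownloadCache(base_dir=Path("/tmp/weights-cache"))  # hypothetical dir
    lora_path = cache.ensure("https://example.com/model.safetensors")  # miss: downloads
    lora_path = cache.ensure("https://example.com/model.safetensors")  # hit: cached
    print(cache.cache_info())  # e.g. CacheInfo(hits=1, misses=1, ...)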