From 4623083c54749921f1ea10eaa9e47d996e1b53d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Wed, 27 Sep 2023 23:10:51 +0700 Subject: [PATCH] Improve loading errors and add docs on HF repos (#256) --- README.md | 54 ++- lib/bumblebee.ex | 498 +++++++++++++++--------- lib/bumblebee/huggingface/hub.ex | 63 ++- mix.exs | 2 +- mix.lock | 2 +- test/bumblebee/huggingface/hub_test.exs | 8 +- 6 files changed, 433 insertions(+), 194 deletions(-) diff --git a/README.md b/README.md index a860b4e5..3c65f661 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,59 @@ Nx.Serving.run(serving, "The capital of [MASK] is Paris.") #=> } ``` -We load the BERT model from Hugging Face Hub, then plug it into an end-to-end pipeline in the form of "serving", finally we use the serving to get our task done. For more details check out [the documentation](https://hexdocs.pm/bumblebee) and the resources below. +We load the BERT model from Hugging Face Hub, then plug it into an end-to-end pipeline in the form of "serving", finally we use the serving to get our task done. For more details check out [the documentation](https://hexdocs.pm/bumblebee). + + + +## HuggingFace Hub + +HuggingFace Hub is a platform hosting models, datasets and demo apps (Spaces), all using Git repositories (with Git LFS for large files). For further information check out the [Hub documentation](https://huggingface.co/docs/hub/index) and explore the [model repositories](https://huggingface.co/models). + +### Models + +Model repositories are regular Git repositories, therefore they can store arbitrary files. However, most repositories store models saved using the Python [Transformers](https://github.com/huggingface/transformers) library. Bumblebee is an Elixir counterpart of Transformers and allows for importing those models, as long as they are implemented in Bumblebee. + +A repository in the Transformers format does not store an actual model, only the trained parameters and a configuration file. The configuration file specifies the model type (e.g. BERT) and high-level properties, such as the number layers and their size. The model implementation lives in the library code (both Transformers and Bumblebee). When loading a model, the library fetches the configuration and builds a matching model, then it fetches the trained parameters to pair them with the model. The key takeaway is that in order to use any given model, it needs to have an implementation in Bumblebee. + +### Model repository + +Here is a list of files commonly found in a repository following the Transformers format. + + * `config.json` - model configuration, specifies the model type and model-specific options. You can think of this as a blueprint for how the model should be constructed + + * `pytorch_model.bin` - raw model parameters (tensors) serialized from a PyTorch model using [PyTorch format](https://pytorch.org/docs/stable/generated/torch.save.html) (supported by Bumblebee) + + * `model.safetensors` - raw model parameters (tensors) serialized from a PyTorch model using [Safetensors](https://github.com/huggingface/safetensors) (supported by Bumblebee) + + * `flax_model.msgpack`, `tf_model.h5` - raw model parameters (tensors) serialized from Flax and Tensorflow models respectively (not supported by Bumblebee) + + * `tokenizer.json`, `tokenizer_config.json` - tokenizer configuration, describes how to convert text input to model inputs (tensors). 
See [Tokenizer support](#tokenizer-support) + + * `preprocessor_config.json` - featurizer configuration, describes how to convert real-world input (image, audio) to model inputs (tensors) + + * `generation_config.json` - a set of configuration options specific to text generation, such as token sampling strategy and various constraints + +### Model support + +As pointed out above, in order to load a model, the given model type must be implemented in Bumblebee. One way to find out is calling `Bumblebee.load_model({:hf, "model-repo"})`. Alternatively, you can open the `config.json` file in the model repository and copy the class name under `"architectures"`, then [search Bumblebee codebase](https://github.com/search?q=repo%3Aelixir-nx%2Fbumblebee+BertForMaskedLM&type=code) for that keyword. + +Also note that certain repositories include multiple models in separate repositories, for example [`stabilityai/stable-diffusion-2`](https://huggingface.co/stabilityai/stable-diffusion-2). In such case use `Bumblebee.load_model({:hf, "model-repo", subdir: "..."})`. + +### Tokenizer support + +The Transformers library distinguishes two types of tokenizer implementations: + + * "slow tokenizer" - a tokenizer implemented in Python and stored as `tokenizer_config.json` and a couple extra files + + * "fast tokenizer" - a tokenizer implemented in Rust and stored in a single file - `tokenizer.json` + +Bumblebee relies on the Rust implementations (through bindings to [Tokenizers](https://github.com/huggingface/tokenizers)) and therefore always requires the `tokenizer.json` file. Many repositories only include files for a "slow tokenizer". When you stumble upon such repository, there are two options you can try. + +First, if the repository is clearly a fine-tuned version of another model, you can look for `tokenizer.json` in the original model repository. For example, [`textattack/bert-base-uncased-yelp-polarity`](https://huggingface.co/textattack/bert-base-uncased-yelp-polarity) only includes `tokenizer_config.json`, but it is a fine-tuned version of [`bert-base-uncased`](https://huggingface.co/bert-base-uncased), which does include `tokenizer.json`. Consequently, you can safely load the model from `textattack/bert-base-uncased-yelp-polarity` and tokenizer from `bert-base-uncased`. + +Otherwise, the Transformers library includes conversion rules to load a "slow tokenizer" and convert it to a corresponding "fast tokenizer", which is possible in most cases. You can generate the `tokenizer.json` file using [this tool](https://jonatanklosko-bumblebee-tools.hf.space/apps/tokenizer-generator). Once successful, you can follow the steps to submit a PR adding `tokenizer.json` to the model repository. Note that you do not have to wait for the PR to be merged, instead you can copy commit SHA from the PR and load the tokenizer with `Bumblebee.load_tokenizer({:hf, "model-repo", revision: "..."})`. + + ## License diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 10800598..72dc24da 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -1,4 +1,11 @@ defmodule Bumblebee do + @external_resource "README.md" + + [_, readme_docs, _] = + "README.md" + |> File.read!() + |> String.split("") + @moduledoc """ Pre-trained `Axon` models for easy inference and boosted training. @@ -51,6 +58,8 @@ defmodule Bumblebee do > > The models are generally large, so make sure to configure an efficient > `Nx` backend, such as `EXLA` or `Torchx`. 
+ + #{readme_docs} """ alias Bumblebee.HuggingFace @@ -58,10 +67,12 @@ defmodule Bumblebee do @config_filename "config.json" @featurizer_filename "preprocessor_config.json" @tokenizer_filename "tokenizer.json" + @tokenizer_config_filename "tokenizer_config.json" @tokenizer_special_tokens_filename "special_tokens_map.json" @generation_filename "generation_config.json" @scheduler_filename "scheduler_config.json" - @params_filename %{pytorch: "pytorch_model.bin"} + @pytorch_params_filename "pytorch_model.bin" + @safetensors_params_filename "model.safetensors" @transformers_class_to_model %{ "AlbertForMaskedLM" => {Bumblebee.Text.Albert, :for_masked_language_modeling}, @@ -335,38 +346,53 @@ defmodule Bumblebee do module = opts[:module] architecture = opts[:architecture] - with {:ok, path} <- download(repository, @config_filename), - {:ok, spec_data} <- decode_config(path) do - {inferred_module, inferred_architecture, inference_error} = - case infer_model_type(spec_data) do - {:ok, module, architecture} -> {module, architecture, nil} - {:error, error} -> {nil, nil, error} - end + with {:ok, repo_files} <- get_repo_files(repository) do + do_load_spec(repository, repo_files, module, architecture) + end + end - module = module || inferred_module - architecture = architecture || inferred_architecture + defp do_load_spec(repository, repo_files, module, architecture) do + case repo_files do + %{@config_filename => etag} -> + with {:ok, path} <- download(repository, @config_filename, etag), + {:ok, spec_data} <- decode_config(path) do + {inferred_module, inferred_architecture, inference_error} = + case infer_model_type(spec_data) do + {:ok, module, architecture} -> {module, architecture, nil} + {:error, error} -> {nil, nil, error} + end - unless module do - raise "#{inference_error}, please specify the :module and :architecture options" - end + module = module || inferred_module + architecture = architecture || inferred_architecture - architectures = module.architectures() + unless module do + raise ArgumentError, + "#{inference_error}, please specify the :module and :architecture options" + end - if architecture && architecture not in architectures do - raise ArgumentError, - "expected architecture to be one of: #{Enum.map_join(architectures, ", ", &inspect/1)}, but got: #{inspect(architecture)}" - end + architectures = module.architectures() - spec = - if architecture do - configure(module, architecture: architecture) - else - configure(module) - end + if architecture && architecture not in architectures do + raise ArgumentError, + "expected architecture to be one of: #{Enum.map_join(architectures, ", ", &inspect/1)}, but got: #{inspect(architecture)}" + end - spec = HuggingFace.Transformers.Config.load(spec, spec_data) + spec = + if architecture do + configure(module, architecture: architecture) + else + configure(module) + end - {:ok, spec} + spec = HuggingFace.Transformers.Config.load(spec, spec_data) + + {:ok, spec} + end + + %{} -> + raise ArgumentError, + "no config file found in the given repository. 
Please refer to Bumblebee" <> + " README to learn about repositories and supported models" end end @@ -458,88 +484,102 @@ defmodule Bumblebee do :log_params_diff ]) - spec_response = - if spec = opts[:spec] do - {:ok, spec} - else - load_spec(repository, Keyword.take(opts, [:module, :architecture])) - end - - with {:ok, spec} <- spec_response, + with {:ok, repo_files} <- get_repo_files(repository), + {:ok, spec} <- maybe_load_model_spec(opts, repository, repo_files), model <- build_model(spec), - {:ok, params} <- - load_params( - spec, - model, - repository, - opts - |> Keyword.take([:params_filename, :log_params_diff, :backend]) - ) do + {:ok, params} <- load_params(spec, model, repository, repo_files, opts) do {:ok, %{model: model, params: params, spec: spec}} end end - defp load_params(%module{} = spec, model, repository, opts) do - # TODO: support format: :auto | :axon | :pytorch - format = :pytorch - filename = opts[:params_filename] || @params_filename[format] + defp maybe_load_model_spec(opts, repository, repo_files) do + if spec = opts[:spec] do + {:ok, spec} + else + do_load_spec(repository, repo_files, opts[:module], opts[:architecture]) + end + end + defp load_params(%module{} = spec, model, repository, repo_files, opts) do input_template = module.input_template(spec) params_mapping = Bumblebee.HuggingFace.Transformers.Model.params_mapping(spec) - with {:ok, paths} <- download_params_files(repository, filename) do - params = - Bumblebee.Conversion.PyTorch.load_params!( - model, - input_template, - paths, - [ - params_mapping: params_mapping, - loader_fun: filename |> Path.extname() |> params_file_loader_fun() - ] ++ Keyword.take(opts, [:backend, :log_params_diff]) - ) + {filename, sharded?} = infer_params_filename(repo_files, opts[:params_filename]) + loader_fun = filename |> Path.extname() |> params_file_loader_fun() + + with {:ok, paths} <- download_params_files(repository, repo_files, filename, sharded?) do + opts = + [ + params_mapping: params_mapping, + loader_fun: loader_fun + ] ++ Keyword.take(opts, [:backend, :log_params_diff]) + params = Bumblebee.Conversion.PyTorch.load_params!(model, input_template, paths, opts) {:ok, params} end end - defp download_params_files(repository, filename) do - case download(repository, filename) do - {:ok, path} -> - {:ok, [path]} + defp infer_params_filename(repo_files, nil = _filename) do + cond do + Map.has_key?(repo_files, @pytorch_params_filename) -> + {@pytorch_params_filename, false} - error -> - # Check for sharded params - with {:ok, path} <- download(repository, filename <> ".index.json"), - {:ok, sharded_metadata} <- decode_config(path) do - filenames = - for {_layer, filename} <- sharded_metadata["weight_map"], uniq: true, do: filename + Map.has_key?(repo_files, @pytorch_params_filename <> ".index.json") -> + {@pytorch_params_filename, true} - Enum.reduce_while(filenames, {:ok, []}, fn filename, {:ok, paths} -> - case download(repository, filename) do - {:ok, path} -> {:cont, {:ok, [path | paths]}} - error -> {:halt, error} - end - end) - else - _ -> error - end + Map.has_key?(repo_files, @safetensors_params_filename) -> + {@safetensors_params_filename, false} + + Map.has_key?(repo_files, @safetensors_params_filename <> ".index.json") -> + {@safetensors_params_filename, true} + + true -> + raise ArgumentError, + "none of the expected parameters files found in the repository." 
<> + " If the file exists under an unusual name, try specifying :params_filename" end end - defp params_file_loader_fun(".safetensors") do - fn path -> - path - |> File.read!() - |> Safetensors.load!() + defp infer_params_filename(repo_files, filename) do + cond do + Map.has_key?(repo_files, filename) -> + {filename, false} + + Map.has_key?(repo_files, filename <> ".index.json") -> + {filename, true} + + true -> + raise ArgumentError, "could not find file #{inspect(filename)} in the repository" end end - defp params_file_loader_fun(_) do - &Bumblebee.Conversion.PyTorch.Loader.load!/1 + defp download_params_files(repository, repo_files, filename, false = _sharded?) do + with {:ok, path} <- download(repository, filename, repo_files[filename]) do + {:ok, [path]} + end end + defp download_params_files(repository, repo_files, filename, true = _sharded?) do + index_filename = filename <> ".index.json" + + with {:ok, path} <- download(repository, index_filename, repo_files[index_filename]), + {:ok, sharded_metadata} <- decode_config(path) do + filenames = + for {_layer, filename} <- sharded_metadata["weight_map"], uniq: true, do: filename + + Enum.reduce_while(filenames, {:ok, []}, fn filename, {:ok, paths} -> + case download(repository, filename, repo_files[filename]) do + {:ok, path} -> {:cont, {:ok, [path | paths]}} + error -> {:halt, error} + end + end) + end + end + + defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!/1 + defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorch.Loader.load!/1 + @doc """ Featurizes `input` with the given featurizer. @@ -592,22 +632,34 @@ defmodule Bumblebee do opts = Keyword.validate!(opts, [:module]) module = opts[:module] - with {:ok, path} <- download(repository, @featurizer_filename), - {:ok, featurizer_data} <- decode_config(path) do - module = - module || - case infer_featurizer_type(featurizer_data, repository) do - {:ok, module} -> module - {:error, error} -> raise "#{error}, please specify the :module option" - end + case get_repo_files(repository) do + {:ok, %{@featurizer_filename => etag} = repo_files} -> + with {:ok, path} <- download(repository, @featurizer_filename, etag), + {:ok, featurizer_data} <- decode_config(path) do + module = + module || + case infer_featurizer_type(featurizer_data, repository, repo_files) do + {:ok, module} -> + module + + {:error, error} -> + raise ArgumentError, "#{error}, please specify the :module option" + end + + featurizer = configure(module) + featurizer = HuggingFace.Transformers.Config.load(featurizer, featurizer_data) + {:ok, featurizer} + end + + {:ok, %{}} -> + raise ArgumentError, "no featurizer found in the given repository" - featurizer = configure(module) - featurizer = HuggingFace.Transformers.Config.load(featurizer, featurizer_data) - {:ok, featurizer} + {:error, message} -> + {:error, message} end end - defp infer_featurizer_type(%{"feature_extractor_type" => class_name}, _repository) do + defp infer_featurizer_type(%{"feature_extractor_type" => class_name}, _repository, _repo_files) do case @transformers_class_to_featurizer[class_name] do nil -> {:error, @@ -618,7 +670,7 @@ defmodule Bumblebee do end end - defp infer_featurizer_type(%{"image_processor_type" => class_name}, _repository) do + defp infer_featurizer_type(%{"image_processor_type" => class_name}, _repository, _repo_files) do case @transformers_image_processor_type_to_featurizer[class_name] do nil -> {:error, @@ -629,8 +681,8 @@ defmodule Bumblebee do end end - defp infer_featurizer_type(_featurizer_data, 
repository) do - with {:ok, path} <- download(repository, @config_filename), + defp infer_featurizer_type(_featurizer_data, repository, repo_files) do + with {:ok, path} <- download(repository, @config_filename, repo_files[@config_filename]), {:ok, featurizer_data} <- decode_config(path) do case featurizer_data do %{"model_type" => model_type} -> @@ -727,36 +779,59 @@ defmodule Bumblebee do opts = Keyword.validate!(opts, [:module]) module = opts[:module] - with {:ok, path} <- download(repository, @tokenizer_filename) do - module = - module || - case infer_tokenizer_type(repository) do - {:ok, module} -> module - {:error, error} -> raise "#{error}, please specify the :module option" - end + case get_repo_files(repository) do + {:ok, %{@tokenizer_filename => etag} = repo_files} -> + with {:ok, path} <- download(repository, @tokenizer_filename, etag) do + module = + module || + case infer_tokenizer_type(repository, repo_files) do + {:ok, module} -> + module + + {:error, error} -> + raise ArgumentError, "#{error}, please specify the :module option" + end + + special_tokens_map_result = + if Map.has_key?(repo_files, @tokenizer_special_tokens_filename) do + etag = repo_files[@tokenizer_special_tokens_filename] + + with {:ok, path} <- download(repository, @tokenizer_special_tokens_filename, etag) do + decode_config(path) + end + else + {:ok, %{}} + end + + with {:ok, special_tokens_map} <- special_tokens_map_result do + tokenizer = struct!(module) - special_tokens_map = - with {:ok, path} <- download(repository, @tokenizer_special_tokens_filename), - {:ok, special_tokens_map} <- decode_config(path) do - special_tokens_map - else - _ -> %{} + tokenizer = + HuggingFace.Transformers.Config.load(tokenizer, %{ + "tokenizer_file" => path, + "special_tokens_map" => special_tokens_map + }) + + {:ok, tokenizer} + end end - tokenizer = struct!(module) + {:ok, %{@tokenizer_config_filename => _}} -> + raise ArgumentError, + "expected a Rust-compatible tokenizer.json file, however the repository" <> + " includes tokenizer in a different format. 
Please refer to Bumblebee" <> + " README to see the possible steps you can take" - tokenizer = - HuggingFace.Transformers.Config.load(tokenizer, %{ - "tokenizer_file" => path, - "special_tokens_map" => special_tokens_map - }) + {:ok, %{}} -> + raise ArgumentError, "no tokenizer found in the given repository" - {:ok, tokenizer} + {:error, message} -> + {:error, message} end end - defp infer_tokenizer_type(repository) do - with {:ok, path} <- download(repository, @config_filename), + defp infer_tokenizer_type(repository, repo_files) do + with {:ok, path} <- download(repository, @config_filename, repo_files[@config_filename]), {:ok, tokenizer_data} <- decode_config(path) do case tokenizer_data do %{"model_type" => model_type} -> @@ -806,46 +881,58 @@ defmodule Bumblebee do repository = normalize_repository!(repository) - with {:ok, path} <- download(repository, @config_filename), - {:ok, spec_data} <- decode_config(path) do - spec_module = opts[:spec_module] + case get_repo_files(repository) do + {:ok, %{@config_filename => etag} = repo_files} -> + with {:ok, path} <- download(repository, @config_filename, etag), + {:ok, spec_data} <- decode_config(path) do + spec_module = opts[:spec_module] - {inferred_module, inference_error} = - case infer_model_type(spec_data) do - {:ok, module, _architecture} -> {module, nil} - {:error, error} -> {nil, error} - end + {inferred_module, inference_error} = + case infer_model_type(spec_data) do + {:ok, module, _architecture} -> {module, nil} + {:error, error} -> {nil, error} + end - spec_module = spec_module || inferred_module + spec_module = spec_module || inferred_module - unless spec_module do - raise "#{inference_error}, please specify the :spec_module option" - end + unless spec_module do + raise ArgumentError, "#{inference_error}, please specify the :spec_module option" + end - generation_data_result = - case download(repository, @generation_filename) do - {:ok, path} -> decode_config(path) - # Fallback to the spec data, since it used to include - # generation attributes - {:error, _} -> {:ok, spec_data} - end + generation_data_result = + if Map.has_key?(repo_files, @generation_filename) do + etag = repo_files[@generation_filename] + + with {:ok, path} <- download(repository, @generation_filename, etag) do + decode_config(path) + end + else + # Fallback to the spec data, since it used to include + # generation attributes + {:ok, spec_data} + end - with {:ok, generation_data} <- generation_data_result do - config = struct!(Bumblebee.Text.GenerationConfig) - config = HuggingFace.Transformers.Config.load(config, generation_data) + with {:ok, generation_data} <- generation_data_result do + config = struct!(Bumblebee.Text.GenerationConfig) + config = HuggingFace.Transformers.Config.load(config, generation_data) - extra_config_module = Bumblebee.Text.Generation.extra_config_module(struct!(spec_module)) + extra_config_module = + Bumblebee.Text.Generation.extra_config_module(struct!(spec_module)) - extra_config = - if extra_config_module do - extra_config = struct!(extra_config_module) - HuggingFace.Transformers.Config.load(extra_config, generation_data) - end + extra_config = + if extra_config_module do + extra_config = struct!(extra_config_module) + HuggingFace.Transformers.Config.load(extra_config, generation_data) + end - config = %{config | extra_config: extra_config} + config = %{config | extra_config: extra_config} - {:ok, config} - end + {:ok, config} + end + end + + {:error, message} -> + {:error, message} end end @@ -918,18 +1005,30 @@ defmodule 
Bumblebee do opts = Keyword.validate!(opts, [:module]) module = opts[:module] - with {:ok, path} <- download(repository, @scheduler_filename), - {:ok, scheduler_data} <- decode_config(path) do - module = - module || - case infer_scheduler_type(scheduler_data) do - {:ok, module} -> module - {:error, error} -> raise "#{error}, please specify the :module option" - end + case get_repo_files(repository) do + {:ok, %{@scheduler_filename => etag}} -> + with {:ok, path} <- download(repository, @scheduler_filename, etag), + {:ok, scheduler_data} <- decode_config(path) do + module = + module || + case infer_scheduler_type(scheduler_data) do + {:ok, module} -> + module + + {:error, error} -> + raise ArgumentError, "#{error}, please specify the :module option" + end + + scheduler = configure(module) + scheduler = HuggingFace.Transformers.Config.load(scheduler, scheduler_data) + {:ok, scheduler} + end - scheduler = configure(module) - scheduler = HuggingFace.Transformers.Config.load(scheduler, scheduler_data) - {:ok, scheduler} + {:ok, %{}} -> + raise ArgumentError, "no scheduler found in the given repository" + + {:error, message} -> + {:error, message} end end @@ -948,7 +1047,56 @@ defmodule Bumblebee do {:error, "could not infer featurizer type from the configuration"} end - defp download({:local, dir}, filename) do + defp get_repo_files({:local, dir}) do + case File.ls(dir) do + {:ok, filenames} -> + repo_files = + for filename <- filenames, + path = Path.join(dir, filename), + File.regular?(path), + into: %{}, + do: {filename, nil} + + {:ok, repo_files} + + {:error, reason} -> + {:error, "could not read #{dir}, reason: #{:file.format_error(reason)}"} + end + end + + defp get_repo_files({:hf, repository_id, opts}) do + subdir = opts[:subdir] + url = HuggingFace.Hub.file_listing_url(repository_id, subdir, opts[:revision]) + + result = + HuggingFace.Hub.cached_download( + url, + Keyword.take(opts, [:cache_dir, :offline, :auth_token]) + ) + + with {:ok, path} <- result, + {:ok, data} <- decode_config(path) do + repo_files = + for entry <- data, entry["type"] == "file", into: %{} do + path = entry["path"] + + name = + if subdir do + String.replace_leading(path, subdir <> "/", "") + else + path + end + + etag_content = entry["lfs"]["oid"] || entry["oid"] + etag = <> + {name, etag} + end + + {:ok, repo_files} + end + end + + defp download({:local, dir}, filename, _etag) do path = Path.join(dir, filename) if File.exists?(path) do @@ -958,21 +1106,19 @@ defmodule Bumblebee do end end - defp download({:hf, repository_id, opts}, filename) do - revision = opts[:revision] - cache_dir = opts[:cache_dir] - offline = opts[:offline] - auth_token = opts[:auth_token] - subdir = opts[:subdir] - - filename = if subdir, do: subdir <> "/" <> filename, else: filename + defp download({:hf, repository_id, opts}, filename, etag) do + filename = + if subdir = opts[:subdir] do + subdir <> "/" <> filename + else + filename + end - url = HuggingFace.Hub.file_url(repository_id, filename, revision) + url = HuggingFace.Hub.file_url(repository_id, filename, opts[:revision]) - HuggingFace.Hub.cached_download(url, - cache_dir: cache_dir, - offline: offline, - auth_token: auth_token + HuggingFace.Hub.cached_download( + url, + [etag: etag] ++ Keyword.take(opts, [:cache_dir, :offline, :auth_token]) ) end diff --git a/lib/bumblebee/huggingface/hub.ex b/lib/bumblebee/huggingface/hub.ex index 8201dcfa..85028dec 100644 --- a/lib/bumblebee/huggingface/hub.ex +++ b/lib/bumblebee/huggingface/hub.ex @@ -14,6 +14,16 @@ defmodule 
Bumblebee.HuggingFace.Hub do @huggingface_endpoint <> "/#{repository_id}/resolve/#{revision}/#{filename}" end + @doc """ + Returns a URL to list the contents of a Hugging Face repository. + """ + @spec file_listing_url(String.t(), String.t() | nil, String.t() | nil) :: String.t() + def file_listing_url(repository_id, subdir, revision) do + revision = revision || "main" + path = if(subdir, do: "/" <> subdir) + @huggingface_endpoint <> "/api/models/#{repository_id}/tree/#{revision}#{path}" + end + @doc """ Downloads file from the given URL and returns a path to the file. @@ -33,6 +43,10 @@ defmodule Bumblebee.HuggingFace.Hub do * `:auth_token` - the token to use as HTTP bearer authorization for remote files + * `:etag` - by default a HEAD request is made to fetch the latest + ETag value, however if the value is already known, it can be + passed as an option instead (to skip the extra request) + """ @spec cached_download(String.t(), keyword()) :: {:ok, String.t()} | {:error, String.t()} def cached_download(url, opts \\ []) do @@ -60,10 +74,18 @@ defmodule Bumblebee.HuggingFace.Hub do {:ok, entry_path} _ -> - {:error, "could not find file in local cache and outgoing traffic is disabled"} + {:error, + "could not find file in local cache and outgoing traffic is disabled, url: #{url}"} end else - with {:ok, etag, download_url} <- head_download(url, headers) do + head_result = + if etag = opts[:etag] do + {:ok, etag, url} + else + head_download(url, headers) + end + + with {:ok, etag, download_url} <- head_result do entry_path = Path.join(dir, entry_filename(url, etag)) case load_json(metadata_path) do @@ -71,7 +93,8 @@ defmodule Bumblebee.HuggingFace.Hub do {:ok, entry_path} _ -> - case HTTP.download(download_url, entry_path, headers: headers) |> finish_request() do + case HTTP.download(download_url, entry_path, headers: headers) + |> finish_request(download_url) do :ok -> :ok = store_json(metadata_path, %{"etag" => etag, "url" => url}) {:ok, entry_path} @@ -88,7 +111,8 @@ defmodule Bumblebee.HuggingFace.Hub do defp head_download(url, headers) do with {:ok, response} <- - HTTP.request(:head, url, follow_redirects: false, headers: headers) |> finish_request(), + HTTP.request(:head, url, follow_redirects: false, headers: headers) + |> finish_request(url), {:ok, etag} <- fetch_etag(response) do download_url = if response.status in 300..399 do @@ -101,20 +125,35 @@ defmodule Bumblebee.HuggingFace.Hub do end end - defp finish_request(:ok), do: :ok + defp finish_request(:ok, _url), do: :ok - defp finish_request({:ok, response}) when response.status in 100..399, do: {:ok, response} + defp finish_request({:ok, response}, _url) when response.status in 100..399, do: {:ok, response} - defp finish_request({:ok, response}) do + defp finish_request({:ok, response}, url) do case HTTP.get_header(response, "x-error-code") do - "RepoNotFound" -> {:error, "repository not found"} - "EntryNotFound" -> {:error, "file not found"} - "RevisionNotFound" -> {:error, "revision not found"} - _ -> {:error, "HTTP request failed with status #{response.status}"} + code when code == "RepoNotFound" or response.status == 401 -> + {:error, + "repository not found, url: #{url}. Please make sure you specified" <> + " the correct repository id. 
If you are trying to access a private" <> + " or gated repository, use an authentication token"} + + "EntryNotFound" -> + {:error, "file not found, url: #{url}"} + + "RevisionNotFound" -> + {:error, "revision not found, url: #{url}"} + + "GatedRepo" -> + {:error, + "cannot access gated repository, url: #{url}. Make sure to request access" <> + " for the repository and use an authentication token"} + + _ -> + {:error, "HTTP request failed with status #{response.status}, url: #{url}"} end end - defp finish_request({:error, reason}) do + defp finish_request({:error, reason}, _url) do {:error, "failed to make an HTTP request, reason: #{inspect(reason)}"} end diff --git a/mix.exs b/mix.exs index c5d6bd18..93a0dde5 100644 --- a/mix.exs +++ b/mix.exs @@ -40,7 +40,7 @@ defmodule Bumblebee.MixProject do # {:torchx, github: "elixir-nx/nx", sparse: "torchx", override: true, only: [:dev, :test]}, {:nx_image, "~> 0.1.0"}, {:unpickler, "~> 0.1.0"}, - {:safetensors, "~> 0.1.1"}, + {:safetensors, "~> 0.1.2"}, {:castore, "~> 0.1 or ~> 1.0"}, {:jason, "~> 1.4.0"}, {:unzip, "0.8.0"}, diff --git a/mix.lock b/mix.lock index 6892a9a9..332ae9a5 100644 --- a/mix.lock +++ b/mix.lock @@ -30,7 +30,7 @@ "progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"}, "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.6.2", "d2218ba08a43fa331957f30481d00b666664d7e3861431b02bd3f4f30eec8e5b", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "b9048eaed8d7d14a53f758c91865cc616608a438d2595f621f6a4b32a5511709"}, - "safetensors": {:hex, :safetensors, "0.1.1", "b5859a010fb56249ecfba4799d316e96b89152576af2db7657786c55dcf2f5b6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "dfbb525bf3debb2e2d90f840728af70da5d55f6caa091cac4d0891a4eb4c52d5"}, + "safetensors": {:hex, :safetensors, "0.1.2", "849434fea20b2ed14b92e74205a925d86039c4ef53efe861e5c7b574c3ba8fa6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "298a5c82e34fc3b955464b89c080aa9a2625a47d69148d51113771e19166d4e0"}, "stb_image": {:hex, :stb_image, "0.6.2", "d680a418416b1d778231d1d16151be3474d187e8505e1bd524aa0d08d2de094f", [:make, :mix], [{:cc_precompiler, "~> 0.1.0", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.0", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.4", [hex: :nx, repo: "hexpm", optional: true]}], "hexpm", "231ad012f649dd2bd5ef99e9171e814f3235e8f7c45009355789ac4836044a39"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "tokenizers": {:hex, :tokenizers, "0.4.0", "140283ca74a971391ddbd83cd8cbdb9bd03736f37a1b6989b82d245a95e1eb97", [:mix], [{:castore, "~> 0.1 or ~> 1.0", 
[hex: :castore, repo: "hexpm", optional: false]}, {:rustler, ">= 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:rustler_precompiled, "~> 0.6", [hex: :rustler_precompiled, repo: "hexpm", optional: false]}], "hexpm", "ef1a9824f5a893cd3b831c0e5b3d72caa250d2ec462035cc6afef6933b13a82e"}, diff --git a/test/bumblebee/huggingface/hub_test.exs b/test/bumblebee/huggingface/hub_test.exs index ef8f81f8..cad31697 100644 --- a/test/bumblebee/huggingface/hub_test.exs +++ b/test/bumblebee/huggingface/hub_test.exs @@ -135,7 +135,7 @@ defmodule Bumblebee.HuggingFace.HubTest do url = url(bypass.port) <> "/file.json" - assert {:error, "HTTP request failed with status 500"} = + assert {:error, "HTTP request failed with status 500, url: " <> _} = Hub.cached_download(url, cache_dir: tmp_dir) end @@ -150,7 +150,8 @@ defmodule Bumblebee.HuggingFace.HubTest do url = url(bypass.port) <> "/file.json" - assert {:error, "repository not found"} = Hub.cached_download(url, cache_dir: tmp_dir) + assert {:error, "repository not found, url: " <> _} = + Hub.cached_download(url, cache_dir: tmp_dir) end @tag :tmp_dir @@ -178,7 +179,8 @@ defmodule Bumblebee.HuggingFace.HubTest do %{bypass: bypass, tmp_dir: tmp_dir} do url = url(bypass.port) <> "/file.json" - assert {:error, "could not find file in local cache and outgoing traffic is disabled"} = + assert {:error, + "could not find file in local cache and outgoing traffic is disabled, url: " <> _} = Hub.cached_download(url, cache_dir: tmp_dir, offline: true) end end
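
For reference, below is a minimal sketch of the loading flow that the README section added by this patch describes. It is not part of the patch itself; the repository names come from the README text above, and the commented variants (the `:subdir` and `:revision` options) are shown only as illustrations of the calls the README mentions.

```elixir
# Load a model from the Hugging Face Hub. Bumblebee reads config.json to infer
# the model type and fetches the parameters (pytorch_model.bin or model.safetensors).
{:ok, model_info} = Bumblebee.load_model({:hf, "bert-base-uncased"})

# Load the matching "fast" tokenizer (requires tokenizer.json in the repository).
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

# For repositories that bundle several models in subdirectories
# (e.g. stabilityai/stable-diffusion-2), point at the subdirectory:
#
#   {:ok, model_info} =
#     Bumblebee.load_model({:hf, "stabilityai/stable-diffusion-2", subdir: "unet"})

# If tokenizer.json only exists in an unmerged PR on the Hub, load it from that
# commit by passing its SHA as the revision:
#
#   {:ok, tokenizer} =
#     Bumblebee.load_tokenizer({:hf, "model-repo", revision: "commit-sha"})
```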