diff --git a/lib/bumblebee/shared.ex b/lib/bumblebee/shared.ex index d5b81f19..36dc135f 100644 --- a/lib/bumblebee/shared.ex +++ b/lib/bumblebee/shared.ex @@ -210,6 +210,14 @@ defmodule Bumblebee.Shared do end end + def validate_string_or_pairs(input) do + case input do + input when is_binary(input) -> {:ok, input} + {left, right} when is_binary(left) and is_binary(right) -> {:ok, input} + _other -> {:error, "expected a string or a pair of strings, got: #{inspect(input)}"} + end + end + @doc """ Validates that the input is a single value and not a batch. """ diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index b11f47a7..770a2192 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -288,7 +288,7 @@ defmodule Bumblebee.Text do defdelegate translation(model_info, tokenizer, generation_config, opts \\ []), to: Bumblebee.Text.Translation - @type text_classification_input :: String.t() + @type text_classification_input :: String.t() | {String.t(), String.t()} @type text_classification_output :: %{predictions: list(text_classification_prediction())} @type text_classification_prediction :: %{score: number(), label: String.t()} diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index 22f0541c..be054ae8 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -74,7 +74,7 @@ defmodule Bumblebee.Text.TextClassification do |> Nx.Serving.batch_size(batch_size) |> Nx.Serving.process_options(batch_keys: batch_keys) |> Nx.Serving.client_preprocessing(fn input -> - {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string/1) + {texts, multi?} = Shared.validate_serving_input!(input, &Shared.validate_string_or_pairs/1) inputs = Nx.with_default_backend(Nx.BinaryBackend, fn ->