|
defmodule Bumblebee.Text.JinaBertTest do
  # Model-loading tests for the JinaBert architecture. These hit the
  # HuggingFace hub, so they are gated behind the shared model-test tags.
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  @tag slow: true
  test "jina-embeddings-v2-small-en" do
    repo = {:hf, "jinaai/jina-embeddings-v2-small-en"}

    # Assert on the load result (consistent with the other tests in this
    # module) so a failed download/parse reports a clear diff instead of a
    # bare MatchError.
    assert {:ok, %{model: model, params: params, spec: _spec}} =
             Bumblebee.load_model(repo,
               params_filename: "model.safetensors",
               spec_overrides: [architecture: :base]
             )

    inputs = %{
      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    }

    outputs = Axon.predict(model, params, inputs)

    # BUG FIX: this previously read `assert Nx.all_close(...)`, which can
    # never fail — Nx.all_close/2 returns a 0/1 u8 tensor and any tensor is
    # truthy in Elixir. Use the assert_all_close helper (as the other tests
    # in this module do) so a numeric mismatch actually fails the test.
    assert_all_close(
      outputs.hidden_state[[.., 1..3, 1..3]],
      Nx.tensor([
        [-0.1346, 0.1457, 0.5572],
        [-0.1383, 0.1412, 0.5643],
        [-0.1125, 0.1354, 0.5599]
      ])
    )
  end

  @tag :skip
  test ":base" do
    # NOTE(review): placeholder repo — replace with a real tiny-random
    # JinaBert checkpoint before unskipping.
    repo = {:hf, "doesnotexist/tiny-random-JinaBert"}

    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(repo)

    assert %Bumblebee.Text.JinaBert{architecture: :base} = spec

    inputs = %{
      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    }

    outputs = Axon.predict(model, params, inputs)

    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

    assert_all_close(
      outputs.hidden_state[[.., 1..3, 1..3]],
      Nx.tensor([
        [[-0.2331, 1.7817, 1.1736], [-1.1001, 1.3922, -0.3391], [0.0408, 0.8677, -0.0779]]
      ])
    )
  end

  @tag :skip
  test ":for_masked_language_modeling" do
    # NOTE(review): placeholder repo — replace with a real tiny-random
    # JinaBert checkpoint before unskipping.
    repo = {:hf, "doesnotexist/tiny-random-JinaBert"}

    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(repo)

    # BUG FIX: previously asserted %Bumblebee.Text.Bert{} — a copy-paste
    # leftover from the Bert test module. This is the JinaBert test, and the
    # :base test above asserts %Bumblebee.Text.JinaBert{}, so the spec here
    # must be JinaBert as well.
    assert %Bumblebee.Text.JinaBert{architecture: :for_masked_language_modeling} = spec

    inputs = %{
      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    }

    outputs = Axon.predict(model, params, inputs)

    assert Nx.shape(outputs.logits) == {1, 10, 1124}

    assert_all_close(
      outputs.logits[[.., 1..3, 1..3]],
      Nx.tensor([[[-0.0127, 0.0508, 0.0904], [0.1151, 0.1189, 0.0922], [0.0089, 0.1132, -0.2470]]])
    )
  end
end
0 commit comments