Skip to content

Commit ae6024c

Browse files
committed
Add jina_bert_test.exs
1 parent ae2baf7 commit ae6024c

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
# Model tests for `Bumblebee.Text.JinaBert`.
#
# Each test loads a checkpoint with `Bumblebee.load_model/2`, runs one forward
# pass through `Axon.predict/3` on a fixed batch, and compares a small slice
# of the output against hard-coded reference values.
defmodule Bumblebee.Text.JinaBertTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag model_test_tags()

  @tag slow: true
  test "jina-embeddings-v2-small-en" do
    repo = {:hf, "jinaai/jina-embeddings-v2-small-en"}

    {:ok, %{model: model, params: params, spec: _spec}} =
      Bumblebee.load_model(repo,
        params_filename: "model.safetensors",
        spec_overrides: [architecture: :base]
      )

    outputs = Axon.predict(model, params, sample_inputs())

    # `Nx.all_close/3` returns a tensor, and a tensor is always truthy for
    # `assert`, so a bare `assert Nx.all_close(...)` could never fail. The
    # `assert_all_close/2` helper actually checks the comparison result.
    assert_all_close(
      outputs.hidden_state[[.., 1..3, 1..3]],
      Nx.tensor([
        [-0.1346, 0.1457, 0.5572],
        [-0.1383, 0.1412, 0.5643],
        [-0.1125, 0.1354, 0.5599]
      ])
    )
  end

  # NOTE(review): skipped; the repository name looks like a placeholder, so
  # presumably no tiny checkpoint has been published yet - confirm before
  # removing the tag.
  @tag :skip
  test ":base" do
    repo = {:hf, "doesnotexist/tiny-random-JinaBert"}

    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(repo)

    assert %Bumblebee.Text.JinaBert{architecture: :base} = spec

    outputs = Axon.predict(model, params, sample_inputs())

    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}

    assert_all_close(
      outputs.hidden_state[[.., 1..3, 1..3]],
      Nx.tensor([
        [[-0.2331, 1.7817, 1.1736], [-1.1001, 1.3922, -0.3391], [0.0408, 0.8677, -0.0779]]
      ])
    )
  end

  # NOTE(review): skipped for the same reason as the `:base` test. It also
  # loads the very same repository as the `:base` test while expecting a
  # different architecture - verify the checkpoint name when enabling it.
  @tag :skip
  test ":for_masked_language_modeling" do
    repo = {:hf, "doesnotexist/tiny-random-JinaBert"}

    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model(repo)

    # The spec must be a JinaBert struct, consistent with the `:base` test in
    # this module (the original matched on `Bumblebee.Text.Bert`).
    assert %Bumblebee.Text.JinaBert{architecture: :for_masked_language_modeling} = spec

    outputs = Axon.predict(model, params, sample_inputs())

    assert Nx.shape(outputs.logits) == {1, 10, 1124}

    assert_all_close(
      outputs.logits[[.., 1..3, 1..3]],
      Nx.tensor([[[-0.0127, 0.0508, 0.0904], [0.1151, 0.1189, 0.0922], [0.0089, 0.1132, -0.2470]]])
    )
  end

  # Builds the input batch shared by all tests: a single sequence of ten
  # token ids whose last two positions are zero, with an attention mask that
  # is 1 for the first eight positions and 0 for the last two.
  defp sample_inputs do
    %{
      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
    }
  end
end

0 commit comments

Comments
 (0)