@@ -75,39 +75,12 @@ defmodule Bumblebee.Text.JinaBert do

  ## Architectures

-    * `:base` - plain BERT without any head on top
+    * `:base` - plain Jina BERT without any head on top

-    * `:for_masked_language_modeling` - BERT with a language modeling
+    * `:for_masked_language_modeling` - Jina BERT with a language modeling
      head. The head returns logits for each token in the original
      sequence

-    * `:for_sequence_classification` - BERT with a sequence
-      classification head. The head returns logits corresponding to
-      possible classes
-
-    * `:for_token_classification` - BERT with a token classification
-      head. The head returns logits for each token in the original
-      sequence
-
-    * `:for_question_answering` - BERT with a span classification head.
-      The head returns logits for the span start and end positions
-
-    * `:for_multiple_choice` - BERT with a multiple choice prediction
-      head. Each input in the batch consists of several sequences to
-      choose from and the model returns logits corresponding to those
-      choices
-
-    * `:for_next_sentence_prediction` - BERT with a next sentence
-      prediction head. The head returns logits predicting whether the
-      second sentence is random or in context
-
-    * `:for_pre_training` - BERT with both MLM and NSP heads as done
-      during the pre-training
-
-    * `:for_causal_language_modeling` - BERT working as a decoder with
-      a language modeling head. The head returns logits for each token
-      in the original sequence
-
  ## Inputs

    * `"input_ids"` - `{batch_size, sequence_length}`
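For orientation, a minimal usage sketch of the two remaining architectures via the public API; the checkpoint name and options below are illustrative and not part of this change:

```elixir
# Sketch only: load the plain encoder (:base). Passing
# architecture: :for_masked_language_modeling instead selects the MLM head.
{:ok, %{model: model, params: params, spec: spec}} =
  Bumblebee.load_model({:hf, "jinaai/jina-embeddings-v2-base-en"},
    architecture: :base
  )
```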
@@ -135,15 +108,6 @@ defmodule Bumblebee.Text.JinaBert do
      Mask to nullify selected heads of the self-attention blocks in
      the encoder.

-  ### Exceptions
-
-  The `:for_multiple_choice` model accepts groups of sequences, so the
-  expected sequence shape is `{batch_size, num_choices, sequence_length}`.
-
-  The `:for_causal_language_modeling` model is a decoder and accepts
-  the following additional inputs: `"encoder_hidden_state"`,
-  `"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`.
-
  ## Global layer options

  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
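A hedged sketch of feeding the documented inputs to the model, assuming `model` and `params` from the earlier `Bumblebee.load_model` sketch and a checkpoint that ships a compatible tokenizer:

```elixir
# Sketch only: build the documented input map with a tokenizer and run the model.
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "jinaai/jina-embeddings-v2-base-en"})

inputs = Bumblebee.apply_tokenizer(tokenizer, ["How is the weather today?"])
# => %{"input_ids" => ..., "attention_mask" => ..., "token_type_ids" => ...}

outputs = Axon.predict(model, params, inputs)
# outputs.hidden_state has shape {batch_size, sequence_length, hidden_size}
```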
@@ -155,6 +119,7 @@ defmodule Bumblebee.Text.JinaBert do
  ## References

    * [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)
+    * [Jina Embeddings 2: 8192-Token General-Purpose Text Embeddings for Long Documents](https://arxiv.org/abs/2310.19923)

  """

@@ -172,14 +137,7 @@ defmodule Bumblebee.Text.JinaBert do
  def architectures(),
    do: [
      :base,
-      :for_masked_language_modeling,
-      :for_sequence_classification,
-      :for_token_classification,
-      :for_question_answering,
-      :for_multiple_choice,
-      :for_next_sentence_prediction,
-      :for_pre_training,
-      :for_causal_language_modeling
+      :for_masked_language_modeling
    ]

  @impl true
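With this hunk applied, the module should report exactly the two supported architectures; a quick IEx check as a sketch:

```elixir
# Sketch only: list the architectures exposed by this module.
Bumblebee.Text.JinaBert.architectures()
#=> [:base, :for_masked_language_modeling]
```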
@@ -220,159 +178,6 @@ defmodule Bumblebee.Text.JinaBert do
    })
  end

-  def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      outputs.pooled_state
-      |> Axon.dropout(
-        rate: classifier_dropout_rate(spec),
-        name: "sequence_classification_head.dropout"
-      )
-      |> Axon.dense(spec.num_labels,
-        kernel_initializer: kernel_initializer(spec),
-        name: "sequence_classification_head.output"
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_token_classification} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      outputs.hidden_state
-      |> Axon.dropout(
-        rate: classifier_dropout_rate(spec),
-        name: "token_classification_head.dropout"
-      )
-      |> Axon.dense(spec.num_labels,
-        kernel_initializer: kernel_initializer(spec),
-        name: "token_classification_head.output"
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_question_answering} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      Axon.dense(outputs.hidden_state, 2,
-        kernel_initializer: kernel_initializer(spec),
-        name: "question_answering_head.output"
-      )
-
-    {start_logits, end_logits} = Layers.split_pair(logits)
-
-    Layers.output(%{
-      start_logits: start_logits,
-      end_logits: end_logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_multiple_choice} = spec) do
-    inputs = inputs(spec, shape: {nil, nil, nil})
-
-    group_inputs = ["input_ids", "attention_mask", "token_type_ids", "position_ids"]
-
-    flat_inputs =
-      Enum.reduce(group_inputs, inputs, fn name, inputs ->
-        Map.update!(inputs, name, &Layers.flatten_leading/1)
-      end)
-
-    outputs = core(flat_inputs, spec)
-
-    logits =
-      outputs.pooled_state
-      |> Axon.dropout(rate: classifier_dropout_rate(spec), name: "multiple_choice_head.dropout")
-      |> Axon.dense(1,
-        kernel_initializer: kernel_initializer(spec),
-        name: "multiple_choice_head.output"
-      )
-
-    # The final shape depends on the dynamic batch size and number
-    # of choices, so we do a reshape based on the input shape
-    logits =
-      Axon.layer(
-        fn logits, input_ids, _opts ->
-          num_choices = Nx.axis_size(input_ids, 1)
-          Nx.reshape(logits, {:auto, num_choices})
-        end,
-        [logits, inputs["input_ids"]]
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_next_sentence_prediction} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      Axon.dense(outputs.pooled_state, 2,
-        kernel_initializer: kernel_initializer(spec),
-        name: "next_sentence_prediction_head.output"
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_pre_training} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    lm_logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
-
-    nsp_logits =
-      Axon.dense(outputs.pooled_state, 2,
-        kernel_initializer: kernel_initializer(spec),
-        name: "next_sentence_prediction_head.output"
-      )
-
-    Layers.output(%{
-      language_modeling_logits: lm_logits,
-      next_sentence_prediction_logits: nsp_logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do
-    inputs = inputs(spec, decoder?: true)
-    outputs = core(inputs, spec, decoder?: true)
-    logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions,
-      cross_attentions: outputs.cross_attentions,
-      cache: outputs.cache
-    })
-  end
-
  @impl true
  def init_cache(spec, batch_size, max_length, inputs) do
    encoder_sequence_length =
@@ -539,11 +344,10 @@ defmodule Bumblebee.Text.JinaBert do

    cross_attention? = decoder? and spec.use_cross_attention

-    # we build the alibi matrix only once instead of rebuilding
-    # for this we must use the maximum seqlen
    alibi_relative_bias_matrix =
      Axon.nx(hidden_state, fn hidden_state ->
        {_, seqlen, _} = Nx.shape(hidden_state)
+
        matrix = alibi_matrix(spec.num_attention_heads, spec.max_positions)

        matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]]
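The slice above assumes a bias matrix precomputed for `spec.max_positions`. A standalone Nx sketch of one way such a symmetric ALiBi bias can be built; the module name and the slope schedule are illustrative, not this file's `alibi_matrix/2`:

```elixir
# Sketch only: a {1, num_heads, max_positions, max_positions} bias where entry
# [h][i][j] is -slope(h) * |i - j|, using the common 2^(-8h/num_heads) slopes.
# It is built once for max_positions and sliced to the runtime sequence length,
# mirroring matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]] above.
defmodule AlibiSketch do
  def bias(num_heads, max_positions) do
    positions = Nx.iota({max_positions})

    # pairwise distances |i - j|
    distance =
      positions
      |> Nx.new_axis(-1)
      |> Nx.subtract(Nx.new_axis(positions, 0))
      |> Nx.abs()

    # one slope per head: 2^(-8h / num_heads) for h = 1..num_heads
    exponents = Nx.multiply(Nx.add(Nx.iota({num_heads}), 1), -8 / num_heads)
    slopes = Nx.pow(2, exponents)

    slopes
    |> Nx.reshape({num_heads, 1, 1})
    |> Nx.multiply(Nx.negate(distance))
    |> Nx.new_axis(0)
  end
end

# e.g. slice the precomputed matrix down to a sequence length of 4
matrix = AlibiSketch.bias(8, 512)
matrix[[.., .., 0..3, 0..3]]
```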
@@ -653,10 +457,6 @@ defmodule Bumblebee.Text.JinaBert do
    |> Axon.bias(name: join(name, "bias"))
  end

-  defp classifier_dropout_rate(spec) do
-    spec.classifier_dropout_rate || spec.dropout_rate
-  end
-
  defp kernel_initializer(spec) do
    Axon.Initializers.normal(scale: spec.initializer_scale)
  end