
Commit ae2baf7

Clean up
1 parent b1b6ec3 commit ae2baf7


lib/bumblebee/text/jina_bert.ex

Lines changed: 5 additions & 205 deletions
@@ -75,39 +75,12 @@ defmodule Bumblebee.Text.JinaBert do
 
   ## Architectures
 
-    * `:base` - plain BERT without any head on top
+    * `:base` - plain Jina BERT without any head on top
 
-    * `:for_masked_language_modeling` - BERT with a language modeling
+    * `:for_masked_language_modeling` - Jina BERT with a language modeling
       head. The head returns logits for each token in the original
       sequence
 
-    * `:for_sequence_classification` - BERT with a sequence
-      classification head. The head returns logits corresponding to
-      possible classes
-
-    * `:for_token_classification` - BERT with a token classification
-      head. The head returns logits for each token in the original
-      sequence
-
-    * `:for_question_answering` - BERT with a span classification head.
-      The head returns logits for the span start and end positions
-
-    * `:for_multiple_choice` - BERT with a multiple choice prediction
-      head. Each input in the batch consists of several sequences to
-      choose from and the model returns logits corresponding to those
-      choices
-
-    * `:for_next_sentence_prediction` - BERT with a next sentence
-      prediction head. The head returns logits predicting whether the
-      second sentence is random or in context
-
-    * `:for_pre_training` - BERT with both MLM and NSP heads as done
-      during the pre-training
-
-    * `:for_causal_language_modeling` - BERT working as a decoder with
-      a language modeling head. The head returns logits for each token
-      in the original sequence
-
   ## Inputs
 
     * `"input_ids"` - `{batch_size, sequence_length}`
@@ -135,15 +108,6 @@ defmodule Bumblebee.Text.JinaBert do
       Mask to nullify selected heads of the self-attention blocks in
       the encoder.
 
-  ### Exceptions
-
-  The `:for_multiple_choice` model accepts groups of sequences, so the
-  expected sequence shape is `{batch_size, num_choices, sequence_length}`.
-
-  The `:for_causal_language_modeling` model is a decoder and accepts
-  the following additional inputs: `"encoder_hidden_state"`,
-  `"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`.
-
   ## Global layer options
 
   #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
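
A minimal usage sketch for the documented inputs, assuming a checkpoint that maps to this module (the repository name below is only an example and is not taken from this diff):

{:ok, %{model: model, params: params}} =
  Bumblebee.load_model({:hf, "jinaai/jina-embeddings-v2-base-en"})

{:ok, tokenizer} =
  Bumblebee.load_tokenizer({:hf, "jinaai/jina-embeddings-v2-base-en"})

# Returns a map with "input_ids", "attention_mask", etc.,
# each shaped {batch_size, sequence_length}
inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello, Jina BERT!"])

# For :base, the output includes hidden_state of shape
# {batch_size, sequence_length, hidden_size}
outputs = Axon.predict(model, params, inputs)
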
@@ -155,6 +119,7 @@ defmodule Bumblebee.Text.JinaBert do
   ## References
 
     * [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)
+    * [Jina Embeddings 2: 8192-Token General-Purpose Text Embeddings for Long Documents](https://arxiv.org/abs/2310.19923)
 
   """

@@ -172,14 +137,7 @@ defmodule Bumblebee.Text.JinaBert do
   def architectures(),
     do: [
       :base,
-      :for_masked_language_modeling,
-      :for_sequence_classification,
-      :for_token_classification,
-      :for_question_answering,
-      :for_multiple_choice,
-      :for_next_sentence_prediction,
-      :for_pre_training,
-      :for_causal_language_modeling
+      :for_masked_language_modeling
     ]
 
   @impl true
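
With only two architectures left, selecting one explicitly at load time looks roughly like this (a sketch relying on Bumblebee's `:architecture` load option; the repository name is illustrative):

{:ok, %{model: model, params: params}} =
  Bumblebee.load_model({:hf, "jinaai/jina-embeddings-v2-base-en"},
    architecture: :for_masked_language_modeling
  )
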
@@ -220,159 +178,6 @@ defmodule Bumblebee.Text.JinaBert do
     })
   end
 
-  def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      outputs.pooled_state
-      |> Axon.dropout(
-        rate: classifier_dropout_rate(spec),
-        name: "sequence_classification_head.dropout"
-      )
-      |> Axon.dense(spec.num_labels,
-        kernel_initializer: kernel_initializer(spec),
-        name: "sequence_classification_head.output"
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_token_classification} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      outputs.hidden_state
-      |> Axon.dropout(
-        rate: classifier_dropout_rate(spec),
-        name: "token_classification_head.dropout"
-      )
-      |> Axon.dense(spec.num_labels,
-        kernel_initializer: kernel_initializer(spec),
-        name: "token_classification_head.output"
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_question_answering} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      Axon.dense(outputs.hidden_state, 2,
-        kernel_initializer: kernel_initializer(spec),
-        name: "question_answering_head.output"
-      )
-
-    {start_logits, end_logits} = Layers.split_pair(logits)
-
-    Layers.output(%{
-      start_logits: start_logits,
-      end_logits: end_logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_multiple_choice} = spec) do
-    inputs = inputs(spec, shape: {nil, nil, nil})
-
-    group_inputs = ["input_ids", "attention_mask", "token_type_ids", "position_ids"]
-
-    flat_inputs =
-      Enum.reduce(group_inputs, inputs, fn name, inputs ->
-        Map.update!(inputs, name, &Layers.flatten_leading/1)
-      end)
-
-    outputs = core(flat_inputs, spec)
-
-    logits =
-      outputs.pooled_state
-      |> Axon.dropout(rate: classifier_dropout_rate(spec), name: "multiple_choice_head.dropout")
-      |> Axon.dense(1,
-        kernel_initializer: kernel_initializer(spec),
-        name: "multiple_choice_head.output"
-      )
-
-    # The final shape depends on the dynamic batch size and number
-    # of choices, so we do a reshape based on the input shape
-    logits =
-      Axon.layer(
-        fn logits, input_ids, _opts ->
-          num_choices = Nx.axis_size(input_ids, 1)
-          Nx.reshape(logits, {:auto, num_choices})
-        end,
-        [logits, inputs["input_ids"]]
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_next_sentence_prediction} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    logits =
-      Axon.dense(outputs.pooled_state, 2,
-        kernel_initializer: kernel_initializer(spec),
-        name: "next_sentence_prediction_head.output"
-      )
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_pre_training} = spec) do
-    inputs = inputs(spec)
-    outputs = core(inputs, spec)
-
-    lm_logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
-
-    nsp_logits =
-      Axon.dense(outputs.pooled_state, 2,
-        kernel_initializer: kernel_initializer(spec),
-        name: "next_sentence_prediction_head.output"
-      )
-
-    Layers.output(%{
-      language_modeling_logits: lm_logits,
-      next_sentence_prediction_logits: nsp_logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions
-    })
-  end
-
-  def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do
-    inputs = inputs(spec, decoder?: true)
-    outputs = core(inputs, spec, decoder?: true)
-    logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")
-
-    Layers.output(%{
-      logits: logits,
-      hidden_states: outputs.hidden_states,
-      attentions: outputs.attentions,
-      cross_attentions: outputs.cross_attentions,
-      cache: outputs.cache
-    })
-  end
-
   @impl true
   def init_cache(spec, batch_size, max_length, inputs) do
     encoder_sequence_length =
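
Since the task-specific heads above are removed, downstream tasks are expected to build on the `:base` outputs instead. A minimal sketch of mean-pooling the hidden state into sentence embeddings, the usual recipe for Jina embedding checkpoints (module and function names here are illustrative, not part of this file):

defmodule MeanPoolSketch do
  # Averages token vectors across the sequence, ignoring padded
  # positions via the attention mask.
  def embed(hidden_state, attention_mask) do
    mask =
      attention_mask
      |> Nx.new_axis(-1)
      |> Nx.broadcast(Nx.shape(hidden_state))

    summed = hidden_state |> Nx.multiply(mask) |> Nx.sum(axes: [1])
    counts = Nx.sum(mask, axes: [1])
    Nx.divide(summed, counts)
  end
end
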
@@ -539,11 +344,10 @@ defmodule Bumblebee.Text.JinaBert do
 
     cross_attention? = decoder? and spec.use_cross_attention
 
-    # we build the alibi matrix only once instead of rebuilding
-    # for this we must use the maximum seqlen
     alibi_relative_bias_matrix =
       Axon.nx(hidden_state, fn hidden_state ->
         {_, seqlen, _} = Nx.shape(hidden_state)
+
         matrix = alibi_matrix(spec.num_attention_heads, spec.max_positions)
 
         matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]]
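
For reference, a standalone sketch of the kind of additive attention bias `alibi_matrix/2` precomputes, assuming the symmetric (encoder-style) ALiBi variant with slopes 2^(-8h / num_heads); module and function names are illustrative:

defmodule AlibiSketch do
  # Returns a {1, num_heads, seq_len, seq_len} bias where
  # bias[h, i, j] = -slope(h) * |i - j| and slope(h) = 2^(-8h / num_heads).
  def bias(num_heads, seq_len) do
    slopes =
      Nx.iota({num_heads})
      |> Nx.add(1)
      |> Nx.multiply(-8.0 / num_heads)
      |> then(&Nx.pow(2.0, &1))

    positions = Nx.iota({seq_len})

    distances =
      Nx.abs(Nx.subtract(Nx.new_axis(positions, -1), Nx.new_axis(positions, 0)))

    slopes
    |> Nx.reshape({num_heads, 1, 1})
    |> Nx.multiply(Nx.negate(distances))
    |> Nx.new_axis(0)
  end
end
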
@@ -653,10 +457,6 @@ defmodule Bumblebee.Text.JinaBert do
     |> Axon.bias(name: join(name, "bias"))
   end
 
-  defp classifier_dropout_rate(spec) do
-    spec.classifier_dropout_rate || spec.dropout_rate
-  end
-
   defp kernel_initializer(spec) do
     Axon.Initializers.normal(scale: spec.initializer_scale)
   end
