
Commit

Clean up
joelpaulkoch committed Nov 8, 2024
1 parent b1b6ec3 commit ae2baf7
Showing 1 changed file with 5 additions and 205 deletions.
210 changes: 5 additions & 205 deletions lib/bumblebee/text/jina_bert.ex
@@ -75,39 +75,12 @@ defmodule Bumblebee.Text.JinaBert do
## Architectures
* `:base` - plain BERT without any head on top
* `:base` - plain Jina BERT without any head on top
* `:for_masked_language_modeling` - BERT with a language modeling
* `:for_masked_language_modeling` - Jina BERT with a language modeling
head. The head returns logits for each token in the original
sequence
* `:for_sequence_classification` - BERT with a sequence
classification head. The head returns logits corresponding to
possible classes
* `:for_token_classification` - BERT with a token classification
head. The head returns logits for each token in the original
sequence
* `:for_question_answering` - BERT with a span classification head.
The head returns logits for the span start and end positions
* `:for_multiple_choice` - BERT with a multiple choice prediction
head. Each input in the batch consists of several sequences to
choose from and the model returns logits corresponding to those
choices
* `:for_next_sentence_prediction` - BERT with a next sentence
prediction head. The head returns logits predicting whether the
second sentence is random or in context
* `:for_pre_training` - BERT with both MLM and NSP heads as done
during the pre-training
* `:for_causal_language_modeling` - BERT working as a decoder with
a language modeling head. The head returns logits for each token
in the original sequence
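As a quick sketch of selecting one of these architectures when loading a
checkpoint (the repository name is a placeholder, and some checkpoints may
also need `module: Bumblebee.Text.JinaBert` passed explicitly):

    {:ok, %{model: model, params: params, spec: spec}} =
      Bumblebee.load_model({:hf, "jinaai/jina-embeddings-v2-base-en"},
        architecture: :base
      )

    spec.architecture
    #=> :base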
## Inputs
* `"input_ids"` - `{batch_size, sequence_length}`
@@ -135,15 +108,6 @@ defmodule Bumblebee.Text.JinaBert do
Mask to nullify selected heads of the self-attention blocks in
the encoder.
### Exceptions
The `:for_multiple_choice` model accepts groups of sequences, so the
expected sequence shape is `{batch_size, num_choices, sequence_length}`.
The `:for_causal_language_modeling` model is a decoder and accepts
the following additional inputs: `"encoder_hidden_state"`,
`"encoder_attention_mask"`, `"cross_attention_head_mask"`, `"cache"`.
## Global layer options
#{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}
@@ -155,6 +119,7 @@
## References
* [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)
* [Jina Embeddings 2: 8192-Token General-Purpose Text Embeddings for Long Documents](https://arxiv.org/abs/2310.19923)
"""

@@ -172,14 +137,7 @@ defmodule Bumblebee.Text.JinaBert do
def architectures(),
do: [
:base,
:for_masked_language_modeling,
:for_sequence_classification,
:for_token_classification,
:for_question_answering,
:for_multiple_choice,
:for_next_sentence_prediction,
:for_pre_training,
:for_causal_language_modeling
:for_masked_language_modeling
]

@impl true
@@ -220,159 +178,6 @@ defmodule Bumblebee.Text.JinaBert do
})
end

def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do
inputs = inputs(spec)
outputs = core(inputs, spec)

logits =
outputs.pooled_state
|> Axon.dropout(
rate: classifier_dropout_rate(spec),
name: "sequence_classification_head.dropout"
)
|> Axon.dense(spec.num_labels,
kernel_initializer: kernel_initializer(spec),
name: "sequence_classification_head.output"
)

Layers.output(%{
logits: logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions
})
end

def model(%__MODULE__{architecture: :for_token_classification} = spec) do
inputs = inputs(spec)
outputs = core(inputs, spec)

logits =
outputs.hidden_state
|> Axon.dropout(
rate: classifier_dropout_rate(spec),
name: "token_classification_head.dropout"
)
|> Axon.dense(spec.num_labels,
kernel_initializer: kernel_initializer(spec),
name: "token_classification_head.output"
)

Layers.output(%{
logits: logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions
})
end

def model(%__MODULE__{architecture: :for_question_answering} = spec) do
inputs = inputs(spec)
outputs = core(inputs, spec)

logits =
Axon.dense(outputs.hidden_state, 2,
kernel_initializer: kernel_initializer(spec),
name: "question_answering_head.output"
)

{start_logits, end_logits} = Layers.split_pair(logits)

Layers.output(%{
start_logits: start_logits,
end_logits: end_logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions
})
end

def model(%__MODULE__{architecture: :for_multiple_choice} = spec) do
inputs = inputs(spec, shape: {nil, nil, nil})

group_inputs = ["input_ids", "attention_mask", "token_type_ids", "position_ids"]

flat_inputs =
Enum.reduce(group_inputs, inputs, fn name, inputs ->
Map.update!(inputs, name, &Layers.flatten_leading/1)
end)

outputs = core(flat_inputs, spec)

logits =
outputs.pooled_state
|> Axon.dropout(rate: classifier_dropout_rate(spec), name: "multiple_choice_head.dropout")
|> Axon.dense(1,
kernel_initializer: kernel_initializer(spec),
name: "multiple_choice_head.output"
)

# The final shape depends on the dynamic batch size and number
# of choices, so we do a reshape based on the input shape
logits =
Axon.layer(
fn logits, input_ids, _opts ->
num_choices = Nx.axis_size(input_ids, 1)
Nx.reshape(logits, {:auto, num_choices})
end,
[logits, inputs["input_ids"]]
)

Layers.output(%{
logits: logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions
})
end

def model(%__MODULE__{architecture: :for_next_sentence_prediction} = spec) do
inputs = inputs(spec)
outputs = core(inputs, spec)

logits =
Axon.dense(outputs.pooled_state, 2,
kernel_initializer: kernel_initializer(spec),
name: "next_sentence_prediction_head.output"
)

Layers.output(%{
logits: logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions
})
end

def model(%__MODULE__{architecture: :for_pre_training} = spec) do
inputs = inputs(spec)
outputs = core(inputs, spec)

lm_logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")

nsp_logits =
Axon.dense(outputs.pooled_state, 2,
kernel_initializer: kernel_initializer(spec),
name: "next_sentence_prediction_head.output"
)

Layers.output(%{
language_modeling_logits: lm_logits,
next_sentence_prediction_logits: nsp_logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions
})
end

def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do
inputs = inputs(spec, decoder?: true)
outputs = core(inputs, spec, decoder?: true)
logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")

Layers.output(%{
logits: logits,
hidden_states: outputs.hidden_states,
attentions: outputs.attentions,
cross_attentions: outputs.cross_attentions,
cache: outputs.cache
})
end

@impl true
def init_cache(spec, batch_size, max_length, inputs) do
encoder_sequence_length =
@@ -539,11 +344,10 @@ defmodule Bumblebee.Text.JinaBert do

cross_attention? = decoder? and spec.use_cross_attention

# we build the alibi matrix only once, for the maximum sequence length,
# and slice out the block matching the current sequence length below
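# (Illustrative aside, hedged: this is not the module's actual alibi_matrix/2.
# A symmetric ALiBi bias gives each attention head a fixed slope and penalizes
# attention linearly with the distance between positions, roughly:
#
#     positions = Nx.iota({seqlen})
#     distances = Nx.abs(Nx.subtract(Nx.new_axis(positions, 1), Nx.new_axis(positions, 0)))
#     slopes = Nx.pow(2.0, Nx.multiply(-8.0 / num_heads, Nx.add(Nx.iota({num_heads}), 1)))
#     bias = Nx.multiply(Nx.reshape(slopes, {num_heads, 1, 1}), Nx.negate(distances))
#
# followed by a leading batch axis, giving {1, num_heads, seqlen, seqlen}.)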
alibi_relative_bias_matrix =
Axon.nx(hidden_state, fn hidden_state ->
{_, seqlen, _} = Nx.shape(hidden_state)

matrix = alibi_matrix(spec.num_attention_heads, spec.max_positions)

matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]]
@@ -653,10 +457,6 @@ defmodule Bumblebee.Text.JinaBert do
|> Axon.bias(name: join(name, "bias"))
end

defp classifier_dropout_rate(spec) do
spec.classifier_dropout_rate || spec.dropout_rate
end

defp kernel_initializer(spec) do
Axon.Initializers.normal(scale: spec.initializer_scale)
end
