I'm surprised that the following doesn't appear to work. I would have thought that we could `vmap` a Flax + Hugging Face model across an array of "tokens".
```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, AutoConfig
from transformers import FlaxAutoModelForSequenceClassification

model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=1)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config)

def process_text(batch):
    tokens = tokenizer(batch, return_tensors="jax", padding=True)
    return {key: jnp.array(v) for key, v in tokens.items()}

data = ["BERT revolutionized natural language processing by using bidirectional training to understand context better.",
        "Developed by Google, it can interpret the meaning of words based on surrounding text, improving tasks like question answering and sentiment analysis.",
        "BERT's architecture allows for fine-tuning on specific tasks without substantial modifications, making it versatile for a wide range of applications.",
        "LLaMA, or Large Language Model, is known for its efficiency and scalability, often utilized in various natural language processing tasks.",
        "It stands out for its ability to generate human-like text, perform translations, and understand context within large datasets.",
        "LLaMA models are designed to be adaptable, supporting tasks ranging from simple text generation to complex question answering with less computational cost than some alternatives.",
        "GEMM is a core algorithm in high-performance computing, critical for tasks in linear algebra, machine learning, and engineering.",
        "It multiplies two matrices and adds the result to a third matrix, serving as a foundational operation in many computational applications.",
        "Due to its importance, optimizations of GEMM are central to accelerating machine learning models and scientific computations on various hardware platforms."]

tokens = process_text(data)

def rfp(params, tokens):
    # keyword arguments, so the tensors aren't silently bound to the wrong
    # positional slots of the Flax __call__ signature
    return model(tokens['input_ids'],
                 attention_mask=tokens['attention_mask'],
                 token_type_ids=tokens['token_type_ids'],
                 params=params).logits

jax.vmap(rfp, in_axes=(None, {'input_ids': 0, 'token_type_ids': 0, 'attention_mask': 0}))(model.params, tokens)
```
This fails with:

```
TypeError: cannot reshape array of shape (32, 768) (size 24576) into shape (32, 768, 12, 64) (size 18874368)
```
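My guess at the cause: `vmap` removes the mapped axis before calling the function, so inside `rfp` each `input_ids` is a 1-D array of shape `(seq_len,)`, while the Flax BERT modules expect 2-D `(batch, seq_len)` inputs; the attention reshape then operates on the wrong ranks. A minimal, self-contained sketch of the effect and one workaround, using a toy function in place of the model (all names here are illustrative, not from the transformers API):

```python
import jax
import jax.numpy as jnp

def needs_batch_dim(x):
    # stands in for the model: it assumes a 2-D (batch, features) input,
    # as the Flax transformer modules do
    assert x.ndim == 2
    return x.sum(axis=-1)

xs = jnp.ones((4, 3))  # 4 "examples" of 3 features each

# vmap hands the inner function 1-D slices of shape (3,) -- the batch
# axis is gone, so we re-add a singleton batch dimension per example
def per_example(x):
    return needs_batch_dim(x[None, :])[0]

out = jax.vmap(per_example)(xs)
print(out.shape)  # (4,)
```

If this diagnosis is right, the analogous fix for `rfp` would be to index each tensor with `[None, :]` before calling the model (or simply to call the model directly on the batched `tokens`, since it already handles a leading batch axis without `vmap`).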