
Performance question regarding next token prediction task #77

@HeMuling

I tried to perform a next-token prediction task using the pretrained model hyenadna-small-32k-seqlen-hf, and the results don't look right to me. Here's the code I tried:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch

# instantiate the pretrained model and tokenizer
checkpoint = 'hyenadna-small-32k-seqlen-hf'
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, config=config)

seq = 'AGCTACATTGGCC'
tok_seq = tokenizer(seq)['input_ids']
print(tok_seq)

# add a batch dimension and sanity-check the round trip
tok_seq = torch.LongTensor(tok_seq).unsqueeze(0)
print(tokenizer.batch_decode(tok_seq))

# decode the argmax token at every position
out = model(tok_seq)
print(tokenizer.batch_decode(out['logits'].argmax(-1)))

and I get:

[7, 9, 8, 10, 7, 8, 7, 10, 10, 9, 9, 8, 8, 1]
['AGCTACATTGGCC[SEP]']

['AAATAAATTGTAAC']

In my understanding, I've set this model up to perform next-token prediction, so if I input the sequence 'AGCTACATTGGCC', the model should return something like 'AGCTACATTGGCC' plus a newly predicted token (i.e., most of the previous bases should stay the same). But the sequence I get differs a lot from my input. I wonder whether there's something wrong with my understanding or my code.
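For context, here is a minimal sketch (my own, not verified against the HyenaDNA internals) of how I'd expect things to line up under the standard causal-LM convention: the logits at position i are the prediction for input token i + 1, so the argmax sequence would be compared against the input shifted left by one, and the token that follows the full sequence would come from the final position. This reuses the model, tokenizer, seq, and tok_seq from the code above:

import torch

with torch.no_grad():
    out = model(tok_seq)

# under the usual causal-LM convention, logits[:, i] predicts input token i + 1
pred = out['logits'][:, :-1, :].argmax(-1)   # drop the prediction at the last position
target = tok_seq[:, 1:]                      # drop the first input token
print((pred == target).float().mean())       # fraction of next tokens recovered

# the token predicted to follow the full sequence is the argmax at the last position
next_id = out['logits'][0, -1].argmax().item()
print(seq + tokenizer.decode([next_id]))     # input plus one predicted base

If this convention holds, a high match fraction would mean the model is tracking the context as I expected, and a low one would point at something else (e.g., the checkpoint or the tokenization).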
