Description
Hi,
I have fine-tuned OPT-125M with finetune.py in bf16. But when I load the model and try to calculate the perplexity of each instance, I get the same perplexity for different instances.
My code is:
import torch
from functools import partial
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def loss_per_sample(logits, labels):
    BATCH_SIZE = logits.size(dim=0)
    VOCAB_SIZE = logits.size(dim=-1)
    labels = labels.to(logits.device)
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # ignore_index defaults to -100, so masked prompt tokens contribute zero loss
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fn(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
    loss_tensor_per_sample = loss.view(BATCH_SIZE, -1)
    # Average only over positions with non-zero loss (i.e. the completion tokens)
    sample_loss = [torch.sum(tensor) / torch.count_nonzero(tensor)
                   if bool(torch.count_nonzero(tensor) > 0)
                   else torch.tensor(0.0).to(tensor.device)
                   for tensor in loss_tensor_per_sample]
    sample_perplexity = [torch.exp(s).item() for s in sample_loss]
    print('sample_perplexity ', sample_perplexity)
    sample_likelihood = [torch.exp(-s).item() for s in sample_loss]
    return sample_perplexity, sample_likelihood
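A small self-contained check of loss_per_sample with random tensors (the shapes and masking below are arbitrary, just to show the expected per-sample output):

dummy_logits = torch.randn(2, 6, 50272)      # (batch, seq_len, vocab); 50272 is OPT-125M's vocab size
dummy_labels = torch.randint(0, 50272, (2, 6))
dummy_labels[:, :3] = -100                   # mask the first few "prompt" tokens
perplexities, likelihoods = loss_per_sample(dummy_logits, dummy_labels)
print(perplexities)                          # two different values, one per sample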
model_name_or_path = 'facebook/opt-125M'
trust_remote_code = False
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    revision=None,
    trust_remote_code=trust_remote_code,
    use_fast=False,
)
config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    config=config,
    trust_remote_code=trust_remote_code,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
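A minimal sketch of how the loaded model could be prepared for evaluation (the device choice here is an assumption, not part of my original snippet):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()  # disable dropout so repeated forward passes are deterministic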
train_file = 'sample.json'
data_files = {}
dataset_args = {}
if train_file is not None:
    data_files["train"] = train_file
raw_datasets = load_dataset(
    "json",
    data_files=data_files,
    **dataset_args,
)
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length):
    messages = example["messages"]
    for msg in messages:
        if msg['role'] == 'user':
            question = msg['content']
        if msg['role'] == 'assistant':
            answer = msg['content']
    if not question.endswith((' ', '\n', '\t')) and not answer.startswith((' ', '\n', '\t')):
        example_text = question + ' ' + answer
    else:
        example_text = question + answer
    example_text = example_text + tokenizer.eos_token
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()
    tokenized_prompt = tokenizer(question, return_tensors='pt', max_length=max_seq_length, truncation=True)
    # Mask out the prompt tokens so only the completion contributes to the loss
    labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }
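For reference, a quick sanity check of the encoding on a made-up example (the question/answer pair below is invented):

toy_example = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
}
encoded = encode_with_prompt_completion_format(toy_example, tokenizer, max_seq_length=512)
print(encoded["input_ids"].shape)
# Number of label positions that actually contribute to the loss
print((encoded["labels"] != -100).sum().item())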
train_dataset = raw_datasets["train"]
preprocessing_num_workers = 8
max_seq_length = 512
train_dataset = train_dataset.map(
partial(encode_with_prompt_completion_format, tokenizer=tokenizer, max_seq_length=max_seq_length),
batched=False,
num_proc=preprocessing_num_workers,
remove_columns=[
name for name in train_dataset.column_names if name not in ["input_ids", "labels", "attention_mask", "ids"]
],
desc="Tokenizing and reformatting instruction data",
)
train_dataset.set_format(type="pt")
train_dataset = train_dataset.filter(lambda example: (example["labels"] != -100).any())
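The dataloader used in the loop below is not shown in my snippet; a minimal sketch of how it could be built (the collator choice and batch size are assumptions):

from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Pads input_ids/attention_mask and pads labels with -100 so padding is ignored by the loss
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")
dataloader = DataLoader(train_dataset, batch_size=8, shuffle=False, collate_fn=data_collator)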
logits_list = []
with torch.no_grad():
    for step, batch in enumerate(dataloader):
        new_batch = {}
        new_batch['input_ids'] = batch['input_ids']
        new_batch['attention_mask'] = batch['attention_mask']
        new_batch['labels'] = batch['labels']
        print('eval batch ', batch)
        outputs = model(**new_batch, use_cache=False)
        print('outputs.loss ', outputs.loss)
        logits = outputs.logits
        print(logits)
        logits_list.append(logits)
        labels = batch["labels"]
        embedding_size = logits.size(dim=-1)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_logits = shift_logits.view(-1, embedding_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        batch_perplexities, batch_likelihood = loss_per_sample(logits, labels)
The question boils down to: how should I run inference once the model has been trained in bf16?
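To frame the question, this is the kind of change I am wondering about, e.g. whether the bf16 logits should be upcast to float32 before the softmax/cross-entropy (a sketch only, not what I currently do):

with torch.no_grad():
    outputs = model(**new_batch, use_cache=False)
    # Upcast the bf16 logits to float32 so per-token probabilities
    # and the cross-entropy are computed at full precision
    logits = outputs.logits.float()
    batch_perplexities, batch_likelihood = loss_per_sample(logits, batch["labels"])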