
Identical perplexity/loss for instances during inference #649

@LalchandPandia

Description


Hi,
I have fine-tuned OPT-125M using finetune.py in bf16. But when I load the model and try to calculate the perplexity of each instance, I get the same perplexity for different instances.
My code is:
import torch
from functools import partial
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def loss_per_sample(logits, labels):
    BATCH_SIZE = logits.size(dim=0)
    VOCAB_SIZE = logits.size(dim=-1)
    labels = labels.to(logits.device)

    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")  # ignore_index defaults to -100
    loss = loss_fn(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
    loss_tensor_per_sample = loss.view(BATCH_SIZE, -1)

    # Average over the positions that actually contribute loss (labels != -100)
    sample_loss = [torch.sum(tensor) / torch.count_nonzero(tensor)
                   if bool(torch.count_nonzero(tensor) > 0)
                   else torch.tensor(0.0).to(tensor.device)
                   for tensor in loss_tensor_per_sample]
    sample_perplexity = [torch.exp(s).item() for s in sample_loss]

    print('sample_perplexity ', sample_perplexity)
    sample_likelihood = [torch.exp(-s).item() for s in sample_loss]

    return sample_perplexity, sample_likelihood
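
For reference, this is how I sanity-check the helper on its own; the shapes and values below are purely illustrative:

logits = torch.randn(2, 6, 100)           # (batch, seq_len, vocab)
labels = torch.randint(0, 100, (2, 6))    # fake token ids
labels[0, :3] = -100                      # mask "prompt" tokens as in training
ppl, lik = loss_per_sample(logits, labels)
print(ppl, lik)                           # two different values expected for random inputs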

model_name_or_path = 'facebook/opt-125M'
trust_remote_code = False
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    revision=None,
    trust_remote_code=trust_remote_code,
    use_fast=False,
)
config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    config=config,
    trust_remote_code=False,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)
train_file = 'sample.json'
data_files = {}
dataset_args = {}
if train_file is not None:
    data_files["train"] = train_file
raw_datasets = load_dataset(
    "json",
    data_files=data_files,
    **dataset_args,
)
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length):
    messages = example["messages"]

    for msg in messages:
        if msg['role'] == 'user':
            question = msg['content']
        if msg['role'] == 'assistant':
            answer = msg['content']

    # Join prompt and completion with a single space unless whitespace is already present
    if not question.endswith((' ', '\n', '\t')) and not answer.startswith((' ', '\n', '\t')):
        example_text = question + ' ' + answer
    else:
        example_text = question + answer
    example_text = example_text + tokenizer.eos_token
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()
    # Mask out the prompt tokens so only the answer contributes to the loss
    tokenized_prompt = tokenizer(question, return_tensors='pt', max_length=max_seq_length, truncation=True)
    labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }

train_dataset = raw_datasets["train"]
preprocessing_num_workers = 8
max_seq_length = 512
train_dataset = train_dataset.map(
    partial(encode_with_prompt_completion_format, tokenizer=tokenizer, max_seq_length=max_seq_length),
    batched=False,
    num_proc=preprocessing_num_workers,
    remove_columns=[
        name for name in train_dataset.column_names
        if name not in ["input_ids", "labels", "attention_mask", "ids"]
    ],
    desc="Tokenizing and reformatting instruction data",
)
train_dataset.set_format(type="pt")
train_dataset = train_dataset.filter(lambda example: (example["labels"] != -100).any())
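
For completeness, the dataloader used in the loop below is built roughly like this (the batch size and the choice of collator are just what I use locally, not something from finetune.py):

from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# DataCollatorForSeq2Seq pads input_ids/attention_mask with the tokenizer's pad token
# and labels with -100, so padded positions are ignored by the cross-entropy loss
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest", label_pad_token_id=-100)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=data_collator)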
logits_list = []
with torch.no_grad():
    for step, batch in enumerate(dataloader):
        new_batch = {}
        new_batch['input_ids'] = batch['input_ids']
        new_batch['attention_mask'] = batch['attention_mask']
        new_batch['labels'] = batch['labels']
        print('eval batch ', batch)
        outputs = model(**new_batch, use_cache=False)
        print('outputs.loss ', outputs.loss)
        logits = outputs.logits
        print(logits)
        logits_list.append(logits)
        labels = batch["labels"]
        embedding_size = logits.size(dim=-1)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_logits = shift_logits.view(-1, embedding_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        # loss_per_sample does the same shift internally on the unshifted tensors
        batch_perplexities, batch_likelihood = loss_per_sample(logits, labels)

The question boils down to: how should I run inference once the model has been fine-tuned in bf16?
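
For reference, this is the kind of loading/inference pattern I have been experimenting with; the float32 cast of the logits before the loss is just my guess at handling bf16, not something taken from finetune.py:

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,              # or the path to the fine-tuned checkpoint
    torch_dtype=torch.bfloat16,
)
model.eval()

with torch.no_grad():
    for batch in dataloader:
        outputs = model(input_ids=batch['input_ids'],
                        attention_mask=batch['attention_mask'],
                        labels=batch['labels'],
                        use_cache=False)
        # upcast to float32 before the per-token cross-entropy so the
        # per-sample averages are not flattened by bf16's limited precision
        logits = outputs.logits.float()
        ppl, lik = loss_per_sample(logits, batch['labels'])
        print(ppl)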
