
Not able to reproduce the scores using provided checkpoint on NLG tasks #138

Open
ylli0218 opened this issue Oct 12, 2023 · 5 comments

@ylli0218

Hi, I was able to reproduce the GLUE benchmark results but not the NLG task.

For the NLG tasks, I downloaded the checkpoint for GPT2-M and followed steps 2, 3, and 4 in the instructions:

https://github.com/microsoft/LoRA/tree/main/examples/NLG

However, the scores were either extremely low, or there were errors during evaluations.

For the E2E task I got:

SCORES:
BLEU: 0.0000
NIST: 0.0196
METEOR: 0.0034
ROUGE_L: 0.0072
CIDEr: 0.0000

For WebNLG and DART, I see this error:
Error: test and reference not same length
ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE JAVA INSTALLED GLOBALLY ON YOUR MACHINE.
I do have Java installed.
The other scores were also low:
| BLEU | BLEU NLTK | METEOR | chrF++ | TER | BERT-SCORE P | BERT-SCORE R | BERT-SCORE F1 | BLEURT |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | -1 | 0.11 | 1.96 | 0 | 0 | 0 | -1 |

Any suggestions? Thank you.

@1181000705

Hello, I have the same problem. Have you solved it?

@ylli0218
Author

@1181000705 Yes, partially. The NLG scores were low because the original script did not load the pretrained backbone model weights. If you make sure the backbone weights are loaded correctly, together with the LoRA weights, you will get the published scores.

As for this error:
Error: test and reference not same length
ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE JAVA INSTALLED GLOBALLY ON YOUR MACHINE.
I have no idea yet.

@edwardjhu
Collaborator

> The low NLG scores are because somehow the original script did not load the pretrained backbone models.

I see. It would be great if you could open a PR to fix that!

> ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE JAVA INSTALLED GLOBALLY ON YOUR MACHINE.

Seems like a dependency issue.

@EdwardIX

EdwardIX commented Dec 20, 2023

> The low NLG scores are because somehow the original script did not load the pretrained backbone models.

The cause is in the finetuning script, examples/NLG/src/gpt2_ft.py, line 256:

    if args.rank == 0:
        model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
        print('saving checkpoint', model_path)
        torch.save({'model_state_dict': model.state_dict()}, model_path)
    distributed_sync(args)
    return train_step

The saved checkpoint contains model.state_dict(), in which the names of the backbone GPT-2 parameters start with the prefix "transformer.".

However, in examples/NLG/src/model.py, line 448:

self.transformer.load_state_dict(state_dict, strict=False)

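This transformer.load_state_dict call expects parameter names WITHOUT the "transformer." prefix. Because strict=False is set, keys that do not match are skipped without raising an error, so when the finetuned model is loaded from the checkpoint, only the LoRA weights are loaded.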
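
A minimal sketch of how one could check for this kind of key mismatch before loading. It uses plain Python sets so it runs without PyTorch. The helper name `diff_keys` and the key names are made up for illustration and are not taken from the repo. In PyTorch itself, the object returned by `load_state_dict(..., strict=False)` exposes `missing_keys` and `unexpected_keys`, which report the same information.

```python
def diff_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) as a non-strict load would see them."""
    checkpoint_keys, model_keys = set(checkpoint_keys), set(model_keys)
    # Parameters the model has but the checkpoint does not provide:
    # these keep whatever values they were initialized with.
    missing = sorted(model_keys - checkpoint_keys)
    # Checkpoint entries with no matching parameter: silently skipped.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Illustrative key names only (an assumption, not copied from the repo).
checkpoint_keys = ["transformer.wte.weight", "transformer.h.0.attn.c_attn.weight"]
model_keys = ["wte.weight", "h.0.attn.c_attn.weight"]

missing, unexpected = diff_keys(checkpoint_keys, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If both lists are non-empty and differ only by a prefix, the checkpoint keys most likely need renaming before loading.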

I suggest slightly modifying the weight-loading function like this:

        for key in state_dict_tmp:
            new_key = None
            if key.endswith(".g"):
                new_key = key[:-2] + ".weight"
            elif key.endswith(".b"):
                new_key = key[:-2] + ".bias"
            elif key.endswith(".w"):
                new_key = key[:-2] + ".weight"
            
            if key.startswith("module.transformer."):
                new_key = key[len("module.transformer."):]

            if key.startswith("transformer."):             # Add 2 lines here to delete the prefix "transformer."
                new_key = key[len("transformer."):]

            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
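
For experimenting outside the repo, the same renaming idea can be sketched as a standalone function. This is a hypothetical variant, not the repo's code: `normalize_key` and `normalize_state_dict` are illustrative names, and unlike the loop above, where the prefix rule overwrites the result of the suffix rule, this version applies both. Whether that difference matters depends on the key names in the checkpoint.

```python
# Prefixes to strip and suffixes to rename, mirroring the rules in the loop above.
PREFIXES = ("module.transformer.", "transformer.")
SUFFIXES = {".g": ".weight", ".b": ".bias", ".w": ".weight"}


def normalize_key(key):
    """Strip a known prefix, then rename a known short suffix."""
    for prefix in PREFIXES:
        if key.startswith(prefix):
            key = key[len(prefix):]
            break
    for old, new in SUFFIXES.items():
        if key.endswith(old):
            key = key[:-len(old)] + new
            break
    return key


def normalize_state_dict(state_dict):
    """Return a new dict with normalized keys; values are left untouched."""
    return {normalize_key(k): v for k, v in state_dict.items()}


# Placeholder values stand in for tensors (illustrative keys only).
example = {
    "transformer.wte.weight": 1,
    "module.transformer.ln_f.g": 2,
    "lm_head.weight": 3,
}
normalized = normalize_state_dict(example)
print(normalized)
```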

@RayCyder

> ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE JAVA INSTALLED GLOBALLY ON YOUR MACHINE.

I get the same error with WebNLG. How can it be fixed?
