|
6 | 6 | from .utils import DEBUG_MESSAGE
|
7 | 7 |
|
8 | 8 | import zipfile
|
| 9 | +from random import shuffle, seed |
9 | 10 |
|
10 | 11 |
|
| 12 | +RANDOM_SEED = 0 |
| 13 | + |
11 | 14 | VQA_SYSTEM_PROMPT = json.dumps({
|
12 | 15 | 'task': 'Answer the question presented to you truthfully.',
|
13 | 16 | 'requirements': [
|
@@ -42,6 +45,10 @@ class MOAT(ImageBaseDataset):
|
42 | 45 | 'MOAT': '803b5a176a5b01aa1b8094fae73532a2',
|
43 | 46 | }
|
44 | 47 |
|
| 48 | + def __init__(self, dataset, **kwargs): |
| 49 | + super().__init__(dataset, **kwargs) |
| 50 | + seed(RANDOM_SEED) # seed the shared module-level `random` generator so the shuffle() of choices in build_prompt is repeatable across runs; NOTE(review): this resets the state for every other user of `random` in the process, and each sample's shuffle order still depends on how many random calls ran before it - a private random.Random(RANDOM_SEED) may be safer, confirm |
| 51 | + |
45 | 52 | def post_build(self, dataset):
|
46 | 53 | assert dataset == "MOAT", f"Wrong dataset name {dataset}"
|
47 | 54 | ROOT = LMUDataRoot()
|
@@ -70,16 +77,19 @@ def build_prompt(self, line):
|
70 | 77 | question, choices, images, outside_knowledge_text, outside_knowledge_images = line['question'], line['choices'], line['images'], line['outside_knowledge_text'], line['outside_knowledge_images'] # noqa: E501
|
71 | 78 | choices, images, outside_knowledge_images = toliststr(choices), toliststr(images), toliststr(outside_knowledge_images) # noqa: E501
|
72 | 79 |
|
| 80 | + if len(choices): |
| 81 | + shuffle(choices) # shuffle the choices to avoid bias |
| 82 | + question += f'\nThe choices are: {choices}' |
73 | 83 | msgs = [
|
74 | 84 | {
|
75 | 85 | 'type': 'text',
|
76 |
| - 'value': VQA_SYSTEM_PROMPT + '\n' + question + (f'\nThe choices are: {choices}' if choices else ''), |
| 86 | + 'value': VQA_SYSTEM_PROMPT + '\n' + question, |
77 | 87 | },
|
78 | 88 | ]
|
79 | 89 | for img in images:
|
80 | 90 | msgs.append({'type': 'image', 'value': osp.join(self.img_root, img)})
|
81 | 91 | if not pd.isna(outside_knowledge_text):
|
82 |
| - msgs.append({'type': 'text', 'value': outside_knowledge_text}) |
| 92 | + msgs.append({'type': 'text', 'value': 'Hint:\n' + outside_knowledge_text}) |
83 | 93 | for img in outside_knowledge_images:
|
84 | 94 | msgs.append({'type': 'image', 'value': osp.join(self.img_root, img)})
|
85 | 95 | return msgs
|
|
0 commit comments