-
Notifications
You must be signed in to change notification settings - Fork 1
/
run-fallacy-classification-without-premise.py
155 lines (123 loc) · 6.04 KB
/
run-fallacy-classification-without-premise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""run-fallacy-classification-without-premise
Usage:
run-fallacy-classification-without-premise.py llama <prompt-template> <model-size> [<seed>] [--dev] [--8bit]
run-fallacy-classification-without-premise.py gpt4 <prompt-template> [--dev] [--overwrite]
run-fallacy-classification-without-premise.py parse-llm-output <file> [--dev]
Arguments:
<prompt-template> Prompt template to be used, e.g. "cls_without_premise/p4-connect-cls-D.txt".
<model-size> Model size to be used ("70b", "13b", "7b").
<seed> Seed value (default=1).
Options:
-h, --help Show this help message and exit
--dev Run on the development set
--8bit Use 8-bit precision. Only relevant if the model is "70b". Default is 4-bit for 70b and 8-bit
otherwise
--overwrite Overwrite existing files.
"""
from os.path import join
from typing import List, Dict
from docopt import docopt
from sklearn.metrics import classification_report
from missci.data.missci_data_loader import MissciDataLoader
from missci.modeling.gpt4 import GPTCaller
from missci.modeling.model_llama import query_llama_for_classification_with_implicit_premise
from missci.output_parser.llm_output_parser_fallacy import ClassifyFallacyLLMOutputParser
from missci.prompt_templates.classify_generate_template_filler import ClassifyGenerateTemplateFiller
from missci.util.directory_util import get_raw_prompt_prediction_directory, get_prediction_directory
from missci.util.fileutil import read_jsonl, write_jsonl
def run_llama_fallacy_classification_without_premise(
        template_file: str, llama_size: str, split: str, instances: List[Dict], seed: int
) -> str:
    """
    Ask Llama2 to name the applied fallacy from the claim, accurate premise and context alone
    (the fallacious premise is withheld). Raw LLM outputs land in the
    "predictions/only-classify-raw" directory.

    :param template_file: relative path of the prompt template within the "prompt_templates" directory.
    :param llama_size: Llama2 size as string ("7b", "70b", "13b")
    :param split: Data split ("train" or "dev"). Only used for naming.
    :param instances: List of all instances that will be prompted.
    :param seed: random seed (default=1)
    :return: path/name of the produced output file (as returned by the query helper).
    """
    return query_llama_for_classification_with_implicit_premise(
        split=split,
        instances=instances,
        output_directory=get_raw_prompt_prediction_directory('classify-only'),
        template_file=template_file,
        llama_size=llama_size,
        seed=seed
    )
def run_gpt4_fallacy_classification(template_file: str, split: str, instances: List[Dict], overwrite: bool) -> str:
    """
    Ask GPT-4 to name the applied fallacy from the claim, accurate premise and context alone
    (the fallacious premise is withheld). Raw LLM outputs land in the
    "predictions/only-classify-raw" directory.

    :param template_file: relative path of the prompt template within the "prompt_templates" directory.
    :param split: Data split ("train" or "dev"). Only used for naming.
    :param instances: List of all instances that will be prompted.
    :param overwrite: If set to true, existing GPT 4 predictions will not be re-used.
    :return: path/name of the produced output file (as returned by GPTCaller.prompt).
    """
    raw_output_directory: str = get_raw_prompt_prediction_directory('classify-only')
    caller: GPTCaller = GPTCaller(
        output_directory=raw_output_directory,
        template_filler=ClassifyGenerateTemplateFiller(template_file),
        overwrite=overwrite
    )
    return caller.prompt(instances, split)
def parse_prompt_llm_output(file_name: str, gold_instances: List[Dict], formatted: bool = False) -> None:
    """
    Parse raw LLM fallacy-classification output, score it against the gold labels and
    write the parsed predictions to the "classify-only" prediction directory.

    A prediction counts as correct if it matches ANY of the interchangeable gold
    fallacies of the respective instance; the gold label fed to the metric is then
    set to the predicted label so the match is rewarded.

    :param file_name: name of the raw prediction file (jsonl) within the raw prompt directory.
    :param gold_instances: gold arguments as loaded by MissciDataLoader.
    :param formatted: must be True; the unformatted parsing path was never implemented.
    :raises ValueError: if the number of predictions does not match the number of gold fallacies.
    :raises NotImplementedError: if formatted is False.
    """
    # Map each fallacy id to all gold classes that are considered correct
    # (every interchangeable fallacy is accepted).
    gold_instance_dict = {
        fallacy['id']: [
            interchangeable_fallacy['class'] for interchangeable_fallacy in fallacy['interchangeable_fallacies']
        ]
        for instance in gold_instances
        for fallacy in instance['argument']['fallacies']
    }
    prompt_directory: str = get_raw_prompt_prediction_directory('classify-only')
    predictions = list(read_jsonl(join(prompt_directory, file_name)))
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if len(predictions) != len(gold_instance_dict):
        raise ValueError(f'MISMATCH {len(predictions)} vs {len(gold_instance_dict)}: {file_name}')
    # Hoisted out of the loop: one parser instance serves all predictions.
    parser = ClassifyFallacyLLMOutputParser('')
    all_gold = []
    all_pred = []
    for pred in predictions:
        if formatted:
            predicted = parser.parse(pred)['predicted_parsed']['fallacy_name']
        else:
            # The original code here was unreachable (`assert False`) and called an
            # undefined helper; fail loudly instead of raising a NameError.
            raise NotImplementedError('Parsing unformatted LLM output is not implemented.')
        gold_labels = gold_instance_dict[pred['data']['fallacy_id']]
        # Persist both labels on the prediction record for the parsed output file.
        pred['predicted-label'] = predicted
        pred['gold-labels'] = gold_labels
        # It is correct if it is among the interchangeable fallacies!
        if predicted in gold_labels:
            gold_label = predicted
        else:
            gold_label = gold_labels[0]
        all_gold.append(gold_label)
        all_pred.append(predicted)
    write_jsonl(
        join(get_prediction_directory('classify-only'), file_name.replace('.jsonl', '.parsed.jsonl')),
        predictions
    )
    report = classification_report(all_gold, all_pred, zero_division=0, output_dict=True)
    acc = round(report['accuracy'], 3)
    f1 = round(report['macro avg']['f1-score'], 3)
    print(file_name, f'\nACC: {acc} ; F1 MACRO: {f1}')
    print()
def main():
    """CLI entry point: parse docopt args and dispatch to the selected mode."""
    args = docopt(__doc__)
    # Default to the test split unless --dev was passed.
    split = 'dev' if args['--dev'] else 'test'
    instances: List[Dict] = MissciDataLoader().load_raw_arguments(split)
    if args['llama']:
        # Optional positional <seed>; fall back to 1 when omitted.
        seed: int = int(args['<seed>']) if args['<seed>'] else 1
        run_llama_fallacy_classification_without_premise(
            args['<prompt-template>'], args['<model-size>'], split, instances, seed
        )
    elif args['gpt4']:
        run_gpt4_fallacy_classification(args['<prompt-template>'], split, instances, args['--overwrite'])
    elif args['parse-llm-output']:
        parse_prompt_llm_output(args['<file>'], instances, formatted=True)
    else:
        raise NotImplementedError()


if __name__ == '__main__':
    main()