Skip to content

Commit 4bc18c7

Browse files
committed
[update] adapt tasks for it
1 parent 55f4e57 commit 4bc18c7

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

moe_peft/tasks/qa_tasks.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ def loading_data(
7777
)
7878
answer = "true" if data_point["answer"] else "false"
7979
if is_train:
80-
prompt += f" {answer}"
81-
labels = None
80+
labels = answer
8281
else:
8382
labels = [self.labels2id_[answer]]
8483
ret.append(InputData(inputs=prompt, labels=labels))
@@ -110,8 +109,7 @@ def loading_data(
110109
prompt += f" ({label}) {text}"
111110
prompt += "\nAnswer:"
112111
if is_train:
113-
prompt += " " + data_point["answerKey"]
114-
labels = None
112+
labels = data_point["answerKey"]
115113
else:
116114
labels = [self.labels2id_[data_point["answerKey"]]]
117115
ret.append(InputData(inputs=prompt, labels=labels))
@@ -138,8 +136,7 @@ def loading_data(
138136
prompt += "\nCorrect solution:"
139137
answer = self.labels_[data_point["label"]]
140138
if is_train:
141-
prompt += f" {answer}"
142-
labels = None
139+
labels = answer
143140
else:
144141
labels = [data_point["label"]]
145142
ret.append(InputData(inputs=prompt, labels=labels))
@@ -168,8 +165,7 @@ def loading_data(
168165
prompt += "\nAnswer:"
169166
label = int(data_point["label"]) - 1
170167
if is_train:
171-
prompt += f" {self.labels_[label]}"
172-
labels = None
168+
labels = self.labels_[label]
173169
else:
174170
labels = [label]
175171
ret.append(InputData(inputs=prompt, labels=labels))
@@ -200,8 +196,7 @@ def loading_data(
200196
prompt += "\nAnswer:"
201197
label = int(data_point["label"])
202198
if is_train:
203-
prompt += f" {self.labels_[label]}"
204-
labels = None
199+
labels = self.labels_[label]
205200
else:
206201
labels = [label]
207202
ret.append(InputData(inputs=prompt, labels=labels))
@@ -230,8 +225,7 @@ def loading_data(
230225
prompt += "\nAnswer:"
231226
label = int(data_point["answer"]) - 1
232227
if is_train:
233-
prompt += f" {self.labels_[label]}"
234-
labels = None
228+
labels = self.labels_[label]
235229
else:
236230
labels = [label]
237231
ret.append(InputData(inputs=prompt, labels=labels))
@@ -261,8 +255,7 @@ def loading_data(
261255
prompt += f" ({label}) {text}"
262256
prompt += "\nAnswer:"
263257
if is_train:
264-
prompt += " " + data_point["answerKey"]
265-
labels = None
258+
labels = data_point["answerKey"]
266259
else:
267260
labels = [self.labels2id_[data_point["answerKey"]]]
268261
ret.append(InputData(inputs=prompt, labels=labels))
@@ -295,13 +288,10 @@ def loading_data(
295288
prompt += f"({label}) {text}\n"
296289
answer = data_point["final_decision"]
297290
assert answer in self.labels2id_
291+
prompt += "Answer:"
298292
if is_train:
299-
prompt += f"Long Answer:\n{data_point['long_answer']}\n"
300-
prompt += "Answer:"
301-
prompt += f" {answer}"
302-
labels = None
293+
labels = answer
303294
else:
304-
prompt += "Answer:"
305295
labels = [self.labels2id_[answer]]
306296
ret.append(InputData(inputs=prompt, labels=labels))
307297

0 commit comments

Comments
 (0)