-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
graddatasetgen.py
72 lines (67 loc) · 2.18 KB
/
graddatasetgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import csv
from datasets import load_dataset
import random
# Source corpus; the helper functions below read the 'summary' and 'text'
# fields of ds['train'][i].
ds = load_dataset("sentence-transformers/wikihow")

# Input rows are pipe-delimited with the columns: Query | Document | Score.
file = open("csv/gradients_test.csv", "r", newline="")
reader = csv.reader(file, delimiter="|")
# Consume the header row so that later reads only see data rows. The default
# keeps an empty input file from raising StopIteration here.
header = next(reader, None)

file_to_write = open("csv/ungradient_data.csv", "w", newline="")
writer = csv.writer(file_to_write, delimiter="|")
# writerow() takes a sequence of fields. Passing the bare string
# "Query|Document" would emit every character as its own column.
writer.writerow(["Query", "Document"])

# Phrases that wrap a value taken from the input file. The "end" variants
# carry a leading space, the "start" variants a trailing space.
doc_templates_end = [" Created {val}", " Published {val}", " Written {val}", " Released {val}", " Posted {val}"]
doc_templates_start = ["Created {val} ", "Published {val} ", "Written {val} ", "Released {val} ", "Posted {val} "]
def modify_query(i, text, start_or_end):
    """Combine the summary of dataset record ``i`` with ``text``.

    Args:
        i: Index of the record in ``ds['train']`` whose ``'summary'`` is used.
        text: Text to attach to the summary.
        start_or_end: ``0`` places ``text`` after the summary (trailing ``.``
            and ``?`` characters are stripped from the summary first). Any
            other value places ``text`` before the summary, whose first
            character is lower-cased.

    Returns:
        The two parts joined by a single space.
    """
    summary = ds['train'][i]['summary']
    if start_or_end == 0:
        return summary.rstrip(".?") + " " + text
    # Slice instead of indexing so an empty summary does not raise IndexError.
    return text + " " + summary[:1].lower() + summary[1:]
def modify_doc(i, text, start_or_end):
    """Combine the text of dataset record ``i`` with ``text``.

    With probability 1/2 the whole of ``text`` is lower-cased before it is
    attached.

    Args:
        i: Index of the record in ``ds['train']`` whose ``'text'`` is used.
        text: Text to attach to the record's text.
        start_or_end: ``0`` places ``text`` after the record's text (trailing
            ``.`` and ``?`` characters are stripped from it first). Any other
            value places ``text`` before it, and the record text's first
            character is lower-cased.

    Returns:
        The two parts joined by a single space.
    """
    doc = ds['train'][i]['text']
    # One random draw per call, made before either branch, as in both
    # original code paths.
    if random.randint(0, 1) == 0:
        text = text.lower()
    if start_or_end == 0:
        return doc.rstrip(".?") + " " + text
    # Slice instead of indexing so an empty document does not raise IndexError.
    return text + " " + doc[:1].lower() + doc[1:]
# Generate the output file: every input row is paired with the dataset record
# at the same position, the row's document value is wrapped in a random
# template and attached to the record's text, and the row's query is attached
# to the record's summary.
j = 0  # number of rows successfully written
i = 0  # dataset index; advances once per input row, written or not
try:
    while j < 30000 - 2:
        # Stop once the input file has no more rows.
        row = next(reader, None)
        if row is None:
            break
        # Only the first two columns are used; the score column is ignored.
        query = row[0]
        document = row[1]

        # Format the document.
        # NOTE(review): value 0 selects a "start" template here, but
        # modify_doc() appends the text *after* the document when given 0.
        # Confirm this pairing is intended.
        start_or_end = random.randint(0, 1)
        if start_or_end == 0:
            doc_template = random.choice(doc_templates_start)
        else:
            doc_template = random.choice(doc_templates_end)
        document = doc_template.format(val=document)
        document = modify_doc(i, document, start_or_end)

        # Format the query, with an independent draw for its position.
        start_or_end = random.randint(0, 1)
        query = modify_query(i, query, start_or_end)

        i += 1

        # Remove newline characters so each record stays on one output line.
        query = query.replace("\n", "")
        document = document.replace("\n", "")
        try:
            writer.writerow([query, document])
        except (csv.Error, UnicodeError):
            # Best effort: skip rows that cannot be written and do not count
            # them towards the row limit.
            continue
        j += 1
finally:
    # Close both files so buffered output is flushed and handles are released.
    file.close()
    file_to_write.close()