# processing.py
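"""Extract highlighted excerpts from a PDF and decontextualize them with a T5 model.

Pipeline: pull highlight annotations out of a PDF with PyMuPDF (fitz), recover
the surrounding sentence context from the full document text, and feed
'[SEP]'-delimited prefix/target/suffix inputs to a T5 decontextualization
SavedModel downloaded from the public decontext_dataset bucket.
"""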
import os

import fitz  # PyMuPDF
import regex as re
import requests
import tensorflow as tf
import tensorflow_text  # noqa: F401  (registers the ops the T5 SavedModel needs)
from nltk import sent_tokenize

def check_contain(r_word, points):
    """Return True if at least 80% of a word's rect lies inside the quad."""
    r = fitz.Quad(points).rect
    r.intersect(r_word)
    return r.get_area() >= r_word.get_area() * 0.8

def extract_annot(annot, words_on_page):
    """Reassemble the text covered by a highlight annotation, quad by quad."""
    quad_points = annot.vertices
    quad_count = len(quad_points) // 4  # each highlighted line is one quad of 4 points
    sentences = []
    for i in range(quad_count):
        points = quad_points[i * 4: i * 4 + 4]
        words = [
            w for w in words_on_page
            if check_contain(fitz.Rect(w[:4]), points)
        ]
        sentences.append(' '.join(w[4] for w in words))
    return ' '.join(sentences)

def pdf_to_excerpts(filename):
    """Collect the highlighted excerpts from every page of a PDF."""
    doc = fitz.open(filename)
    excerpts = []
    for page in doc:
        words = page.get_text('words')
        for annot in page.annots():
            if annot is not None and annot.vertices is not None:
                excerpt = extract_annot(annot, words)
                excerpts.append(re.sub(r'\s+', ' ', excerpt).strip())
    return excerpts

def pdf_to_text(filename):
    """Extract a PDF's full text as a single whitespace-normalized string."""
    doc = fitz.open(filename)
    text = ' '.join(page.get_text() for page in doc)
    return re.sub(r'\s+', ' ', text.replace('\n', ' '))

def extract_context(excerpt, document, size=400):
    """Return roughly `size` characters of context on each side of the excerpt,
    trimmed to whole sentences; fall back to the excerpt if it cannot be found."""
    match = re.search(re.escape(excerpt), document)
    if match is None:
        return excerpt
    excerpt_start = match.start()
    excerpt_end = excerpt_start + len(excerpt)
    context_start = max(0, excerpt_start - size)  # clamp so a negative index cannot wrap around
    context = document[context_start:excerpt_end + size]
    # Drop the (likely partial) first and last sentences of the window.
    return ' '.join(sent_tokenize(context)[1:-1])

def download_model():
    """Fetch the exported T5 decontextualization model into ./model."""
    base_url = 'https://storage.googleapis.com/decontext_dataset'
    model_specific_path = 't5_base/1611267950'  # alternative: 't5_3B/1611333896'
    if not os.path.exists('model'):
        os.makedirs('model')
    if not os.path.exists('model/saved_model.pb'):
        saved_model = requests.get(f'{base_url}/{model_specific_path}/saved_model.pb')
        with open('model/saved_model.pb', 'wb') as f:
            f.write(saved_model.content)
    if not os.path.exists('model/variables'):
        os.makedirs('model/variables')
        for name in ('variables.data-00000-of-00002',
                     'variables.data-00001-of-00002',
                     'variables.index'):
            blob = requests.get(f'{base_url}/{model_specific_path}/variables/{name}')
            with open(f'model/variables/{name}', 'wb') as f:
                f.write(blob.content)

def load_predict_fn():
    """Load the SavedModel and wrap its serving signature: list of strings in, array of bytes out."""
    imported = tf.saved_model.load('./model', ['serve'])
    return lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()

def create_input(excerpt, context):
    """Find the smallest window of context sentences that contains the excerpt
    and format the model input as prefix / target / suffix, '[SEP]'-delimited."""
    context_sents = sent_tokenize(context)
    # Grow the window until some span of sentences contains the excerpt verbatim.
    size = 1
    wrapper = None
    while size <= len(context_sents) and wrapper is None:
        left = 0
        while left + size <= len(context_sents) and wrapper is None:
            if excerpt in ' '.join(context_sents[left:left + size]):
                wrapper = (left, left + size)
            left += 1
        size += 1
    if wrapper is None:
        return None
    prefix = ' '.join(context_sents[:wrapper[0]])
    target = ' '.join(context_sents[wrapper[0]:wrapper[1]])
    suffix = ' '.join(context_sents[wrapper[1]:])
    # The two leading empty fields are presumably the unused title/section slots.
    return ' [SEP] '.join(['', '', prefix, target, suffix])
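
# For illustration (hypothetical toy strings):
#   create_input('We stayed in.', 'It rained. We stayed in. Then it cleared.')
#   -> ' [SEP]  [SEP] It rained. [SEP] We stayed in. [SEP] Then it cleared.'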

def decontextualize_excerpt(excerpt, context, predict_fn):
    """Run the model on one excerpt; returns None if no input could be built."""
    model_input = create_input(excerpt, context)
    if model_input is None:
        return None
    output = predict_fn([model_input])[0].decode('utf-8')
    # The model emits '<label> #### <rewritten sentence>'; keep the sentence.
    return output.split('####')[1]
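

# A minimal end-to-end sketch, not part of the original pipeline. It assumes a
# hypothetical annotated 'paper.pdf' and that NLTK's 'punkt' tokenizer data has
# already been fetched (nltk.download('punkt')).
if __name__ == '__main__':
    download_model()
    predict_fn = load_predict_fn()
    document = pdf_to_text('paper.pdf')
    for excerpt in pdf_to_excerpts('paper.pdf'):
        context = extract_context(excerpt, document)
        decontextualized = decontextualize_excerpt(excerpt, context, predict_fn)
        print(decontextualized if decontextualized is not None else excerpt)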