
Commit: fix memory problem
Realive333 committed Apr 27, 2023
1 parent 25d268b commit 4c61f30
Showing 7 changed files with 22 additions and 17 deletions.
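
As the diff reads, the memory problem came from split_works() in work_factory.py: it tokenized every work up front and kept each token list referenced in self.works for the whole run. The commit comments that method out and instead calls work_seperator.splitter() on each work inside the processing loops, so only one token list is alive at a time. A minimal sketch of the two patterns (names and data are illustrative, not the repository's API):

    def process(split_text):
        return len(split_text)   # stand-in for first_match / nearest_k

    works = [{'text': 'a b c'}, {'text': 'd e'}]

    # before: split_works() kept every token list referenced until the run ended
    for w in works:
        w['split'] = w['text'].split(' ')
    before = [process(w['split']) for w in works]

    # after: each token list is built inside the loop and can be freed as soon
    # as the iteration moves on
    after = [process(w['text'].split(' ')) for w in works]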
10 changes: 6 additions & 4 deletions clipper/controller.py
@@ -25,9 +25,10 @@ def first_match(self):
         results = []
         for work in tqdm(self.works):
             label = work['label']
+            split_text = work_seperator.splitter(work['text'])
             try:
-                assert work_seperator.calculate_length(work['split']) > self.delimiter_size * self.BERT_MAXIMUM_INPUT, 'Work length is shorter than expected'
-                paragraphs = first_match.get_list(work, self.target_words, self.delimiter_size)
+                assert work_seperator.calculate_length(split_text) > self.delimiter_size * self.BERT_MAXIMUM_INPUT, 'Work length is shorter than expected'
+                paragraphs = first_match.get_list(split_text, self.target_words, self.delimiter_size)
             except AssertionError as err:
                 print(err)
             results.append({'label': label, 'paragraphs': paragraphs})
@@ -37,9 +38,10 @@ def nearest_k(self):
         results = []
         for work in tqdm(self.works):
             label = work['label']
+            split_text = work_seperator.splitter(work['text'])
             try:
-                assert work_seperator.calculate_length(work['split']) > self.delimiter_size * self.BERT_MAXIMUM_INPUT, 'Work length is shorter than expected'
-                paragraphs = nearest_k.get_list(work, self.wordlist, self.offset, self.delimiter_size)
+                assert work_seperator.calculate_length(split_text) > self.delimiter_size * self.BERT_MAXIMUM_INPUT, 'Work length is shorter than expected'
+                paragraphs = nearest_k.get_list(split_text, self.wordlist, self.offset, self.delimiter_size)
             except AssertionError as err:
                 print(err)
             results.append({'label': label, 'paragraphs': paragraphs})
3 changes: 1 addition & 2 deletions clipper/method/first_match.py
@@ -24,9 +24,8 @@ def first_match_chunk(chunk, target_words):
         if ct > 512:
             return 'no match', 'n', ' '.join(chunk[:idx])
 
-def get_list(work, target_words, sep_delimiter=1):
+def get_list(split_text, target_words, sep_delimiter=1):
     paragraphs = []
-    split_text = work['split']
     chunks = list(work_seperator.seperate(split_text, len(split_text)//sep_delimiter))
     for chunk in chunks[:sep_delimiter]:
         try:
3 changes: 1 addition & 2 deletions clipper/method/nearest_k.py
@@ -41,9 +41,8 @@ def nearest_k_chunk(split_text, similar_wordlist, offset):
             best_paragraph = paragraph
     return (best_paragraph, best_score)
 
-def get_list(work, similar_wordlist, offset=256, sep_delimiter=1):
+def get_list(split_text, similar_wordlist, offset=256, sep_delimiter=1):
     paragraphs = []
-    split_text = work['split']
     chunks = list(work_seperator.seperate(split_text, len(split_text)//sep_delimiter))
     for chunk in chunks[:sep_delimiter]:
         try:
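
One detail worth spelling out in both get_list() rewrites above: seperate() yields slices of len(split_text)//sep_delimiter items, so any remainder becomes one extra short chunk, and chunks[:sep_delimiter] exists to drop it. A pure-Python sketch with made-up tokens; seperate() is copied from work_seperator.py (next file):

    def seperate(lst, n):
        for i in range(0, len(lst), n):
            yield lst[i:i+n]

    tokens = ['a', 'b', 'c', 'd', 'e', 'f', 'g']       # 7 tokens, sep_delimiter = 3
    chunks = list(seperate(tokens, len(tokens) // 3))  # n = 2
    # chunks == [['a','b'], ['c','d'], ['e','f'], ['g']]; the remainder makes a
    # fourth, shorter chunk, which chunks[:sep_delimiter] silently drops.
    print(chunks[:3])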
8 changes: 7 additions & 1 deletion clipper/method/work_seperator.py
@@ -3,4 +3,10 @@ def calculate_length(split_text):
 
 def seperate(lst, n):
     for i in range(0, len(lst), n):
-        yield lst[i:i+n]
+        yield lst[i:i+n]
+
+def splitter(work_text):
+    text = work_text.replace('\n', '。 ')
+    split_text = text.split(' ')
+    split_text.remove('')
+    return split_text
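
For reference, how the new splitter() helper behaves, re-run here with its definition copied from the diff above and a made-up two-line work:

    def splitter(work_text):
        text = work_text.replace('\n', '。 ')
        split_text = text.split(' ')
        # a trailing '\n' becomes '。 ', leaving one empty string after split();
        # remove('') strips it (and raises ValueError when no '' is present)
        split_text.remove('')
        return split_text

    print(splitter('吾輩は猫である。 名前はまだ無い\n'))
    # ['吾輩は猫である。', '名前はまだ無い。']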
4 changes: 2 additions & 2 deletions clipper/work_factory.py
@@ -29,7 +29,7 @@ def load_tsv(self, docType, dataType):
                 label = row[0]
                 text = row[1]
                 self.works.append({'label': label, 'text': text})
-
+    """
     def split_works(self):
         works = []
         for work in tqdm(self.works):
@@ -40,7 +40,7 @@ def split_works(self):
             split_text.remove('')
             works.append({'label': label, 'text': text, 'split': split_text})
         self.works = works
-
+    """
     def load_target_words(self):
         rows = []
         with open('/data/realive333/kakuyomu-dataset/numeric_label.tsv', encoding='utf-8') as f:
5 changes: 3 additions & 2 deletions test/test_first_match.py
@@ -13,7 +13,7 @@ def setUp(self):
         self.wf = WorkFactory(42)
         self.wf.load_tsv('morpheme', 'test')
         self.wf.load_target_words()
-        self.wf.split_works()
+        #self.wf.split_works()
         self.test_work = self.wf.get_works()[251]
         self.target = self.wf.get_target()

@@ -24,6 +24,7 @@ def test_calculate_length(self):
         self.assertEqual(assertion, 6)
 
     def test_first_match(self):
-        result = first_match.get_list(self.test_work, self.target_words, 5)
+        split = work_seperator.splitter(self.test_work['text'])
+        result = first_match.get_list(split, self.target_words, 5)
         print(result)
         self.assertTrue(True)
6 changes: 2 additions & 4 deletions test/test_nearest_k.py
@@ -13,9 +13,6 @@ def setUp(self):
         self.wf = WorkFactory(42)
         self.wf.load_tsv('morpheme', 'test')
         self.wf.load_similar_wordlist()
-
-        self.wf.split_works()
-
         self.test_work = self.wf.get_works()[251]
         self.similar_wordlist = self.wf.get_similar_wordlist(10)

@@ -25,7 +22,8 @@ def test_get_nearest_words(self):
 
     def test_nearest_k(self):
         word_list = ['父', 'コー', '白鳥', '銃', 'ぼうや']
-        result = nearest_k.get_list(self.test_work, word_list, 256, 5)
+        split = work_seperator.splitter(self.test_work['text'])
+        result = nearest_k.get_list(split, word_list, 256, 5)
         for r in result:
             print(r)
             print('='*10)
