Skip to content

Commit

Permalink
refactor nearest-k, add controller
Browse files Browse the repository at this point in the history
  • Loading branch information
Realive333 committed Apr 27, 2023
1 parent b4efbc5 commit 25d268b
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,6 @@ dmypy.json
# Pyre type checker
.pyre/

/saves
*.json
*.tsv
46 changes: 46 additions & 0 deletions clipper/controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from . import work_factory
from tqdm import tqdm
from .method import nearest_k, first_match, work_seperator

class Controller:
    """Drive the paragraph-clipping pipeline for one target dataset split.

    Loads works through a WorkFactory and extracts paragraph lists with
    either the 'first-match' or the 'nearest-k' strategy.
    """

    def __init__(self, target, dataset_type, clip_type, size=5, offset=256, n_size=100):
        """Prepare the work factory and the data the chosen strategy needs.

        Args:
            target: dataset identifier forwarded to WorkFactory.
            dataset_type: split name (e.g. 'train', 'dev', 'test') loaded from TSV.
            clip_type: 'first-match' or 'nearest-k'; decides which word data is loaded.
            size: number of BERT-sized chunks each work must be able to fill.
            offset: context offset used by the nearest-k clipping method.
            n_size: number of similar words kept for nearest-k.
        """
        self.offset = offset
        # Hard upper bound on tokens a BERT encoder accepts per input.
        self.BERT_MAXIMUM_INPUT = 512
        self.wf = work_factory.WorkFactory(target)
        self.wf.load_tsv('morpheme', dataset_type)
        self.wf.load_target_words()
        self.wf.split_works()
        self.delimiter_size = size
        self.n = n_size

        # Only load the word data the selected strategy actually uses.
        if clip_type == 'first-match':
            self.target_words = self.wf.get_target_words()
        elif clip_type == 'nearest-k':
            self.wf.load_similar_wordlist()
            self.wordlist = self.wf.get_similar_wordlist(self.n)

        self.works = self.wf.get_works()

    def _is_long_enough(self, work):
        """Return True when the work can fill delimiter_size BERT inputs."""
        minimum = self.delimiter_size * self.BERT_MAXIMUM_INPUT
        return work_seperator.calculate_length(work['split']) > minimum

    def first_match(self):
        """Clip every sufficiently long work with the first-match strategy.

        Returns:
            List of {'label': ..., 'paragraphs': ...} dicts. Works that are
            too short are reported and skipped.

        Bug fix: previously a failed length assertion left ``paragraphs``
        unbound (NameError on the first iteration) or stale from the prior
        iteration, yet the work was still appended to the results. Short
        works are now skipped explicitly, and ``assert`` is no longer used
        for validation (it is stripped under ``python -O``).
        """
        results = []
        for work in tqdm(self.works):
            if not self._is_long_enough(work):
                print('Work length is shorter than expectation')
                continue
            paragraphs = first_match.get_list(work, self.target_words, self.delimiter_size)
            results.append({'label': work['label'], 'paragraphs': paragraphs})
        return results

    def nearest_k(self):
        """Clip every sufficiently long work with the nearest-k strategy.

        Returns:
            List of {'label': ..., 'paragraphs': ...} dicts. Works that are
            too short are reported and skipped (same fix as first_match).
        """
        results = []
        for work in tqdm(self.works):
            if not self._is_long_enough(work):
                print('Work length is shorter than expectation')
                continue
            paragraphs = nearest_k.get_list(work, self.wordlist, self.offset, self.delimiter_size)
            results.append({'label': work['label'], 'paragraphs': paragraphs})
        return results
7 changes: 3 additions & 4 deletions clipper/work_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import csv
import sys
csv.field_size_limit(sys.maxsize)

from tqdm import tqdm
import logging

class WorkFactory:
Expand All @@ -25,14 +25,14 @@ def load_tsv(self, docType, dataType):
path = self.get_path(docType, dataType)
with open(path, 'r', encoding='utf-8') as file:
rows = csv.reader(file, delimiter='\t')
for row in rows:
for row in tqdm(rows):
label = row[0]
text = row[1]
self.works.append({'label': label, 'text': text})

def split_works(self):
works = []
for work in self.works:
for work in tqdm(self.works):
label = work['label']
text = work['text']
text = text.replace('\n', '。 ')
Expand All @@ -55,7 +55,6 @@ def load_similar_wordlist(self):
next(reader, None) # Skip headder
for row in reader:
self.similar_wordlist.append({'word': row[0], 'score': row[1]})


    def get_works(self):
        """Return the list of loaded works.

        Each work is a dict with at least 'label' and 'text' keys as built
        by load_tsv (presumably extended with a 'split' key by split_works —
        TODO confirm against callers).
        """
        return self.works
Expand Down
43 changes: 43 additions & 0 deletions kakuyomu_clipper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
import json
from argparse import ArgumentParser
from clipper.method import nearest_k, first_match, work_seperator
from clipper import controller
from tqdm import tqdm

def main(args):
    """Clip all three dataset splits for one target and save each as JSON.

    Args:
        args: parsed command-line namespace (target, type, size, offset, n).

    Bug fix: the clip strategy is now chosen from ``--type``; previously
    ``nearest_k()`` was called unconditionally, which crashed for the
    default ``first-match`` type because the Controller never loads the
    similar-word list in that mode.

    Raises:
        ValueError: if ``--type`` is not a known clip strategy.
    """
    target = args.target
    clip_type = args.type
    size = args.size
    offset = args.offset
    n_size = args.n

    save_path = f'./saves/{clip_type}-{size}/{target}'
    os.makedirs(save_path, exist_ok=True)

    # Identical pipeline for every split: build a controller, clip, dump JSON.
    for split in ('train', 'dev', 'test'):
        clipper = controller.Controller(target, split, clip_type, size, offset, n_size)
        if clip_type == 'first-match':
            result = clipper.first_match()
        elif clip_type == 'nearest-k':
            result = clipper.nearest_k()
        else:
            raise ValueError(f'unknown clip type: {clip_type!r}')
        with open(f'{save_path}/{split}.json', 'w') as f:
            json.dump(result, f)

if __name__ == '__main__':
    parser = ArgumentParser(description='Kakuyomu-Clipper')

    # --target: numeric id of the dataset/work group to clip.
    parser.add_argument('--target', type=int, default=42)
    # --type: clip strategy, 'first-match' or 'nearest-k'.
    parser.add_argument('--type', type=str, default='first-match')
    # --size: number of BERT-sized chunks per work (delimiter size).
    parser.add_argument('--size', type=int, default=5)
    # --offset: context offset used by the nearest-k strategy.
    parser.add_argument('--offset', type=int, default=256)
    # --n: number of similar words kept for nearest-k.
    parser.add_argument('--n', type=int, default=100)

    args = parser.parse_args()
    main(args)
13 changes: 13 additions & 0 deletions script_kakuyomu_clipper.bash
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
# Batch driver: runs kakuyomu_clipper.py once per target id with fixed
# clipping parameters. Fixes the "cliping" typo in the progress message
# and quotes all variable expansions for robustness.
echo "kakuyomu-clipper shell script start"

# Dataset/work-group ids to process.
target=( 1 2 3 4 5 20 39 40 42 69 70 71 73 74 75 77 79 80 81 83 84 87 90 96 120 121 122 126 128 199 200 203 204 214 259 260 281 284 291)

# Clipping parameters forwarded to the Python script.
type='nearest-k'
docsize=5
offset=256
nsize=100

for item in "${target[@]}";
do
    echo "clipping... type=$type target=$item size=$docsize offset=$offset nsize=$nsize"
    python kakuyomu_clipper.py --type="$type" --target="$item" --size="$docsize" --offset="$offset" --n="$nsize"
done

0 comments on commit 25d268b

Please sign in to comment.