diff --git a/test/test_pipeline_sample.py b/test/test_pipeline_sample.py index f3e9b22..7110fba 100644 --- a/test/test_pipeline_sample.py +++ b/test/test_pipeline_sample.py @@ -1,16 +1,10 @@ -from arekit.common.pipeline.items.base import BasePipelineItem -from arekit.common.utils import split_by_whitespaces - import utils import unittest -import ru_sent_tokenize -from ru_sent_tokenize import ru_sent_tokenize from os.path import join -from arekit.common.docs.base import Document +from arekit.common.utils import split_by_whitespaces from arekit.common.docs.entities_grouping import EntitiesGroupingPipelineItem -from arekit.common.docs.sentence import BaseDocumentSentence from arekit.common.experiment.data_type import DataType from arekit.common.labels.base import NoLabel from arekit.common.labels.scaler.single import SingleLabelScaler @@ -18,6 +12,7 @@ from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders from arekit.common.data import const from arekit.common.pipeline.context import PipelineContext +from arekit.common.pipeline.items.base import BasePipelineItem from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection @@ -49,16 +44,6 @@ class BertTestSerialization(unittest.TestCase): model. """ - @staticmethod - def input_to_docs(texts): - docs = [] - for doc_id, contents in enumerate(texts): - sentences = ru_sent_tokenize(contents) - sentences = list(map(lambda text: BaseDocumentSentence(text), sentences)) - doc = Document(doc_id=doc_id, sentences=sentences) - docs.append(doc) - return docs - @staticmethod def iter_groups(filepath): with open(filepath, 'r', encoding='utf-8') as file: @@ -95,7 +80,7 @@ def test(self): ] # Composing labels formatter and experiment preparation. - doc_provider = utils.InMemoryDocProvider(docs=BertTestSerialization.input_to_docs(texts)) + doc_provider = utils.InMemoryDocProvider(docs=utils.input_to_docs(texts)) pipeline = [ AREkitSerializerPipelineItem( rows_provider=create_bert_sample_provider(