From 4f3726f4669e9de48184a9c04b3511a14ec3d03b Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Thu, 4 Jan 2024 22:29:34 +0000 Subject: [PATCH] #134 fixed. #136 implemented, not switched to batching mode. --- arelight/download.py | 0 .../pipelines/data/annot_pairs_nolabel.py | 4 +- .../pipelines/items/backend_d3js_graphs.py | 19 +++---- .../items/backend_d3js_operations.py | 25 ++++---- arelight/pipelines/items/entities_default.py | 4 +- arelight/pipelines/items/entities_ner_dp.py | 10 ++-- .../items/entities_ner_transformers.py | 12 +--- .../pipelines/items/inference_bert_opennre.py | 24 ++++---- arelight/pipelines/items/inference_writer.py | 14 ++--- arelight/pipelines/items/serializer_arekit.py | 4 +- arelight/run/infer.py | 20 +++---- arelight/run/operations.py | 28 +++++---- arelight/run/utils.py | 2 +- arelight/synonyms.py | 14 +++++ arelight/utils.py | 28 +++++++++ dependencies.txt | 3 +- test/test_arekit_iter_data.py | 9 +-- test/test_backend_d3js_graphs.py | 5 +- test/test_csv_reader.py | 2 +- test/test_dp_ner_pipeline_item.py | 43 +++++++------- test/test_dp_ner_sample.py | 11 ++-- test/test_pipeline_infer.py | 45 +++++++++------ test/test_pipeline_sample.py | 57 ++++++++++--------- test/test_transformers_ner_ppl.py | 15 +++-- test/utils_ner.py | 19 ++++--- 25 files changed, 231 insertions(+), 186 deletions(-) create mode 100644 arelight/download.py create mode 100644 arelight/synonyms.py diff --git a/arelight/download.py b/arelight/download.py new file mode 100644 index 0000000..e69de29 diff --git a/arelight/pipelines/data/annot_pairs_nolabel.py b/arelight/pipelines/data/annot_pairs_nolabel.py index e7362d0..13befc9 100644 --- a/arelight/pipelines/data/annot_pairs_nolabel.py +++ b/arelight/pipelines/data/annot_pairs_nolabel.py @@ -9,7 +9,7 @@ def create_neutral_annotation_pipeline(synonyms, dist_in_terms_bound, terms_per_context, - doc_provider, text_parser, dist_in_sentences=0): + doc_provider, text_pipeline, dist_in_sentences=0): nolabel_annotator = AlgorithmBasedTextOpinionAnnotator( value_to_group_id_func=lambda value: @@ -28,7 +28,7 @@ def create_neutral_annotation_pipeline(synonyms, dist_in_terms_bound, terms_per_ annotation_pipeline = text_opinion_extraction_pipeline( entity_index_func=lambda indexed_entity: indexed_entity.ID, - text_parser=text_parser, + pipeline_items=text_pipeline, get_doc_by_id_func=doc_provider.by_id, annotators=[ nolabel_annotator diff --git a/arelight/pipelines/items/backend_d3js_graphs.py b/arelight/pipelines/items/backend_d3js_graphs.py index 07c660a..e3fa27c 100644 --- a/arelight/pipelines/items/backend_d3js_graphs.py +++ b/arelight/pipelines/items/backend_d3js_graphs.py @@ -7,7 +7,6 @@ from arekit.common.experiment.data_type import DataType from arekit.common.labels.scaler.base import BaseLabelScaler from arekit.common.labels.str_fmt import StringLabelsFormatter -from arekit.common.pipeline.context import PipelineContext from arekit.common.pipeline.items.base import BasePipelineItem from arelight.arekit.parse_predict import iter_predicted_labels @@ -20,7 +19,8 @@ class D3jsGraphsBackendPipelineItem(BasePipelineItem): - def __init__(self, graph_min_links=0.01, graph_a_labels=None, weights=True): + def __init__(self, graph_min_links=0.01, graph_a_labels=None, weights=True, **kwargs): + super(D3jsGraphsBackendPipelineItem, self).__init__(**kwargs) self.__graph_min_links = graph_min_links # Setup filters for the A and B graphs for further operations application. 
@@ -58,19 +58,18 @@ def iter_column_value(self, samples, column_value): yield parsed_row[column_value] def apply_core(self, input_data, pipeline_ctx): - assert(isinstance(input_data, PipelineContext)) - predict_filepath = input_data.provide("predict_filepath") - result_reader = input_data.provide("predict_reader") - labels_fmt = input_data.provide("labels_formatter") + predict_filepath = pipeline_ctx.provide("predict_filepath") + result_reader = pipeline_ctx.provide("predict_reader") + labels_fmt = pipeline_ctx.provide("labels_formatter") assert(isinstance(labels_fmt, StringLabelsFormatter)) - labels_scaler = input_data.provide("labels_scaler") + labels_scaler = pipeline_ctx.provide("labels_scaler") assert(isinstance(labels_scaler, BaseLabelScaler)) predict_storage = result_reader.read(predict_filepath) assert(isinstance(predict_storage, BaseRowsStorage)) # Reading samples. - samples_io = input_data.provide("samples_io") + samples_io = pipeline_ctx.provide("samples_io") samples_filepath = samples_io.create_target(data_type=DataType.Test) samples = samples_io.Reader.read(samples_filepath) @@ -90,5 +89,5 @@ def apply_core(self, input_data, pipeline_ctx): weights=self.__graph_weights) # Saving graph as the collection name for it. - input_data.update("d3js_graph_a", value=graph) - input_data.update("d3js_collection_name", value=samples_io.Prefix) + pipeline_ctx.update("d3js_graph_a", value=graph) + pipeline_ctx.update("d3js_collection_name", value=samples_io.Prefix) diff --git a/arelight/pipelines/items/backend_d3js_operations.py b/arelight/pipelines/items/backend_d3js_operations.py index 5e7dc53..b3998a7 100644 --- a/arelight/pipelines/items/backend_d3js_operations.py +++ b/arelight/pipelines/items/backend_d3js_operations.py @@ -4,7 +4,6 @@ from arekit.common.data.rows_fmt import create_base_column_fmt from arekit.common.labels.str_fmt import StringLabelsFormatter -from arekit.common.pipeline.context import PipelineContext from arekit.common.pipeline.items.base import BasePipelineItem from arelight.backend.d3js.relations_graph_operations import graphs_operations @@ -17,21 +16,21 @@ class D3jsGraphOperationsBackendPipelineItem(BasePipelineItem): - def __init__(self): + def __init__(self, **kwargs): # Parameters for sampler. 
+ super(D3jsGraphOperationsBackendPipelineItem, self).__init__(**kwargs) self.__column_fmts = [create_base_column_fmt(fmt_type="parser")] def apply_core(self, input_data, pipeline_ctx): - assert(isinstance(input_data, PipelineContext)) - - graph_a = input_data.provide_or_none("d3js_graph_a") - graph_b = input_data.provide_or_none("d3js_graph_b") - op = input_data.provide_or_none("d3js_graph_operations") - weights = input_data.provide_or_none("d3js_graph_weights") - target_dir = input_data.provide("d3js_graph_output_dir") - collection_name = input_data.provide("d3js_collection_name") - labels_fmt = input_data.provide("labels_formatter") - host_port = input_data.provide_or_none("d3js_host") + + graph_a = pipeline_ctx.provide_or_none("d3js_graph_a") + graph_b = pipeline_ctx.provide_or_none("d3js_graph_b") + op = pipeline_ctx.provide_or_none("d3js_graph_operations") + weights = pipeline_ctx.provide_or_none("d3js_graph_weights") + target_dir = pipeline_ctx.provide("d3js_graph_output_dir") + collection_name = pipeline_ctx.provide("d3js_collection_name") + labels_fmt = pipeline_ctx.provide("labels_formatter") + host_port = pipeline_ctx.provide_or_none("d3js_host") assert(isinstance(labels_fmt, StringLabelsFormatter)) graph = graphs_operations(graph_A=graph_a, graph_B=graph_b, operation=op, weights=weights) \ @@ -49,7 +48,7 @@ def apply_core(self, input_data, pipeline_ctx): save_demo_page(target_dir=target_dir, collection_name=collection_name, host_root_path=f"http://localhost:{host_port}/" if host_port is not None else "./", - desc_name=input_data.provide_or_none("d3js_collection_description"), + desc_name=pipeline_ctx.provide_or_none("d3js_collection_description"), desc_labels={label_type.__name__: labels_fmt.label_to_str(label_type()) for label_type in labels_fmt._stol.values()}) diff --git a/arelight/pipelines/items/entities_default.py b/arelight/pipelines/items/entities_default.py index 9042735..f75478f 100644 --- a/arelight/pipelines/items/entities_default.py +++ b/arelight/pipelines/items/entities_default.py @@ -5,10 +5,10 @@ class TextEntitiesParser(BasePipelineItem): - def __init__(self, id_assigner, display_value_func=None): + def __init__(self, id_assigner, display_value_func=None, **kwargs): assert(isinstance(id_assigner, IdAssigner)) assert(callable(display_value_func) or display_value_func is None) - super(TextEntitiesParser, self).__init__() + super(TextEntitiesParser, self).__init__(**kwargs) self.__id_assigner = id_assigner self.__disp_value_func = display_value_func diff --git a/arelight/pipelines/items/entities_ner_dp.py b/arelight/pipelines/items/entities_ner_dp.py index 8c0430a..761e5ca 100644 --- a/arelight/pipelines/items/entities_ner_dp.py +++ b/arelight/pipelines/items/entities_ner_dp.py @@ -10,7 +10,8 @@ class DeepPavlovNERPipelineItem(SentenceObjectsParserPipelineItem): - def __init__(self, id_assigner, ner_model_name, obj_filter=None, chunk_limit=128, display_value_func=None): + def __init__(self, id_assigner, ner_model_name, obj_filter=None, + chunk_limit=128, display_value_func=None, **kwargs): """ chunk_limit: int length of text part in words that is going to be provided in input. 
""" @@ -18,6 +19,7 @@ def __init__(self, id_assigner, ner_model_name, obj_filter=None, chunk_limit=128 assert(isinstance(chunk_limit, int) and chunk_limit > 0) assert(isinstance(id_assigner, IdAssigner)) assert(callable(display_value_func) or display_value_func is None) + super(DeepPavlovNERPipelineItem, self).__init__(partitioning=TermsPartitioning(), **kwargs) # Initialize bert-based model instance. self.__dp_ner = DeepPavlovNER(ner_model_name) @@ -25,9 +27,8 @@ def __init__(self, id_assigner, ner_model_name, obj_filter=None, chunk_limit=128 self.__chunk_limit = chunk_limit self.__id_assigner = id_assigner self.__disp_value_func = display_value_func - super(DeepPavlovNERPipelineItem, self).__init__(TermsPartitioning()) - def _get_parts_provider_func(self, input_data, pipeline_ctx): + def _get_parts_provider_func(self, input_data): return self.__iter_subs_values_with_bounds(input_data) def __iter_subs_values_with_bounds(self, terms_list): @@ -65,6 +66,3 @@ def __iter_parsed_entities(self, processed_sequences, chunk_terms_list, chunk_of value=value, e_type=s_obj.ObjectType, entity_id=self.__id_assigner.get_id(), display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None) yield entity, Bound(pos=chunk_offset + s_obj.Position, length=s_obj.Length) - - def apply_core(self, input_data, pipeline_ctx): - return super(DeepPavlovNERPipelineItem, self).apply_core(input_data=input_data, pipeline_ctx=pipeline_ctx) diff --git a/arelight/pipelines/items/entities_ner_transformers.py b/arelight/pipelines/items/entities_ner_transformers.py index 7aebad1..7406a88 100644 --- a/arelight/pipelines/items/entities_ner_transformers.py +++ b/arelight/pipelines/items/entities_ner_transformers.py @@ -1,6 +1,5 @@ from arekit.common.bound import Bound from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem -from arekit.common.text.partitioning.str import StringPartitioning from arelight.pipelines.items.entity import IndexedEntity from arelight.utils import IdAssigner, auto_import @@ -8,20 +7,20 @@ class TransformersNERPipelineItem(SentenceObjectsParserPipelineItem): - def __init__(self, id_assigner, ner_model_name, device, obj_filter=None, display_value_func=None): + def __init__(self, id_assigner, ner_model_name, device, obj_filter=None, display_value_func=None, **kwargs): """ chunk_limit: int length of text part in words that is going to be provided in input. """ assert(callable(obj_filter) or obj_filter is None) assert(isinstance(id_assigner, IdAssigner)) assert(callable(display_value_func) or display_value_func is None) + super(TransformersNERPipelineItem, self).__init__(**kwargs) # Setup third-party modules. model_init = auto_import("arelight.third_party.transformers.init_token_classification_model") self.annotate_ner = auto_import("arelight.third_party.transformers.annotate_ner") # Transformers-related parameters. 
- self.__device = device self.__model, self.__tokenizer = model_init(model_path=ner_model_name, device=self.__device) @@ -30,9 +29,7 @@ def __init__(self, id_assigner, ner_model_name, device, obj_filter=None, display self.__id_assigner = id_assigner self.__disp_value_func = display_value_func - super(TransformersNERPipelineItem, self).__init__(StringPartitioning()) - - def _get_parts_provider_func(self, input_data, pipeline_ctx): + def _get_parts_provider_func(self, input_data): assert(isinstance(input_data, str)) parts = self.annotate_ner(model=self.__model, tokenizer=self.__tokenizer, text=input_data, device=self.__device) @@ -55,6 +52,3 @@ def __iter_parsed_entities(self, parts): display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None) yield entity, Bound(pos=p["start"], length=p["end"] - p["start"]) - - def apply_core(self, input_data, pipeline_ctx): - return super(TransformersNERPipelineItem, self).apply_core(input_data=input_data, pipeline_ctx=pipeline_ctx) diff --git a/arelight/pipelines/items/inference_bert_opennre.py b/arelight/pipelines/items/inference_bert_opennre.py index c22a6cf..fa12c3d 100644 --- a/arelight/pipelines/items/inference_bert_opennre.py +++ b/arelight/pipelines/items/inference_bert_opennre.py @@ -1,21 +1,17 @@ -import json -import logging import os from os.path import exists, join +import logging import torch from arekit.common.experiment.data_type import DataType -from arekit.common.pipeline.context import PipelineContext from arekit.common.pipeline.items.base import BasePipelineItem -from arekit.common.utils import download from opennre.encoder import BERTEntityEncoder, BERTEncoder from opennre.model import SoftmaxNN from arelight.third_party.torch import sentence_re_loader -from arelight.utils import get_default_download_dir - +from arelight.utils import get_default_download_dir, download logger = logging.getLogger(__name__) @@ -24,8 +20,9 @@ class BertOpenNREInferencePipelineItem(BasePipelineItem): def __init__(self, pretrained_bert=None, checkpoint_path=None, device_type='cpu', max_seq_length=128, pooler='cls', batch_size=10, tokenizers_parallelism=True, - table_name="contents", task_kwargs=None, predefined_ckpts=None): + table_name="contents", task_kwargs=None, predefined_ckpts=None, **kwargs): assert(isinstance(tokenizers_parallelism, bool)) + super(BertOpenNREInferencePipelineItem, self).__init__(**kwargs) self.__model = None self.__pretrained_bert = pretrained_bert @@ -161,21 +158,20 @@ def __iter_predict_result(self, samples_filepath, batch_size): return results_it, total def apply_core(self, input_data, pipeline_ctx): - assert(isinstance(input_data, PipelineContext)) # Fetching the input data. - labels_scaler = input_data.provide("labels_scaler") + labels_scaler = pipeline_ctx.provide("labels_scaler") # Try to obrain from the specific input variable. - samples_filepath = input_data.provide_or_none("opennre_samples_filepath") + samples_filepath = pipeline_ctx.provide_or_none("opennre_samples_filepath") if samples_filepath is None: - samples_io = input_data.provide("samples_io") + samples_io = pipeline_ctx.provide("samples_io") samples_filepath = samples_io.create_target(data_type=DataType.Test) # Initialize model if the latter has not been yet. 
if self.__model is None: - ckpt_dir = input_data.provide_or_none("opennre_ckpt_cache_dir") + ckpt_dir = pipeline_ctx.provide_or_none("opennre_ckpt_cache_dir") self.__model = self.init_bert_model( pretrain_path=self.__pretrained_bert, @@ -189,5 +185,5 @@ def apply_core(self, input_data, pipeline_ctx): dir_to_donwload=get_default_download_dir() if ckpt_dir is None else ckpt_dir) iter_infer, total = self.__iter_predict_result(samples_filepath=samples_filepath, batch_size=self.__batch_size) - input_data.update("iter_infer", iter_infer) - input_data.update("iter_total", total) + pipeline_ctx.update("iter_infer", iter_infer) + pipeline_ctx.update("iter_total", total) diff --git a/arelight/pipelines/items/inference_writer.py b/arelight/pipelines/items/inference_writer.py index c889c26..0d2b4d4 100644 --- a/arelight/pipelines/items/inference_writer.py +++ b/arelight/pipelines/items/inference_writer.py @@ -1,4 +1,3 @@ -from arekit.common.pipeline.context import PipelineContext from arekit.common.pipeline.items.base import BasePipelineItem from arelight.predict_provider import BasePredictProvider @@ -7,24 +6,23 @@ class InferenceWriterPipelineItem(BasePipelineItem): - def __init__(self, writer): + def __init__(self, writer, **kwargs): assert(isinstance(writer, BasePredictWriter)) + super(InferenceWriterPipelineItem, self).__init__(**kwargs) self.__writer = writer def apply_core(self, input_data, pipeline_ctx): - assert(isinstance(input_data, PipelineContext)) # Setup predicted result writer. - target = input_data.provide("predict_filepath") - print(target) + target = pipeline_ctx.provide("predict_filepath") self.__writer.set_target(target) # Gathering the content title, contents_it = BasePredictProvider().provide( - sample_id_with_uint_labels_iter=input_data.provide("iter_infer"), - labels_count=input_data.provide("labels_scaler").LabelsCount) + sample_id_with_uint_labels_iter=pipeline_ctx.provide("iter_infer"), + labels_count=pipeline_ctx.provide("labels_scaler").LabelsCount) with self.__writer: self.__writer.write(title=title, contents_it=contents_it, - total=input_data.provide_or_none("iter_total")) + total=pipeline_ctx.provide_or_none("iter_total")) diff --git a/arelight/pipelines/items/serializer_arekit.py b/arelight/pipelines/items/serializer_arekit.py index e7549b2..1ae0082 100644 --- a/arelight/pipelines/items/serializer_arekit.py +++ b/arelight/pipelines/items/serializer_arekit.py @@ -1,4 +1,3 @@ -from arekit.common.pipeline.context import PipelineContext from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem @@ -8,9 +7,8 @@ class AREkitSerializerPipelineItem(BaseSerializerPipelineItem): """ def apply_core(self, input_data, pipeline_ctx): - assert(isinstance(input_data, PipelineContext)) super(AREkitSerializerPipelineItem, self).apply_core(input_data=input_data, pipeline_ctx=pipeline_ctx) # Host samples into the result for further pipeline items. 
- input_data.update("samples_io", self._samples_io) + pipeline_ctx.update("samples_io", self._samples_io) diff --git a/arelight/run/infer.py b/arelight/run/infer.py index 2376f76..90ade1c 100644 --- a/arelight/run/infer.py +++ b/arelight/run/infer.py @@ -7,16 +7,16 @@ from arekit.common.experiment.data_type import DataType from arekit.common.labels.base import NoLabel from arekit.common.labels.scaler.single import SingleLabelScaler -from arekit.common.pipeline.base import BasePipeline +from arekit.common.pipeline.base import BasePipelineLauncher +from arekit.common.pipeline.items.base import BasePipelineItem from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders -from arekit.common.text.parser import BaseTextParser +from arekit.common.utils import split_by_whitespaces from arekit.contrib.bert.input.providers.text_pair import PairTextProvider from arekit.contrib.utils.data.readers.sqlite import SQliteReader from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter from arekit.contrib.utils.entities.formatters.str_simple_sharp_prefixed_fmt import SharpPrefixedEntitiesSimpleFormatter from arekit.contrib.utils.io_utils.samples import SamplesIO -from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser from arekit.contrib.utils.pipelines.items.text.translator import MLTextTranslatorPipelineItem from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.synonyms.simple import SimpleSynonymCollection @@ -161,6 +161,7 @@ def __entity_display_value(entity_value): display_value_func=__entity_display_value), # Parser based on DeepPavlov backend. "deeppavlov": lambda: DeepPavlovNERPipelineItem( + src_func=lambda text: split_by_whitespaces(text), obj_filter=None if ner_object_types is None else lambda s_obj: s_obj.ObjectType in ner_object_types, ner_model_name=ner_model_name, id_assigner=IdAssigner(), @@ -209,8 +210,6 @@ def __entity_display_value(entity_value): infer_engines={key: infer_engines_setup[key] for key in [args.bert_framework]}, backend_engines={key: backend_setups[key] for key in [args.backend]}) - pipeline = BasePipeline(pipeline) - # Settings. settings = [] @@ -240,15 +239,15 @@ def __entity_display_value(entity_value): synonyms = synonyms_setup["lemmatized" if args.stemmer is not None else None]() # Setup text parser. - text_parser = BaseTextParser(pipeline=[ - TermsSplitterParser() if ner_framework == "deeppavlov" else None, + text_parser_pipeline = [ + BasePipelineItem(src_func=lambda s: s.Text), entity_parsers[ner_framework](), CustomTermsSplitterPipelineItem() if ner_framework == "transformers" else None, text_translator_setup["ml-based" if args.translate_text is not None else None](), EntitiesGroupingPipelineItem( lambda value: SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value( synonyms=synonyms, value=value)) - ]) + ] # Reading from the optionally large list of files. doc_provider = CachedFilesDocProvider( @@ -263,7 +262,7 @@ def __entity_display_value(entity_value): dist_in_terms_bound=terms_per_context, doc_provider=doc_provider, terms_per_context=terms_per_context, - text_parser=text_parser) + text_pipeline=text_parser_pipeline) settings.append({ "data_type_pipelines": {DataType.Test: data_pipeline}, @@ -288,4 +287,5 @@ def __entity_display_value(entity_value): }) # Launch application. 
-    pipeline.run(input_data=PipelineResult(merge_dictionaries(settings)))
+    BasePipelineLauncher.run(pipeline=pipeline, pipeline_ctx=PipelineResult(merge_dictionaries(settings)),
+                             src_key="doc_ids")
diff --git a/arelight/run/operations.py b/arelight/run/operations.py
index f8a8187..64ead92 100644
--- a/arelight/run/operations.py
+++ b/arelight/run/operations.py
@@ -2,7 +2,7 @@
 import os
 from datetime import datetime
 
-from arekit.common.pipeline.base import BasePipeline
+from arekit.common.pipeline.base import BasePipelineLauncher
 
 from arelight.backend.d3js.relations_graph_operations import OP_UNION, OP_DIFFERENCE, OP_INTERSECTION
 from arelight.backend.d3js.ui_web import GRAPH_TYPE_FORCE
@@ -117,21 +117,19 @@ def get_graph_path(text):
     description = args.description if args.description else \
         get_input_with_default("Specify description of new graph (enter to skip)\n", default_description)
 
-    pipeline = BasePipeline([
-        D3jsGraphOperationsBackendPipelineItem()
-    ])
-
     labels_fmt = {a: v for a, v in map(lambda item: item.split(":"), args.d3js_label_names.split(','))}
 
     # Launch application.
-    pipeline.run(input_data=PipelineResult({
-        # We provide this settings for inference.
-        "labels_formatter": CustomLabelsFormatter(**labels_fmt),
-        "d3js_graph_output_dir": output_dir,
-        "d3js_collection_description": description,
-        "d3js_host": str(8000) if do_host else None,
-        "d3js_graph_a": load_graph(graph_A_file_path),
-        "d3js_graph_b": load_graph(graph_B_file_path),
-        "d3js_graph_operations": operation,
-        "d3js_collection_name": collection_name
+    BasePipelineLauncher.run(
+        pipeline=[D3jsGraphOperationsBackendPipelineItem()],
+        pipeline_ctx=PipelineResult({
+            # We provide these settings for inference.
+            "labels_formatter": CustomLabelsFormatter(**labels_fmt),
+            "d3js_graph_output_dir": output_dir,
+            "d3js_collection_description": description,
+            "d3js_host": str(8000) if do_host else None,
+            "d3js_graph_a": load_graph(graph_A_file_path),
+            "d3js_graph_b": load_graph(graph_B_file_path),
+            "d3js_graph_operations": operation,
+            "d3js_collection_name": collection_name
         }))
diff --git a/arelight/run/utils.py b/arelight/run/utils.py
index abfb2f8..40ad604 100644
--- a/arelight/run/utils.py
+++ b/arelight/run/utils.py
@@ -5,9 +5,9 @@
 from arekit.common.docs.base import Document
 from arekit.common.docs.sentence import BaseDocumentSentence
-from arekit.contrib.source.synonyms.utils import iter_synonym_groups
 
 from arelight.pipelines.demo.labels.scalers import CustomLabelScaler
+from arelight.synonyms import iter_synonym_groups
 from arelight.utils import auto_import, iter_csv_lines
 
 logger = logging.getLogger(__name__)
diff --git a/arelight/synonyms.py b/arelight/synonyms.py
new file mode 100644
index 0000000..2277f0a
--- /dev/null
+++ b/arelight/synonyms.py
@@ -0,0 +1,14 @@
+from arekit.common.utils import progress_bar_defined
+
+
+def iter_synonym_groups(input_file, sep=",", desc=""):
+    """ All the synonym groups are organized in lines, separated by `sep`.
+    """
+    lines = input_file.readlines()
+
+    for line in progress_bar_defined(lines, total=len(lines), desc=desc, unit="opins"):
+
+        if isinstance(line, bytes):
+            line = line.decode()
+
+        yield line.split(sep)
diff --git a/arelight/utils.py b/arelight/utils.py
index cd6f409..76c46e8 100644
--- a/arelight/utils.py
+++ b/arelight/utils.py
@@ -3,6 +3,9 @@
 import os
 import sys
 
+import requests
+from tqdm import tqdm
+
 
 def auto_import(name):
     """ Import from the external python packages. 
@@ -53,3 +56,28 @@ def iter_csv_lines(csv_file, column_name, delimiter=","):
 
     for row in csv_reader:
         yield row[column_name]
+
+
+def download(dest_file_path, source_url):
+    """ Referred to https://github.com/nicolay-r/ner-bilstm-crf-tensorflow/blob/master/ner/utils.py
+        Simple HTTP file downloader
+    """
+    print(('Downloading from {src} to {dest}'.format(src=source_url, dest=dest_file_path)))
+
+    sys.stdout.flush()
+    datapath = os.path.dirname(dest_file_path)
+
+    if not os.path.exists(datapath):
+        os.makedirs(datapath, mode=0o755)
+
+    dest_file_path = os.path.abspath(dest_file_path)
+
+    r = requests.get(source_url, stream=True)
+    total_length = int(r.headers.get('content-length', 0))
+
+    with open(dest_file_path, 'wb') as f:
+        pbar = tqdm(total=total_length, unit='B', unit_scale=True)
+        for chunk in r.iter_content(chunk_size=32 * 1024):
+            if chunk:  # filter out keep-alive new chunks
+                pbar.update(len(chunk))
+                f.write(chunk)
\ No newline at end of file
diff --git a/dependencies.txt b/dependencies.txt
index 0122a9f..e97b64f 100644
--- a/dependencies.txt
+++ b/dependencies.txt
@@ -5,4 +5,5 @@ pytorch-crf==0.7.2
 arekit @ git+https://github.com/nicolay-r/AREkit@0.25.0-rc
 open-nre @ git+https://github.com/thunlp/OpenNRE@53b6c9400775ab066dc4f462e81ce05ea2b128e7
 nltk==3.8.1
-googletrans==3.1.0a0
\ No newline at end of file
+googletrans==3.1.0a0
+requests
\ No newline at end of file
diff --git a/test/test_arekit_iter_data.py b/test/test_arekit_iter_data.py
index c71579b..7cae583 100644
--- a/test/test_arekit_iter_data.py
+++ b/test/test_arekit_iter_data.py
@@ -1,11 +1,11 @@
 from arekit.common.data.rows_fmt import create_base_column_fmt
 from arekit.common.data.rows_parser import ParsedSampleRow
+from arekit.common.pipeline.base import BasePipelineLauncher
 
 import utils
 import unittest
 from os.path import join
 
-from arekit.common.pipeline.base import BasePipeline
 from arekit.common.data.storages.base import BaseRowsStorage
 from arekit.common.experiment.data_type import DataType
 from arekit.contrib.utils.data.readers.csv_pd import PandasCsvReader
@@ -43,9 +43,9 @@ def test_pipeline_item(self):
             prefix="arekit-iter-data",
             writer=None)
 
-        pipeline = BasePipeline(pipeline=[
+        pipeline = [
             D3jsGraphsBackendPipelineItem()
-        ])
+        ]
 
         ppl_result = PipelineResult({
             "labels_scaler": CustomLabelScaler(),
@@ -55,5 +55,6 @@ def test_pipeline_item(self):
         })
         ppl_result.update("samples_io", samples_io)
         ppl_result.update("predict_filepath", value=join(utils.TEST_OUT_DIR, "predict.tsv.gz"))
-        pipeline.run(input_data=ppl_result)
+
+        BasePipelineLauncher.run(pipeline=pipeline, pipeline_ctx=ppl_result, src_key="labels_scaler")
 
diff --git a/test/test_backend_d3js_graphs.py b/test/test_backend_d3js_graphs.py
index 4810fff..20db38c 100644
--- a/test/test_backend_d3js_graphs.py
+++ b/test/test_backend_d3js_graphs.py
@@ -4,7 +4,7 @@
 from os.path import join, exists
 
 import pandas as pd
-from arekit.common.pipeline.base import BasePipeline
+from arekit.common.pipeline.base import BasePipelineLauncher
 from arekit.contrib.utils.data.readers.jsonl import JsonlReader
 from arekit.contrib.utils.io_utils.samples import SamplesIO
 
@@ -90,7 +90,6 @@ def test_pipeline(self):
             reader=JsonlReader(),
             writer=None)
 
-        pipeline = BasePipeline(ppl)
         ppl_result = PipelineResult(extra_params={
             "samples_io": samples_io,
             "labels_scaler": CustomLabelScaler(),
@@ -99,4 +98,4 @@ def test_pipeline(self):
 
         ppl_result.update("predict_filepath", value=join(utils.TEST_OUT_DIR, "predict.tsv.gz"))
         ppl_result.update("labels_formatter", value=CustomLabelsFormatter())
 
-        
pipeline.run(input_data=ppl_result) + BasePipelineLauncher.run(pipeline=ppl, pipeline_ctx=ppl_result, src_key="samples_io") diff --git a/test/test_csv_reader.py b/test/test_csv_reader.py index e30e3bb..496e507 100644 --- a/test/test_csv_reader.py +++ b/test/test_csv_reader.py @@ -10,7 +10,7 @@ class CsvReadingTest(unittest.TestCase): def test(self): - file_path = join(utils.TEST_DATA_DIR, 'arekit-iter-data-test-0.csv') + file_path = join(utils.TEST_DATA_DIR, 'arekit-iter-data-test.csv') csv_file = open(file_path, mode="r", encoding="utf-8-sig") for line in iter_csv_lines(csv_file=csv_file, delimiter=',', column_name="text_a"): print(line) diff --git a/test/test_dp_ner_pipeline_item.py b/test/test_dp_ner_pipeline_item.py index 65498b3..753768e 100644 --- a/test/test_dp_ner_pipeline_item.py +++ b/test/test_dp_ner_pipeline_item.py @@ -1,17 +1,18 @@ +from arekit.common.utils import split_by_whitespaces + import utils from os.path import join import unittest +from arekit.common.docs.base import Document +from arekit.common.docs.parser import DocumentParsers from arekit.common.docs.entities_grouping import EntitiesGroupingPipelineItem from arekit.common.entities.base import Entity from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders from arekit.common.text.enums import TermFormat -from arekit.common.text.parsed import BaseParsedText -from arekit.common.text.parser import BaseTextParser from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection -from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem from arelight.pipelines.items.entity import IndexedEntity @@ -26,13 +27,13 @@ class BertOntonotesPipelineItemTest(unittest.TestCase): def test_pipeline_item_rus(self): # Declaring text processing pipeline. - text_parser = BaseTextParser(pipeline=[ - TermsSplitterParser(), + pipeline_items = [ DeepPavlovNERPipelineItem( + src_func=lambda text: split_by_whitespaces(text), id_assigner=IdAssigner(), obj_filter=lambda s_obj: s_obj.ObjectType in ["ORG", "PERSON", "LOC", "GPE"], ner_model_name="ner_ontonotes_bert_mult"), - ]) + ] # Read file contents. text_filepath = join(utils.TEST_DATA_DIR, "rus_input_text_example.txt") @@ -40,8 +41,8 @@ def test_pipeline_item_rus(self): text = f.read().rstrip() # Output parsed text. - parsed_text = text_parser.run(text) - for t in parsed_text.iter_terms(TermFormat.Raw): + parsed_doc = DocumentParsers.parse(doc=Document(doc_id=0, sentences=[text]), pipeline_items=pipeline_items) + for t in parsed_doc.get_sentence(0).iter_terms(TermFormat.Raw): print("<{}> ({})".format(t.Value, t.Type) if isinstance(t, Entity) else t) def test_pipeline(self): @@ -55,19 +56,19 @@ def test_pipeline(self): stemmer=MystemWrapper(), is_read_only=False) # Declare text parser. - text_parser = BaseTextParser(pipeline=[ - TermsSplitterParser(), - DeepPavlovNERPipelineItem(id_assigner=IdAssigner(), ner_model_name="ner_ontonotes_bert_mult"), + pipeline_items = [ + DeepPavlovNERPipelineItem( + src_func=lambda t: split_by_whitespaces(t), + id_assigner=IdAssigner(), + ner_model_name="ner_ontonotes_bert_mult"), EntitiesGroupingPipelineItem( lambda value: SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value( synonyms=synonyms, value=value)) - ]) + ] # Launch pipeline. 
- parsed_text = text_parser.run(text) - assert(isinstance(parsed_text, BaseParsedText)) - - for term in parsed_text.iter_terms(TermFormat.Raw): + parsed_doc = DocumentParsers.parse(doc=Document(doc_id=0, sentences=[text]), pipeline_items=pipeline_items) + for term in parsed_doc.get_sentence(0).iter_terms(TermFormat.Raw): if isinstance(term, IndexedEntity): print(term.ID, term.GroupIndex, term.Value) else: @@ -76,13 +77,13 @@ def test_pipeline(self): def test_pipeline_item_eng_book(self): # Declaring text processing pipeline. - text_parser = BaseTextParser(pipeline=[ - TermsSplitterParser(), + pipeline_items = [ DeepPavlovNERPipelineItem( + src_func=lambda t: split_by_whitespaces(t), id_assigner=IdAssigner(), obj_filter=lambda s_obj: s_obj.ObjectType in ["ORG", "PERSON", "LOC", "GPE"], ner_model_name="ner_ontonotes_bert"), - ]) + ] # Read file contents. text_filepath = join(utils.TEST_DATA_DIR, "book-war-and-peace-test.txt") @@ -90,8 +91,8 @@ def test_pipeline_item_eng_book(self): text = f.read().rstrip() # Output parsed text. - parsed_text = text_parser.run(text) - for t in parsed_text.iter_terms(TermFormat.Raw): + parsed_doc = DocumentParsers.parse(doc=Document(doc_id=0, sentences=[text]), pipeline_items=pipeline_items) + for t in parsed_doc.get_sentence(0).iter_terms(TermFormat.Raw): print("<{}> ({})".format(t.Value, t.Type) if isinstance(t, Entity) else t) diff --git a/test/test_dp_ner_sample.py b/test/test_dp_ner_sample.py index fa9270e..117d6cb 100644 --- a/test/test_dp_ner_sample.py +++ b/test/test_dp_ner_sample.py @@ -1,9 +1,10 @@ import unittest -from arekit.contrib.source.synonyms.utils import iter_synonym_groups -from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser +from arekit.common.pipeline.items.base import BasePipelineItem +from arekit.common.utils import split_by_whitespaces from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem +from arelight.synonyms import iter_synonym_groups from arelight.utils import IdAssigner from utils_ner import test_ner @@ -32,8 +33,10 @@ def test(self): ] ner_ppl_items = [ - TermsSplitterParser(), - DeepPavlovNERPipelineItem(id_assigner=IdAssigner(), ner_model_name="ner_ontonotes_bert_mult") + BasePipelineItem(src_func=lambda s: s.Text), + DeepPavlovNERPipelineItem(src_func=lambda text: split_by_whitespaces(text), + id_assigner=IdAssigner(), + ner_model_name="ner_ontonotes_bert_mult") ] test_ner(texts=texts, ner_ppl_items=ner_ppl_items, prefix="dp_ner") diff --git a/test/test_pipeline_infer.py b/test/test_pipeline_infer.py index 77d15b2..cf5daef 100644 --- a/test/test_pipeline_infer.py +++ b/test/test_pipeline_infer.py @@ -1,3 +1,6 @@ +from arekit.common.pipeline.items.base import BasePipelineItem +from arekit.common.utils import split_by_whitespaces + import utils from os.path import join, realpath, dirname @@ -16,11 +19,8 @@ from arekit.common.docs.entities_grouping import EntitiesGroupingPipelineItem from arekit.common.docs.sentence import BaseDocumentSentence from arekit.common.experiment.data_type import DataType -from arekit.common.pipeline.base import BasePipeline +from arekit.common.pipeline.base import BasePipelineLauncher from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders -from arekit.common.text.parser import BaseTextParser -from arekit.contrib.source.synonyms.utils import iter_synonym_groups -from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser from arekit.contrib.utils.synonyms.simple import 
SimpleSynonymCollection from arelight.pipelines.data.annot_pairs_nolabel import create_neutral_annotation_pipeline @@ -29,6 +29,7 @@ from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem from arelight.samplers.bert import create_bert_sample_provider from arelight.samplers.types import BertSampleProviderTypes +from arelight.synonyms import iter_synonym_groups from arelight.utils import IdAssigner from ru_sent_tokenize import ru_sent_tokenize @@ -87,22 +88,22 @@ def launch(self, pipeline): # We consider a texts[0] from the constant list. actual_content = self.texts - pipeline = BasePipeline(pipeline) synonyms = SimpleSynonymCollection(iter_group_values_lists=[], is_read_only=False) id_assigner = IdAssigner() # Setup text parsing. - text_parser = BaseTextParser(pipeline=[ - TermsSplitterParser(), + text_parser = [ + BasePipelineItem(src_func=lambda s: s.Text), DeepPavlovNERPipelineItem(ner_model_name="ner_ontonotes_bert_mult", + src_func=lambda text: split_by_whitespaces(text), id_assigner=id_assigner, obj_filter=lambda s_obj: s_obj.ObjectType in ["ORG", "PERSON", "LOC", "GPE"], chunk_limit=128), EntitiesGroupingPipelineItem( lambda value: SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value( synonyms=synonyms, value=value)) - ]) + ] data_pipeline = create_neutral_annotation_pipeline( synonyms=synonyms, @@ -110,14 +111,16 @@ def launch(self, pipeline): dist_in_sentences=0, doc_provider=utils.InMemoryDocProvider(docs=self.input_to_docs(actual_content)), terms_per_context=50, - text_parser=text_parser) + text_pipeline=text_parser) - pipeline.run(input_data=PipelineContext(d={ - "labels_scaler": CustomLabelScaler(), - "predict_filepath": join(utils.TEST_OUT_DIR, "predict.tsv.gz"), - "data_type_pipelines": {DataType.Test: data_pipeline}, - "doc_ids": list(range(len(actual_content))), - })) + BasePipelineLauncher.run(pipeline=pipeline, + pipeline_ctx=PipelineContext(d={ + "labels_scaler": CustomLabelScaler(), + "predict_filepath": join(utils.TEST_OUT_DIR, "predict.tsv.gz"), + "data_type_pipelines": {DataType.Test: data_pipeline}, + "doc_ids": list(range(len(actual_content))), + }), + src_key="labels_scaler") def test_opennre(self): @@ -130,8 +133,14 @@ def test_opennre(self): "pretrained_bert": "DeepPavlov/rubert-base-cased", "checkpoint_path": "ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar", "device_type": "cpu", - "max_seq_length": 128 - } - }) + "max_seq_length": 128, + "task_kwargs": { + "no_label": "0", + "default_id_column": "id", + "index_columns": ["s_ind", "t_ind"], + "text_columns": ["text_a", "text_b"] + }, + } + }) self.launch(pipeline) diff --git a/test/test_pipeline_sample.py b/test/test_pipeline_sample.py index 1fc54b6..4d645f3 100644 --- a/test/test_pipeline_sample.py +++ b/test/test_pipeline_sample.py @@ -1,3 +1,6 @@ +from arekit.common.pipeline.items.base import BasePipelineItem +from arekit.common.utils import split_by_whitespaces + import utils import unittest import ru_sent_tokenize @@ -11,25 +14,23 @@ from arekit.common.experiment.data_type import DataType from arekit.common.labels.base import NoLabel from arekit.common.labels.scaler.single import SingleLabelScaler -from arekit.common.pipeline.base import BasePipeline +from arekit.common.pipeline.base import BasePipelineLauncher from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders -from arekit.common.text.parser import BaseTextParser from arekit.common.data import const from arekit.common.pipeline.context import PipelineContext from 
arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter from arekit.contrib.utils.io_utils.samples import SamplesIO -from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser from arekit.contrib.utils.processing.lemmatization.mystem import MystemWrapper from arekit.contrib.utils.synonyms.stemmer_based import StemmerBasedSynonymCollection from arekit.contrib.utils.entities.formatters.str_simple_sharp_prefixed_fmt import SharpPrefixedEntitiesSimpleFormatter from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage -from arekit.contrib.source.synonyms.utils import iter_synonym_groups from arelight.pipelines.data.annot_pairs_nolabel import create_neutral_annotation_pipeline from arelight.pipelines.items.entities_default import TextEntitiesParser from arelight.pipelines.items.serializer_arekit import AREkitSerializerPipelineItem from arelight.samplers.bert import create_bert_sample_provider from arelight.samplers.types import BertSampleProviderTypes +from arelight.synonyms import iter_synonym_groups from arelight.utils import IdAssigner @@ -85,29 +86,29 @@ def test(self): is_read_only=False) # Declare text parser. - text_parser = BaseTextParser(pipeline=[ - TermsSplitterParser(), - TextEntitiesParser(id_assigner=IdAssigner()), - EntitiesGroupingPipelineItem(lambda value: - SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value( + text_parser_pipeline = [ + BasePipelineItem(src_func=lambda s: s.Text), + TextEntitiesParser(src_func=lambda s: split_by_whitespaces(s), id_assigner=IdAssigner()), + EntitiesGroupingPipelineItem( + lambda value: SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value( synonyms=synonyms, value=value)) - ]) + ] # Composing labels formatter and experiment preparation. 
doc_provider = utils.InMemoryDocProvider(docs=BertTestSerialization.input_to_docs(texts)) - pipeline = BasePipeline([AREkitSerializerPipelineItem( - rows_provider=create_bert_sample_provider( - label_scaler=SingleLabelScaler(NoLabel()), - provider_type=BertSampleProviderTypes.NLI_M, - entity_formatter=SharpPrefixedEntitiesSimpleFormatter(), - crop_window=50, - ), - save_labels_func=lambda _: False, - samples_io=SamplesIO(target_dir=utils.TEST_OUT_DIR, writer=SQliteWriter()), - storage=RowCacheStorage(force_collect_columns=[ - const.ENTITIES, const.ENTITY_VALUES, const.ENTITY_TYPES, const.SENT_IND + pipeline = [ + AREkitSerializerPipelineItem( + rows_provider=create_bert_sample_provider( + label_scaler=SingleLabelScaler(NoLabel()), + provider_type=BertSampleProviderTypes.NLI_M, + entity_formatter=SharpPrefixedEntitiesSimpleFormatter(), + crop_window=50), + save_labels_func=lambda _: False, + samples_io=SamplesIO(target_dir=utils.TEST_OUT_DIR, writer=SQliteWriter()), + storage=RowCacheStorage(force_collect_columns=[ + const.ENTITIES, const.ENTITY_VALUES, const.ENTITY_TYPES, const.SENT_IND ])) - ]) + ] synonyms = StemmerBasedSynonymCollection(iter_group_values_lists=[], stemmer=MystemWrapper(), is_read_only=False) @@ -117,13 +118,15 @@ def test(self): dist_in_terms_bound=50, dist_in_sentences=0, doc_provider=doc_provider, - text_parser=text_parser, + text_pipeline=text_parser_pipeline, terms_per_context=50) - pipeline.run(input_data=PipelineContext(d={ - "doc_ids": list(range(len(texts))), - "data_type_pipelines": {DataType.Test: test_pipeline} - })) + BasePipelineLauncher.run(pipeline=pipeline, + pipeline_ctx=PipelineContext(d={ + "doc_ids": list(range(len(texts))), + "data_type_pipelines": {DataType.Test: test_pipeline} + }), + src_key="doc_ids") if __name__ == '__main__': diff --git a/test/test_transformers_ner_ppl.py b/test/test_transformers_ner_ppl.py index 0347f73..79bb1a2 100644 --- a/test/test_transformers_ner_ppl.py +++ b/test/test_transformers_ner_ppl.py @@ -1,9 +1,11 @@ import unittest import time - -from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser from tqdm import tqdm +from arekit.common.pipeline.items.base import BasePipelineItem +from arekit.common.text.partitioning.str import StringPartitioning +from arekit.common.utils import split_by_whitespaces + from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem from arelight.pipelines.items.entities_ner_transformers import TransformersNERPipelineItem from arelight.pipelines.items.terms_splitter import CustomTermsSplitterPipelineItem @@ -48,7 +50,8 @@ def test_transformers(self): # Declare input texts. 
ppl_items = [ - TransformersNERPipelineItem(id_assigner=IdAssigner(), ner_model_name="dslim/bert-base-NER", device="cpu"), + TransformersNERPipelineItem(id_assigner=IdAssigner(), ner_model_name="dslim/bert-base-NER", device="cpu", + src_func=lambda s: s.Text, partitioning=StringPartitioning()), CustomTermsSplitterPipelineItem(), ] @@ -57,8 +60,10 @@ def test_transformers(self): def test_benchmark(self): ppl_items = [ - TermsSplitterParser(), - DeepPavlovNERPipelineItem(id_assigner=IdAssigner(), ner_model_name="ner_ontonotes_bert") + BasePipelineItem(src_func=lambda s: s.Text), + DeepPavlovNERPipelineItem(id_assigner=IdAssigner(), + src_func=lambda text: split_by_whitespaces(text), + ner_model_name="ner_ontonotes_bert") ] test_ner(texts=self.get_texts(), diff --git a/test/utils_ner.py b/test/utils_ner.py index 4b67d92..97903b7 100644 --- a/test/utils_ner.py +++ b/test/utils_ner.py @@ -2,10 +2,9 @@ from arekit.common.experiment.data_type import DataType from arekit.common.labels.base import NoLabel from arekit.common.labels.scaler.single import SingleLabelScaler -from arekit.common.pipeline.base import BasePipeline +from arekit.common.pipeline.base import BasePipelineLauncher from arekit.common.pipeline.context import PipelineContext from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders -from arekit.common.text.parser import BaseTextParser from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage from arekit.contrib.utils.data.writers.csv_native import NativeCsvWriter from arekit.contrib.utils.entities.formatters.str_simple_sharp_prefixed_fmt import SharpPrefixedEntitiesSimpleFormatter @@ -26,11 +25,11 @@ def test_ner(texts, ner_ppl_items, prefix): synonyms = SimpleSynonymCollection(iter_group_values_lists=[], is_read_only=False) # Declare text parser. - text_parser = BaseTextParser(pipeline=ner_ppl_items + [ + text_pipeline_items = ner_ppl_items + [ EntitiesGroupingPipelineItem( lambda value: SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value( synonyms=synonyms, value=value)) - ]) + ] # Single label scaler. single_label_scaler = SingleLabelScaler(NoLabel()) @@ -44,7 +43,7 @@ def test_ner(texts, ner_ppl_items, prefix): entity_formatter=SharpPrefixedEntitiesSimpleFormatter(), crop_window=50) - pipeline = BasePipeline([ + pipeline_items = [ BaseSerializerPipelineItem( rows_provider=rows_provider, storage=RowCacheStorage(), @@ -52,17 +51,19 @@ def test_ner(texts, ner_ppl_items, prefix): writer=NativeCsvWriter(delimiter=','), prefix=prefix), save_labels_func=lambda data_type: data_type != DataType.Test) - ]) + ] # Initialize data processing pipeline. test_pipeline = create_neutral_annotation_pipeline(synonyms=synonyms, dist_in_terms_bound=50, dist_in_sentences=0, doc_provider=doc_provider, - text_parser=text_parser, + text_pipeline=text_pipeline_items, terms_per_context=50) - pipeline.run(input_data=PipelineContext({ + ctx = PipelineContext({ "doc_ids": list(range(len(texts))), "data_type_pipelines": {DataType.Test: test_pipeline} - })) + }) + + BasePipelineLauncher.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key="doc_ids")