Commit b89623d
#135 removed. #131 simplified NER embedding into AREkit pipelines.
nicolay-r committed Jan 6, 2024
1 parent 652723c commit b89623d
Showing 6 changed files with 36 additions and 33 deletions.
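In essence, both NER items drop the SentenceObjectsParserPipelineItem base class in favour of BasePipelineItem: the partitioning object becomes a private field, and the merge of recognized entities back into the text moves into each item's apply_core. Below is a minimal sketch of the resulting pattern; the simplified base-class body and the dummy NER step are assumptions for illustration, and only the names appearing in the diffs that follow come from the commit.

# Sketch of the pattern this commit converges on. BasePipelineItem's body
# and the dummy NER step are assumptions, not part of the commit.
class BasePipelineItem(object):

    def __init__(self, **kwargs):
        pass

    def apply_core(self, input_data, pipeline_ctx):
        raise NotImplementedError()


class SomeNERPipelineItem(BasePipelineItem):

    def __init__(self, partitioning, **kwargs):
        super(SomeNERPipelineItem, self).__init__(**kwargs)
        # The partitioning is now a private detail of the item, no longer
        # passed up to a SentenceObjectsParserPipelineItem parent.
        self.__partitioning = partitioning

    def __iter_entities_with_bounds(self, text):
        # Stand-in for the real NER step yielding (entity, Bound) pairs.
        return iter([])

    def apply_core(self, input_data, pipeline_ctx):
        # Entities are merged back into the input here, via the item's own
        # partitioning, instead of a parent-class hook.
        parts_it = self.__iter_entities_with_bounds(input_data)
        return self.__partitioning.provide(text=input_data, parts_it=parts_it)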
14 changes: 8 additions & 6 deletions arelight/pipelines/items/entities_ner_dp.py
@@ -1,5 +1,5 @@
 from arekit.common.bound import Bound
-from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
+from arekit.common.pipeline.items.base import BasePipelineItem
 from arekit.common.text.partitioning.terms import TermsPartitioning

 from arelight.ner.deep_pavlov import DeepPavlovNER
@@ -8,7 +8,7 @@
 from arelight.utils import IdAssigner


-class DeepPavlovNERPipelineItem(SentenceObjectsParserPipelineItem):
+class DeepPavlovNERPipelineItem(BasePipelineItem):

     def __init__(self, id_assigner, ner_model_name, obj_filter=None,
                  chunk_limit=128, display_value_func=None, **kwargs):
@@ -19,17 +19,15 @@ def __init__(self, id_assigner, ner_model_name, obj_filter=None,
         assert(isinstance(chunk_limit, int) and chunk_limit > 0)
         assert(isinstance(id_assigner, IdAssigner))
         assert(callable(display_value_func) or display_value_func is None)
-        super(DeepPavlovNERPipelineItem, self).__init__(partitioning=TermsPartitioning(), **kwargs)
+        super(DeepPavlovNERPipelineItem, self).__init__(**kwargs)

         # Initialize bert-based model instance.
         self.__dp_ner = DeepPavlovNER(ner_model_name)
         self.__obj_filter = obj_filter
         self.__chunk_limit = chunk_limit
         self.__id_assigner = id_assigner
         self.__disp_value_func = display_value_func
-
-    def _get_parts_provider_func(self, input_data):
-        return self.__iter_subs_values_with_bounds(input_data)
+        self.__partitioning = TermsPartitioning()

     def __iter_subs_values_with_bounds(self, terms_list):
         assert(isinstance(terms_list, list))
@@ -66,3 +64,7 @@ def __iter_parsed_entities(self, processed_sequences, chunk_terms_list, chunk_offset):
                 value=value, e_type=s_obj.ObjectType, entity_id=self.__id_assigner.get_id(),
                 display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None)
             yield entity, Bound(pos=chunk_offset + s_obj.Position, length=s_obj.Length)
+
+    def apply_core(self, input_data, pipeline_ctx):
+        parts_it = self.__iter_subs_values_with_bounds(input_data)
+        return self.__partitioning.provide(text=input_data, parts_it=parts_it)
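With apply_core owning the partitioning, the item plugs into a pipeline without any follow-up splitter item. A usage sketch, assuming an installed arelight; the DeepPavlov model name is a placeholder, not one taken from this commit:

# Usage sketch; "ner_ontonotes_bert_mult" is a placeholder model name.
from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem
from arelight.utils import IdAssigner

ner_item = DeepPavlovNERPipelineItem(
    id_assigner=IdAssigner(),                  # shared counter for entity ids
    ner_model_name="ner_ontonotes_bert_mult",  # placeholder DeepPavlov config
    src_func=lambda s: s.Text)                 # same src_func style as the tests

# apply_core() takes the term list, collects (IndexedEntity, Bound) pairs
# from DeepPavlov, and merges them back via the private TermsPartitioning.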
27 changes: 24 additions & 3 deletions arelight/pipelines/items/entities_ner_transformers.py
@@ -1,11 +1,13 @@
 from arekit.common.bound import Bound
-from arekit.common.docs.objects_parser import SentenceObjectsParserPipelineItem
+from arekit.common.pipeline.items.base import BasePipelineItem
 from arekit.common.text.partitioning.str import StringPartitioning
+from arekit.common.utils import split_by_whitespaces

 from arelight.pipelines.items.entity import IndexedEntity
 from arelight.utils import IdAssigner, auto_import


-class TransformersNERPipelineItem(SentenceObjectsParserPipelineItem):
+class TransformersNERPipelineItem(BasePipelineItem):

     def __init__(self, id_assigner, ner_model_name, device, obj_filter=None, display_value_func=None, **kwargs):
         """ chunk_limit: int
@@ -28,8 +30,11 @@ def __init__(self, id_assigner, ner_model_name, device, obj_filter=None, display_value_func=None, **kwargs):
         self.__obj_filter = obj_filter
         self.__id_assigner = id_assigner
         self.__disp_value_func = display_value_func
+        self.__partitioning = StringPartitioning()

-    def _get_parts_provider_func(self, input_data):
+    # region Private methods
+
+    def __get_parts_provider_func(self, input_data):
         assert(isinstance(input_data, str))
         parts = self.annotate_ner(model=self.__model, tokenizer=self.__tokenizer, text=input_data,
                                   device=self.__device)
@@ -52,3 +57,19 @@ def __iter_parsed_entities(self, parts):
                 display_value=self.__disp_value_func(value) if self.__disp_value_func is not None else None)

             yield entity, Bound(pos=p["start"], length=p["end"] - p["start"])
+
+    @staticmethod
+    def __iter_fixed_terms(terms):
+        for e in terms:
+            if isinstance(e, str):
+                for term in split_by_whitespaces(e):
+                    yield term
+            else:
+                yield e
+
+    # endregion
+
+    def apply_core(self, input_data, pipeline_ctx):
+        parts_it = self.__get_parts_provider_func(input_data)
+        handled = self.__partitioning.provide(text=input_data, parts_it=parts_it)
+        return list(self.__iter_fixed_terms(handled))
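The new __iter_fixed_terms step is what makes the standalone CustomTermsSplitterPipelineItem (deleted below) redundant: plain-string chunks returned by the partitioning are split into terms on the spot, while entity objects pass through untouched. A self-contained sketch of that behaviour, with str.split() standing in for arekit's split_by_whitespaces:

# Self-contained sketch of the term-fixing step; str.split() stands in
# for arekit's split_by_whitespaces.
def iter_fixed_terms(terms):
    for e in terms:
        if isinstance(e, str):
            # Plain-text chunks between entities become individual terms ...
            for term in e.split():
                yield term
        else:
            # ... while non-string items (e.g. IndexedEntity) pass through.
            yield e


entity = object()  # stand-in for an IndexedEntity
print(list(iter_fixed_terms(["Barack Obama visited", entity, "in 2015"])))
# -> ['Barack', 'Obama', 'visited', <object ...>, 'in', '2015']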
16 changes: 0 additions & 16 deletions arelight/pipelines/items/terms_splitter.py

This file was deleted.

2 changes: 0 additions & 2 deletions arelight/run/infer.py
@@ -31,7 +31,6 @@
 from arelight.pipelines.items.entities_default import TextEntitiesParser
 from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem
 from arelight.pipelines.items.entities_ner_transformers import TransformersNERPipelineItem
-from arelight.pipelines.items.terms_splitter import CustomTermsSplitterPipelineItem
 from arelight.run.utils import merge_dictionaries, iter_group_values, create_sentence_parser, \
     create_translate_model, is_port_number, iter_content, OPENNRE_CHECKPOINTS
 from arelight.samplers.bert import create_bert_sample_provider
@@ -242,7 +241,6 @@ def __entity_display_value(entity_value):
     text_parser_pipeline = [
         BasePipelineItem(src_func=lambda s: s.Text),
         entity_parsers[ner_framework](),
-        CustomTermsSplitterPipelineItem() if ner_framework == "transformers" else None,
         text_translator_setup["ml-based" if args.translate_text is not None else None](),
         EntitiesGroupingPipelineItem(
             lambda value: SynonymsCollectionValuesGroupingProviders.provide_existed_or_register_missed_value(
8 changes: 3 additions & 5 deletions test/test_transformers_ner_ppl.py
@@ -3,12 +3,10 @@
 from tqdm import tqdm

 from arekit.common.pipeline.items.base import BasePipelineItem
-from arekit.common.text.partitioning.str import StringPartitioning
 from arekit.common.utils import split_by_whitespaces

 from arelight.pipelines.items.entities_ner_dp import DeepPavlovNERPipelineItem
 from arelight.pipelines.items.entities_ner_transformers import TransformersNERPipelineItem
-from arelight.pipelines.items.terms_splitter import CustomTermsSplitterPipelineItem
 from arelight.third_party.transformers import annotate_ner_ppl, init_token_classification_model, annotate_ner
 from arelight.utils import IdAssigner
 from utils_ner import test_ner
@@ -50,9 +48,9 @@ def test_transformers(self):
         # Declare input texts.

         ppl_items = [
-            TransformersNERPipelineItem(id_assigner=IdAssigner(), ner_model_name="dslim/bert-base-NER", device="cpu",
-                                        src_func=lambda s: s.Text, partitioning=StringPartitioning()),
-            CustomTermsSplitterPipelineItem(),
+            TransformersNERPipelineItem(id_assigner=IdAssigner(),
+                                        ner_model_name="dslim/bert-base-NER", device="cpu",
+                                        src_func=lambda s: s.Text)
         ]

         test_ner(texts=self.get_texts(), ner_ppl_items=ppl_items, prefix="transformers-ner")
2 changes: 1 addition & 1 deletion test/utils_ner.py
@@ -66,4 +66,4 @@ def test_ner(texts, ner_ppl_items, prefix):
         "data_type_pipelines": {DataType.Test: test_pipeline}
     })

-    BasePipelineLauncher.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key="doc_ids")
\ No newline at end of file
+    BasePipelineLauncher.run(pipeline=pipeline_items, pipeline_ctx=ctx, src_key="doc_ids")