#134. Task-related parameters were moved out of the third-party OpenNRE-related classes.
nicolay-r committed Dec 24, 2023
1 parent ade6cbe commit 10d98b3
Showing 5 changed files with 44 additions and 34 deletions.
arelight/pipelines/items/inference_bert_opennre.py (7 changes: 5 additions & 2 deletions)
@@ -24,7 +24,7 @@ class BertOpenNREInferencePipelineItem(BasePipelineItem):

     def __init__(self, pretrained_bert=None, checkpoint_path=None, device_type='cpu',
                  max_seq_length=128, pooler='cls', batch_size=10, tokenizers_parallelism=True,
-                 predefined_ckpts=None):
+                 table_name="contents", task_kwargs=None, predefined_ckpts=None):
         assert(isinstance(tokenizers_parallelism, bool))
 
         self.__model = None
@@ -35,6 +35,8 @@ def __init__(self, pretrained_bert=None, checkpoint_path=None, device_type='cpu'
         self.__pooler = pooler
         self.__batch_size = batch_size
         self.__predefined_ckpts = {} if predefined_ckpts is None else predefined_ckpts
+        self.__task_kwargs = task_kwargs
+        self.__table_name = table_name
 
         # Huggingface/Tokenizers compatibility.
         os.environ['TOKENIZERS_PARALLELISM'] = str(tokenizers_parallelism).lower()
@@ -143,7 +145,8 @@ def __iter_predict_result(self, samples_filepath, batch_size):
             rel2id=self.__model.rel2id,
             tokenizer=self.__model.sentence_encoder.tokenize,
             batch_size=batch_size,
-            table_name="contents",
+            table_name=self.__table_name,
+            task_kwargs=self.__task_kwargs,
             shuffle=False)
 
         with sentence_eval.dataset as dataset:
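Taken together, the changes in this file let the caller configure the SQLite table name and the task-specific columns from outside the pipeline item. A minimal usage sketch; only the parameter names come from this commit, the checkpoint values are placeholders, and the task_kwargs values mirror the test file further below:

from arelight.pipelines.items.inference_bert_opennre import BertOpenNREInferencePipelineItem

item = BertOpenNREInferencePipelineItem(
    pretrained_bert="bert-base-uncased",    # placeholder HuggingFace model name
    checkpoint_path="model_ckpt.pth.tar",   # placeholder OpenNRE checkpoint
    device_type="cpu",
    table_name="contents",                  # SQLite table the samples are read from
    task_kwargs={
        "no_label": "0",
        "default_id_column": "id",
        "index_columns": ["s_ind", "t_ind"],
        "text_columns": ["text_a", "text_b"],
    })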
arelight/run/infer.py (21 changes: 15 additions & 6 deletions)
@@ -2,13 +2,15 @@
 from os.path import join, dirname, basename
 
 from arekit.common.data import const
+from arekit.common.data.const import S_IND, T_IND, ID
 from arekit.common.docs.entities_grouping import EntitiesGroupingPipelineItem
 from arekit.common.experiment.data_type import DataType
 from arekit.common.labels.base import NoLabel
 from arekit.common.labels.scaler.single import SingleLabelScaler
 from arekit.common.pipeline.base import BasePipeline
 from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders
 from arekit.common.text.parser import BaseTextParser
+from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
 from arekit.contrib.utils.data.readers.sqlite import SQliteReader
 from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
 from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter
@@ -86,7 +88,9 @@

     assert(is_port_number(number=args.d3js_host, is_optional=True))
 
-    labels_scaler = SingleLabelScaler(NoLabel())
+    # Classification task label scaler setup.
+    labels_scl = {a: int(v) for a, v in map(lambda itm: itm.split(":"), args.labels_fmt.split(','))}
+    labels_scaler = CustomLabelScaler(**labels_scl)
 
     def setup_collection_name(value):
         # Considering Predefined name if the latter has been declared.
@@ -114,7 +118,7 @@ def setup_collection_name(value):
"samples_io": SamplesIO(target_dir=output_dir,
prefix=collection_name,
reader=SQliteReader(table_name="contents"),
writer=SQliteWriter()),
writer=SQliteWriter(table_name="contents")),
"storage": RowCacheStorage(
force_collect_columns=[const.ENTITIES, const.ENTITY_VALUES, const.ENTITY_TYPES, const.SENT_IND]),
"save_labels_func": lambda data_type: data_type != DataType.Test
@@ -180,7 +184,14 @@ def __entity_display_value(entity_value):
"batch_size": args.batch_size,
"pooler": "cls",
"predefined_ckpts": OPENNRE_CHECKPOINTS,
},
"table_name": "contents",
"task_kwargs": {
"no_label": str(labels_scaler.label_to_int(NoLabel())),
"default_id_column": ID,
"index_columns": [S_IND, T_IND],
"text_columns": [PairTextProvider.TEXT_A, PairTextProvider.TEXT_B]
},
},
}

backend_setups = {
@@ -269,10 +280,8 @@ def __entity_display_value(entity_value):
"d3js_host": args.d3js_host,
})

labels_scl = {a: int(v) for a, v in map(lambda itm: itm.split(":"), args.labels_fmt.split(','))}

settings.append({
"labels_scaler": CustomLabelScaler(**labels_scl),
"labels_scaler": labels_scaler,
# We provide this settings for inference.
"predict_filepath": join(output_dir, "{}-predict.tsv.gz".format(collection_name)),
"samples_io": sampling_engines_setup["arekit"]["samples_io"],
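The labels_scl parsing that used to run near the end of the script is now performed up front (the @@ -86,7 +88,9 @@ hunk above), so a single labels_scaler feeds both the task_kwargs "no_label" entry and the final inference settings. A standalone sketch of that dictionary-comprehension parse; the label string below is illustrative, not from this commit:

labels_fmt = "neu:0,pos:1,neg:2"  # illustrative --labels-fmt value
labels_scl = {a: int(v) for a, v in map(lambda itm: itm.split(":"), labels_fmt.split(','))}
print(labels_scl)  # {'neu': 0, 'pos': 1, 'neg': 2}; the script then builds CustomLabelScaler(**labels_scl)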
arelight/third_party/torch.py (40 changes: 17 additions & 23 deletions)
@@ -11,61 +11,56 @@ class SQLiteSentenceREDataset(data.Dataset):
     This is a original OpenNRE implementation, adapted for SQLite.
     """
 
-    def __init__(self, path, table_name, rel2id, tokenizer, kwargs, text_columns=None):
+    def __init__(self, path, table_name, rel2id, tokenizer, kwargs, task_kwargs):
         """
         Args:
             path: path of the input file sqlite file
             rel2id: dictionary of relation->id mapping
             tokenizer: function of tokenizing
         """
+        assert(isinstance(task_kwargs, dict))
+        assert("no_label" in task_kwargs)
+        assert("default_id_column" in task_kwargs)
+        assert("index_columns" in task_kwargs)
+        assert("text_columns" in task_kwargs)
         super().__init__()
         self.path = path
         self.tokenizer = tokenizer
         self.rel2id = rel2id
         self.kwargs = kwargs
+        self.task_kwargs = task_kwargs
         self.table_name = table_name
         self.sqlite_service = SQLite3Service()
-        # Task-related parameters.
-        # OpenNRE-related task provider.
-        self.no_label = "0"
-        self.default_id_column = "id"
-        self.index_columns = ["s_ind", "t_ind"]
-        self.text_columns = text_columns
 
     def iter_ids(self, id_column=None):
-        col_name = self.default_id_column if id_column is None else id_column
+        col_name = self.task_kwargs["default_id_column"] if id_column is None else id_column
         for row in self.sqlite_service.iter_rows(select_columns=col_name, table_name=self.table_name):
             yield row[0]

     def __len__(self):
         return self.sqlite_service.table_rows_count(self.table_name)
 
     def __getitem__(self, index):
-
-        # Automatically assign column names.
-        # This is expected to be refactored as a task-specific approach of text columns assignation.
-        if self.text_columns is None:
-            self.text_columns = self.sqlite_service.get_column_names(
-                table_name=self.table_name,
-                filter_name=lambda col_name: col_name.startswith("text"))
+        found_text_columns = self.sqlite_service.get_column_names(
+            table_name=self.table_name, filter_name=lambda col_name: col_name in self.task_kwargs["text_columns"])
 
         iter_rows = self.sqlite_service.iter_rows(
-            select_columns=",".join(self.index_columns + self.text_columns),
+            select_columns=",".join(self.task_kwargs["index_columns"] + found_text_columns),
             value=index,
-            column_value=self.default_id_column,
+            column_value=self.task_kwargs["default_id_column"],
             table_name=self.table_name)

         fetched_row = next(iter_rows)
 
         opennre_item = {
-            "text": " ".join(fetched_row[-len(self.text_columns):]),
+            "text": " ".join(fetched_row[-len(found_text_columns):]),
             "h": {"pos": [fetched_row[0], fetched_row[0] + 1]},
             "t": {"pos": [fetched_row[1], fetched_row[1] + 1]},
         }
 
         seq = list(self.tokenizer(opennre_item, **self.kwargs))
 
-        return [self.rel2id[self.no_label]] + seq  # label, seq1, seq2, ...
+        return [self.rel2id[self.task_kwargs["no_label"]]] + seq  # label, seq1, seq2, ...
 
     def collate_fn(data):
         data = list(zip(*data))
@@ -90,11 +85,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.sqlite_service.disconnect()
 
 
-def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size,
-                       shuffle, text_columns=None, num_workers=8, collate_fn=SQLiteSentenceREDataset.collate_fn,
-                       **kwargs):
+def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size, shuffle,
+                       task_kwargs, num_workers=8, collate_fn=SQLiteSentenceREDataset.collate_fn, **kwargs):
     dataset = SQLiteSentenceREDataset(path=path, table_name=table_name, rel2id=rel2id,
-                                      tokenizer=tokenizer, kwargs=kwargs, text_columns=text_columns)
+                                      tokenizer=tokenizer, kwargs=kwargs, task_kwargs=task_kwargs)
     data_loader = data.DataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   shuffle=shuffle,
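With task_kwargs now a required argument of sentence_re_loader, a call looks roughly as follows. A sketch assuming an already-loaded OpenNRE model object; the database path is a placeholder, and the task_kwargs values mirror the test below:

loader = sentence_re_loader(
    path="samples.sqlite",                      # placeholder SQLite database path
    table_name="contents",
    rel2id=model.rel2id,                        # `model` assumed to be a loaded OpenNRE model
    tokenizer=model.sentence_encoder.tokenize,
    batch_size=10,
    shuffle=False,
    task_kwargs={
        "no_label": "0",
        "default_id_column": "id",
        "index_columns": ["s_ind", "t_ind"],
        "text_columns": ["text_a", "text_b"],
    })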
test/test_opennre_bert_infer.py (6 changes: 6 additions & 0 deletions)
@@ -51,6 +51,12 @@ def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, ckpt_
         rel2id=model.rel2id,
         tokenizer=model.sentence_encoder.tokenize,
         batch_size=batch_size,
+        task_kwargs={
+            "no_label": "0",
+            "default_id_column": "id",
+            "index_columns": ["s_ind", "t_ind"],
+            "text_columns": ["text_a", "text_b"]
+        },
         shuffle=False)
 
     # Open database.
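Each dataset item comes back as [label, seq1, seq2, ...] (the return statement in torch.py above), and the dataset also acts as a context manager around its SQLite connection, which is how inference_bert_opennre.py drives it. A consumption sketch, assuming the `loader` object from the previous sketch:

with loader.dataset as dataset:  # __enter__ assumed to connect; __exit__ disconnects (shown above)
    for batch in loader:
        labels = batch[0]  # rel2id-encoded "no_label" placeholders
        seqs = batch[1:]   # encoder-specific sequences produced by the tokenizer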
test/test_pipeline_infer.py (4 changes: 1 addition & 3 deletions)
@@ -67,8 +67,6 @@ def input_to_docs(texts):
         return docs
 
     def create_sampling_params(self):
-        writer = SQliteWriter()
-
         return {
             "rows_provider": create_bert_sample_provider(
                 label_scaler=SingleLabelScaler(NoLabel()),
@@ -78,7 +76,7 @@ def create_sampling_params(self):
"save_labels_func": lambda _: False,
"samples_io": SamplesIO(target_dir=utils.TEST_OUT_DIR,
reader=JsonlReader(),
writer=writer),
writer=SQliteWriter()),
"storage": RowCacheStorage(force_collect_columns=[
const.ENTITIES, const.ENTITY_VALUES, const.ENTITY_TYPES, const.SENT_IND
])