diff --git a/arelight/pipelines/items/inference_bert_opennre.py b/arelight/pipelines/items/inference_bert_opennre.py
index a43c4f5..c22a6cf 100644
--- a/arelight/pipelines/items/inference_bert_opennre.py
+++ b/arelight/pipelines/items/inference_bert_opennre.py
@@ -24,7 +24,7 @@ class BertOpenNREInferencePipelineItem(BasePipelineItem):
 
     def __init__(self, pretrained_bert=None, checkpoint_path=None, device_type='cpu',
                  max_seq_length=128, pooler='cls', batch_size=10, tokenizers_parallelism=True,
-                 predefined_ckpts=None):
+                 table_name="contents", task_kwargs=None, predefined_ckpts=None):
         assert(isinstance(tokenizers_parallelism, bool))
 
         self.__model = None
@@ -35,6 +35,8 @@ def __init__(self, pretrained_bert=None, checkpoint_path=None, device_type='cpu'
         self.__pooler = pooler
         self.__batch_size = batch_size
         self.__predefined_ckpts = {} if predefined_ckpts is None else predefined_ckpts
+        self.__task_kwargs = task_kwargs
+        self.__table_name = table_name
 
         # Huggingface/Tokenizers compatibility.
         os.environ['TOKENIZERS_PARALLELISM'] = str(tokenizers_parallelism).lower()
@@ -143,7 +145,8 @@ def __iter_predict_result(self, samples_filepath, batch_size):
             rel2id=self.__model.rel2id,
             tokenizer=self.__model.sentence_encoder.tokenize,
             batch_size=batch_size,
-            table_name="contents",
+            table_name=self.__table_name,
+            task_kwargs=self.__task_kwargs,
             shuffle=False)
 
         with sentence_eval.dataset as dataset:
diff --git a/arelight/run/infer.py b/arelight/run/infer.py
index 09ddec5..2376f76 100644
--- a/arelight/run/infer.py
+++ b/arelight/run/infer.py
@@ -2,6 +2,7 @@
 from os.path import join, dirname, basename
 
 from arekit.common.data import const
+from arekit.common.data.const import S_IND, T_IND, ID
 from arekit.common.docs.entities_grouping import EntitiesGroupingPipelineItem
 from arekit.common.experiment.data_type import DataType
 from arekit.common.labels.base import NoLabel
@@ -9,6 +10,7 @@
 from arekit.common.pipeline.base import BasePipeline
 from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders
 from arekit.common.text.parser import BaseTextParser
+from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
 from arekit.contrib.utils.data.readers.sqlite import SQliteReader
 from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
 from arekit.contrib.utils.data.writers.sqlite_native import SQliteWriter
@@ -86,7 +88,9 @@
 assert(is_port_number(number=args.d3js_host, is_optional=True))
 
-labels_scaler = SingleLabelScaler(NoLabel())
+# Classification task label scaler setup.
+labels_scl = {a: int(v) for a, v in map(lambda itm: itm.split(":"), args.labels_fmt.split(','))}
+labels_scaler = CustomLabelScaler(**labels_scl)
 
 def setup_collection_name(value):
     # Considering Predefined name if the latter has been declared.
@@ -114,7 +118,7 @@ def setup_collection_name(value):
         "samples_io": SamplesIO(target_dir=output_dir,
                                 prefix=collection_name,
                                 reader=SQliteReader(table_name="contents"),
-                                writer=SQliteWriter()),
+                                writer=SQliteWriter(table_name="contents")),
         "storage": RowCacheStorage(
             force_collect_columns=[const.ENTITIES, const.ENTITY_VALUES, const.ENTITY_TYPES, const.SENT_IND]),
         "save_labels_func": lambda data_type: data_type != DataType.Test
@@ -180,7 +184,14 @@ def __entity_display_value(entity_value):
         "batch_size": args.batch_size,
         "pooler": "cls",
         "predefined_ckpts": OPENNRE_CHECKPOINTS,
-    },
+        "table_name": "contents",
+        "task_kwargs": {
+            "no_label": str(labels_scaler.label_to_int(NoLabel())),
+            "default_id_column": ID,
+            "index_columns": [S_IND, T_IND],
+            "text_columns": [PairTextProvider.TEXT_A, PairTextProvider.TEXT_B]
+        },
+    },
 }
 
 backend_setups = {
@@ -269,10 +280,8 @@ def __entity_display_value(entity_value):
     "d3js_host": args.d3js_host,
 })
 
-labels_scl = {a: int(v) for a, v in map(lambda itm: itm.split(":"), args.labels_fmt.split(','))}
-
 settings.append({
-    "labels_scaler": CustomLabelScaler(**labels_scl),
+    "labels_scaler": labels_scaler,
     # We provide this settings for inference.
     "predict_filepath": join(output_dir, "{}-predict.tsv.gz".format(collection_name)),
     "samples_io": sampling_engines_setup["arekit"]["samples_io"],
diff --git a/arelight/third_party/torch.py b/arelight/third_party/torch.py
index ab3684a..791ce45 100644
--- a/arelight/third_party/torch.py
+++ b/arelight/third_party/torch.py
@@ -11,29 +11,29 @@ class SQLiteSentenceREDataset(data.Dataset):
     This is a original OpenNRE implementation, adapted for SQLite.
     """
 
-    def __init__(self, path, table_name, rel2id, tokenizer, kwargs, text_columns=None):
+    def __init__(self, path, table_name, rel2id, tokenizer, kwargs, task_kwargs):
         """
         Args:
             path: path of the input file sqlite file
             rel2id: dictionary of relation->id mapping
             tokenizer: function of tokenizing
         """
+        assert(isinstance(task_kwargs, dict))
+        assert("no_label" in task_kwargs)
+        assert("default_id_column" in task_kwargs)
+        assert("index_columns" in task_kwargs)
+        assert("text_columns" in task_kwargs)
         super().__init__()
         self.path = path
         self.tokenizer = tokenizer
         self.rel2id = rel2id
         self.kwargs = kwargs
+        self.task_kwargs = task_kwargs
         self.table_name = table_name
         self.sqlite_service = SQLite3Service()
-        # Task-related parameters.
-        # OpenNRE-related task provider.
-        self.no_label = "0"
-        self.default_id_column = "id"
-        self.index_columns = ["s_ind", "t_ind"]
-        self.text_columns = text_columns
 
     def iter_ids(self, id_column=None):
-        col_name = self.default_id_column if id_column is None else id_column
+        col_name = self.task_kwargs["default_id_column"] if id_column is None else id_column
         for row in self.sqlite_service.iter_rows(select_columns=col_name, table_name=self.table_name):
             yield row[0]
@@ -41,31 +41,26 @@ def __len__(self):
         return self.sqlite_service.table_rows_count(self.table_name)
 
     def __getitem__(self, index):
-
-        # Automatically assign column names.
-        # This is expected to be refactored as a task-specific approach of text columns assignation.
-        if self.text_columns is None:
-            self.text_columns = self.sqlite_service.get_column_names(
-                table_name=self.table_name,
-                filter_name=lambda col_name: col_name.startswith("text"))
+        found_text_columns = self.sqlite_service.get_column_names(
+            table_name=self.table_name, filter_name=lambda col_name: col_name in self.task_kwargs["text_columns"])
 
         iter_rows = self.sqlite_service.iter_rows(
-            select_columns=",".join(self.index_columns + self.text_columns),
+            select_columns=",".join(self.task_kwargs["index_columns"] + found_text_columns),
             value=index,
-            column_value=self.default_id_column,
+            column_value=self.task_kwargs["default_id_column"],
             table_name=self.table_name)
 
         fetched_row = next(iter_rows)
 
         opennre_item = {
-            "text": " ".join(fetched_row[-len(self.text_columns):]),
+            "text": " ".join(fetched_row[-len(found_text_columns):]),
             "h": {"pos": [fetched_row[0], fetched_row[0] + 1]},
             "t": {"pos": [fetched_row[1], fetched_row[1] + 1]},
         }
 
         seq = list(self.tokenizer(opennre_item, **self.kwargs))
 
-        return [self.rel2id[self.no_label]] + seq  # label, seq1, seq2, ...
+        return [self.rel2id[self.task_kwargs["no_label"]]] + seq  # label, seq1, seq2, ...
 
     def collate_fn(data):
         data = list(zip(*data))
@@ -90,11 +85,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.sqlite_service.disconnect()
 
 
-def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size,
-                       shuffle, text_columns=None, num_workers=8, collate_fn=SQLiteSentenceREDataset.collate_fn,
-                       **kwargs):
+def sentence_re_loader(path, table_name, rel2id, tokenizer, batch_size, shuffle,
+                       task_kwargs, num_workers=8, collate_fn=SQLiteSentenceREDataset.collate_fn, **kwargs):
     dataset = SQLiteSentenceREDataset(path=path, table_name=table_name, rel2id=rel2id,
-                                      tokenizer=tokenizer, kwargs=kwargs, text_columns=text_columns)
+                                      tokenizer=tokenizer, kwargs=kwargs, task_kwargs=task_kwargs)
     data_loader = data.DataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   shuffle=shuffle,
diff --git a/test/test_opennre_bert_infer.py b/test/test_opennre_bert_infer.py
index 2fafbc9..b2684f6 100644
--- a/test/test_opennre_bert_infer.py
+++ b/test/test_opennre_bert_infer.py
@@ -51,6 +51,12 @@ def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, ckpt_
         rel2id=model.rel2id,
         tokenizer=model.sentence_encoder.tokenize,
         batch_size=batch_size,
+        task_kwargs={
+            "no_label": "0",
+            "default_id_column": "id",
+            "index_columns": ["s_ind", "t_ind"],
+            "text_columns": ["text_a", "text_b"]
+        },
         shuffle=False)
 
     # Open database.
diff --git a/test/test_pipeline_infer.py b/test/test_pipeline_infer.py
index 71b0d2d..77d15b2 100644
--- a/test/test_pipeline_infer.py
+++ b/test/test_pipeline_infer.py
@@ -67,8 +67,6 @@ def input_to_docs(texts):
         return docs
 
     def create_sampling_params(self):
-        writer = SQliteWriter()
-
         return {
             "rows_provider": create_bert_sample_provider(
                 label_scaler=SingleLabelScaler(NoLabel()),
@@ -78,7 +76,7 @@ def create_sampling_params(self):
             "save_labels_func": lambda _: False,
             "samples_io": SamplesIO(target_dir=utils.TEST_OUT_DIR,
                                     reader=JsonlReader(),
-                                    writer=SQliteWriter()),
+                                    writer=SQliteWriter()),
             "storage": RowCacheStorage(force_collect_columns=[
                 const.ENTITIES, const.ENTITY_VALUES, const.ENTITY_TYPES, const.SENT_IND
             ])
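
Usage sketch (illustrative, not part of the patch): how a caller wires the new task_kwargs contract into sentence_re_loader, mirroring the values used in test/test_opennre_bert_infer.py. The database path and the model object are placeholders.

    # A minimal sketch, assuming `model` is an OpenNRE model loaded elsewhere
    # and "samples-test.sqlite" is a previously serialized samples database.
    from arelight.third_party.torch import sentence_re_loader

    loader = sentence_re_loader(
        path="samples-test.sqlite",      # placeholder path
        table_name="contents",           # same table the SQliteWriter targets
        rel2id=model.rel2id,
        tokenizer=model.sentence_encoder.tokenize,
        batch_size=10,
        task_kwargs={
            "no_label": "0",                       # relation used as the dummy label
            "default_id_column": "id",             # primary-key column of `contents`
            "index_columns": ["s_ind", "t_ind"],   # subject/object entity indices
            "text_columns": ["text_a", "text_b"],  # parts joined into one input text
        },
        shuffle=False)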
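
Label-scaler note (illustrative): hoisting the labels_fmt parsing above the sampling setup lets task_kwargs derive "no_label" from the same scaler via labels_scaler.label_to_int(NoLabel()). The parsing itself is plain Python; the label names below are hypothetical, the real set comes from --labels-fmt.

    # "--labels-fmt" style string -> {label_name: int} mapping, as in arelight/run/infer.py.
    labels_fmt = "neu:0,pos:1,neg:2"  # hypothetical value
    labels_scl = {a: int(v) for a, v in map(lambda itm: itm.split(":"), labels_fmt.split(','))}
    assert labels_scl == {"neu": 0, "pos": 1, "neg": 2}
    # labels_scaler = CustomLabelScaler(**labels_scl)  # keys must match the scaler's kwargs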