From a8a28e6dfc91ee775a7cba3bc2b41f3f2df7aa37 Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Mon, 8 Jan 2024 10:13:12 +0000 Subject: [PATCH] Unit-test: link to the default downloading dir. #131 --- test/test_opennre_bert_infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_opennre_bert_infer.py b/test/test_opennre_bert_infer.py index 100cd66..9e49424 100644 --- a/test/test_opennre_bert_infer.py +++ b/test/test_opennre_bert_infer.py @@ -14,6 +14,7 @@ from arelight.predict.writer_csv import TsvPredictWriter from arelight.run.utils import OPENNRE_CHECKPOINTS from arelight.third_party.torch import sentence_re_loader +from arelight.utils import get_default_download_dir class TestLoadModel(unittest.TestCase): @@ -44,7 +45,7 @@ def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, ckpt_ model = BertOpenNREInferencePipelineItem.init_bert_model( pretrain_path=pretrain_path, labels_scaler=labels_scaler, ckpt_path=ckpt_path, device_type="cpu", max_length=max_length, mask_entity=mask_entity, - dir_to_donwload=os.getcwd(), pooler=pooler, predefined=predefined) + dir_to_donwload=get_default_download_dir(), pooler=pooler, predefined=predefined) eval_loader = sentence_re_loader(path=test_data_file, table_name="contents",