diff --git a/test/test_opennre_bert_infer.py b/test/test_opennre_bert_infer.py index 100cd66..9e49424 100644 --- a/test/test_opennre_bert_infer.py +++ b/test/test_opennre_bert_infer.py @@ -14,6 +14,7 @@ from arelight.predict.writer_csv import TsvPredictWriter from arelight.run.utils import OPENNRE_CHECKPOINTS from arelight.third_party.torch import sentence_re_loader +from arelight.utils import get_default_download_dir class TestLoadModel(unittest.TestCase): @@ -44,7 +45,7 @@ def infer_bert(pretrain_path, labels_scaler, output_file_gzip, predefined, ckpt_ model = BertOpenNREInferencePipelineItem.init_bert_model( pretrain_path=pretrain_path, labels_scaler=labels_scaler, ckpt_path=ckpt_path, device_type="cpu", max_length=max_length, mask_entity=mask_entity, - dir_to_donwload=os.getcwd(), pooler=pooler, predefined=predefined) + dir_to_donwload=get_default_download_dir(), pooler=pooler, predefined=predefined) eval_loader = sentence_re_loader(path=test_data_file, table_name="contents",