Skip to content

Commit

Permalink
Merge branch 'master' of personal.github.com:cdqa-suite/cdQA
Browse files Browse the repository at this point in the history
  • Loading branch information
andrelmfarias committed Oct 25, 2019
2 parents c27a392 + c062154 commit ea3aacf
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 11 deletions.
3 changes: 3 additions & 0 deletions cdqa/pipeline/cdqa_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __init__(self, reader=None, retriever="bm25", retrieve_by_doc=False, **kwarg

self.retrieve_by_doc = retrieve_by_doc

if torch.cuda.is_available():
self.cuda()

def fit_retriever(self, df: pd.DataFrame = None):
""" Fit the QAPipeline retriever to a list of documents in a dataframe.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorial-first-steps-cdqa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Instantiate the cdQA pipeline from a pre-trained CPU reader"
"### Instantiate the cdQA pipeline from a pre-trained reader model"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions examples/tutorial-train-reader-squad.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save CPU model locally"
"### Save model locally"
]
},
{
Expand All @@ -159,7 +159,7 @@
"metadata": {},
"outputs": [],
"source": [
"joblib.dump(reader, os.path.join(reader.output_dir, 'bert_qa_vCPU.joblib'))"
"joblib.dump(reader, os.path.join(reader.output_dir, 'bert_qa.joblib'))"
]
}
],
Expand Down
5 changes: 1 addition & 4 deletions examples/tutorial-use-pdf-converter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@
"id": "FLZd4H_vPJuU"
},
"source": [
"### Instantiate the cdQA pipeline from a pre-trained CPU reader and send it to GPU"
"### Instantiate the cdQA pipeline from a pre-trained reader model"
]
},
{
Expand Down Expand Up @@ -299,9 +299,6 @@
"source": [
"cdqa_pipeline = QAPipeline(reader='./models/bert_qa_vCPU-sklearn.joblib', max_df=1.0)\n",
"\n",
"# Send model to GPU\n",
"cdqa_pipeline.cuda()\n",
"\n",
"# Fit Retriever to documents\n",
"cdqa_pipeline.fit_retriever(X=df)"
]
Expand Down
2 changes: 0 additions & 2 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def test_evaluate_pipeline():

cdqa_pipeline = QAPipeline(reader="./models/bert_qa_vCPU-sklearn.joblib", n_jobs=-1)
cdqa_pipeline.fit_retriever(df)
if torch.cuda.is_available():
cdqa_pipeline.cuda()

eval_dict = evaluate_pipeline(cdqa_pipeline, "./test_data.json", output_dir=None)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ def execute_pipeline(query, n_predictions=None):

cdqa_pipeline = QAPipeline(reader="models/bert_qa_vCPU-sklearn.joblib")
cdqa_pipeline.fit_retriever(df)
if torch.cuda.is_available():
cdqa_pipeline.cuda()

if n_predictions is not None:
predictions = cdqa_pipeline.predict(query, n_predictions=n_predictions)
result = []
Expand Down

0 comments on commit ea3aacf

Please sign in to comment.