Skip to content

Commit

Permalink
Merge pull request #132 from dod-advana/task/UOT-145148
Browse files (browse the repository at this point in the history)
Task/uot 145148
  • Loading branch information
JaredJRoss authored Jun 14, 2022
2 parents b4fe0bd + 3663018 commit 908d2d8
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 94 deletions.
116 changes: 55 additions & 61 deletions gamechangerml/api/fastapi/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,16 @@ def initQA(qa_model_name=QA_MODEL.value):
Returns:
"""
try:
if MODEL_LOAD_FLAG:
logger.info("Starting QA pipeline")
ModelLoader.__qa_model = QAReader(
transformer_path=LOCAL_TRANSFORMERS_DIR.value,
use_gpu=True,
model_name=qa_model_name,
**QAConfig.MODEL_ARGS,
)
# set cache variable defined in settings.py
QA_MODEL.value = ModelLoader.__qa_model.READER_PATH
logger.info("Finished loading QA Reader")
logger.info("Starting QA pipeline")
ModelLoader.__qa_model = QAReader(
transformer_path=LOCAL_TRANSFORMERS_DIR.value,
use_gpu=True,
model_name=qa_model_name,
**QAConfig.MODEL_ARGS,
)
# set cache variable defined in settings.py
QA_MODEL.value = ModelLoader.__qa_model.READER_PATH
logger.info("Finished loading QA Reader")
except OSError:
logger.error(f"Could not load Question Answer Model")

Expand All @@ -190,11 +189,10 @@ def initQE(qexp_model_path=QEXP_MODEL_NAME.value):
"""
logger.info(f"Loading Pretrained Vector from {qexp_model_path}")
try:
if MODEL_LOAD_FLAG:
ModelLoader.__query_expander = qe.QE(
qexp_model_path, **QexpConfig.MODEL_ARGS["init"]
)
logger.info("** Loaded Query Expansion Model")
ModelLoader.__query_expander = qe.QE(
qexp_model_path, **QexpConfig.MODEL_ARGS["init"]
)
logger.info("** Loaded Query Expansion Model")
except Exception as e:
logger.warning("** Could not load QE model")
logger.warning(e)
Expand All @@ -207,11 +205,10 @@ def initQEJBook(qexp_jbook_model_path=QEXP_JBOOK_MODEL_NAME.value):
"""
logger.info(f"Loading Pretrained Vector from {qexp_jbook_model_path}")
try:
if MODEL_LOAD_FLAG:
ModelLoader.__query_expander_jbook = qe.QE(
qexp_jbook_model_path, **QexpConfig.MODEL_ARGS["init"]
)
logger.info("** Loaded JBOOK Query Expansion Model")
ModelLoader.__query_expander_jbook = qe.QE(
qexp_jbook_model_path, **QexpConfig.MODEL_ARGS["init"]
)
logger.info("** Loaded JBOOK Query Expansion Model")
except Exception as e:
logger.warning("** Could not load JBOOK QE model")
logger.warning(e)
Expand All @@ -224,9 +221,8 @@ def initWordSim(model_path=WORD_SIM_MODEL.value):
"""
logger.info(f"Loading Word Sim Model from {model_path}")
try:
if MODEL_LOAD_FLAG:
ModelLoader.__word_sim = WordSim(model_path)
logger.info("** Loaded Word Sim Model")
ModelLoader.__word_sim = WordSim(model_path)
logger.info("** Loaded Word Sim Model")
except Exception as e:
logger.warning("** Could not load Word Sim model")
logger.warning(e)
Expand All @@ -240,21 +236,22 @@ def initSentenceSearcher(
Args:
Returns:
"""
logger.info(f"Loading Sentence Searcher with sent index path: {index_path}")
logger.info(
f"Loading Sentence Searcher with sent index path: {index_path}")
try:
if MODEL_LOAD_FLAG:
ModelLoader.__sentence_searcher = SentenceSearcher(
sim_model_name=SimilarityConfig.BASE_MODEL,
index_path=index_path,
transformer_path=transformer_path,
)

sim_model = ModelLoader.__sentence_searcher.similarity
# set cache variable defined in settings.py
latest_intel_model_sim.value = sim_model.sim_model
logger.info(
f"** Loaded Similarity Model from {sim_model.sim_model} and sent index from {index_path}"
)

ModelLoader.__sentence_searcher = SentenceSearcher(
sim_model_name=SimilarityConfig.BASE_MODEL,
index_path=index_path,
transformer_path=transformer_path,
)

sim_model = ModelLoader.__sentence_searcher.similarity
# set cache variable defined in settings.py
latest_intel_model_sim.value = sim_model.sim_model
logger.info(
f"** Loaded Similarity Model from {sim_model.sim_model} and sent index from {index_path}"
)

except Exception as e:
logger.warning("** Could not load Similarity model")
Expand All @@ -269,16 +266,15 @@ def initSentenceEncoder(transformer_path=LOCAL_TRANSFORMERS_DIR.value):
"""
logger.info(f"Loading encoder model")
try:
if MODEL_LOAD_FLAG:
ModelLoader.__sentence_encoder = SentenceEncoder(
encoder_model_name=EmbedderConfig.BASE_MODEL,
transformer_path=transformer_path,
**EmbedderConfig.MODEL_ARGS,
)
encoder_model = ModelLoader.__sentence_encoder.encoder_model
# set cache variable defined in settings.py
latest_intel_model_encoder.value = encoder_model
logger.info(f"** Loaded Encoder Model from {encoder_model}")
ModelLoader.__sentence_encoder = SentenceEncoder(
encoder_model_name=EmbedderConfig.BASE_MODEL,
transformer_path=transformer_path,
**EmbedderConfig.MODEL_ARGS,
)
encoder_model = ModelLoader.__sentence_encoder.encoder_model
# set cache variable defined in settings.py
latest_intel_model_encoder.value = encoder_model
logger.info(f"** Loaded Encoder Model from {encoder_model}")

except Exception as e:
logger.warning("** Could not load Encoder model")
Expand All @@ -291,7 +287,7 @@ def initDocumentCompareSearcher(
):
"""
initDocumentCompareSearcher - loads SentenceSearcher class on start
Args:
Args:
Returns:
"""
logger.info(
Expand Down Expand Up @@ -341,9 +337,9 @@ def initDocumentCompareEncoder(transformer_path=LOCAL_TRANSFORMERS_DIR.value):
@staticmethod
def initSparse(model_name=latest_intel_model_trans.value):
try:
if MODEL_LOAD_FLAG:
ModelLoader.__sparse_reader = sparse.SparseReader(model_name=model_name)
logger.info(f"Sparse Reader: {model_name} loaded")
ModelLoader.__sparse_reader = sparse.SparseReader(
model_name=model_name)
logger.info(f"Sparse Reader: {model_name} loaded")
except Exception as e:
logger.warning("** Could not load Sparse Reader")
logger.warning(e)
Expand All @@ -355,11 +351,10 @@ def initTopics(model_path=TOPICS_MODEL.value) -> None:
Returns:
"""
try:
if MODEL_LOAD_FLAG:
logger.info(f"Loading topic model {model_path}")
logger.info(TopicsConfig.DATA_ARGS)
ModelLoader.__topic_model = Topics(directory=model_path)
logger.info("Finished loading Topic Model")
logger.info(f"Loading topic model {model_path}")
logger.info(TopicsConfig.DATA_ARGS)
ModelLoader.__topic_model = Topics(directory=model_path)
logger.info("Finished loading Topic Model")
except Exception as e:
logger.warning("** Could not load Topic model")
logger.warning(e)
Expand All @@ -371,9 +366,8 @@ def initRecommender():
Returns:
"""
try:
if MODEL_LOAD_FLAG:
logger.info("Starting Recommender pipeline")
ModelLoader.__recommender = Recommender()
logger.info("Finished loading Recommender")
logger.info("Starting Recommender pipeline")
ModelLoader.__recommender = Recommender()
logger.info("Finished loading Recommender")
except OSError:
logger.error(f"** Could not load Recommender")
50 changes: 37 additions & 13 deletions gamechangerml/api/fastapi/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import requests
import base64
import hashlib
import datetime

# must import sklearn first or you get an import error
from gamechangerml.src.search.query_expansion.utils import remove_original_kw
from gamechangerml.src.featurization.keywords.extract_keywords import get_keywords
from gamechangerml.src.text_handling.process import preprocess
from gamechangerml.api.fastapi.version import __version__
from gamechangerml.src.utilities import gc_web_api
from gamechangerml.api.utils.redisdriver import CacheVariable

# from gamechangerml.models.topic_models.tfidf import bigrams, tfidf_model
# from gamechangerml.src.featurization.summary import GensimSumm
from gamechangerml.api.fastapi.settings import *
from gamechangerml.api.fastapi.settings import CACHE_EXPIRE_DAYS
from gamechangerml.api.utils.logger import logger
from gamechangerml.api.fastapi.model_loader import ModelLoader

from gamechangerml.configs.config import QexpConfig
Expand Down Expand Up @@ -103,16 +106,32 @@ async def trans_sentence_infer(
results = {}
try:
query_text = body["text"]
results = MODELS.sentence_searcher.search(
query_text,
num_results,
process=process,
externalSim=False,
threshold=threshold,
)
cache = CacheVariable(query_text, True)
cached_value = cache.get_value()
if cached_value:
logger.info("Searched was found in cache")
results = cached_value
else:
results = MODELS.sentence_searcher.search(
query_text,
num_results,
process=process,
externalSim=False,
threshold=threshold,
)
cache.set_value(
results,
expire=int(
(
datetime.datetime.utcnow()
+ datetime.timedelta(days=CACHE_EXPIRE_DAYS)
).timestamp()
),
)
logger.info(results)
except Exception:
logger.error(f"Unable to get results from sentence transformer for {body}")
logger.error(
f"Unable to get results from sentence transformer for {body}")
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
raise
return results
Expand Down Expand Up @@ -169,7 +188,7 @@ async def post_expand_query_terms(body: dict, response: Response) -> dict:
query_expander = (
MODELS.query_expander
if body.get("qe_model", "gc_core") != "jbook"
or MODELS.query_expander_jbook == None
or MODELS.query_expander_jbook is None
else MODELS.query_expander_jbook
)
try:
Expand All @@ -181,7 +200,8 @@ async def post_expand_query_terms(body: dict, response: Response) -> dict:
# Removes original word from the return terms unless it is combined with another word
logger.info(f"original expanded terms: {expansion_list}")
finalTerms = remove_original_kw(expansion_list, terms_string)
expansion_dict[terms_string] = ['"{}"'.format(exp) for exp in finalTerms]
expansion_dict[terms_string] = [
'"{}"'.format(exp) for exp in finalTerms]
logger.info(f"-- Expanded {terms_string} to \n {finalTerms}")
# Perform word similarity
logger.info(f"Finding similiar words for: {terms_string}")
Expand Down Expand Up @@ -229,7 +249,8 @@ async def post_recommender(body: dict, response: Response) -> dict:
if body["sample"]:
sample = body["sample"]
logger.info(f"Recommending similar documents to {filenames}")
results = MODELS.recommender.get_recs(filenames=filenames, sample=sample)
results = MODELS.recommender.get_recs(
filenames=filenames, sample=sample)
if results["results"] != []:
logger.info(f"Found similar docs: \n {str(results)}")
else:
Expand All @@ -244,7 +265,10 @@ async def post_recommender(body: dict, response: Response) -> dict:

@router.post("/documentCompare", status_code=200)
async def document_compare_infer(
body: dict, response: Response, num_results: int = 10, process: bool = True,
body: dict,
response: Response,
num_results: int = 10,
process: bool = True,
) -> dict:
"""document_compare_infer - endpoint for document compare inference
Args:
Expand Down
5 changes: 3 additions & 2 deletions gamechangerml/api/fastapi/routers/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,18 @@

@router.on_event("startup")
async def load_models():

if MODEL_LOAD_FLAG:
MODELS.initQA()
MODELS.initQE()
MODELS.initQEJBook()
MODELS.initSentenceEncoder()
MODELS.initSentenceSearcher()
# MODELS.initSentenceEncoder()
MODELS.initWordSim()
MODELS.initTopics()
MODELS.initRecommender()
MODELS.initDocumentCompareEncoder()
MODELS.initDocumentCompareSearcher()
# MODELS.initDocumentCompareSearcher()
logger.info("AFTER LOAD MODELS")
else:
logger.info("MODEL_LOAD_FLAG set to False, no models loaded")
Expand Down
2 changes: 1 addition & 1 deletion gamechangerml/api/fastapi/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MODEL_LOAD_FLAG = False
else:
MODEL_LOAD_FLAG = True

CACHE_EXPIRE_DAYS = 15
if GC_ML_HOST == "":
GC_ML_HOST = "localhost"
ignore_files = ["._.DS_Store", ".DS_Store", "index"]
Expand Down
Loading

0 comments on commit 908d2d8

Please sign in to comment.