diff --git a/gamechangerml/api/fastapi/model_loader.py b/gamechangerml/api/fastapi/model_loader.py index f675edc9..8994b9eb 100644 --- a/gamechangerml/api/fastapi/model_loader.py +++ b/gamechangerml/api/fastapi/model_loader.py @@ -16,7 +16,8 @@ from gamechangerml.src.search.sent_transformer.model import ( SentenceSearcher, SentenceEncoder, - SemanticSearcher + SemanticSearcher, + GcSentenceTransformer ) from gamechangerml.src.search.doc_compare import ( DocCompareSentenceEncoder, @@ -68,6 +69,7 @@ def __init__(self): __document_compare_searcher = None __document_compare_encoder = None __semantic_searcher = None + __gc_sentence_transformer = None # Get methods for the models. If they don't exist try initializing them. def getQA(self): @@ -161,6 +163,14 @@ def getSemanticSearcher(self): ModelLoader.initSemanticSearcher() return ModelLoader.__semantic_searcher + def getGcSentenceTransformer(self): + if ModelLoader.__gc_sentence_transformer == None: + logger.warning( + "semantic_searcher was not set and was attempted to be used. Running init" + ) + ModelLoader.initGcSentenceTransformer() + return ModelLoader.__gc_sentence_transformer + def set_error(self): logger.error("Models cannot be directly set. Must use init methods.") @@ -178,6 +188,7 @@ def set_error(self): document_compare_searcher = property(getDocumentCompareSearcher, set_error) document_compare_encoder = property(getDocumentCompareEncoder, set_error) semantic_searcher = property(getSemanticSearcher, set_error) + gc_sentence_transformer = property(getGcSentenceTransformer, set_error) @staticmethod def initQA(qa_model_name=QA_MODEL.value): @@ -385,6 +396,28 @@ def initSemanticSearcher( except Exception as e: logger.warning("** Could not load Similarity model for Semantic Searcher") logger.warning(e) + + @staticmethod + def initGcSentenceTransformer(transformer_path=LOCAL_TRANSFORMERS_DIR.value): + """ + initGcSentenceTransformer - loads Sentence Transformer on start + Args: + Returns: + """ + logger.info(f"Loading GC sentence transformer model") + try: + ModelLoader.__gc_sentence_transformer = GcSentenceTransformer( + encoder_model_name=EmbedderConfig.BASE_MODEL, + transformer_path=transformer_path + ) + encoder_model = ModelLoader.__gc_sentence_transformer.encoder_model + # set cache variable defined in settings.py + latest_intel_model_encoder.value = encoder_model + logger.info(f"** Loaded GC Sentence Transformer Model from {encoder_model}") + + except Exception as e: + logger.warning("** Could not load Encoder model") + logger.warning(e) @staticmethod def initSparse(model_name=latest_intel_model_trans.value): diff --git a/gamechangerml/api/fastapi/routers/search.py b/gamechangerml/api/fastapi/routers/search.py index 72a62951..c65197f8 100644 --- a/gamechangerml/api/fastapi/routers/search.py +++ b/gamechangerml/api/fastapi/routers/search.py @@ -93,9 +93,10 @@ async def semantic_search( Returns: results: list of floats """ + logger.info(f"Hit embedSemanticQuery with json: {body}") query_text = body["query"] - embeddings = MODELS.semantic_searcher.embed_query(query_text) - return list(embeddings.astype(float)) + embeddings = MODELS.gc_sentence_transformer.embed_query(query_text) + return list(embeddings.astype(float)) @router.post("/semanticSearch", status_code=200) diff --git a/gamechangerml/api/fastapi/routers/startup.py b/gamechangerml/api/fastapi/routers/startup.py index 465b7428..03604826 100644 --- a/gamechangerml/api/fastapi/routers/startup.py +++ b/gamechangerml/api/fastapi/routers/startup.py @@ -39,7 +39,8 @@ MODELS.initTopics, MODELS.initRecommender, MODELS.initDocumentCompareSearcher, - MODELS.initSemanticSearcher + MODELS.initSemanticSearcher, + MODELS.initGcSentenceTransformer ] diff --git a/gamechangerml/src/search/sent_transformer/model.py b/gamechangerml/src/search/sent_transformer/model.py index 6a37a449..fed5f002 100644 --- a/gamechangerml/src/search/sent_transformer/model.py +++ b/gamechangerml/src/search/sent_transformer/model.py @@ -1,6 +1,7 @@ from txtai.embeddings import Embeddings from txtai.pipeline import Similarity from txtai.ann import ANN +from sentence_transformers import SentenceTransformer import os import numpy as np @@ -474,4 +475,21 @@ def search( def embed_query(self, query_text): embeddings = self.embedder.transform(query_text) + return embeddings + +class GcSentenceTransformer(object): + """ + Instantiates Sentence Transformer model + + Args: + transformer_path (str): Path to transformer directory + encoder_model (str): Model name supported by sentence_transformers + """ + + def __init__(self, encoder_model_name, transformer_path): + self.encoder_model = SentenceTransformer(os.path.join(transformer_path, encoder_model_name)) + + def embed_query(self, query_text): + query_text = " ".join(preprocess(query_text)) + embeddings = self.encoder_model.encode(query_text) return embeddings \ No newline at end of file