Skip to content

Commit

Permalink
Merge branch 'init_SentenceTransformer' into 'advana-release/v2.0.1'
Browse files Browse the repository at this point in the history
Init sentence transformer

See merge request advana/gamechanger/gamechanger-ml-source!180

Description:

How to test:

Remember to attach the issue if it's not
  • Loading branch information
Jared Ross committed Oct 24, 2023
2 parents f4b9f37 + 629d95b commit 45ba988
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
35 changes: 34 additions & 1 deletion gamechangerml/api/fastapi/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from gamechangerml.src.search.sent_transformer.model import (
SentenceSearcher,
SentenceEncoder,
SemanticSearcher
SemanticSearcher,
GcSentenceTransformer
)
from gamechangerml.src.search.doc_compare import (
DocCompareSentenceEncoder,
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(self):
__document_compare_searcher = None
__document_compare_encoder = None
__semantic_searcher = None
__gc_sentence_transformer = None

# Get methods for the models. If they don't exist try initializing them.
def getQA(self):
Expand Down Expand Up @@ -161,6 +163,14 @@ def getSemanticSearcher(self):
ModelLoader.initSemanticSearcher()
return ModelLoader.__semantic_searcher

def getGcSentenceTransformer(self):
    """Return the shared GcSentenceTransformer, initializing it on first use.

    If the class-level instance has not been set yet, a warning is logged and
    ModelLoader.initGcSentenceTransformer() is run before returning.

    Returns:
        The value stored in ModelLoader.__gc_sentence_transformer after the
        optional init attempt.
    """
    # Identity comparison is the correct way to test for None.
    if ModelLoader.__gc_sentence_transformer is None:
        logger.warning(
            "gc_sentence_transformer was not set and was attempted to be used. Running init"
        )
        ModelLoader.initGcSentenceTransformer()
    return ModelLoader.__gc_sentence_transformer

def set_error(self, value=None):
    """Reject direct assignment to a model property by logging an error.

    This function is passed as the setter to property(...), so it is invoked
    with the value being assigned; the value is accepted and ignored so that
    the error below is logged instead of a TypeError being raised.

    Args:
        value: the value a caller attempted to assign (unused)
    """
    logger.error("Models cannot be directly set. Must use init methods.")

Expand All @@ -178,6 +188,7 @@ def set_error(self):
document_compare_searcher = property(getDocumentCompareSearcher, set_error)
document_compare_encoder = property(getDocumentCompareEncoder, set_error)
semantic_searcher = property(getSemanticSearcher, set_error)
gc_sentence_transformer = property(getGcSentenceTransformer, set_error)

@staticmethod
def initQA(qa_model_name=QA_MODEL.value):
Expand Down Expand Up @@ -385,6 +396,28 @@ def initSemanticSearcher(
except Exception as e:
logger.warning("** Could not load Similarity model for Semantic Searcher")
logger.warning(e)

@staticmethod
def initGcSentenceTransformer(transformer_path=LOCAL_TRANSFORMERS_DIR.value):
    """
    initGcSentenceTransformer - loads Sentence Transformer on start
    Args:
        transformer_path (str): directory holding the transformer models;
            defaults to LOCAL_TRANSFORMERS_DIR.value
    Returns:
        None. The loaded model is stored on ModelLoader; if loading fails
        the exception is logged as a warning and not re-raised.
    """
    logger.info("Loading GC sentence transformer model")
    try:
        ModelLoader.__gc_sentence_transformer = GcSentenceTransformer(
            encoder_model_name=EmbedderConfig.BASE_MODEL,
            transformer_path=transformer_path,
        )
        encoder_model = ModelLoader.__gc_sentence_transformer.encoder_model
        # set cache variable defined in settings.py
        # NOTE(review): encoder_model is the attribute GcSentenceTransformer
        # assigns from SentenceTransformer(...), i.e. a model object rather
        # than a path/name -- confirm the cache variable is meant to hold it.
        latest_intel_model_encoder.value = encoder_model
        logger.info(f"** Loaded GC Sentence Transformer Model from {encoder_model}")

    except Exception as e:
        # Best-effort load: report the failure without interrupting the caller.
        logger.warning("** Could not load Encoder model")
        logger.warning(e)

@staticmethod
def initSparse(model_name=latest_intel_model_trans.value):
Expand Down
5 changes: 3 additions & 2 deletions gamechangerml/api/fastapi/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ async def semantic_search(
Returns:
results: list of floats
"""
logger.info(f"Hit embedSemanticQuery with json: {body}")
query_text = body["query"]
embeddings = MODELS.semantic_searcher.embed_query(query_text)
return list(embeddings.astype(float))
embeddings = MODELS.gc_sentence_transformer.embed_query(query_text)
return list(embeddings.astype(float))


@router.post("/semanticSearch", status_code=200)
Expand Down
3 changes: 2 additions & 1 deletion gamechangerml/api/fastapi/routers/startup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
MODELS.initTopics,
MODELS.initRecommender,
MODELS.initDocumentCompareSearcher,
MODELS.initSemanticSearcher
MODELS.initSemanticSearcher,
MODELS.initGcSentenceTransformer
]


Expand Down
18 changes: 18 additions & 0 deletions gamechangerml/src/search/sent_transformer/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from txtai.embeddings import Embeddings
from txtai.pipeline import Similarity
from txtai.ann import ANN
from sentence_transformers import SentenceTransformer

import os
import numpy as np
Expand Down Expand Up @@ -474,4 +475,21 @@ def search(

def embed_query(self, query_text):
    """Return the result of passing query_text through self.embedder.transform."""
    return self.embedder.transform(query_text)

class GcSentenceTransformer(object):
    """
    Loads a sentence_transformers model from a local directory and uses it
    to embed query text
    Args:
        encoder_model_name (str): Name of the model directory, joined onto
            transformer_path to locate the model
        transformer_path (str): Path to transformer directory
    """

    def __init__(self, encoder_model_name, transformer_path):
        model_dir = os.path.join(transformer_path, encoder_model_name)
        self.encoder_model = SentenceTransformer(model_dir)

    def embed_query(self, query_text):
        """Run preprocess() on query_text, join the pieces with single
        spaces, and return the encoder model's encoding of that string."""
        cleaned_query = " ".join(preprocess(query_text))
        return self.encoder_model.encode(cleaned_query)

0 comments on commit 45ba988

Please sign in to comment.