Skip to content

Commit

Permalink
Merge branch 'semantic_results_threshold' into 'advana-release/v2.0.1'
Browse files Browse the repository at this point in the history
implement default threshold to filter results

See merge request advana/gamechanger/gamechanger-ml-source!179

Description:

How to test:

Remember to attach the issue if it's not
  • Loading branch information
Jared Ross committed Jul 25, 2023
2 parents 5d4c3a1 + 390e1c7 commit 5422b0d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
8 changes: 5 additions & 3 deletions gamechangerml/api/fastapi/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ async def semantic_search(
body: dict,
response: Response,
num_results: int = 30,
threshold: float = .4
) -> dict:
"""semantic_title_search - endpoint for title transformer inference. Takes in a query, gets the embedding of the query, then finds the top (num_results) that match based on the embedding of the target (e.g. body["field"])
The threshold filters results so that only those with a similarity score greater than or equal to the given threshold are returned.
Args:
(dict) json format of the search query.\n
query: (str, required) a string of any length to embed and use for semantic search.\n
Expand All @@ -96,7 +97,7 @@ async def semantic_search(
results: (dict) results of inference.
"""
logger.info("SEMANTIC SEARCH - embedding query " + str(body["query"]) + "and pulling top results based on field " + str(body.get("field","title")))
logger.info("SEMANTIC SEARCH - embedding query " + str(body["query"]) + " and pulling top results based on field " + str(body.get("field","title")))
results = {}

try:
Expand All @@ -112,7 +113,8 @@ async def semantic_search(
search_results = MODELS.semantic_searcher.search(
query_text=query_text,
target_field=target_field,
num_results=num_results
num_results=num_results,
threshold=threshold
)
end = time.perf_counter()
logger.info(f"time: {end - start:0.4f} seconds")
Expand Down
10 changes: 7 additions & 3 deletions gamechangerml/src/search/sent_transformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,14 @@ def __init__(self, sim_model_name, index_path, transformer_path):

self.similarity = SimilarityRanker(sim_model_name, transformer_path)

def retrieve_topn(self, query, num_results, target_field):
def retrieve_topn(self, query, num_results, threshold, target_field):
# We'll use the title field for retrieval
retrieved = self.embedder.search(query, num_results)
results = []
for doc_id, score in retrieved:
doc = {}
if score < threshold:
continue
target_field_result = self.data[self.data['doc_id'] == str(doc_id)][target_field].values[0]
if not isinstance(target_field_result, str):
target_field_result = str(target_field_result)
Expand All @@ -445,7 +447,7 @@ def retrieve_topn(self, query, num_results, target_field):
return results

def search(
self, query_text, num_results, target_field="title"
self, query_text, num_results, threshold, target_field="title"
):
"""
Search the index and perform a similarity scoring reranker at
Expand All @@ -458,7 +460,9 @@ def search(
"""
logger.info(f"Sentence searching for: {query_text}")
if len(query_text) > 2:
top_results = self.retrieve_topn(query=query_text, num_results=num_results, target_field=target_field)
top_results = self.retrieve_topn(
query=query_text, num_results=num_results, target_field=target_field, threshold=threshold
)

top_results = sorted(
top_results, key=lambda i: i["score"], reverse=True
Expand Down

0 comments on commit 5422b0d

Please sign in to comment.