Skip to content

Commit

Permalink
Basic search
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibaudDauce committed Jul 30, 2024
1 parent 5dcb7b0 commit 093f4e1
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 114 deletions.
77 changes: 61 additions & 16 deletions udata/core/elasticsearch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Document,
Field,
Float,
Index,
Integer,
Keyword,
Nested,
Expand All @@ -21,6 +22,7 @@
token_filter,
tokenizer,
)
from mongoengine import Document as MongoDocument
from mongoengine import signals

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -89,29 +91,54 @@ class Index:

attributes = {"Index": Index}

for key, field in cls._fields.items():
info = getattr(field, "__additional_field_info__", None)
if info is None:
continue
for key, field, searchable in get_searchable_fields(cls):
attributes[key] = convert_db_field_to_elasticsearch(field, searchable)

searchable = info.get("searchable", False)
ElasticSearchModel = type(f"{cls.__name__}ElasticsearchModel", (Document,), attributes)

if not searchable:
continue
ensure_index_exists(ElasticSearchModel._index, index_name)

attributes[key] = convert_db_field_to_elasticsearch(field, searchable)
def elasticsearch_index(cls, document, **kwargs):
print("calling it!")
print(document.id)
print(document.title)
convert_mongo_document_to_elasticsearch_document(document).save()

ElasticSearchModel = type(f"{cls.__name__}ElasticsearchModel", (Document,), attributes)
def elasticsearch_search(query_text):
s = Search(using=client, index=index_name).query("match", title=query_text)
response = s.execute()
print(response)

# Get all the models from MongoDB to fetch all the correct fields.
models = {
str(model.id): model for model in cls.objects(id__in=[hit.id for hit in response])
}

def elasticsearch_index():
pass
# Map these object to the response array in order to preserve the sort order
# returned by Elasticsearch
return [models[hit.id] for hit in response]

cls.__elasticsearch_model__ = ElasticSearchModel
cls.__elasticsearch_index__ = elasticsearch_index
cls.__elasticsearch_search__ = elasticsearch_search

signals.post_save.connect(cls.__elasticsearch_index__, sender=cls)


def get_searchable_fields(cls):
    """Yield ``(key, field, searchable)`` for each searchable field of *cls*.

    A MongoEngine field takes part in search only when it carries an
    ``__additional_field_info__`` dict whose ``searchable`` entry is truthy.
    The ``searchable`` value is yielded unchanged, so callers may receive
    either ``True`` or a string describing a specific mapping.
    """
    for name, db_field in cls._fields.items():
        extra_info = getattr(db_field, "__additional_field_info__", None)
        if extra_info is not None:
            flag = extra_info.get("searchable", False)
            if flag:
                yield name, db_field, flag


def convert_db_field_to_elasticsearch(field, searchable: bool | str) -> Field:
if isinstance(searchable, str):
return {
Expand All @@ -131,9 +158,27 @@ def convert_db_field_to_elasticsearch(field, searchable: bool | str) -> Field:
raise ValueError(f"Unsupported MongoEngine field type {field.__class__.__name__}")


# def ensure_index_exists(cls: type) -> None:
# now = datetime.utcnow().strftime("%Y-%m-%d-%H-%M")
# index_name_with_suffix = f"{alias}-{now}"
# pattern = f"{alias}-*"
def convert_mongo_document_to_elasticsearch_document(document: MongoDocument) -> Document:
    """Build the Elasticsearch document mirroring *document*.

    Copies the MongoDB ``id`` (stringified) plus the current value of every
    searchable field into a new instance of the model's generated
    ``__elasticsearch_model__`` class. The instance is returned unsaved;
    the caller decides when to persist it.
    """
    attributes = {"id": str(document.id)}

    # Only the key is needed to copy the raw value; the field object and the
    # searchable flag from get_searchable_fields() are intentionally unused.
    for key, _field, _searchable in get_searchable_fields(document.__class__):
        attributes[key] = getattr(document, key)

    return document.__elasticsearch_model__(**attributes)


def ensure_index_exists(index: Index, index_name: str) -> None:
    """Create a timestamped Elasticsearch index and alias it to *index_name*.

    First saves an index template matching ``{index_name}-*`` so the concrete
    index inherits the model's mappings, then creates a dated index
    (``{index_name}-YYYY-MM-DD-HH-MM``) and points the bare *index_name*
    alias at it — searches go through the alias, so future reindexing can
    swap the underlying index atomically.
    """
    now = datetime.utcnow().strftime("%Y-%m-%d-%H-%M")
    index_name_with_suffix = f"{index_name}-{now}"
    pattern = f"{index_name}-*"

    # Use the module logger instead of print() so operational noise is
    # controllable through the standard logging configuration.
    log.debug("Saving index template %s (pattern %s)", index_name, pattern)
    index_template = index.as_template(index_name, pattern)
    index_template.save()

    log.debug("Creating index %s", index_name_with_suffix)
    client.indices.create(index=index_name_with_suffix)
    log.debug("Creating alias %s -> %s", index_name, index_name_with_suffix)
    client.indices.put_alias(index=index_name_with_suffix, name=index_name)
    log.debug("Index %s ready", index_name)
110 changes: 12 additions & 98 deletions udata/tests/api/test_dataservices_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,107 +319,21 @@ def test_dataservice_api_create_with_custom_user_or_org(self):
self.assertEqual(dataservice.organization.id, me_org.id)

def test_elasticsearch(self):
from datetime import datetime

from elasticsearch import Elasticsearch
from elasticsearch_dsl import Date, Document, Integer, Keyword, Search, Text, connections

french_elision = token_filter(
"french_elision",
type="elision",
articles_case=True,
articles=[
"l",
"m",
"t",
"qu",
"n",
"s",
"j",
"d",
"c",
"jusqu",
"quoiqu",
"lorsqu",
"puisqu",
],
)
SEARCH_SYNONYMS = [
"AMD, administrateur ministériel des données, AMDAC",
"lolf, loi de finance",
"waldec, RNA, répertoire national des associations",
"ovq, baromètre des résultats",
"contour, découpage",
"rp, recensement de la population",
]
french_stop = token_filter("french_stop", type="stop", stopwords="_french_")
french_stemmer = token_filter("french_stemmer", type="stemmer", language="light_french")
french_synonym = token_filter(
"french_synonym",
type="synonym",
ignore_case=True,
expand=True,
synonyms=SEARCH_SYNONYMS,
)

dgv_analyzer = analyzer(
"french_dgv",
tokenizer=tokenizer("standard"),
filter=[french_elision, french_synonym, french_stemmer, french_stop],
)

# Define a default Elasticsearch client
client = connections.create_connection(hosts=["localhost"])

alias = "".join(random.choices(string.ascii_lowercase, k=10))
now = datetime.utcnow().strftime("%Y-%m-%d-%H-%M")
index_name_with_suffix = f"{alias}-{now}"
pattern = f"{alias}-*"

class Index:
name = alias

attributes = {"Index": Index}
attributes["title"] = Text(analyzer=dgv_analyzer)

Article = type("Article", (Document,), attributes)

print("exporting template")
index_template = Article._index.as_template(alias, pattern)
index_template.save()

print("creating index")
client.indices.create(index=index_name_with_suffix)
print("creating alias")
client.indices.put_alias(index=index_name_with_suffix, name=alias)
print("done")

# create and save and article
article = Article(meta={"id": 42}, title="Hello AMD world!", tags=["test"])
article.body = """ looong text """
article.published_from = datetime.now()
article.save()

article = Article.get(id=42)
print(article.title)

# Display cluster health
print(connections.get_connection().cluster.health())

print("sleeping")
dataservice_a = DataserviceFactory(title="Hello AMD world!")
dataservice_b = DataserviceFactory(title="Other one")
time.sleep(1)
print("go!")

s = Search(using=client, index=alias).query("match", title="AMDAC")
dataservices = Dataservice.__elasticsearch_search__("AMDAC")

# s.aggs.bucket("per_tag", "terms", field="tags").metric("max_lines", "max", field="lines")
assert len(dataservices) == 1
assert dataservices[0].id == dataservice_a.id

response = s.execute()

print("hit in response")
dataservice_b.title = "Hello AMD world!"
dataservice_b.save()
time.sleep(1)

for hit in response:
print(hit.meta.score, hit.title)
dataservices = Dataservice.__elasticsearch_search__("AMDAC")

assert len(response) == 1
assert response[0].title == "Hello AMD world!"
assert len(dataservices) == 2
assert dataservices[0].id == dataservice_a.id
assert dataservices[1].id == dataservice_b.id

0 comments on commit 093f4e1

Please sign in to comment.