diff --git a/udata/core/dataservices/models.py b/udata/core/dataservices/models.py
index 24db36b1b9..a3b7311bb2 100644
--- a/udata/core/dataservices/models.py
+++ b/udata/core/dataservices/models.py
@@ -1,5 +1,7 @@
 from datetime import datetime
 
+from elasticsearch_dsl import Search, query
+
 import udata.core.contact_point.api_fields as contact_api_fields
 import udata.core.dataset.api_fields as datasets_api_fields
 from udata.api_fields import field, function_field, generate_fields
@@ -23,6 +25,50 @@
 DATASERVICE_FORMATS = ["REST", "WMS", "WSL"]
 
 
+def build_search_query(query_text: str, score_functions):
+    """Build the full-text query used to search dataservices.
+
+    Two scored clauses are OR-ed together (`bool`/`should`): a `phrase` match
+    with the highest boosts, and a `cross_fields` match using the `and`
+    operator. Each clause is wrapped in a `function_score` query so that
+    `score_functions` are applied to it.
+    """
+    return query.Q(
+        "bool",
+        should=[
+            query.Q(
+                "function_score",
+                query=query.Bool(
+                    should=[
+                        query.MultiMatch(
+                            query=query_text,
+                            type="phrase",
+                            fields=["title^15", "acronym^15", "description^8"],
+                        )
+                    ]
+                ),
+                functions=score_functions,
+            ),
+            query.Q(
+                "function_score",
+                query=query.Bool(
+                    should=[
+                        query.MultiMatch(
+                            query=query_text,
+                            type="cross_fields",
+                            fields=["title^7", "acronym^7", "description^4"],
+                            operator="and",
+                        )
+                    ]
+                ),
+                functions=score_functions,
+            ),
+            # TODO: evaluate adding a fuzzy clause, e.g.
+            # query.Match(title={"query": query_text, "fuzziness": "AUTO:4,6"}),
+        ],
+    )
+
+
 class DataserviceQuerySet(OwnedQuerySet):
     def visible(self):
         return self(archived_at=None, deleted_at=None, private=False)
@@ -58,7 +104,12 @@ class HarvestMetadata(db.EmbeddedDocument):
 
 
 @generate_fields()
-@elasticsearch()
+@elasticsearch(
+    score_functions_description={
+        "metrics.followers": {"factor": 4, "modifier": "sqrt", "missing": 1}
+    },
+    build_search_query=build_search_query,
+)
 class Dataservice(WithMetrics, Owned, db.Document):
     meta = {
         "indexes": [
diff --git a/udata/core/elasticsearch/__init__.py b/udata/core/elasticsearch/__init__.py
index 94dd0cb700..2bd39c7637 100644
--- a/udata/core/elasticsearch/__init__.py
+++ b/udata/core/elasticsearch/__init__.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import random
 import string
@@ -12,13 +13,17 @@
     Field,
     Float,
     Index,
+    InnerDoc,
     Integer,
     Keyword,
     Nested,
+    Object,
+    Q,
     Search,
     Text,
     analyzer,
     connections,
+    query,
     token_filter,
     tokenizer,
 )
@@ -74,15 +79,30 @@
 client = connections.create_connection(hosts=["localhost"])
 
 
-def elasticsearch(**kwargs):
+def elasticsearch(score_functions_description=None, build_search_query=None, **kwargs):
+    """Class decorator generating the Elasticsearch model of a MongoDB document.
+
+    `score_functions_description` maps a field path to the keyword arguments of
+    a `field_value_factor` score function. `build_search_query` receives the
+    query text and the score functions and returns the query to execute.
+    Extra keyword arguments are accepted but currently unused.
+    """
+
     def wrapper(cls):
-        cls.elasticsearch = generate_elasticsearch_model(cls)
+        cls.elasticsearch = generate_elasticsearch_model(
+            cls,
+            # Never share a mutable default between decorated classes.
+            score_functions_description=score_functions_description or {},
+            build_search_query=build_search_query,
+        )
         return cls
 
     return wrapper
 
 
-def generate_elasticsearch_model(cls: type) -> type:
+def generate_elasticsearch_model(
+    cls: type, score_functions_description, build_search_query
+) -> type:
     index_name = cls._get_collection_name()
 
     # Testing name to have a new index in each test.
@@ -103,10 +123,51 @@ class Index:
     def elasticsearch_index(cls, document, **kwargs):
         convert_mongo_document_to_elasticsearch_document(document).save()
 
+    score_functions = [
+        query.SF("field_value_factor", field=key, **value)
+        for key, value in score_functions_description.items()
+    ]
+
     def elasticsearch_search(query_text):
-        s = Search(using=client, index=index_name).query("match", title=query_text)
-        response = s.execute()
-        print(response)
+        s: Search = ElasticSearchModel.search()
+
+        # NOTE: the local is deliberately not named `query`: that would shadow
+        # the imported `elasticsearch_dsl.query` module inside this function and
+        # make `query.MatchAll()` raise `UnboundLocalError`.
+        if not query_text:
+            search_query = Q(
+                "function_score",
+                query=query.MatchAll(),
+                functions=score_functions,
+            )
+        elif build_search_query is None:
+            # No custom builder: keep the previous plain title match.
+            search_query = Q("match", title=query_text)
+        else:
+            search_query = build_search_query(query_text, score_functions)
+
+        # Dotted score-function fields (e.g. `metrics.followers`) are mapped as
+        # `Nested`, so wrap the query in a `nested` query once per parent path.
+        nested_paths = []
+        for field in score_functions_description:
+            levels = field.split(".")
+
+            if len(levels) > 2:
+                raise RuntimeError(
+                    f"This system only support one level deep score function fields. '{field}' contains two or more dots."
+                )
+
+            if len(levels) == 2 and levels[0] not in nested_paths:
+                nested_paths.append(levels[0])
+
+        for path in nested_paths:
+            search_query = Q("nested", path=path, query=search_query)
+
+        search = s.query(search_query)
+        logging.getLogger(__name__).debug(
+            "Elasticsearch query: %s", json.dumps(search.to_dict(), indent=2)
+        )
+        response = search.execute()
 
         # Get all the models from MongoDB to fetch all the correct fields.
         models = {
@@ -151,6 +212,8 @@ def convert_db_field_to_elasticsearch(field, searchable: bool | str) -> Field:
         return Boolean()
     elif isinstance(field, mongo_fields.DateTimeField):
         return Date()
+    elif isinstance(field, mongo_fields.DictField):
+        return Nested()
     elif isinstance(field, mongo_fields.ReferenceField):
         return Nested(field.document_type_obj.__elasticsearch_model__)
     else:
@@ -160,6 +223,7 @@ def convert_db_field_to_elasticsearch(field, searchable: bool | str) -> Field:
 def convert_mongo_document_to_elasticsearch_document(document: MongoDocument) -> Document:
     attributes = {}
     attributes["id"] = str(document.id)
+    attributes["meta"] = {"id": str(document.id)}
 
     for key, field, searchable in get_searchable_fields(document.__class__):
         attributes[key] = getattr(document, key)
@@ -180,7 +244,7 @@ def ensure_index_exists(index: Index, index_name: str) -> None:
     if index.exists():
         return
 
-    now = datetime.now(datetime.UTC).strftime("%Y-%m-%d-%H-%M")
+    now = datetime.utcnow().strftime("%Y-%m-%d-%H-%M")
     index_name_with_suffix = f"{index_name}-{now}"
 
     # Because we create the index manually (`elasticsearch_dsl` creates an index
diff --git a/udata/core/metrics/models.py b/udata/core/metrics/models.py
index 0baf737f36..9823e5a367 100644
--- a/udata/core/metrics/models.py
+++ b/udata/core/metrics/models.py
@@ -8,6 +8,7 @@ class WithMetrics(object):
     metrics = field(
         db.DictField(),
         readonly=True,
+        searchable=True,  # TODO change to indexable
     )
 
     __metrics_keys__ = []
diff --git a/udata/tests/api/test_dataservices_api.py b/udata/tests/api/test_dataservices_api.py
index 07527fa8c7..75866f0400 100644
--- a/udata/tests/api/test_dataservices_api.py
+++ b/udata/tests/api/test_dataservices_api.py
@@ -319,8 +319,18 @@ def test_dataservice_api_create_with_custom_user_or_org(self):
         self.assertEqual(dataservice.organization.id, me_org.id)
 
     def test_elasticsearch(self):
-        dataservice_a = DataserviceFactory(title="Hello AMD world!")
-        dataservice_b = DataserviceFactory(title="Other one")
+        dataservice_a = DataserviceFactory(
+            title="Hello AMD world!",
+            metrics={
+                "followers": 42,
+            },
+        )
+        dataservice_b = DataserviceFactory(
+            title="Other one",
+            metrics={
+                "followers": 1337,
+            },
+        )
         time.sleep(1)
 
         dataservices = Dataservice.__elasticsearch_search__("AMDAC")
@@ -330,10 +340,12 @@ def test_elasticsearch(self):
 
         dataservice_b.title = "Hello AMD world!"
         dataservice_b.save()
-        time.sleep(1)
+        time.sleep(3)
 
         dataservices = Dataservice.__elasticsearch_search__("AMDAC")
 
         assert len(dataservices) == 2
-        assert dataservices[0].id == dataservice_a.id
-        assert dataservices[1].id == dataservice_b.id
+
+        # `dataservice_b` should be first because it has a lot of followers
+        assert dataservices[0].id == dataservice_b.id
+        assert dataservices[1].id == dataservice_a.id