diff --git a/udata/api_fields.py b/udata/api_fields.py index 9e88645b57..4c4fe1dce7 100644 --- a/udata/api_fields.py +++ b/udata/api_fields.py @@ -6,6 +6,7 @@ import udata.api.fields as custom_restx_fields from udata.api import api, base_reference +from udata.core.elasticsearch import is_elasticsearch_enable from udata.mongo.errors import FieldValidationError lazy_reference = api.model( @@ -244,6 +245,9 @@ def wrapper(cls): if info is None: continue + if not info.get("api", True): + continue + def make_lambda(method): """ Factory function to create a lambda with the correct scope. @@ -308,43 +312,64 @@ def make_lambda(method): def apply_sort_filters_and_pagination(base_query): args = cls.__index_parser__.parse_args() - if sortables and args["sort"]: - negate = args["sort"].startswith("-") - sort_key = args["sort"][1:] if negate else args["sort"] - - sort_by = next( - (sortable["value"] for sortable in sortables if sortable["key"] == sort_key), - None, + if ( + args.get("q") + and is_elasticsearch_enable() + and getattr(cls, "__elasticsearch_search__", None) is not None + ): + # Do an Elasticsearch query + print(cls.__elasticsearch_search__(args.get("q"))) + print( + { + "data": cls.__elasticsearch_search__(args.get("q")), + } ) + return { + "data": cls.__elasticsearch_search__(args.get("q")), + } + else: + # Do a regular MongoDB query + if sortables and args["sort"]: + negate = args["sort"].startswith("-") + sort_key = args["sort"][1:] if negate else args["sort"] + + sort_by = next( + ( + sortable["value"] + for sortable in sortables + if sortable["key"] == sort_key + ), + None, + ) - if sort_by: - if negate: - sort_by = "-" + sort_by + if sort_by: + if negate: + sort_by = "-" + sort_by - base_query = base_query.order_by(sort_by) + base_query = base_query.order_by(sort_by) - if searchable and args.get("q"): - phrase_query = " ".join([f'"{elem}"' for elem in args["q"].split(" ")]) - base_query = base_query.search_text(phrase_query) + if searchable and args.get("q"): + phrase_query = " ".join([f'"{elem}"' for elem in args["q"].split(" ")]) + base_query = base_query.search_text(phrase_query) - for filterable in filterables: - if args.get(filterable["key"]): - for constraint in filterable["constraints"]: - if constraint == "objectid" and not ObjectId.is_valid( - args[filterable["key"]] - ): - api.abort(400, f'`{filterable["key"]}` must be an identifier') + for filterable in filterables: + if args.get(filterable["key"]): + for constraint in filterable["constraints"]: + if constraint == "objectid" and not ObjectId.is_valid( + args[filterable["key"]] + ): + api.abort(400, f'`{filterable["key"]}` must be an identifier') - base_query = base_query.filter( - **{ - filterable["column"]: args[filterable["key"]], - } - ) + base_query = base_query.filter( + **{ + filterable["column"]: args[filterable["key"]], + } + ) - if paginable: - base_query = base_query.paginate(args["page"], args["page_size"]) + if paginable: + base_query = base_query.paginate(args["page"], args["page_size"]) - return base_query + return base_query cls.apply_sort_filters_and_pagination = apply_sort_filters_and_pagination return cls diff --git a/udata/core/dataservices/api.py b/udata/core/dataservices/api.py index 49a44b52b1..1c8d355eb9 100644 --- a/udata/core/dataservices/api.py +++ b/udata/core/dataservices/api.py @@ -25,7 +25,11 @@ def get(self): """List or search all dataservices""" query = Dataservice.objects.visible() - return Dataservice.apply_sort_filters_and_pagination(query) + results = Dataservice.apply_sort_filters_and_pagination(query) + print(results) + + print("here") + return results @api.secure @api.doc("create_dataservice", responses={400: "Validation error"}) diff --git a/udata/core/dataservices/models.py b/udata/core/dataservices/models.py index bebc89b016..89b85a2540 100644 --- a/udata/core/dataservices/models.py +++ b/udata/core/dataservices/models.py @@ -95,7 +95,7 @@ class HarvestMetadata(db.EmbeddedDocument): archived_at = field(db.DateTimeField()) -@generate_fields() +@generate_fields(searchable=True) @elasticsearch( score_functions_description={ "public_service_score": {"factor": 8, "modifier": "sqrt", "missing": 1}, diff --git a/udata/core/elasticsearch/__init__.py b/udata/core/elasticsearch/__init__.py index 5bbcfc86b2..9cf8d1bd43 100644 --- a/udata/core/elasticsearch/__init__.py +++ b/udata/core/elasticsearch/__init__.py @@ -79,6 +79,10 @@ T = TypeVar("T") +def is_elasticsearch_enable() -> bool: + return True + + def elasticsearch( score_functions_description: dict[str, dict] = {}, build_search_query=None, @@ -166,7 +170,7 @@ def elasticsearch_search(query_text: str): else: query = Q( "function_score", - query=query.MatchAll(), + query=query.MatchAll(), # todo only match `searchable` field and not `indexable` / `filterable` functions=score_functions, ) diff --git a/udata/tests/api/test_dataservices_api.py b/udata/tests/api/test_dataservices_api.py index 186644ceb9..89cded2f17 100644 --- a/udata/tests/api/test_dataservices_api.py +++ b/udata/tests/api/test_dataservices_api.py @@ -353,48 +353,50 @@ def test_elasticsearch(self): ) time.sleep(1) - dataservices = Dataservice.__elasticsearch_search__("AMDAC") + print(self.get(url_for("api.dataservices", q="AMDAC")).json) + + dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"] assert len(dataservices) == 3 - assert dataservices[0].title == dataservice_c.title - assert dataservices[1].title == dataservice_a.title + assert dataservices[0]["title"] == dataservice_c.title + assert dataservices[1]["title"] == dataservice_a.title assert ( - dataservices[2].title == dataservice_b.title + dataservices[2]["title"] == dataservice_b.title ) # b is last even if it doesn't really match. dataservice_b.title = "B - Hello AMD world!" dataservice_b.save() time.sleep(3) - dataservices = Dataservice.__elasticsearch_search__("AMDAC") + dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"] assert len(dataservices) == 3 # `dataservice_b` should be first because it has a lot of followers - assert dataservices[0].title == dataservice_b.title - assert dataservices[1].title == dataservice_c.title - assert dataservices[2].title == dataservice_a.title + assert dataservices[0]["title"] == dataservice_b.title + assert dataservices[1]["title"] == dataservice_c.title + assert dataservices[2]["title"] == dataservice_a.title dataservice_a.organization = orga_sp dataservice_a.save() assert dataservice_a.public_service_score() == 4 time.sleep(3) - dataservices = Dataservice.__elasticsearch_search__("AMDAC") + dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"] assert len(dataservices) == 3 - assert dataservices[0].title == dataservice_b.title - assert dataservices[1].title == dataservice_a.title - assert dataservices[2].title == dataservice_c.title + assert dataservices[0]["title"] == dataservice_b.title + assert dataservices[1]["title"] == dataservice_a.title + assert dataservices[2]["title"] == dataservice_c.title dataservice_b.archived_at = datetime.utcnow() dataservice_b.save() time.sleep(3) - dataservices = Dataservice.__elasticsearch_search__("AMDAC") + dataservices = self.get(url_for("api.dataservices", q="AMDAC")).json["data"] assert len(dataservices) == 2 - assert dataservices[0].title == dataservice_a.title - assert dataservices[1].title == dataservice_c.title + assert dataservices[0]["title"] == dataservice_a.title + assert dataservices[1]["title"] == dataservice_c.title