Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rc/5.4.5 #866

Merged
merged 12 commits into from
Mar 10, 2025
Merged
2 changes: 1 addition & 1 deletion doc/source/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Adjust driver configuration - these options are only available for this connecti
config.MAX_TRANSACTION_RETRY_TIME = 30.0 # default
config.RESOLVER = None # default
config.TRUST = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # default
config.USER_AGENT = neomodel/v5.4.4 # default
config.USER_AGENT = neomodel/v5.4.5 # default

Setting the database name, if different from the default one::

Expand Down
20 changes: 20 additions & 0 deletions doc/source/traversal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,26 @@ With both `traverse_relations` and `fetch_relations`, you can force the use of a

Person.nodes.fetch_relations('city__country', Optional('country')).all()

Unique variables
----------------

If you want chained traversals that share a path prefix to reuse the same node variable — and therefore be guaranteed to traverse the same nodes — use the `unique_variables` method::

# This does not guarantee that coffees__species will traverse the same nodes as coffees
# So coffees__species can traverse the Coffee node "Gold 3000"
nodeset = (
Supplier.nodes.fetch_relations("coffees", "coffees__species")
.filter(coffees__name="Nescafe")
)

# This guarantees that coffees__species will traverse the same nodes as coffees
# So when fetching species, it will only fetch those of the Coffee node "Nescafe"
nodeset = (
Supplier.nodes.fetch_relations("coffees", "coffees__species")
.filter(coffees__name="Nescafe")
.unique_variables("coffees")
)

Resolve results
---------------

Expand Down
2 changes: 1 addition & 1 deletion neomodel/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "5.4.4"
__version__ = "5.4.5"
28 changes: 20 additions & 8 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,11 @@ def create_relation_identifier(self) -> str:
self._relation_identifier_count += 1
return f"r{self._relation_identifier_count}"

def create_node_identifier(self, prefix: str) -> str:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
def create_node_identifier(self, prefix: str, path: str) -> str:
if path not in self.node_set._unique_variables:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
return prefix

def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
if "?" in source.order_by_elements:
Expand Down Expand Up @@ -613,14 +615,16 @@ def build_traversal_from_path(
rhs_label = relationship.definition["node_class"].__label__
if relation.get("relation_filtering"):
rhs_name = rel_ident
rhs_ident = f":{rhs_label}"
else:
if index + 1 == len(parts) and "alias" in relation:
# If an alias is defined, use it to store the last hop in the path
rhs_name = relation["alias"]
else:
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
rhs_name = self.create_node_identifier(rhs_name)
rhs_ident = f"{rhs_name}:{rhs_label}"
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
rhs_ident = f"{rhs_name}:{rhs_label}"

if relation["include_in_return"] and not already_present:
self._additional_return(rhs_name)

Expand Down Expand Up @@ -825,9 +829,11 @@ def add_to_target(statement: str, connector: Q, optional: bool) -> None:
match_filters = [filter[0] for filter in target if not filter[1]]
opt_match_filters = [filter[0] for filter in target if filter[1]]
if q.connector == Q.OR and match_filters and opt_match_filters:
raise ValueError(
"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements"
)
# In this case, we can't split filters in two WHERE statements so we move
# everything into the one applied after OPTIONAL MATCH statements...
opt_match_filters += match_filters
match_filters = []

ret = f" {q.connector} ".join(match_filters)
if ret and q.negated:
ret = f"NOT ({ret})"
Expand Down Expand Up @@ -1381,6 +1387,7 @@ def __init__(self, source: Any) -> None:
self._extra_results: list = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []
self._unique_variables: list[str] = []

def __await__(self) -> Any:
return self.all().__await__() # type: ignore[attr-defined]
Expand Down Expand Up @@ -1552,6 +1559,11 @@ def _register_relation_to_fetch(
item["alias"] = alias
return item

def unique_variables(self, *paths: str) -> "AsyncNodeSet":
    """Reuse a single node variable for the given traversal paths.

    When the same relation path is traversed more than once (e.g.
    "coffees" and "coffees__species"), each traversal normally gets its
    own numbered variable and may therefore match different nodes.
    Registering a path here makes later traversals of that path bind to
    the same variable, so they are guaranteed to traverse the same nodes.

    :param paths: relation paths whose node variable should be shared.
    :return: self, to allow method chaining.
    """
    # Store as a list to match the declared type of _unique_variables
    # (each *paths element is a str; the varargs annotation reflects that).
    self._unique_variables = list(paths)
    return self

def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
"""Specify a set of relations to traverse and return."""
relations = []
Expand Down
28 changes: 20 additions & 8 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,11 @@ def create_relation_identifier(self) -> str:
self._relation_identifier_count += 1
return f"r{self._relation_identifier_count}"

def create_node_identifier(self, prefix: str) -> str:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
def create_node_identifier(self, prefix: str, path: str) -> str:
if path not in self.node_set._unique_variables:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
return prefix

def build_order_by(self, ident: str, source: "NodeSet") -> None:
if "?" in source.order_by_elements:
Expand Down Expand Up @@ -611,14 +613,16 @@ def build_traversal_from_path(
rhs_label = relationship.definition["node_class"].__label__
if relation.get("relation_filtering"):
rhs_name = rel_ident
rhs_ident = f":{rhs_label}"
else:
if index + 1 == len(parts) and "alias" in relation:
# If an alias is defined, use it to store the last hop in the path
rhs_name = relation["alias"]
else:
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
rhs_name = self.create_node_identifier(rhs_name)
rhs_ident = f"{rhs_name}:{rhs_label}"
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
rhs_ident = f"{rhs_name}:{rhs_label}"

if relation["include_in_return"] and not already_present:
self._additional_return(rhs_name)

Expand Down Expand Up @@ -823,9 +827,11 @@ def add_to_target(statement: str, connector: Q, optional: bool) -> None:
match_filters = [filter[0] for filter in target if not filter[1]]
opt_match_filters = [filter[0] for filter in target if filter[1]]
if q.connector == Q.OR and match_filters and opt_match_filters:
raise ValueError(
"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements"
)
# In this case, we can't split filters in two WHERE statements so we move
# everything into the one applied after OPTIONAL MATCH statements...
opt_match_filters += match_filters
match_filters = []

ret = f" {q.connector} ".join(match_filters)
if ret and q.negated:
ret = f"NOT ({ret})"
Expand Down Expand Up @@ -1377,6 +1383,7 @@ def __init__(self, source: Any) -> None:
self._extra_results: list = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []
self._unique_variables: list[str] = []

def __await__(self) -> Any:
return self.all().__await__() # type: ignore[attr-defined]
Expand Down Expand Up @@ -1548,6 +1555,11 @@ def _register_relation_to_fetch(
item["alias"] = alias
return item

def unique_variables(self, *paths: str) -> "NodeSet":
    """Reuse a single node variable for the given traversal paths.

    When the same relation path is traversed more than once (e.g.
    "coffees" and "coffees__species"), each traversal normally gets its
    own numbered variable and may therefore match different nodes.
    Registering a path here makes later traversals of that path bind to
    the same variable, so they are guaranteed to traverse the same nodes.

    :param paths: relation paths whose node variable should be shared.
    :return: self, to allow method chaining.
    """
    # Store as a list to match the declared type of _unique_variables
    # (each *paths element is a str; the varargs annotation reflects that).
    self._unique_variables = list(paths)
    return self

def fetch_relations(self, *relation_names: tuple[str, ...]) -> "NodeSet":
"""Specify a set of relations to traverse and return."""
relations = []
Expand Down
64 changes: 56 additions & 8 deletions test/async_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from datetime import datetime
from test._async_compat import mark_async_test

import numpy as np
from pytest import raises, skip, warns

from neomodel import (
Expand Down Expand Up @@ -545,13 +544,14 @@ async def test_q_filters():
assert c6 in combined_coffees
assert c3 not in combined_coffees

with raises(
ValueError,
match=r"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements",
):
await Coffee.nodes.fetch_relations(Optional("species")).filter(
Q(name="Latte") | Q(species__name="Robusta")
).all()
robusta = await Species(name="Robusta").save()
await c4.species.connect(robusta)
latte_or_robusta_coffee = (
await Coffee.nodes.fetch_relations(Optional("species"))
.filter(Q(name="Latte") | Q(species__name="Robusta"))
.all()
)
assert len(latte_or_robusta_coffee) == 2

class QQ:
pass
Expand Down Expand Up @@ -632,6 +632,11 @@ async def test_relation_prop_filtering():
await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
await nescafe.species.connect(arabica)

result = await Coffee.nodes.filter(
**{"suppliers|since__gt": datetime(2010, 4, 1, 0, 0)}
).all()
assert len(result) == 1

results = await Supplier.nodes.filter(
**{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)}
).all()
Expand Down Expand Up @@ -1155,6 +1160,49 @@ async def test_in_filter_with_array_property():
), "Species found by tags with not match tags given"


@mark_async_test
async def test_unique_variables():
    """unique_variables() should make chained traversals reuse one node variable."""
    # Fixture graph: Nescafe has two suppliers, Gold 3000 has one; both
    # coffees share the Arabica species. supplier3 is never connected.
    arabica = await Species(name="Arabica").save()
    nescafe = await Coffee(name="Nescafe", price=99).save()
    gold3000 = await Coffee(name="Gold 3000", price=11).save()
    supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
    supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
    supplier3 = await Supplier(name="Supplier 3", delivery_cost=20).save()

    await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
    await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
    await nescafe.species.connect(arabica)
    await gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
    await gold3000.species.connect(arabica)

    # Without unique_variables, each traversal gets its own numbered variable.
    nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
        coffees__name="Nescafe"
    )
    ast = await nodeset.query_cls(nodeset).build_ast()
    query = ast.build_query()
    assert "coffee_coffees1" in query
    assert "coffee_coffees2" in query
    results = await nodeset.all()
    # This will be 3 because 2 suppliers for Nescafe and 1 for Gold 3000.
    # Gold 3000 is traversed because coffees__species redefines the coffees traversal.
    assert len(results) == 3

    # With unique_variables("coffees"), both traversals share one variable.
    nodeset = (
        Supplier.nodes.fetch_relations("coffees", "coffees__species")
        .filter(coffees__name="Nescafe")
        .unique_variables("coffees")
    )
    ast = await nodeset.query_cls(nodeset).build_ast()
    query = ast.build_query()
    assert "coffee_coffees" in query
    assert "coffee_coffees1" not in query
    assert "coffee_coffees2" not in query
    results = await nodeset.all()
    # This will be 2 because Gold 3000 is excluded this time,
    # as coffees will be reused in coffees__species.
    assert len(results) == 2


@mark_async_test
async def test_async_iterator():
n = 10
Expand Down
64 changes: 56 additions & 8 deletions test/sync_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from datetime import datetime
from test._async_compat import mark_sync_test

import numpy as np
from pytest import raises, skip, warns

from neomodel import (
Expand Down Expand Up @@ -541,13 +540,14 @@ def test_q_filters():
assert c6 in combined_coffees
assert c3 not in combined_coffees

with raises(
ValueError,
match=r"Cannot filter using OR operator on variables coming from both MATCH and OPTIONAL MATCH statements",
):
Coffee.nodes.fetch_relations(Optional("species")).filter(
Q(name="Latte") | Q(species__name="Robusta")
).all()
robusta = Species(name="Robusta").save()
c4.species.connect(robusta)
latte_or_robusta_coffee = (
Coffee.nodes.fetch_relations(Optional("species"))
.filter(Q(name="Latte") | Q(species__name="Robusta"))
.all()
)
assert len(latte_or_robusta_coffee) == 2

class QQ:
pass
Expand Down Expand Up @@ -624,6 +624,11 @@ def test_relation_prop_filtering():
nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
nescafe.species.connect(arabica)

result = Coffee.nodes.filter(
**{"suppliers|since__gt": datetime(2010, 4, 1, 0, 0)}
).all()
assert len(result) == 1

results = Supplier.nodes.filter(
**{"coffees__name": "Nescafe", "coffees|since__gt": datetime(2018, 4, 1, 0, 0)}
).all()
Expand Down Expand Up @@ -1139,6 +1144,49 @@ def test_in_filter_with_array_property():
), "Species found by tags with not match tags given"


@mark_sync_test
def test_unique_variables():
    """unique_variables() should make chained traversals reuse one node variable."""
    # Fixture graph: Nescafe has two suppliers, Gold 3000 has one; both
    # coffees share the Arabica species. supplier3 is never connected.
    arabica = Species(name="Arabica").save()
    nescafe = Coffee(name="Nescafe", price=99).save()
    gold3000 = Coffee(name="Gold 3000", price=11).save()
    supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
    supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
    supplier3 = Supplier(name="Supplier 3", delivery_cost=20).save()

    nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
    nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
    nescafe.species.connect(arabica)
    gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
    gold3000.species.connect(arabica)

    # Without unique_variables, each traversal gets its own numbered variable.
    nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
        coffees__name="Nescafe"
    )
    ast = nodeset.query_cls(nodeset).build_ast()
    query = ast.build_query()
    assert "coffee_coffees1" in query
    assert "coffee_coffees2" in query
    results = nodeset.all()
    # This will be 3 because 2 suppliers for Nescafe and 1 for Gold 3000.
    # Gold 3000 is traversed because coffees__species redefines the coffees traversal.
    assert len(results) == 3

    # With unique_variables("coffees"), both traversals share one variable.
    nodeset = (
        Supplier.nodes.fetch_relations("coffees", "coffees__species")
        .filter(coffees__name="Nescafe")
        .unique_variables("coffees")
    )
    ast = nodeset.query_cls(nodeset).build_ast()
    query = ast.build_query()
    assert "coffee_coffees" in query
    assert "coffee_coffees1" not in query
    assert "coffee_coffees2" not in query
    results = nodeset.all()
    # This will be 2 because Gold 3000 is excluded this time,
    # as coffees will be reused in coffees__species.
    assert len(results) == 2


@mark_sync_test
def test_async_iterator():
n = 10
Expand Down