Added method to generate unique variable names for specific paths. #864

Merged: 6 commits, Mar 10, 2025
20 changes: 20 additions & 0 deletions doc/source/traversal.rst
@@ -77,6 +77,26 @@ With both `traverse_relations` and `fetch_relations`, you can force the use of a

Person.nodes.fetch_relations('city__country', Optional('country')).all()

Unique variables
----------------

When chaining traversals, each occurrence of a path normally gets its own generated variable. If you want the traversed nodes of a given path to be bound to a single, reused variable instead, use the `unique_variables` method::

# This does not guarantee that coffees__species traverses the same nodes as coffees,
# so coffees__species can also traverse the Coffee node "Gold 3000"
nodeset = (
Supplier.nodes.fetch_relations("coffees", "coffees__species")
.filter(coffees__name="Nescafe")
)

# This guarantees that coffees__species traverses the same nodes as coffees,
# so when fetching species, only those of the Coffee node "Nescafe" are fetched
nodeset = (
Supplier.nodes.fetch_relations("coffees", "coffees__species")
.filter(coffees__name="Nescafe")
.unique_variables("coffees")
)
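
To check which node variables end up in the generated Cypher, the query can be built without executing it, mirroring what the test suite for this feature does (`query_cls` and `build_ast` are internal helpers, so treat this purely as an inspection aid)::

    nodeset = (
        Supplier.nodes.fetch_relations("coffees", "coffees__species")
        .filter(coffees__name="Nescafe")
        .unique_variables("coffees")
    )
    ast = nodeset.query_cls(nodeset).build_ast()
    print(ast.build_query())
    # With unique_variables("coffees"), the query binds a single coffee_coffees
    # variable; without it, numbered variables such as coffee_coffees1 and
    # coffee_coffees2 are generated instead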

Resolve results
---------------

17 changes: 13 additions & 4 deletions neomodel/async_/match.py
@@ -503,9 +503,11 @@ def create_relation_identifier(self) -> str:
self._relation_identifier_count += 1
return f"r{self._relation_identifier_count}"

def create_node_identifier(self, prefix: str) -> str:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
def create_node_identifier(self, prefix: str, path: str) -> str:
if path not in self.node_set._unique_variables:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
return prefix

def build_order_by(self, ident: str, source: "AsyncNodeSet") -> None:
if "?" in source.order_by_elements:
@@ -620,8 +622,9 @@ def build_traversal_from_path(
rhs_name = relation["alias"]
else:
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
rhs_name = self.create_node_identifier(rhs_name)
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
rhs_ident = f"{rhs_name}:{rhs_label}"

if relation["include_in_return"] and not already_present:
self._additional_return(rhs_name)

@@ -1382,6 +1385,7 @@ def __init__(self, source: Any) -> None:
self._extra_results: list = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []
self._unique_variables: list[str] = []

def __await__(self) -> Any:
return self.all().__await__() # type: ignore[attr-defined]
@@ -1553,6 +1557,11 @@ def _register_relation_to_fetch(
item["alias"] = alias
return item

def unique_variables(self, *paths: tuple[str, ...]) -> "AsyncNodeSet":
"""Generate a single, reusable variable name for each of the given paths."""
self._unique_variables = list(paths)
return self

def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
"""Specify a set of relations to traverse and return."""
relations = []
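
For illustration only, here is a standalone sketch of the naming rule that the new create_node_identifier applies: a path registered through unique_variables() keeps its bare prefix, while any other path gets a fresh numbered identifier. The helper and the example values below are hypothetical and only mirror the diffed logic.

def make_identifier(prefix: str, path: str, unique_paths: set, counter: dict) -> str:
    # Mirrors the diffed logic: registered paths reuse the bare prefix,
    # every other path gets a new numbered identifier.
    if path not in unique_paths:
        counter["n"] += 1
        return f"{prefix}{counter['n']}"
    return prefix

counter = {"n": 0}
registered = {"coffees"}  # as if .unique_variables("coffees") had been called

print(make_identifier("coffee_coffees", "coffees", registered, counter))  # coffee_coffees
print(make_identifier("coffee_coffees", "coffees", registered, counter))  # coffee_coffees (reused)
print(make_identifier("species_coffees__species", "coffees__species", registered, counter))
# species_coffees__species1 (not registered, so it gets a number)
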
17 changes: 13 additions & 4 deletions neomodel/sync_/match.py
@@ -501,9 +501,11 @@ def create_relation_identifier(self) -> str:
self._relation_identifier_count += 1
return f"r{self._relation_identifier_count}"

def create_node_identifier(self, prefix: str) -> str:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
def create_node_identifier(self, prefix: str, path: str) -> str:
if path not in self.node_set._unique_variables:
self._node_identifier_count += 1
return f"{prefix}{self._node_identifier_count}"
return prefix

def build_order_by(self, ident: str, source: "NodeSet") -> None:
if "?" in source.order_by_elements:
@@ -618,8 +620,9 @@ def build_traversal_from_path(
rhs_name = relation["alias"]
else:
rhs_name = f"{rhs_label.lower()}_{rel_iterator}"
rhs_name = self.create_node_identifier(rhs_name)
rhs_name = self.create_node_identifier(rhs_name, rel_iterator)
rhs_ident = f"{rhs_name}:{rhs_label}"

if relation["include_in_return"] and not already_present:
self._additional_return(rhs_name)

@@ -1378,6 +1381,7 @@ def __init__(self, source: Any) -> None:
self._extra_results: list = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []
self._unique_variables: list[str] = []

def __await__(self) -> Any:
return self.all().__await__() # type: ignore[attr-defined]
@@ -1549,6 +1553,11 @@ def _register_relation_to_fetch(
item["alias"] = alias
return item

def unique_variables(self, *paths: tuple[str, ...]) -> "NodeSet":
"""Generate a single, reusable variable name for each of the given paths."""
self._unique_variables = list(paths)
return self

def fetch_relations(self, *relation_names: tuple[str, ...]) -> "NodeSet":
"""Specify a set of relations to traverse and return."""
relations = []
44 changes: 43 additions & 1 deletion test/async_/test_match_api.py
@@ -2,7 +2,6 @@
from datetime import datetime
from test._async_compat import mark_async_test

import numpy as np
from pytest import raises, skip, warns

from neomodel import (
@@ -1160,6 +1159,49 @@ async def test_in_filter_with_array_property():
), "Species found by tags with not match tags given"


@mark_async_test
async def test_unique_variables():
arabica = await Species(name="Arabica").save()
nescafe = await Coffee(name="Nescafe", price=99).save()
gold3000 = await Coffee(name="Gold 3000", price=11).save()
supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save()
supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save()
supplier3 = await Supplier(name="Supplier 3", delivery_cost=20).save()

await nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
await nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
await nescafe.species.connect(arabica)
await gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
await gold3000.species.connect(arabica)

nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
coffees__name="Nescafe"
)
ast = await nodeset.query_cls(nodeset).build_ast()
query = ast.build_query()
assert "coffee_coffees1" in query
assert "coffee_coffees2" in query
results = await nodeset.all()
# This will be 3: 2 rows for Nescafe's suppliers and 1 more via Gold 3000,
# which is reachable because coffees__species re-traverses coffees under its own variable
assert len(results) == 3

nodeset = (
Supplier.nodes.fetch_relations("coffees", "coffees__species")
.filter(coffees__name="Nescafe")
.unique_variables("coffees")
)
ast = await nodeset.query_cls(nodeset).build_ast()
query = ast.build_query()
assert "coffee_coffees" in query
assert "coffee_coffees1" not in query
assert "coffee_coffees2" not in query
results = await nodeset.all()
# This will be 2 because Gold 3000 is excluded this time,
# as the coffees variable is reused in coffees__species
assert len(results) == 2


@mark_async_test
async def test_async_iterator():
n = 10
44 changes: 43 additions & 1 deletion test/sync_/test_match_api.py
@@ -2,7 +2,6 @@
from datetime import datetime
from test._async_compat import mark_sync_test

import numpy as np
from pytest import raises, skip, warns

from neomodel import (
@@ -1144,6 +1143,49 @@ def test_in_filter_with_array_property():
), "Species found by tags with not match tags given"


@mark_sync_test
def test_unique_variables():
arabica = Species(name="Arabica").save()
nescafe = Coffee(name="Nescafe", price=99).save()
gold3000 = Coffee(name="Gold 3000", price=11).save()
supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save()
supplier2 = Supplier(name="Supplier 2", delivery_cost=20).save()
supplier3 = Supplier(name="Supplier 3", delivery_cost=20).save()

nescafe.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
nescafe.suppliers.connect(supplier2, {"since": datetime(2010, 4, 1, 0, 0)})
nescafe.species.connect(arabica)
gold3000.suppliers.connect(supplier1, {"since": datetime(2020, 4, 1, 0, 0)})
gold3000.species.connect(arabica)

nodeset = Supplier.nodes.fetch_relations("coffees", "coffees__species").filter(
coffees__name="Nescafe"
)
ast = nodeset.query_cls(nodeset).build_ast()
query = ast.build_query()
assert "coffee_coffees1" in query
assert "coffee_coffees2" in query
results = nodeset.all()
# This will be 3: 2 rows for Nescafe's suppliers and 1 more via Gold 3000,
# which is reachable because coffees__species re-traverses coffees under its own variable
assert len(results) == 3

nodeset = (
Supplier.nodes.fetch_relations("coffees", "coffees__species")
.filter(coffees__name="Nescafe")
.unique_variables("coffees")
)
ast = nodeset.query_cls(nodeset).build_ast()
query = ast.build_query()
assert "coffee_coffees" in query
assert "coffee_coffees1" not in query
assert "coffee_coffees2" not in query
results = nodeset.all()
# This will be 2 because Gold 3000 is excluded this time,
# as the coffees variable is reused in coffees__species
assert len(results) == 2


@mark_sync_test
def test_async_iterator():
n = 10