Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORM: Use skip_orm as the default implementation for SqlaGroup.add_nodes and SqlaGroup.remove_nodes #6720

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 16 additions & 58 deletions src/aiida/storage/psql_dos/orm/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,17 +167,10 @@ def add_nodes(self, nodes, **kwargs):
:note: all the nodes *and* the group itself have to be stored.

:param nodes: a list of `BackendNode` instances to be added to this group

:param kwargs:
skip_orm: When the flag is on, the SQLA ORM is skipped and SQLA is used
to create a direct SQL INSERT statement to the group-node relationship
table (to improve speed).
"""
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError

super().add_nodes(nodes)
skip_orm = kwargs.get('skip_orm', False)

def check_node(given_node):
"""Check if given node is of correct type and stored"""
Expand All @@ -188,31 +181,16 @@ def check_node(given_node):
raise ValueError('At least one of the provided nodes is unstored, stopping...')

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes

for node in nodes:
check_node(node)

# Use pattern as suggested here:
# http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint
try:
with session.begin_nested():
dbnodes.append(node.bare_model)
session.flush()
except IntegrityError:
# Duplicate entry, skip
pass
else:
ins_dict = []
for node in nodes:
check_node(node)
ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})

table = self.GROUP_NODE_CLASS.__table__
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))
ins_dict = []
for node in nodes:
check_node(node)
ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})
if len(ins_dict) == 0:
return

table = self.GROUP_NODE_CLASS.__table__
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

# Commit everything as up till now we've just flushed
if not session.in_nested_transaction():
Expand All @@ -224,45 +202,25 @@ def remove_nodes(self, nodes, **kwargs):
:note: all the nodes *and* the group itself have to be stored.

:param nodes: a list of `BackendNode` instances to be removed from this group
:param kwargs:
skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL
DELETE statement to the group-node relationship table in order to improve speed.
"""
from sqlalchemy import and_

super().remove_nodes(nodes)

# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes
skip_orm = kwargs.get('skip_orm', False)

def check_node(node):
if not isinstance(node, self.NODE_CLASS):
raise TypeError(f'invalid type {type(node)}, has to be {self.NODE_CLASS}')

if node.id is None:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

list_nodes = []

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
for node in nodes:
check_node(node)

# Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error
if node.bare_model in dbnodes:
list_nodes.append(node.bare_model)

for node in list_nodes:
dbnodes.remove(node)
else:
table = self.GROUP_NODE_CLASS.__table__
for node in nodes:
check_node(node)
clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
statement = table.delete().where(clause)
session.execute(statement)
table = self.GROUP_NODE_CLASS.__table__
for node in nodes:
check_node(node)
clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
statement = table.delete().where(clause)
session.execute(statement)

if not session.in_nested_transaction():
session.commit()
Expand Down
63 changes: 0 additions & 63 deletions tests/orm/implementation/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,66 +25,3 @@ def test_creation_from_dbgroup(backend):

assert group.pk == gcopy.pk
assert group.uuid == gcopy.uuid


def test_add_nodes_skip_orm():
    """Check that `SqlaGroup.add_nodes` works when called with `skip_orm=True`.

    Five stored nodes are added over three calls (a one-element list, a two-element list and a tuple), after
    which the group must contain exactly those nodes. Adding a node that is already a member must not raise
    and must leave the membership unchanged.
    """
    backend_group = orm.Group(label='test_adding_nodes').store().backend_entity

    stored = [orm.Data().store().backend_entity for _ in range(5)]
    expected_pks = {entity.pk for entity in stored}

    backend_group.add_nodes(stored[:1], skip_orm=True)
    backend_group.add_nodes(stored[1:3], skip_orm=True)
    backend_group.add_nodes(tuple(stored[3:]), skip_orm=True)

    assert expected_pks == {entity.pk for entity in backend_group.nodes}

    # Re-adding an existing member must neither raise nor change the membership.
    backend_group.add_nodes(stored[:1], skip_orm=True)
    assert expected_pks == {entity.pk for entity in backend_group.nodes}


def test_add_nodes_skip_orm_batch():
rabbull marked this conversation as resolved.
Show resolved Hide resolved
"""Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag and batches."""
nodes = [orm.Data().store().backend_entity for _ in range(100)]

# Add the nodes to groups using different batch sizes, then check that they were all added correctly.
batch_sizes = (1, 3, 10, 1000)
for batch_size in batch_sizes:
group = orm.Group(label=f'test_batches_{batch_size!s}').store()
group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)


def test_remove_nodes_bulk():
    """Check that `SqlaGroup.remove_nodes` works when called with `skip_orm=True`.

    Covers three cases in sequence: removing a node that is not a member (must be a no-op), removing a single
    member, and removing several members in one call.
    """
    backend_group = orm.Group(label='test_removing_nodes').store().backend_entity

    first, second, third, outsider = (orm.Data().store().backend_entity for _ in range(4))
    members = [first, second, third]

    def assert_membership():
        """Assert that the group holds exactly the nodes currently listed in ``members``."""
        assert set(_.pk for _ in members) == set(_.pk for _ in backend_group.nodes)

    backend_group.add_nodes(members)
    assert_membership()

    # A node that was never added: removing it must leave the group untouched.
    backend_group.remove_nodes([outsider], skip_orm=True)
    assert_membership()

    # Single-node removal.
    members.remove(third)
    backend_group.remove_nodes([third], skip_orm=True)
    assert_membership()

    # Multi-node removal in a single call.
    members.remove(first)
    members.remove(second)
    backend_group.remove_nodes([first, second], skip_orm=True)
    assert_membership()
25 changes: 24 additions & 1 deletion tests/orm/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,27 @@ def test_add_nodes(self):
group.add_nodes(node_01)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Try to add nothing: there should be no problem
group.add_nodes([])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

nodes = [orm.Data().store().backend_entity for _ in range(100)]

# Add the nodes to groups using different batch sizes, then check that they were all added correctly.
batch_sizes = (1, 3, 10, 1000)
for batch_size in batch_sizes:
group = orm.Group(label=f'test_batches_{batch_size!s}').store()
group.backend_entity.add_nodes(nodes, batch_size=batch_size)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

def test_remove_nodes(self):
"""Test node removal."""
node_01 = orm.Data().store()
node_02 = orm.Data().store()
node_03 = orm.Data().store()
node_04 = orm.Data().store()
nodes = [node_01, node_02, node_03]
node_05 = orm.Data().store()
nodes = [node_01, node_02, node_03, node_05]
group = orm.Group(label=uuid.uuid4().hex).store()

# Add initial nodes
Expand All @@ -177,6 +191,15 @@ def test_remove_nodes(self):
group.remove_nodes([node_01, node_02])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
rabbull marked this conversation as resolved.
Show resolved Hide resolved

# Remove the last remaining node, leaving the group empty
nodes.remove(node_05)
group.remove_nodes([node_05])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Try to remove nothing: there should be no problem
group.remove_nodes([])
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

def test_clear(self):
"""Test the `clear` method to remove all nodes."""
node_01 = orm.Data().store()
Expand Down
Loading