diff --git a/src/aiida/storage/psql_dos/orm/groups.py b/src/aiida/storage/psql_dos/orm/groups.py
index ee82dc24ec..c030e74123 100644
--- a/src/aiida/storage/psql_dos/orm/groups.py
+++ b/src/aiida/storage/psql_dos/orm/groups.py
@@ -167,17 +167,10 @@ def add_nodes(self, nodes, **kwargs):
         :note: all the nodes *and* the group itself have to be stored.
 
         :param nodes: a list of `BackendNode` instance to be added to this group
-
-        :param kwargs:
-            skip_orm: When the flag is on, the SQLA ORM is skipped and SQLA is used
-                to create a direct SQL INSERT statement to the group-node relationship
-                table (to improve speed).
         """
         from sqlalchemy.dialects.postgresql import insert
-        from sqlalchemy.exc import IntegrityError
 
         super().add_nodes(nodes)
-        skip_orm = kwargs.get('skip_orm', False)
 
         def check_node(given_node):
             """Check if given node is of correct type and stored"""
@@ -188,31 +181,16 @@ def check_node(given_node):
                 raise ValueError('At least one of the provided nodes is unstored, stopping...')
 
         with utils.disable_expire_on_commit(self.backend.get_session()) as session:
-            if not skip_orm:
-                # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
-                dbnodes = self.model.dbnodes
-
-                for node in nodes:
-                    check_node(node)
-
-                    # Use pattern as suggested here:
-                    # http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint
-                    try:
-                        with session.begin_nested():
-                            dbnodes.append(node.bare_model)
-                            session.flush()
-                    except IntegrityError:
-                        # Duplicate entry, skip
-                        pass
-            else:
-                ins_dict = []
-                for node in nodes:
-                    check_node(node)
-                    ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})
-
-                table = self.GROUP_NODE_CLASS.__table__
-                ins = insert(table).values(ins_dict)
-                session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))
+            ins_dict = []
+            for node in nodes:
+                check_node(node)
+                ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})
+            if len(ins_dict) == 0:
+                return
+
+            table = self.GROUP_NODE_CLASS.__table__
+            ins = insert(table).values(ins_dict)
+            session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))
             # Commit everything as up till now we've just flushed
             if not session.in_nested_transaction():
                 session.commit()
@@ -224,18 +202,11 @@ def remove_nodes(self, nodes, **kwargs):
         :note: all the nodes *and* the group itself have to be stored.
 
         :param nodes: a list of `BackendNode` instance to be added to this group
-        :param kwargs:
-            skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL
-                DELETE statement to the group-node relationship table in order to improve speed.
""" from sqlalchemy import and_ super().remove_nodes(nodes) - # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database - dbnodes = self.model.dbnodes - skip_orm = kwargs.get('skip_orm', False) - def check_node(node): if not isinstance(node, self.NODE_CLASS): raise TypeError(f'invalid type {type(node)}, has to be {self.NODE_CLASS}') @@ -243,26 +214,13 @@ def check_node(node): if node.id is None: raise ValueError('At least one of the provided nodes is unstored, stopping...') - list_nodes = [] - with utils.disable_expire_on_commit(self.backend.get_session()) as session: - if not skip_orm: - for node in nodes: - check_node(node) - - # Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error - if node.bare_model in dbnodes: - list_nodes.append(node.bare_model) - - for node in list_nodes: - dbnodes.remove(node) - else: - table = self.GROUP_NODE_CLASS.__table__ - for node in nodes: - check_node(node) - clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id) - statement = table.delete().where(clause) - session.execute(statement) + table = self.GROUP_NODE_CLASS.__table__ + for node in nodes: + check_node(node) + clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id) + statement = table.delete().where(clause) + session.execute(statement) if not session.in_nested_transaction(): session.commit() diff --git a/tests/orm/implementation/test_groups.py b/tests/orm/implementation/test_groups.py index 199d363e25..19a6cf9d70 100644 --- a/tests/orm/implementation/test_groups.py +++ b/tests/orm/implementation/test_groups.py @@ -25,66 +25,3 @@ def test_creation_from_dbgroup(backend): assert group.pk == gcopy.pk assert group.uuid == gcopy.uuid - - -def test_add_nodes_skip_orm(): - """Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag.""" - group = orm.Group(label='test_adding_nodes').store().backend_entity - - node_01 = orm.Data().store().backend_entity - node_02 = orm.Data().store().backend_entity - node_03 = orm.Data().store().backend_entity - node_04 = orm.Data().store().backend_entity - node_05 = orm.Data().store().backend_entity - nodes = [node_01, node_02, node_03, node_04, node_05] - - group.add_nodes([node_01], skip_orm=True) - group.add_nodes([node_02, node_03], skip_orm=True) - group.add_nodes((node_04, node_05), skip_orm=True) - - assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) - - # Try to add a node that is already present: there should be no problem - group.add_nodes([node_01], skip_orm=True) - assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes) - - -def test_add_nodes_skip_orm_batch(): - """Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag and batches.""" - nodes = [orm.Data().store().backend_entity for _ in range(100)] - - # Add nodes to groups using different batch size. Check in the end the correct addition. 
-    batch_sizes = (1, 3, 10, 1000)
-    for batch_size in batch_sizes:
-        group = orm.Group(label=f'test_batches_{batch_size!s}').store()
-        group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
-        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-
-def test_remove_nodes_bulk():
-    """Test node removal with `skip_orm=True`."""
-    group = orm.Group(label='test_removing_nodes').store().backend_entity
-
-    node_01 = orm.Data().store().backend_entity
-    node_02 = orm.Data().store().backend_entity
-    node_03 = orm.Data().store().backend_entity
-    node_04 = orm.Data().store().backend_entity
-    nodes = [node_01, node_02, node_03]
-
-    group.add_nodes(nodes)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-    # Remove a node that is not in the group: nothing should happen
-    group.remove_nodes([node_04], skip_orm=True)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-    # Remove one Node
-    nodes.remove(node_03)
-    group.remove_nodes([node_03], skip_orm=True)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-    # Remove a list of Nodes and check
-    nodes.remove(node_01)
-    nodes.remove(node_02)
-    group.remove_nodes([node_01, node_02], skip_orm=True)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
diff --git a/tests/orm/test_groups.py b/tests/orm/test_groups.py
index c62f903400..3e0c592c9f 100644
--- a/tests/orm/test_groups.py
+++ b/tests/orm/test_groups.py
@@ -149,13 +149,27 @@ def test_add_nodes(self):
         group.add_nodes(node_01)
         assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
 
+        # Try to add nothing: there should be no problem
+        group.add_nodes([])
+        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
+        nodes = [orm.Data().store().backend_entity for _ in range(100)]
+
+        # Add nodes to groups using different batch size. Check in the end the correct addition.
+        batch_sizes = (1, 3, 10, 1000)
+        for batch_size in batch_sizes:
+            group = orm.Group(label=f'test_batches_{batch_size!s}').store()
+            group.backend_entity.add_nodes(nodes, batch_size=batch_size)
+            assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
     def test_remove_nodes(self):
         """Test node removal."""
         node_01 = orm.Data().store()
         node_02 = orm.Data().store()
         node_03 = orm.Data().store()
         node_04 = orm.Data().store()
-        nodes = [node_01, node_02, node_03]
+        node_05 = orm.Data().store()
+        nodes = [node_01, node_02, node_03, node_05]
         group = orm.Group(label=uuid.uuid4().hex).store()
 
         # Add initial nodes
@@ -177,6 +191,15 @@ def test_remove_nodes(self):
         group.remove_nodes([node_01, node_02])
         assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
 
+        # Remove to empty
+        nodes.remove(node_05)
+        group.remove_nodes([node_05])
+        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
+        # Try to remove nothing: there should be no problem
+        group.remove_nodes([])
+        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
     def test_clear(self):
         """Test the `clear` method to remove all nodes."""
         node_01 = orm.Data().store()
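Note (not part of the patch): with `skip_orm` gone, `add_nodes` always builds a list of `{'dbnode_id': ..., 'dbgroup_id': ...}` rows and issues a single `INSERT ... ON CONFLICT DO NOTHING`, returning early on an empty list, while `remove_nodes` always issues direct `DELETE` statements against the join table. The sketch below is a minimal standalone illustration of those two SQLAlchemy patterns; the table name `db_dbgroup_dbnodes`, the helper names, and the connection URL are placeholders, not taken from this diff.

```python
# Standalone sketch of the bulk patterns used by the new add_nodes/remove_nodes.
# Assumptions: table name 'db_dbgroup_dbnodes' and the connection URL are placeholders.
from sqlalchemy import Column, Integer, MetaData, Table, and_, create_engine
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

metadata = MetaData()

# Stand-in for self.GROUP_NODE_CLASS.__table__ (the group-node join table).
group_nodes = Table(
    'db_dbgroup_dbnodes',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('dbnode_id', Integer),
    Column('dbgroup_id', Integer),
)


def add_nodes_bulk(session: Session, group_id: int, node_ids: list) -> None:
    """Insert all (node, group) pairs in one statement, skipping duplicates."""
    rows = [{'dbnode_id': node_id, 'dbgroup_id': group_id} for node_id in node_ids]
    if not rows:
        return  # .values([]) would build an invalid INSERT, hence the early return in the diff
    statement = insert(group_nodes).values(rows)
    session.execute(statement.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))


def remove_nodes_bulk(session: Session, group_id: int, node_ids: list) -> None:
    """Delete each (node, group) pair directly; pairs not in the group are no-ops."""
    for node_id in node_ids:
        clause = and_(group_nodes.c.dbnode_id == node_id, group_nodes.c.dbgroup_id == group_id)
        session.execute(group_nodes.delete().where(clause))


if __name__ == '__main__':
    engine = create_engine('postgresql://user:password@localhost/example')  # placeholder URL
    with Session(engine) as session:
        add_nodes_bulk(session, group_id=1, node_ids=[10, 11, 12])
        remove_nodes_bulk(session, group_id=1, node_ids=[11])
        session.commit()
```

The `index_elements` passed to `on_conflict_do_nothing` must match the unique constraint on the join table; that is what makes re-adding an already present node a silent no-op, which is the behaviour the updated `test_add_nodes` and `test_remove_nodes` exercise with the empty-list and duplicate cases.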