diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index 9257187ca..641f73ed0 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -803,13 +803,13 @@ def get_ids_by_creator(self, component: str) -> Set[int]: return entry_set def is_created_by( - self, entry: Entry, components: Union[str, Iterable[str]] + self, entry: Union[Entry, int], components: Union[str, Iterable[str]] ) -> bool: """ Check if the entry is created by any of the provided components. Args: - entry: The entry to check. + entry: `tid` of the entry or the entry object to check components: The list of component names. Returns: @@ -818,8 +818,10 @@ def is_created_by( if isinstance(components, str): components = [components] + entry_tid = entry.tid if isinstance(entry, Entry) else entry + for c in components: - if entry.tid in self._creation_records[c]: + if entry_tid in self._creation_records[c]: break else: # The entry not created by any of these components. diff --git a/forte/data/data_pack.py b/forte/data/data_pack.py index f79b89f7f..3ca647876 100644 --- a/forte/data/data_pack.py +++ b/forte/data/data_pack.py @@ -37,7 +37,7 @@ ProcessExecutionException, UnknownOntologyClassException, ) -from forte.common.constants import TID_INDEX +from forte.common.constants import TID_INDEX, BEGIN_ATTR_NAME, END_ATTR_NAME from forte.data import data_utils_io from forte.data.data_store import DataStore from forte.data.entry_converter import EntryConverter @@ -49,10 +49,11 @@ Annotation, Link, Group, - SinglePackEntries, Generics, AudioAnnotation, Payload, + SinglePackEntries, + AnnotationLikeEntries, ) from forte.data.modality import Modality @@ -1503,9 +1504,12 @@ def covers( def get( # type: ignore self, entry_type: Union[str, Type[EntryType]], - range_annotation: Optional[Union[Annotation, AudioAnnotation]] = None, + range_annotation: Optional[ + Union[Annotation, AudioAnnotation, int] + ] = None, components: Optional[Union[str, Iterable[str]]] = None, include_sub_type: bool = True, + get_raw: bool = False, ) -> Iterable[EntryType]: r"""This function is used to get data from a data pack with various methods. @@ -1571,28 +1575,47 @@ def get( # type: ignore Args: entry_type: The type of entries requested. range_annotation: The - range of entries requested. If `None`, will return valid - entries in the range of whole data pack. + range of entries requested. This value can be given by an + entry object or the ``tid`` of that entry. If `None`, will + return valid entries in the range of whole data pack. components: The component (creator) generating the entries requested. If `None`, will return valid entries generated by any component. include_sub_type: whether to consider the sub types of the provided entry type. Default `True`. + get_raw: boolean to indicate if the entry should be returned in + its primitive form as opposed to an object. False by default Yields: Each `Entry` found using this method. """ - entry_type_: Type[EntryType] = as_entry_type(entry_type) + # Convert entry_type to str + entry_type_ = ( + get_full_module_name(entry_type) + if not isinstance(entry_type, str) + else entry_type + ) + + # pylint: disable=protected-access + # Check if entry_type_ represents a valid entry + if not self._data_store._is_subclass(entry_type_, Entry): + raise ValueError( + f"The specified entry type [{entry_type}] " + f"does not correspond to a " + f"`forte.data.ontology.core.Entry` class" + ) def require_annotations(entry_class=Annotation) -> bool: - if issubclass(entry_type_, entry_class): + if self._data_store._is_subclass(entry_type_, entry_class): return True - if issubclass(entry_type_, Link): + + curr_class: Type[EntryType] = as_entry_type(entry_type_) + if issubclass(curr_class, Link): return issubclass( - entry_type_.ParentType, entry_class - ) and issubclass(entry_type_.ChildType, entry_class) - if issubclass(entry_type_, Group): - return issubclass(entry_type_.MemberType, entry_class) + curr_class.ParentType, entry_class + ) and issubclass(curr_class.ChildType, entry_class) + if issubclass(curr_class, Group): + return issubclass(curr_class.MemberType, entry_class) return False # If we don't have any annotations but the items to check requires them, @@ -1631,28 +1654,58 @@ def require_annotations(entry_class=Annotation) -> bool: yield from [] return + # If range_annotation is specified, we record its begin and + # end index + range_begin: int + range_end: int + + if range_annotation is not None: + if isinstance(range_annotation, AnnotationLikeEntries): + range_begin = range_annotation.begin + range_end = range_annotation.end + else: + # range_annotation is given by the tid of the entry it + # represents + range_raw = self._data_store.transform_data_store_entry( + self.get_entry_raw(range_annotation) + ) + range_begin = range_raw[BEGIN_ATTR_NAME] + range_end = range_raw[END_ATTR_NAME] + try: for entry_data in self._data_store.get( - type_name=get_full_module_name(entry_type_), + type_name=entry_type_, include_sub_type=include_sub_type, range_span=range_annotation # type: ignore - and (range_annotation.begin, range_annotation.end), + and (range_begin, range_end), ): - entry: Entry = self.get_entry(tid=entry_data[TID_INDEX]) + # Filter by components if components is not None: - if not self.is_created_by(entry, components): + if not self.is_created_by( + entry_data[TID_INDEX], components + ): continue - # Filter out incompatible audio span comparison for Links and Groups - if ( - issubclass(entry_type_, (Link, Group)) - and isinstance(range_annotation, AudioAnnotation) - and not self._index.in_audio_span( - entry, range_annotation.span + entry: Union[Entry, Dict[str, Any]] + if get_raw: + entry = self._data_store.transform_data_store_entry( + entry_data ) - ): - continue + else: + entry = self.get_entry(tid=entry_data[TID_INDEX]) + + # Filter out incompatible audio span comparison for Links and Groups + if ( + self._data_store._is_subclass( + entry_type_, (Link, Group) + ) + and isinstance(range_annotation, AudioAnnotation) + and not self._index.in_audio_span( + entry, range_annotation.span + ) + ): + continue yield entry # type: ignore except ValueError: diff --git a/forte/data/data_store.py b/forte/data/data_store.py index ec2721d4e..77952e86c 100644 --- a/forte/data/data_store.py +++ b/forte/data/data_store.py @@ -813,9 +813,22 @@ def fetch_entry_type_data( # ie. NoneType. if attr_class is None: attr_class = type(None) - attr_args = get_args(attr_info.type) - if len(attr_args) == 0: + raw_attr_args = get_args(attr_info.type) + if len(raw_attr_args) == 0: attr_args = tuple([attr_info.type]) + else: + attr_args = () + for args in raw_attr_args: + # This is the case when we have a multidimensional + # type attribute like List[Tuple[int, int]]. In this + # case get_args will return a tuple of tuples that + # looks like ((Tuple, int, int),). We thus convert + # this into a single dimensional tuple - + # (Tuple, int, int). + if isinstance(args, tuple): + attr_args += args + else: + attr_args += (args,) # Prior to Python 3.7, fetching generic type # aliases resulted in actual type objects whereas from @@ -1321,6 +1334,91 @@ def _get_existing_ann_entry_tid(self, entry: List[Any]): "getting entry id for annotation-like entry." ) + def get_attribute_positions(self, type_name: str) -> Dict[str, int]: + r"""This function returns a dictionary where the key represents + the attributes of the entry of type ``type_name`` and value + represents the index of the position where this attribute is + stored in the data store entry of this type. + For example: + + .. code-block:: python + + positions = data_store.get_attribute_positions( + "ft.onto.base_ontology.Document" + ) + + # positions = { + # "begin": 2, + # "end": 3, + # "payload_idx": 4, + # "document_class": 5, + # "sentiment": 6, + # "classifications": 7 + # } + + Args: + type_name (str): The fully qualified type name of a type. + + Returns: + A dictionary indicating the attributes of an entry of type + ``type_name`` and their respective positions in a data store + entry. + """ + type_data = self._get_type_info(type_name) + + positions: Dict[str, int] = {} + for attr, val in type_data[constants.ATTR_INFO_KEY].items(): + positions[attr] = val[constants.ATTR_INDEX_KEY] + + return positions + + def transform_data_store_entry(self, entry: List[Any]) -> Dict: + r""" + This method converts a raw data store entry into a format more easily + understandable to users. Data Store entries are stored as lists and + are not very easily understandable. This method converts ``DataStore`` + entries from a list format to a dictionary based format where the key + is the names of the attributes of an entry and the value is the values + corresponding attributes in the data store entry. + For example: + + .. code-block:: python + + >>> data_store = DataStore() + >>> tid = data_store.add_entry_raw( + ... type_name = 'ft.onto.base_ontology.Sentence', + ... tid = 101, attribute_data = [0,10]) + >>> entry = data_store.get_entry(tid)[0] + >>> transformed_entry = data_store.transform_data_store_entry(entry) + >>> transformed_entry == { 'begin': 0, 'end': 10, 'payload_idx': 0, + ... 'speaker': None, 'part_id': None, 'sentiment': {}, + ... 'classification': {}, 'classifications': {}, 'tid': 101, + ... 'type': 'ft.onto.base_ontology.Sentence'} + True + + Args: + entry: A list representing a valid data store entry + + Returns: + a dictionary representing the the input data store entry + """ + + attribute_positions = self.get_attribute_positions( + entry[constants.ENTRY_TYPE_INDEX] + ) + + # We now convert the entry from data store format (list) to user + # representation format (dict) to make the contents of the entry more + # understandable. + user_rep: Dict[str, Any] = {} + for attr, pos in attribute_positions.items(): + user_rep[attr] = entry[pos] + + user_rep["tid"] = entry[constants.TID_INDEX] + user_rep["type"] = entry[constants.ENTRY_TYPE_INDEX] + + return user_rep + def set_attribute(self, tid: int, attr_name: str, attr_value: Any): r"""This function locates the entry data with ``tid`` and sets its ``attr_name`` with `attr_value`. It first finds ``attr_id`` according diff --git a/forte/data/multi_pack.py b/forte/data/multi_pack.py index 3121bfc2f..ef064fa84 100644 --- a/forte/data/multi_pack.py +++ b/forte/data/multi_pack.py @@ -16,7 +16,7 @@ import logging from pathlib import Path -from typing import Dict, List, Union, Iterator, Optional, Type, Any, Tuple +from typing import Dict, List, Union, Iterator, Optional, Type, Any, Tuple, cast import jsonpickle @@ -40,7 +40,7 @@ MultiPackGeneric, ) from forte.data.types import DataRequest -from forte.utils import get_class, get_full_module_name +from forte.utils import get_full_module_name from forte.version import DEFAULT_PACK_VERSION @@ -799,7 +799,8 @@ def get( # type: ignore self, entry_type: Union[str, Type[EntryType]], components: Optional[Union[str, List[str]]] = None, - include_sub_type=True, + include_sub_type: bool = True, + get_raw: bool = False, ) -> Iterator[EntryType]: """Get entries of ``entry_type`` from this multi pack. @@ -824,6 +825,8 @@ def get( # type: ignore any component will be returned. include_sub_type: whether to return the sub types of the queried `entry_type`. True by default. + get_raw: boolean to indicate if the entry should be returned in + its primitive form as opposed to an object. False by default Returns: An iterator of the entries matching the arguments, following @@ -831,17 +834,20 @@ def get( # type: ignore insertion) """ - entry_type_: Type[EntryType] - if isinstance(entry_type, str): - entry_type_ = get_class(entry_type) - if not issubclass(entry_type_, Entry): - raise AttributeError( - f"The specified entry type [{entry_type}] " - f"does not correspond to a " - f"'forte.data.ontology.core.Entry' class" - ) - else: - entry_type_ = entry_type + entry_type_ = ( + get_full_module_name(entry_type) + if not isinstance(entry_type, str) + else entry_type + ) + + # Check if entry_type_ represents a valid entry + # pylint: disable=protected-access + if not self._data_store._is_subclass(entry_type_, Entry): + raise ValueError( + f"The specified entry type [{entry_type}] " + f"does not correspond to a " + f"`forte.data.ontology.core.Entry` class" + ) if components is not None: if isinstance(components, str): @@ -849,19 +855,28 @@ def get( # type: ignore try: for entry_data in self._data_store.get( - type_name=get_full_module_name(entry_type_), + type_name=entry_type_, include_sub_type=include_sub_type, ): - entry: Entry = self._entry_converter.get_entry_object( - tid=entry_data[TID_INDEX], - pack=self, - type_name=entry_data[ENTRY_TYPE_INDEX], - ) # Filter by components if components is not None: - if not self.is_created_by(entry, components): + if not self.is_created_by( + entry_data[TID_INDEX], components + ): continue + entry: Union[Entry, Dict[str, Any]] + + if get_raw: + data_store = cast(DataStore, self._data_store) + entry = data_store.transform_data_store_entry(entry_data) + else: + entry = self._entry_converter.get_entry_object( + tid=entry_data[TID_INDEX], + pack=self, + type_name=entry_data[ENTRY_TYPE_INDEX], + ) + yield entry # type: ignore except ValueError: # type_name does not exist in DataStore diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py index 24f569aef..da1613666 100644 --- a/forte/data/ontology/top.py +++ b/forte/data/ontology/top.py @@ -1136,3 +1136,4 @@ def __setstate__(self, state): Payload, ) MultiPackEntries = (MultiPackLink, MultiPackGroup, MultiPackGeneric) +AnnotationLikeEntries = (Annotation, AudioAnnotation) diff --git a/tests/forte/data/data_pack_test.py b/tests/forte/data/data_pack_test.py index d8c06a452..9cf0dd119 100644 --- a/tests/forte/data/data_pack_test.py +++ b/tests/forte/data/data_pack_test.py @@ -318,6 +318,47 @@ def test_get_entries(self): with self.assertRaises(ValueError): for doc in self.data_pack.get("forte.data.data_pack.DataPack"): print(doc) + + # Test get raw entries + + # fetching documents + primitive_documents = list(self.data_pack.get(Document, get_raw = True)) + object_documents = list(self.data_pack.get(Document)) + + self.assertEqual( + primitive_documents[0], + { + 'begin': 0, + 'end': 228, + 'payload_idx': 0, + 'document_class': [], + 'sentiment': {}, + 'classifications': {}, + 'tid': object_documents[0].tid, + 'type': 'ft.onto.base_ontology.Document' + } + ) + + # fetching groups + for doc in object_documents: + members: List[str] = [] + group_members: List[List[str]] = [] + # Fetching raw group entries + for group in self.data_pack.get( + "ft.onto.base_ontology.CoreferenceGroup", doc, get_raw=True + ): + em: EntityMention + # Note that group is a dict and not an object + for em in group["members"]: + em_object = self.data_pack.get_entry(em) + members.append(em_object.text) + group_members.append(sorted(members)) + + self.assertEqual( + group_members, + [["He", "The Indonesian billionaire James Riady", "he"]] + ) + def test_delete_entry(self): # test delete entry diff --git a/tests/forte/data/data_store_serialization_test.py b/tests/forte/data/data_store_serialization_test.py index 73cd88443..0db86f554 100644 --- a/tests/forte/data/data_store_serialization_test.py +++ b/tests/forte/data/data_store_serialization_test.py @@ -75,7 +75,7 @@ def setUp(self) -> None: }, "forte.data.ontology.top.Group": { "attributes": { - 'members': {'type': (FList, (Entry,)), 'index': 2}, + 'members': {'type': (list, (int,)), 'index': 2}, 'member_type': {'type': (type(None), (str,)), 'index': 3} }, "parent_entry": "forte.data.ontology.core.BaseGroup", @@ -333,7 +333,7 @@ def test_save_attribute_pickle(self): }, "forte.data.ontology.top.Group": { "attributes": { - 'members': {'type': (FList, (Entry,)), 'index': 2}, + 'members': {'type': (list, (int,)), 'index': 2}, 'member_type': {'type': (type(None), (str,)), 'index': 3} }, "parent_entry": "forte.data.ontology.core.BaseGroup", @@ -800,7 +800,7 @@ def test_fast_pickle(self): }, "forte.data.ontology.top.Group": { "attributes": { - 'members': {'type': (FList, (Entry,)), 'index': 2}, + 'members': {'type': (list, (int,)), 'index': 2}, 'member_type': {'type': (type(None), (str,)), 'index': 3} }, "parent_entry": "forte.data.ontology.core.BaseGroup", @@ -923,7 +923,7 @@ def test_delete_serialize(self): }, "forte.data.ontology.top.Group": { "attributes": { - 'members': {'type': (FList, (Entry,)), 'index': 2}, + 'members': {'type': (list, (int,)), 'index': 2}, 'member_type': {'type': (type(None), (str,)), 'index': 3} }, "parent_entry": "forte.data.ontology.core.BaseGroup", diff --git a/tests/forte/data/data_store_test.py b/tests/forte/data/data_store_test.py index ba56e7e1e..b6338190e 100644 --- a/tests/forte/data/data_store_test.py +++ b/tests/forte/data/data_store_test.py @@ -203,7 +203,7 @@ def setUp(self) -> None: }, "forte.data.ontology.top.Group": { "attributes": { - 'members': {'type': (FList, (Entry,)), 'index': 2}, + 'members': {'type': (list, (int,)), 'index': 2}, 'member_type': {'type': (type(None), (str,)), 'index': 3} }, "parent_class": {"Entry", "BaseGroup"} @@ -222,7 +222,7 @@ def setUp(self) -> None: }, "forte.data.ontology.top.MultiPackGroup": { "attributes": { - 'members': {'type': (type, (FList[Entry], type(None))), 'index': 2}, + 'members': {'type': (list, (Tuple, int, int)), 'index': 2}, 'member_type': {'type': (type(None), (str,)), 'index': 3} }, "parent_class": {"Entry", "MultiEntry", "BaseGroup"} diff --git a/tests/forte/data/multi_pack_test.py b/tests/forte/data/multi_pack_test.py index 14c5eb0fc..30aaa7002 100644 --- a/tests/forte/data/multi_pack_test.py +++ b/tests/forte/data/multi_pack_test.py @@ -2,6 +2,7 @@ Unit tests for multi pack related operations. """ import logging +from typing import Any, Dict import unittest from forte.data.data_pack import DataPack @@ -96,10 +97,28 @@ def test_multipack_groups(self): g: MultiPackGroup for g in self.multi_pack.get(MultiPackGroup): e: Annotation - group_content.append(tuple([e.text for e in g.get_members()])) + temp_list = [] + for e in g.get_members(): + temp_list.append(e.text) + group_content.append(tuple(temp_list)) self.assertListEqual(expected_content, group_content) + # Get raw groups + group_content = [] + grp: Dict[str, Any] + for grp in self.multi_pack.get(MultiPackGroup, get_raw=True): + temp_list = [] + # Note here that grp represents a dictionary and not an object + for pack, mem in grp['members']: + mem_obj = self.multi_pack.get_subentry(pack, mem) + temp_list.append(mem_obj.text) + + group_content.append(tuple(temp_list)) + self.assertListEqual(expected_content, group_content) + + + def test_multipack_entries(self): """ Test some multi pack entry.