diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index a57e6d430..e4c3a9f5e 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -791,13 +791,13 @@ def get_ids_by_creator(self, component: str) -> Set[int]: return entry_set def is_created_by( - self, entry_tid: int, components: Union[str, Iterable[str]] + self, entry: Union[Entry, int], components: Union[str, Iterable[str]] ) -> bool: """ Check if the entry is created by any of the provided components. Args: - entry_tid: `tid` of the entry to check. + entry: `tid` of the entry or the entry object to check components: The list of component names. Returns: @@ -807,7 +807,7 @@ def is_created_by( components = [components] for c in components: - if entry_tid in self._creation_records[c]: + if entry in self._creation_records[c]: break else: # The entry not created by any of these components. diff --git a/forte/data/data_pack.py b/forte/data/data_pack.py index 9d9fcfe4f..bec44179c 100644 --- a/forte/data/data_pack.py +++ b/forte/data/data_pack.py @@ -37,7 +37,7 @@ ProcessExecutionException, UnknownOntologyClassException, ) -from forte.common.constants import TID_INDEX +from forte.common.constants import TID_INDEX, BEGIN_ATTR_NAME, END_ATTR_NAME from forte.data import data_utils_io from forte.data.data_store import DataStore from forte.data.entry_converter import EntryConverter @@ -1471,7 +1471,9 @@ def covers( def get( # type: ignore self, entry_type: Union[str, Type[EntryType]], - range_annotation: Optional[Union[Annotation, AudioAnnotation]] = None, + range_annotation: Optional[ + Union[Annotation, AudioAnnotation, int] + ] = None, components: Optional[Union[str, Iterable[str]]] = None, include_sub_type: bool = True, get_raw: bool = False, @@ -1540,8 +1542,9 @@ def get( # type: ignore Args: entry_type: The type of entries requested. range_annotation: The - range of entries requested. If `None`, will return valid - entries in the range of whole data pack. + range of entries requested. This value can be given by an + entry object or the ``tid`` of that entry. If `None`, will + return valid entries in the range of whole data pack. components: The component (creator) generating the entries requested. If `None`, will return valid entries generated by any component. @@ -1553,41 +1556,43 @@ def get( # type: ignore Yields: Each `Entry` found using this method. """ - entry_type_: Type[EntryType] = as_entry_type(entry_type) + # Convert entry_type to str + entry_type_ = ( + get_full_module_name(entry_type) + if not isinstance(entry_type, str) + else entry_type + ) + + # pylint: disable=protected-access + # Check if entry_type_ represents a valid entry + if not self._data_store._is_subclass(entry_type_, Entry): + raise ValueError( + f"The specified entry type [{entry_type}] " + f"does not correspond to a " + f"`forte.data.ontology.core.Entry` class" + ) def require_annotations(entry_class=Annotation) -> bool: - if issubclass(entry_type_, entry_class): + if self._data_store._is_subclass(entry_type_, entry_class): return True - if issubclass(entry_type_, Link): + if self._data_store._is_subclass(entry_type_, Link): + entry_class = as_entry_type(entry_type_) return issubclass( - entry_type_.ParentType, entry_class - ) and issubclass(entry_type_.ChildType, entry_class) - if issubclass(entry_type_, Group): - return issubclass(entry_type_.MemberType, entry_class) + entry_class.ParentType, entry_class + ) and issubclass(entry_class.ChildType, entry_class) + if self._data_store._is_subclass(entry_type_, Group): + entry_class = as_entry_type(entry_type_) + return issubclass(entry_class.MemberType, entry_class) return False # If we don't have any annotations but the items to check requires them, # then we simply yield from an empty list. if ( - len( - list( - self._data_store.all_entries( - "forte.data.ontology.top.Annotation" - ) - ) - ) - == 0 + self.num_annotations == 0 and isinstance(range_annotation, Annotation) and require_annotations(Annotation) ) or ( - len( - list( - self._data_store.all_entries( - "forte.data.ontology.top.AudioAnnotation" - ) - ) - ) - == 0 + self.num_audio_annotations == 0 and isinstance(range_annotation, AudioAnnotation) and require_annotations(AudioAnnotation) ): @@ -1614,12 +1619,30 @@ def require_annotations(entry_class=Annotation) -> bool: yield from [] return + # If range_annotation is specified, we record its begin and + # end index + range_begin: int + range_end: int + + if range_annotation is not None: + if isinstance(range_annotation, (Annotation, AudioAnnotation)): + range_begin = range_annotation.begin + range_end = range_annotation.end + else: + # range_annotation is given by the tid of the entry it + # represents + range_raw = self._data_store.transform_data_store_entry( + self.get_entry_raw(range_annotation) + ) + range_begin = range_raw[BEGIN_ATTR_NAME] + range_end = range_raw[END_ATTR_NAME] + try: for entry_data in self._data_store.get( - type_name=get_full_module_name(entry_type_), + type_name=entry_type_, include_sub_type=include_sub_type, range_span=range_annotation # type: ignore - and (range_annotation.begin, range_annotation.end), + and (range_begin, range_end), ): # Filter by components @@ -1639,7 +1662,9 @@ def require_annotations(entry_class=Annotation) -> bool: # Filter out incompatible audio span comparison for Links and Groups if ( - issubclass(entry_type_, (Link, Group)) + self._data_store._is_subclass( + entry_type_, (Link, Group) + ) and isinstance(range_annotation, AudioAnnotation) and not self._index.in_audio_span( entry, range_annotation.span diff --git a/forte/data/multi_pack.py b/forte/data/multi_pack.py index 82031178f..98a36dcbe 100644 --- a/forte/data/multi_pack.py +++ b/forte/data/multi_pack.py @@ -40,7 +40,7 @@ MultiPackGeneric, ) from forte.data.types import DataRequest -from forte.utils import get_class, get_full_module_name +from forte.utils import get_full_module_name from forte.version import DEFAULT_PACK_VERSION @@ -837,17 +837,20 @@ def get( # type: ignore insertion) """ - entry_type_: Type[EntryType] - if isinstance(entry_type, str): - entry_type_ = get_class(entry_type) - if not issubclass(entry_type_, Entry): - raise AttributeError( - f"The specified entry type [{entry_type}] " - f"does not correspond to a " - f"'forte.data.ontology.core.Entry' class" - ) - else: - entry_type_ = entry_type + entry_type_ = ( + get_full_module_name(entry_type) + if not isinstance(entry_type, str) + else entry_type + ) + + # Check if entry_type_ represents a valid entry + # pylint: disable=protected-access + if not self._data_store._is_subclass(entry_type_, Entry): + raise ValueError( + f"The specified entry type [{entry_type}] " + f"does not correspond to a " + f"`forte.data.ontology.core.Entry` class" + ) if components is not None: if isinstance(components, str): @@ -855,7 +858,7 @@ def get( # type: ignore try: for entry_data in self._data_store.get( - type_name=get_full_module_name(entry_type_), + type_name=entry_type_, include_sub_type=include_sub_type, ): # Filter by components