Skip to content

Commit

Permalink
modified interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Pushkar-Bhuse committed Aug 25, 2022
1 parent 0365eef commit 7ec93f5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 46 deletions.
6 changes: 3 additions & 3 deletions forte/data/base_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,13 +791,13 @@ def get_ids_by_creator(self, component: str) -> Set[int]:
return entry_set

def is_created_by(
self, entry_tid: int, components: Union[str, Iterable[str]]
self, entry: Union[Entry, int], components: Union[str, Iterable[str]]
) -> bool:
"""
Check if the entry is created by any of the provided components.
Args:
entry_tid: `tid` of the entry to check.
entry: `tid` of the entry or the entry object to check
components: The list of component names.
Returns:
Expand All @@ -807,7 +807,7 @@ def is_created_by(
components = [components]

for c in components:
if entry_tid in self._creation_records[c]:
if entry in self._creation_records[c]:
break
else:
# The entry not created by any of these components.
Expand Down
85 changes: 55 additions & 30 deletions forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
ProcessExecutionException,
UnknownOntologyClassException,
)
from forte.common.constants import TID_INDEX
from forte.common.constants import TID_INDEX, BEGIN_ATTR_NAME, END_ATTR_NAME
from forte.data import data_utils_io
from forte.data.data_store import DataStore
from forte.data.entry_converter import EntryConverter
Expand Down Expand Up @@ -1471,7 +1471,9 @@ def covers(
def get( # type: ignore
self,
entry_type: Union[str, Type[EntryType]],
range_annotation: Optional[Union[Annotation, AudioAnnotation]] = None,
range_annotation: Optional[
Union[Annotation, AudioAnnotation, int]
] = None,
components: Optional[Union[str, Iterable[str]]] = None,
include_sub_type: bool = True,
get_raw: bool = False,
Expand Down Expand Up @@ -1540,8 +1542,9 @@ def get( # type: ignore
Args:
entry_type: The type of entries requested.
range_annotation: The
range of entries requested. If `None`, will return valid
entries in the range of whole data pack.
range of entries requested. This value can be given by an
entry object or the ``tid`` of that entry. If `None`, will
return valid entries in the range of whole data pack.
components: The component (creator)
generating the entries requested. If `None`, will return valid
entries generated by any component.
Expand All @@ -1553,41 +1556,43 @@ def get( # type: ignore
Yields:
Each `Entry` found using this method.
"""
entry_type_: Type[EntryType] = as_entry_type(entry_type)
# Convert entry_type to str
entry_type_ = (
get_full_module_name(entry_type)
if not isinstance(entry_type, str)
else entry_type
)

# pylint: disable=protected-access
# Check if entry_type_ represents a valid entry
if not self._data_store._is_subclass(entry_type_, Entry):
raise ValueError(
f"The specified entry type [{entry_type}] "
f"does not correspond to a "
f"`forte.data.ontology.core.Entry` class"
)

def require_annotations(entry_class=Annotation) -> bool:
if issubclass(entry_type_, entry_class):
if self._data_store._is_subclass(entry_type_, entry_class):
return True
if issubclass(entry_type_, Link):
if self._data_store._is_subclass(entry_type_, Link):
entry_class = as_entry_type(entry_type_)
return issubclass(
entry_type_.ParentType, entry_class
) and issubclass(entry_type_.ChildType, entry_class)
if issubclass(entry_type_, Group):
return issubclass(entry_type_.MemberType, entry_class)
entry_class.ParentType, entry_class
) and issubclass(entry_class.ChildType, entry_class)
if self._data_store._is_subclass(entry_type_, Group):
entry_class = as_entry_type(entry_type_)
return issubclass(entry_class.MemberType, entry_class)
return False

# If we don't have any annotations but the items to check requires them,
# then we simply yield from an empty list.
if (
len(
list(
self._data_store.all_entries(
"forte.data.ontology.top.Annotation"
)
)
)
== 0
self.num_annotations == 0
and isinstance(range_annotation, Annotation)
and require_annotations(Annotation)
) or (
len(
list(
self._data_store.all_entries(
"forte.data.ontology.top.AudioAnnotation"
)
)
)
== 0
self.num_audio_annotations == 0
and isinstance(range_annotation, AudioAnnotation)
and require_annotations(AudioAnnotation)
):
Expand All @@ -1614,12 +1619,30 @@ def require_annotations(entry_class=Annotation) -> bool:
yield from []
return

# If range_annotation is specified, we record its begin and
# end index
range_begin: int
range_end: int

if range_annotation is not None:
if isinstance(range_annotation, (Annotation, AudioAnnotation)):
range_begin = range_annotation.begin
range_end = range_annotation.end
else:
# range_annotation is given by the tid of the entry it
# represents
range_raw = self._data_store.transform_data_store_entry(
self.get_entry_raw(range_annotation)
)
range_begin = range_raw[BEGIN_ATTR_NAME]
range_end = range_raw[END_ATTR_NAME]

try:
for entry_data in self._data_store.get(
type_name=get_full_module_name(entry_type_),
type_name=entry_type_,
include_sub_type=include_sub_type,
range_span=range_annotation # type: ignore
and (range_annotation.begin, range_annotation.end),
and (range_begin, range_end),
):

# Filter by components
Expand All @@ -1639,7 +1662,9 @@ def require_annotations(entry_class=Annotation) -> bool:

# Filter out incompatible audio span comparison for Links and Groups
if (
issubclass(entry_type_, (Link, Group))
self._data_store._is_subclass(
entry_type_, (Link, Group)
)
and isinstance(range_annotation, AudioAnnotation)
and not self._index.in_audio_span(
entry, range_annotation.span
Expand Down
29 changes: 16 additions & 13 deletions forte/data/multi_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
MultiPackGeneric,
)
from forte.data.types import DataRequest
from forte.utils import get_class, get_full_module_name
from forte.utils import get_full_module_name
from forte.version import DEFAULT_PACK_VERSION


Expand Down Expand Up @@ -837,25 +837,28 @@ def get( # type: ignore
insertion)
"""
entry_type_: Type[EntryType]
if isinstance(entry_type, str):
entry_type_ = get_class(entry_type)
if not issubclass(entry_type_, Entry):
raise AttributeError(
f"The specified entry type [{entry_type}] "
f"does not correspond to a "
f"'forte.data.ontology.core.Entry' class"
)
else:
entry_type_ = entry_type
entry_type_ = (
get_full_module_name(entry_type)
if not isinstance(entry_type, str)
else entry_type
)

# Check if entry_type_ represents a valid entry
# pylint: disable=protected-access
if not self._data_store._is_subclass(entry_type_, Entry):
raise ValueError(
f"The specified entry type [{entry_type}] "
f"does not correspond to a "
f"`forte.data.ontology.core.Entry` class"
)

if components is not None:
if isinstance(components, str):
components = [components]

try:
for entry_data in self._data_store.get(
type_name=get_full_module_name(entry_type_),
type_name=entry_type_,
include_sub_type=include_sub_type,
):
# Filter by components
Expand Down

0 comments on commit 7ec93f5

Please sign in to comment.