Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let get function take range annotation #769

Merged
merged 17 commits into from
May 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions forte/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,17 @@
# in the `tid_idx_dict` in DataStore.
ENTRY_DICT_ENTRY_INDEX = 1

# Index of the parent entry's tid within the internal data of a Link entry.
PARENT_TID_INDEX = 0

# Index of the child entry's tid within the internal data of a Link entry.
CHILD_TID_INDEX = 1

# Index of the member entry type within the internal data of a Group entry.
MEMBER_TYPE_INDEX = 0

# Index of the list of member entry tids within the internal data of a
# Group entry.
MEMBER_TID_INDEX = 1

# Index at which the first attribute appears in the internal entry data of
# DataStore.
ATTR_BEGIN_INDEX = 4
9 changes: 8 additions & 1 deletion forte/data/base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,20 @@ def get_entry_index(self, tid: int) -> int:
raise NotImplementedError

@abstractmethod
def get(self, type_name: str, include_sub_type: bool) -> Iterator[List]:
def get(
self,
type_name: str,
include_sub_type: bool,
range_annotation: Optional[Tuple[int]] = None,
) -> Iterator[List]:
r"""This function fetches entries from the data store of
type ``type_name``.

Args:
type_name: The fully qualified type name of the entries to fetch.
include_sub_type: A boolean indicating whether to also fetch entries of its subclasses.
range_annotation: A tuple that contains the begin and end indices
of the searching range of annotation-like entries.

Returns:
An iterator of the entries matching the provided arguments.
Expand Down
113 changes: 101 additions & 12 deletions forte/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,6 @@ def _delete_entry_by_loc(self, type_name: str, index_id: int):
f"The specified index_id [{index_id}] of type [{type_name}]"
f"is out of boundary for entry list of length {len(target_list)}."
)

if self._is_annotation(type_name):
target_list.pop(index_id)
if not target_list:
Expand Down Expand Up @@ -877,6 +876,23 @@ def get_entry_index(self, tid: int) -> int:
]
return index_id

def get_length(self, type_name: str) -> int:
    r"""Return the number of live entries stored under ``type_name``.

    Annotation-like entry lists are counted as they are. For every other
    type, the deletion count recorded for that type is subtracted from the
    list length so that the ``None`` placeholders left in
    non-annotation-like entry lists are not counted.

    Args:
        type_name (str): The fully qualified type name of a type.

    Returns:
        The count of not None entries.
    """
    # Guard clause: annotation-like lists need no placeholder correction.
    if self._is_annotation(type_name):
        return len(self.__elements[type_name])

    # A type with no recorded deletions contributes zero placeholders.
    placeholder_count = self.__deletion_count.get(type_name, 0)
    return len(self.__elements[type_name]) - placeholder_count

def co_iterator_annotation_like(
self, type_names: List[str]
) -> Iterator[List]:
Expand Down Expand Up @@ -935,9 +951,9 @@ def co_iterator_annotation_like(
) from e
except IndexError as e: # self.__elements[tn][0] will be caught here.
raise ValueError(
f"Entry list of type name, {tn} which is "
f"Entry list of type name, {tn} which is"
" one list item of input argument `type_names`,"
" is empty. Please check data in this DataStore). "
" is empty. Please check data in this DataStore"
" to see if empty lists are expected"
f" or remove {tn} from input parameter type_names"
) from e
Expand Down Expand Up @@ -1010,30 +1026,103 @@ def co_iterator_annotation_like(
yield entry

def get(
self, type_name: str, include_sub_type: bool = True
self,
type_name: str,
include_sub_type: bool = True,
range_annotation: Optional[Tuple[int]] = None,
) -> Iterator[List]:
r"""This function fetches entries from the data store of
type ``type_name``.
type ``type_name``. If `include_sub_type` is set to True and
``type_name`` is in [Annotation, Group, Link], this function also
fetches entries of subtype of ``type_name``. Otherwise, it only
fetches entries of type ``type_name``.

Args:
type_name: The fully qualified name of the entry.
include_sub_type: A boolean indicating whether to also fetch entries of its subclasses.
range_annotation: A tuple that contains the begin and end indices
of the searching range of entries.

Returns:
An iterator of the entries matching the provided arguments.
"""

def within_range(
entry: List[Any], range_annotation: Tuple[int]
) -> bool:
"""
A helper function for deciding whether an annotation entry is
inside the `range_annotation`.
"""
if not self._is_annotation(entry[constants.ENTRY_TYPE_INDEX]):
return False
return (
entry[constants.BEGIN_INDEX]
>= range_annotation[constants.BEGIN_INDEX]
and entry[constants.END_INDEX]
<= range_annotation[constants.END_INDEX]
)

if type_name not in self.__elements:
raise ValueError(f"type {type_name} does not exist")
entry_class = get_class(type_name)
all_types = set()
if include_sub_type:
entry_class = get_class(type_name)
all_types = []
# iterate all classes to find subclasses
for type in self.__elements:
if issubclass(get_class(type), entry_class):
all_types.append(type)
all_types.add(type)
else:
all_types.add(type_name)
all_types = list(all_types)
all_types.sort()
hunterhector marked this conversation as resolved.
Show resolved Hide resolved
if self._is_annotation(type_name):
if range_annotation is None:
yield from self.co_iterator_annotation_like(all_types)
else:
for entry in self.co_iterator_annotation_like(all_types):
if within_range(entry, range_annotation):
yield entry
hunterhector marked this conversation as resolved.
Show resolved Hide resolved
elif issubclass(entry_class, Link):
for type in all_types:
yield from self.iter(type)
if range_annotation is None:
yield from self.iter(type)
else:
for entry in self.__elements[type]:
if (
entry[constants.PARENT_TID_INDEX]
in self.__tid_ref_dict
) and (
entry[constants.CHILD_TID_INDEX]
in self.__tid_ref_dict
):
parent = self.__tid_ref_dict[
entry[constants.PARENT_TID_INDEX]
]
child = self.__tid_ref_dict[
entry[constants.CHILD_TID_INDEX]
]
if within_range(
parent, range_annotation
) and within_range(child, range_annotation):
yield entry
elif issubclass(entry_class, Group):
for type in all_types:
if range_annotation is None:
yield from self.iter(type)
else:
for entry in self.__elements[type]:
member_type = entry[constants.MEMBER_TYPE_INDEX]
if self._is_annotation(member_type):
members = entry[constants.MEMBER_TID_INDEX]
within = True
for m in members:
e = self.__tid_ref_dict[m]
if not within_range(e, range_annotation):
within = False
break
if within:
yield entry
else:
if type_name not in self.__elements:
raise KeyError(f"type {type_name} does not exist")
yield from self.iter(type_name)

def iter(self, type_name: str) -> Iterator[List]:
Expand Down
87 changes: 51 additions & 36 deletions tests/forte/data/data_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,16 +499,9 @@ def test_add_annotation_raw(self):
tid_sent: int = self.data_store.add_annotation_raw(
"ft.onto.base_ontology.Sentence", 5, 8
)
num_doc = len(
self.data_store._DataStore__elements[
"ft.onto.base_ontology.Document"
]
)
num_sent = len(
self.data_store._DataStore__elements[
"ft.onto.base_ontology.Sentence"
]
)
num_doc = self.data_store.get_length("ft.onto.base_ontology.Document")

num_sent = self.data_store.get_length("ft.onto.base_ontology.Sentence")

self.assertEqual(num_doc, 3)
self.assertEqual(num_sent, 3)
Expand Down Expand Up @@ -536,11 +529,10 @@ def test_add_annotation_raw(self):
tid_em: int = self.data_store.add_annotation_raw(
"ft.onto.base_ontology.EntityMention", 10, 12
)
num_phrase = len(
self.data_store._DataStore__elements[
"ft.onto.base_ontology.EntityMention"
]
num_phrase = self.data_store.get_length(
"ft.onto.base_ontology.EntityMention"
)

self.assertEqual(num_phrase, 1)
self.assertEqual(len(DataStore._type_attributes), 3)
self.assertEqual(len(self.data_store._DataStore__tid_ref_dict), 8)
Expand Down Expand Up @@ -701,14 +693,50 @@ def test_get(self):
self.assertEqual(instances[0][2], 1234)
self.assertEqual(instances[1][2], 3456)

# get all entries
# For types other than annotation, group or link, `include_sub_type` is not supported
instances = list(self.data_store.get("forte.data.ontology.core.Entry"))
self.assertEqual(len(instances), 9)
self.assertEqual(len(instances), 0)

self.assertEqual(
self.data_store.get_length("forte.data.ontology.core.Entry"), 0
)

# get annotations with subclasses and range annotation
instances = list(
self.data_store.get(
"forte.data.ontology.top.Annotation", range_annotation=(1, 20)
)
)
self.assertEqual(len(instances), 2)

# get groups with subclasses
instances = list(self.data_store.get("forte.data.ontology.top.Group"))
self.assertEqual(len(instances), 3)

# get entries without subclasses
# get groups with subclasses and range annotation
instances = list(
self.data_store.get(
"forte.data.ontology.core.Entry", include_sub_type=False
"forte.data.ontology.top.Group", range_annotation=(1, 20)
)
)
self.assertEqual(len(instances), 0)

# get links with subclasses
instances = list(self.data_store.get("forte.data.ontology.top.Link"))
self.assertEqual(len(instances), 1)

# get links with subclasses and range annotation
instances = list(
self.data_store.get(
"forte.data.ontology.top.Link", range_annotation=(0, 9)
)
)
self.assertEqual(len(instances), 1)

# get links with subclasses and range annotation
instances = list(
self.data_store.get(
"forte.data.ontology.top.Link", range_annotation=(4, 11)
)
)
self.assertEqual(len(instances), 0)
Expand All @@ -720,22 +748,13 @@ def test_delete_entry(self):
self.data_store.delete_entry(1234567)
self.data_store.delete_entry(1234)
self.data_store.delete_entry(9999)
# After 3 deletion. 2 left. (2 documents, 1 sentence, 1 group)
num_doc = len(
self.data_store._DataStore__elements[
"ft.onto.base_ontology.Document"
]
)

# num_sent = len(
# self.data_store._DataStore__elements[
# "ft.onto.base_ontology.Sentence"
# ]
# )
# After 3 deletion. 5 left. (1 document, 1 annotation, 2 groups, 1 link)
num_doc = self.data_store.get_length("ft.onto.base_ontology.Document")
num_group = self.data_store.get_length("forte.data.ontology.top.Group")

self.assertEqual(len(self.data_store._DataStore__tid_ref_dict), 2)
self.assertEqual(num_doc, 1)
# self.assertEqual(num_sent, 0)
self.assertEqual(num_group, 3)

# delete group
self.data_store.delete_entry(10123)
Expand Down Expand Up @@ -771,11 +790,7 @@ def test_delete_entry_by_loc(self):
# dict entry is not deleted; only delete entry in element list
self.assertEqual(len(self.data_store._DataStore__tid_ref_dict), 5)
self.assertEqual(
len(
self.data_store._DataStore__elements[
"ft.onto.base_ontology.Document"
]
),
self.data_store.get_length("ft.onto.base_ontology.Document"),
1,
)

Expand Down