Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create basic (lower) level interface for attributes (for new feature #920) #921

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
167 changes: 167 additions & 0 deletions forte/data/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,173 @@ def get_attribute(self, tid: int, attr_name: str) -> Any:

return entry[attr_id]

def get_attributes_of_tid(self, tid: int, attr_names: List[str]) -> dict:
r"""This function returns the value of attributes listed in
``attr_names`` for the entry with ``tid``. It locates the entry data
with ``tid`` and finds attributes listed in ``attr_names`` and return
as a dict.

Args:
tid: Unique id of the entry.
attr_names: List of names of the attribute.

Returns:
A dict with keys listed in ``attr_names`` for attributes of the
entry with ``tid``.

Raises:
KeyError: when ``tid`` or ``attr_name`` is not found.
"""
entry, entry_type = self.get_entry(tid)
attrs: dict = {}
for attr_name in attr_names:
try:
attr_id = self._get_type_attribute_dict(entry_type)[attr_name][
constants.ATTR_INDEX_KEY
]
except KeyError as e:
raise KeyError(
f"{entry_type} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

return attrs

def get_attributes_of_tids(
self, list_of_tid: List[int], attr_names: List[str]
) -> List[Any]:
r"""This function returns the value of attributes listed in
``attr_names`` for entries in listed in the ``list_of_tid``.
It locates the entries data with ``tid`` and put attributes
listed in ``attr_name`` in a dict for each entry.

Args:
list_of_tid: List of unique ids of the entry.
attr_names: List of name of the attribute.

Returns:
A list of dict with ``attr_name`` as key for attributes
of the entries requested.

Raises:
KeyError: when ``tid`` or ``attr_name`` is not found.
"""
tids_attrs = []
for tid in list_of_tid:
entry, entry_type = self.get_entry(tid)
attrs: dict = {}
for attr_name in attr_names:
try:
attr_id = self._get_type_attribute_dict(entry_type)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{entry_type} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

tids_attrs.append(attrs)

return tids_attrs

def get_attributes_of_type(
self,
type_name: str,
attributes_names: List[str],
include_sub_type: bool = True,
range_span: Optional[Tuple[int, int]] = None,
) -> Iterator[dict]:
r"""This function fetches required attributes of entries from the
data store of type ``type_name``. If `include_sub_type` is set to
True and ``type_name`` is in [Annotation], this function also
fetches entries of subtype of ``type_name``. Otherwise, it only
fetches entries of type ``type_name``.

Args:
type_name: The fully qualified name of the entry.
attributes_names: list of attributes to be fetched for each entry
include_sub_type: A boolean to indicate whether get its subclass.
range_span: A tuple that contains the begin and end indices
of the searching range of entries.

Returns:
An iterator of the attributes of the entry in dict matching the
provided arguments.
"""

entry_class = get_class(type_name)
all_types = set()
if include_sub_type:
for type in self.__elements:
if issubclass(get_class(type), entry_class):
all_types.add(type)
else:
all_types.add(type_name)
all_types = list(all_types)
all_types.sort()

if self._is_annotation(type_name):
if range_span is None:
# yield from self.co_iterator_annotation_like(all_types)
for entry in self.co_iterator_annotation_like(all_types):
attrs: dict = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs
else:
for entry in self.co_iterator_annotation_like(
all_types, range_span=range_span
):
attrs = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs # attrs instead of entry
elif issubclass(entry_class, Link):
raise NotImplementedError(
f"{type_name} of Link is not currently supported."
)
elif issubclass(entry_class, Group):
raise NotImplementedError(
f"{type_name} of Group is not currently supported."
)
else:
if type_name not in self.__elements:
raise ValueError(f"type {type_name} does not exist")
# yield from self.iter(type_name)
for entry in self.iter(type_name):
attrs = {"tid": entry[0]}
for attr_name in attributes_names:
try:
attr_id = self._get_type_attribute_dict(type_name)[
attr_name
][constants.ATTR_INDEX_KEY]
except KeyError as e:
raise KeyError(
f"{type_name} has no {attr_name} attribute."
) from e
attrs[attr_name] = entry[attr_id]

yield attrs

def _get_attr(self, tid: int, attr_id: int) -> Any:
r"""This function locates the entry data with ``tid`` and gets the value
of ``attr_id`` of this entry. Called by `get_attribute()`.
Expand Down
123 changes: 120 additions & 3 deletions tests/forte/data/data_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,6 @@ def value_err_fn():
self.assertRaises(ValueError, value_err_fn)

def test_add_annotation_raw(self):

# test add Document entry
tid_doc: int = self.data_store.add_entry_raw(
type_name="ft.onto.base_ontology.Document",
Expand Down Expand Up @@ -1039,6 +1038,126 @@ def test_get_attribute(self):
):
self.data_store.get_attribute(9999, "class")

def test_get_attributes_of_tid(self):
result_dict = self.data_store.get_attributes_of_tid(
9999, ["begin", "end", "speaker"]
)
result_dict2 = self.data_store.get_attributes_of_tid(
3456, ["payload_idx", "classifications"]
)

self.assertEqual(result_dict["begin"], 6)
self.assertEqual(result_dict["end"], 9)
self.assertEqual(result_dict["speaker"], "teacher")
self.assertEqual(result_dict2["payload_idx"], 1)
self.assertEqual(result_dict2["classifications"], {})

# Entry with such tid does not exist
with self.assertRaisesRegex(KeyError, "Entry with tid 1111 not found."):
self.data_store.get_attributes_of_tid(1111, ["speaker"])

# Get attribute field that does not exist
with self.assertRaisesRegex(
KeyError, "ft.onto.base_ontology.Sentence has no class attribute."
):
self.data_store.get_attributes_of_tid(9999, ["class"])

def test_get_attributes_of_tids(self):
tids_attrs: list[dict]
# tids_attrs2: list[dict]
tids_attrs = self.data_store.get_attributes_of_tids(
[9999, 3456], ["begin", "end", "payload_idx"]
)
tids_attrs2 = self.data_store.get_attributes_of_tids(
[9999], ["begin", "speaker"]
)

self.assertEqual(tids_attrs2[0]["begin"], 6)
self.assertEqual(tids_attrs[0]["end"], 9)
self.assertEqual(tids_attrs[1]["payload_idx"], 1)
self.assertEqual(tids_attrs2[0]["speaker"], "teacher")

# Entry with such tid does not exist
with self.assertRaisesRegex(KeyError, "Entry with tid 1111 not found."):
self.data_store.get_attributes_of_tids([1111], ["speaker"])

# Get attribute field that does not exist
with self.assertRaisesRegex(
KeyError, "ft.onto.base_ontology.Sentence has no class attribute."
):
self.data_store.get_attributes_of_tids([9999], ["class"])

def test_get_attributes_of_type(self):
# get document entries
instances = list(
self.data_store.get_attributes_of_type(
"ft.onto.base_ontology.Document",
["begin", "end", "payload_idx"],
)
)
# print(instances)
self.assertEqual(len(instances), 2)
# check tid
self.assertEqual(instances[0]["tid"], 1234)
self.assertEqual(instances[0]["end"], 5)
self.assertEqual(instances[1]["tid"], 3456)
self.assertEqual(instances[1]["begin"], 10)

# For types other than annotation, group or link, not support include_subtype
instances = list(
self.data_store.get_attributes_of_type(
"forte.data.ontology.core.Entry", ["begin", "end"]
)
)
self.assertEqual(len(instances), 0)

self.assertEqual(
self.data_store.get_length("forte.data.ontology.core.Entry"), 0
)

# get annotations with subclasses and range annotation
instances = list(
self.data_store.get_attributes_of_type(
"forte.data.ontology.top.Annotation",
["begin", "end"],
range_span=(1, 20),
)
)
self.assertEqual(len(instances), 2)

# get groups with subclasses
# instances = list(self.data_store.get_attributes_of_type(
# "forte.data.ontology.top.Group", ["begin", "end"]))
# self.assertEqual(len(instances), 3)

# # get groups with subclasses and range annotation
# instances = list(
# self.data_store.get(
# "forte.data.ontology.top.Group", range_span=(1, 20)
# )
# )
# self.assertEqual(len(instances), 0)
#
# # get links with subclasses
# instances = list(self.data_store.get("forte.data.ontology.top.Link"))
# self.assertEqual(len(instances), 1)
#
# # get links with subclasses and range annotation
# instances = list(
# self.data_store.get(
# "forte.data.ontology.top.Link", range_span=(0, 9)
# )
# )
# self.assertEqual(len(instances), 1)
#
# # get links with subclasses and range annotation
# instances = list(
# self.data_store.get(
# "forte.data.ontology.top.Link", range_span=(4, 11)
# )
# )
# self.assertEqual(len(instances), 0)

def test_set_attribute(self):
# change attribute
self.data_store.set_attribute(9999, "speaker", "student")
Expand Down Expand Up @@ -1328,7 +1447,6 @@ def test_get_entry_attribute_by_class(self):
)

def test_is_subclass(self):

import forte

self.assertEqual(
Expand Down Expand Up @@ -1396,7 +1514,6 @@ def test_is_subclass(self):
)

def test_check_onto_file(self):

expected_type_attributes = {
"ft.onto.test.Description": {
"attributes": {
Expand Down