From 72e8bce7ace63730cbe402cb840900706e3bab43 Mon Sep 17 00:00:00 2001
From: Pushkar-Bhuse <41834903+Pushkar-Bhuse@users.noreply.github.com>
Date: Mon, 2 Jan 2023 19:49:52 -0600
Subject: [PATCH] Allow `FList` to store `MultiPack` entries (#902)

* Add MultiPack Entries to FList

* FastAPI bug fix

* add doc

Co-authored-by: mylibrar <54747962+mylibrar@users.noreply.github.com>
Co-authored-by: Hector <hunterhector@gmail.com>
---
 forte/data/ontology/core.py | 75 ++++++++++++++++++++++++++++++++-----
 forte/data/ontology/top.py  | 22 +++--------
 2 files changed, 70 insertions(+), 27 deletions(-)

diff --git a/forte/data/ontology/core.py b/forte/data/ontology/core.py
index ca5dc4663..58a2446b7 100644
--- a/forte/data/ontology/core.py
+++ b/forte/data/ontology/core.py
@@ -31,6 +31,7 @@
     Union,
     Dict,
     Iterator,
+    cast,
     overload,
     List,
     Any,
@@ -297,11 +298,13 @@ class FList(Generic[ParentEntryType], MutableSequence):
     def __init__(
         self,
         parent_entry: ParentEntryType,
-        data: Optional[List[int]] = None,
+        data: Optional[List[Union[int, Tuple[int, int]]]] = None,
     ):
         super().__init__()
         self.__parent_entry = parent_entry
-        self.__data: List[int] = [] if data is None else data
+        self.__data: List[Union[int, Tuple[int, int]]] = (
+            [] if data is None else data
+        )
 
     def __eq__(self, other):
         return self.__data == other._FList__data
@@ -310,7 +313,15 @@ def _set_parent(self, parent_entry: ParentEntryType):
         self.__parent_entry = parent_entry
 
     def insert(self, index: int, entry: EntryType):
-        self.__data.insert(index, entry.tid)
+        """Insert a reference to ``entry`` before position ``index``."""
+        # An entry whose pack id differs from the parent's pack id lives
+        # in another pack (the MultiPack case), so it is stored as a
+        # (pack_id, tid) tuple. An entry of the parent's own pack is
+        # stored by its tid alone.
+        if entry.pack.pack_id != self.__parent_entry.pack.pack_id:
+            self.__data.insert(index, (entry.pack.pack_id, entry.tid))
+        else:
+            self.__data.insert(index, entry.tid)
 
     @overload
     @abstractmethod
@@ -326,22 +337,66 @@ def __getitem__(
         self, index: Union[int, slice]
     ) -> Union[EntryType, MutableSequence]:
+        """Return the entry or entries referenced at ``index``."""
         if isinstance(index, slice):
-            return [
-                self.__parent_entry.pack.get_entry(tid)
-                for tid in self.__data[index]
-            ]
+            # Decode every stored reference on its own instead of
+            # checking the whole list: ``insert`` encodes per entry, so
+            # one FList can hold both plain tids and (pack_id, tid)
+            # tuples at the same time.
+            pack = self.__parent_entry.pack
+            return [
+                # An ``int`` is the tid of an entry in the parent's own
+                # pack; a tuple is the (pack_id, tid) of an entry that
+                # belongs to another pack.
+                pack.get_entry(ref)
+                if isinstance(ref, int)
+                else pack.get_subentry(*ref)
+                for ref in self.__data[index]
+            ]
         else:
-            return self.__parent_entry.pack.get_entry(self.__data[index])
+            # Same per-reference decoding as for slices; this also avoids
+            # scanning the whole list on every single-item access.
+            ref = self.__data[index]
+            if isinstance(ref, int):
+                # An ``int`` is the tid of an entry in the parent's own pack.
+                return self.__parent_entry.pack.get_entry(ref)
+            else:
+                # A tuple is the (pack_id, tid) of a Multi Pack sub-entry.
+                return self.__parent_entry.pack.get_subentry(*ref)
 
     def __setitem__(
         self,
         index: Union[int, slice],
         value: Union[EntryType, Iterable[EntryType]],
     ) -> None:
+        """Replace the entry or entries stored at ``index``."""
         if isinstance(index, int):
-            self.__data[index] = value.tid  # type: ignore
+            value = cast(EntryType, value)
+            if value.pack.pack_id != self.__parent_entry.pack.pack_id:
+                # The entry belongs to another pack than the parent
+                # (the MultiPack case), so keep its pack id next to
+                # its tid.
+                self.__data[index] = (value.pack.pack_id, value.tid)
+            else:
+                # The entry belongs to the parent's own pack (the Single
+                # Pack case), so only its tid is stored, as for regular
+                # entries.
+                self.__data[index] = value.tid
         else:
-            self.__data[index] = [v.tid for v in value]  # type: ignore
+            # Encode each entry on its own, as ``insert`` and the integer
+            # branch above do. A single decision for the whole slice would
+            # drop the pack id of entries from other packs whenever one
+            # entry of the parent's pack is present, and testing ``value``
+            # before encoding it would consume a one-shot iterator.
+            parent_pack_id = self.__parent_entry.pack.pack_id
+            encoded: List[Union[int, Tuple[int, int]]] = []
+            for v in cast(Iterable[EntryType], value):
+                if v.pack.pack_id == parent_pack_id:
+                    # Same pack as the parent: the tid alone is enough.
+                    encoded.append(v.tid)
+                else:
+                    # Entry of another pack: keep its pack id as well.
+                    encoded.append((v.pack.pack_id, v.tid))
+            self.__data[index] = encoded
 
     def __delitem__(self, index: Union[int, slice]) -> None:
         del self.__data[index]
diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py
index 2be964f4b..24f569aef 100644
--- a/forte/data/ontology/top.py
+++ b/forte/data/ontology/top.py
@@ -23,7 +23,6 @@
     Union,
     Iterable,
     List,
-    cast,
 )
 import numpy as np
 
@@ -343,11 +342,7 @@ def get_members(self) -> List[Entry]:
                 "attached to any data pack."
             )
 
-        member_entries = []
-        if self.members is not None:
-            for m in self.members:
-                member_entries.append(m)
-        return member_entries
+        return list(self.members)
 
 
 class MultiPackGeneric(MultiEntry, Entry):
@@ -498,7 +493,7 @@ class MultiPackGroup(MultiEntry, BaseGroup[Entry]):
     of members.
     """
     member_type: str
-    members: Optional[FList[Entry]]
+    members: FList[Entry]
 
     MemberType = Entry
 
@@ -520,18 +515,11 @@ def add_member(self, member: Entry):
                 f"The members of {type(self)} should be "
                 f"instances of {self.MemberType}, but got {type(member)}"
             )
-        if self.members is None:
-            self.members = cast(FList, [member])
-        else:
-            self.members.append(member)
+
+        self.members.append(member)
 
     def get_members(self) -> List[Entry]:
-        members = []
-        if self.members is not None:
-            member_data = self.members
-            for m in member_data:
-                members.append(m)
-        return members
+        return list(self.members)
 
 
 @dataclass