Skip to content

Commit

Permalink
implement op that reads multiple rows from dataframe (#395)
Browse files Browse the repository at this point in the history
Co-authored-by: Moshe Raboh [email protected] <[email protected]>
  • Loading branch information
mosheraboh and Moshe Raboh [email protected] authored Feb 4, 2025
1 parent 99cb04b commit 0f24c39
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
65 changes: 65 additions & 0 deletions fuse/data/ops/ops_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,71 @@ def get_all_keys(self) -> List[Hashable]:
return list(self.data.keys())


class OpReadMultiFromDataframe(OpReadDataframe):
"""
Read multiple entries from dataframe at once.
In that case the key expected a string that build from multiple dataframe indices separated by "@SEP@"
For example
df = pd.DataFrame({
"sample_id": [0, 1, 2, 3, 4]
"my_data": [10, 11, 12, 13, 14]
})
sample_dict = {
"data.sample_id": "3@SEP@4"
}
will read row 3 from dataframe into sample_dict[f"my_data.0"]=13
And row 4 into sample_dict[f"my_data.1"]=14
"""

def __init__(
self,
data: Optional[pd.DataFrame] = None,
data_filename: Optional[str] = None,
columns_to_extract: Optional[List[str]] = None,
rename_columns: Optional[Dict[str, str]] = None,
key_name: str = "data.sample_id",
key_column: str = "sample_id",
multi_key_sep: str = "@SEP@",
):
super().__init__(
data,
data_filename,
columns_to_extract,
rename_columns,
key_name,
key_column,
)

self._multi_key_sep = multi_key_sep

# convert ids to strings to support simple split and concat
if not isinstance(next(iter(self._data.keys())), str):
self._data = {str(k): v for k, v in self._data.items()}

def __call__(self, sample_dict: NDict, prefix: Optional[str] = None) -> NDict:
multi_key = sample_dict[self._key_name]

assert isinstance(multi_key, str), "Error: only str sample ids are supported"

if self._multi_key_sep in multi_key:
keys = multi_key.split(self._multi_key_sep)
else:
keys = [multi_key]

for key_index, key in enumerate(keys):
# locate the required item
sample_data = self._data[key].copy()

# add values tp sample_dict
for name, value in sample_data.items():
if prefix is None:
sample_dict[f"{name}.{key_index}"] = value
else:
sample_dict[f"{prefix}.{name}.{key_index}"] = value

return sample_dict


class OpReadHDF5(OpBase):
"""
Op reading data from hd5f based dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ def encode_list(
# of input mapped to unk.
if on_unknown == "raise":
raise RuntimeError(
f"Encountered {unk_count} unknown tokens out of {len(merged_encoding.ids)} in input starting with {typed_input_list[0].input_string}"
f"Encountered {unk_count} unknown tokens out of {len(merged_encoding.ids)} in input starting with {[typed_input.input_string for typed_input in typed_input_list]}"
)
elif on_unknown == "warn":
if verbose == 0:
Expand Down

0 comments on commit 0f24c39

Please sign in to comment.