diff --git a/prointvar/utils.py b/prointvar/utils.py index d03d577..c32a4c4 100755 --- a/prointvar/utils.py +++ b/prointvar/utils.py @@ -492,7 +492,7 @@ def row_selector(data, key=None, value=None, method="isin"): assert type(method) is str if key in table: if method == "isin": - assert type(value) is tuple + assert hasattr(value, '__iter__') table = table.loc[table[key].isin(value)] elif method == "equals": # assert type(values) is str @@ -701,7 +701,7 @@ def constrain_column_types(data, dictionary, nan_value=None): # probably there are some NaNs in there pass if table[col].isnull().any().any() and nan_value is not None: - table[col] = table[col].fillna(nan_value, axis=1) + table[col] = table[col].fillna(nan_value) return table @@ -717,7 +717,7 @@ def exclude_columns(data, excluded=()): table = data if excluded is not None: - assert type(excluded) is tuple + assert hasattr(excluded, '__iter__') try: table = table.drop(list(excluded), axis=1) except ValueError: diff --git a/tests/test_utils.py b/tests/test_utils.py index a9fefe6..dd8f8d5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -41,6 +41,8 @@ from prointvar.utils import get_pairwise_indexes from prointvar.utils import Make from prointvar.utils import get_start_end_ranges_consecutive_ints +from prointvar.utils import constrain_column_types +from prointvar.utils import exclude_columns from prointvar.config import config as c @@ -149,6 +151,14 @@ def setUp(self): self.get_pairwise_indexes = get_pairwise_indexes self.make_class = Make self.get_start_end_ranges_consecutive_ints = get_start_end_ranges_consecutive_ints + self.mock_df = pd.DataFrame( + [{'label': '1', 'value': 1, 'type': 23.4}, + {'label': '2', 'value': 1, 'type': 1}, + {'label': '3', 'value': 2, 'type': np.nan}, + {'label': '4', 'value': 3, 'type': 123.1}, + {'label': '5', 'value': 5, 'type': 0.32}]) + self.constrain_column_types = constrain_column_types + self.exclude_columns = exclude_columns logging.disable(logging.DEBUG) @@ -181,6 +191,9 @@ def tearDown(self): self.get_pairwise_indexes = None self.make_class = None self.get_start_end_ranges_consecutive_ints = None + self.mock_df = None + self.constrain_column_types = None + self.exclude_columns = None logging.disable(logging.NOTSET) @@ -382,18 +395,13 @@ def test_get_new_pro_ids(self): self.assertEqual(seq_id_2, '2') def test_row_selector(self): - data = pd.DataFrame([{'label': '1', 'value': 1}, - {'label': '2', 'value': 1}, - {'label': '3', 'value': 2}, - {'label': '4', 'value': 3}, - {'label': '5', 'value': 5}]) - d = self.row_selector(data, key='value', value=3, method='equals') + d = self.row_selector(self.mock_df, key='value', value=3, method='equals') self.assertEqual(len(d.index), 1) - d = self.row_selector(data, key='value', value=3, method='diffs') + d = self.row_selector(self.mock_df, key='value', value=3, method='diffs') self.assertEqual(len(d.index), 4) - d = self.row_selector(data, key='value', value=None, method='first') + d = self.row_selector(self.mock_df, key='value', value=None, method='first') self.assertEqual(len(d.index), 2) - d = self.row_selector(data, key='value', value=(2, 3), method='isin') + d = self.row_selector(self.mock_df, key='value', value=(2, 3), method='isin') self.assertEqual(len(d.index), 2) def test_merging_down_by_key(self): @@ -513,6 +521,27 @@ def test_get_real_ranges(self): self.assertEqual(starts, (1, 7, 13, 23, 32, 47)) self.assertEqual(ends, (3, 9, 19, 29, 42, 50)) + def test_constrain_column_types(self): + dtypes = {'type': 'float64', + 'value': 'int64', + 'label': 'object'} + + self.mock_df = self.constrain_column_types(self.mock_df, dtypes) + self.assertEqual(self.mock_df["type"].dtype, np.float64) + self.assertEqual(self.mock_df["value"].dtype, np.int64) + self.assertEqual(self.mock_df["label"].dtype, np.object) + + self.mock_df = self.constrain_column_types(self.mock_df, dtypes, + nan_value=0.0) + self.assertEqual(self.mock_df["type"].dtype, np.float64) + self.assertEqual(self.mock_df.loc[2, "type"], 0.0) + + def test_exclude_columns(self): + self.assertEqual(len(self.mock_df.columns), 3) + self.mock_df = self.exclude_columns(self.mock_df, excluded=("type",)) + self.assertEqual(len(self.mock_df.columns), 2) + self.assertNotIn("type", self.mock_df) + if __name__ == '__main__': logging.basicConfig(stream=sys.stderr)