From 1479fc36813ec7b8d3f3d31a0388fd22dbe3289f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fa=CC=81bio=20Madeira?= Date: Fri, 22 Sep 2017 10:51:58 +0100 Subject: [PATCH] Improved the row_selector function making it more generic and easy to get right, by allowing to pass iterables and characters/numericals. --- prointvar/arpeggio.py | 16 ++++++++-------- prointvar/dssp.py | 6 +++--- prointvar/hbplus.py | 8 ++++---- prointvar/pdbx.py | 22 ++++++++++------------ prointvar/sifts.py | 10 +++++----- prointvar/utils.py | 32 +++++++++++++++++--------------- prointvar/variants.py | 2 +- tests/test_utils.py | 8 ++++---- 8 files changed, 52 insertions(+), 52 deletions(-) diff --git a/prointvar/arpeggio.py b/prointvar/arpeggio.py index 8575cad..49e1ab7 100644 --- a/prointvar/arpeggio.py +++ b/prointvar/arpeggio.py @@ -382,35 +382,35 @@ def get_arpeggio_selected_from_table(data, chain_A=None, chain_B=None, table = data if chain_A is not None: - table = row_selector(table, 'CHAIN_A', chain_A, method="isin") + table = row_selector(table, 'CHAIN_A', chain_A) logger.info("Arpeggio table filtered by CHAIN_A...") if chain_B is not None: - table = row_selector(table, 'CHAIN_B', chain_B, method="isin") + table = row_selector(table, 'CHAIN_B', chain_B) logger.info("Arpeggio table filtered by CHAIN_B...") if res_A is not None: - table = row_selector(table, 'RES_A', res_A, method="isin") + table = row_selector(table, 'RES_A', res_A) logger.info("Arpeggio table filtered by RES_A...") if res_B is not None: - table = row_selector(table, 'RES_B', res_B, method="isin") + table = row_selector(table, 'RES_B', res_B) logger.info("Arpeggio table filtered by RES_B...") if res_full_A is not None: - table = row_selector(table, 'RES_FULL_A', res_full_A, method="isin") + table = row_selector(table, 'RES_FULL_A', res_full_A) logger.info("Arpeggio table filtered by RES_FULL_A...") if res_full_B is not None: - table = row_selector(table, 'RES_FULL_B', res_full_B, method="isin") + table = row_selector(table, 'RES_FULL_B', res_full_B) logger.info("Arpeggio table filtered by RES_FULL_B...") if atom_A is not None: - table = row_selector(table, 'ATOM_A', atom_A, method="isin") + table = row_selector(table, 'ATOM_A', atom_A) logger.info("Arpeggio table filtered by ATOM_A...") if atom_B is not None: - table = row_selector(table, 'ATOM_B', atom_B, method="isin") + table = row_selector(table, 'ATOM_B', atom_B) logger.info("Arpeggio table filtered by ATOM_B...") return table diff --git a/prointvar/dssp.py b/prointvar/dssp.py index ace42ea..decbeac 100644 --- a/prointvar/dssp.py +++ b/prointvar/dssp.py @@ -201,15 +201,15 @@ def get_dssp_selected_from_table(data, chain=None, chain_full=None, res=None): # excluding rows table = data if chain is not None: - table = row_selector(table, 'CHAIN', chain, method="isin") + table = row_selector(table, 'CHAIN', chain) logger.info("DSSP table filtered by CHAIN...") if chain_full is not None: - table = row_selector(table, 'CHAIN_FULL', chain_full, method="isin") + table = row_selector(table, 'CHAIN_FULL', chain_full) logger.info("DSSP table filtered by CHAIN_FULL...") if res is not None: - table = row_selector(table, 'RES', res, method="isin") + table = row_selector(table, 'RES', res) logger.info("DSSP table filtered by RES...") return table diff --git a/prointvar/hbplus.py b/prointvar/hbplus.py index f6a318a..7437f0f 100644 --- a/prointvar/hbplus.py +++ b/prointvar/hbplus.py @@ -118,19 +118,19 @@ def get_hbplus_selected_from_table(data, chain_A=None, chain_D=None, # excluding rows table = data if chain_D is not None: - table = row_selector(table, 'CHAIN_D', chain_D, method="isin") + table = row_selector(table, 'CHAIN_D', chain_D) logger.info("HBPLUS table filtered by CHAIN_D...") if chain_A is not None: - table = row_selector(table, 'CHAIN_A', chain_A, method="isin") + table = row_selector(table, 'CHAIN_A', chain_A) logger.info("HBPLUS table filtered by CHAIN_A...") if res_D is not None: - table = row_selector(table, 'RES_D', res_D, method="isin") + table = row_selector(table, 'RES_D', res_D) logger.info("HBPLUS table filtered by RES_D...") if res_A is not None: - table = row_selector(table, 'RES_A', res_A, method="isin") + table = row_selector(table, 'RES_A', res_A) logger.info("HBPLUS table filtered by RES_A...") return table diff --git a/prointvar/pdbx.py b/prointvar/pdbx.py index af5a580..b55eaef 100644 --- a/prointvar/pdbx.py +++ b/prointvar/pdbx.py @@ -100,8 +100,7 @@ def parse_mmcif_atoms_from_file(inputfile, excluded=(), add_res_full=True, # if only first model (>1 in NMR structures) if first_model: - table = row_selector(table, key='pdbx_PDB_model_num', value=None, - method='first') + table = row_selector(table, key='pdbx_PDB_model_num', value='first') # table modular extensions if add_contacts: @@ -126,7 +125,7 @@ def parse_mmcif_atoms_from_file(inputfile, excluded=(), add_res_full=True, logger.info("PDBx removed altlocs...") if remove_hydrogens: - table = row_selector(table, key='type_symbol', value='H', method='diffs') + table = row_selector(table, key='type_symbol', value='H', reverse=True) logger.info("PDBx removed existing hydrogens...") if remove_partial_res: @@ -220,8 +219,7 @@ def parse_pdb_atoms_from_file(inputfile, excluded=(), add_contacts=False, # if only first model (>1 in NMR structures) if first_model: - table = row_selector(table, key='pdbx_PDB_model_num', value=None, - method='first') + table = row_selector(table, key='pdbx_PDB_model_num', value='first') # fixes the 'pdbx_PDB_ins_code' table = fix_pdb_ins_code(table) @@ -249,7 +247,7 @@ def parse_pdb_atoms_from_file(inputfile, excluded=(), add_contacts=False, logger.info("PDBx removed altlocs...") if remove_hydrogens: - table = row_selector(table, key='type_symbol', value='H', method='diffs') + table = row_selector(table, key='type_symbol', value='H', reverse=True) logger.info("PDBx removed existing hydrogens...") if remove_partial_res: @@ -375,27 +373,27 @@ def get_mmcif_selected_from_table(data, chain=None, res=None, res_full=None, com # excluding rows table = data if chain is not None: - table = row_selector(table, '{}_asym_id'.format(category), chain, method="isin") + table = row_selector(table, '{}_asym_id'.format(category), chain) logger.info("PDBx table filtered by %s_asym_id...", category) if res is not None: - table = row_selector(table, '{}_seq_id'.format(category), res, method="isin") + table = row_selector(table, '{}_seq_id'.format(category), res) logger.info("PDBx table filtered by %s_seq_id...", category) if res_full is not None: - table = row_selector(table, '{}_seq_id_full'.format(category), res_full, method="isin") + table = row_selector(table, '{}_seq_id_full'.format(category), res_full) logger.info("PDBx table filtered by %s_seq_id_full...", category) if comp is not None: - table = row_selector(table, '{}_comp_id'.format(category), comp, method="isin") + table = row_selector(table, '{}_comp_id'.format(category), comp) logger.info("PDBx table filtered by %s_comp_id...", category) if atom is not None: - table = row_selector(table, '{}_atom_id'.format(category), atom, method="isin") + table = row_selector(table, '{}_atom_id'.format(category), atom) logger.info("PDBx table filtered by %s_atom_id...", category) if lines is not None: - table = row_selector(table, 'group_PDB', lines, method="isin") + table = row_selector(table, 'group_PDB', lines) logger.info("PDBx table filtered by group_PDB...") return table diff --git a/prointvar/sifts.py b/prointvar/sifts.py index 91a6b20..6b0eeaf 100644 --- a/prointvar/sifts.py +++ b/prointvar/sifts.py @@ -355,23 +355,23 @@ def get_sifts_selected_from_table(data, chain=None, chain_auth=None, res=None, # excluding rows table = data if chain is not None: - table = row_selector(table, 'PDB_entityId', chain, method="isin") + table = row_selector(table, 'PDB_entityId', chain) logger.info("SIFTS table filtered by PDB_entityId...") if chain_auth is not None: - table = row_selector(table, 'PDB_dbChainId', chain_auth, method="isin") + table = row_selector(table, 'PDB_dbChainId', chain_auth) logger.info("SIFTS table filtered by PDB_dbChainId...") if res is not None: - table = row_selector(table, 'PDB_dbResNum', res, method="isin") + table = row_selector(table, 'PDB_dbResNum', res) logger.info("SIFTS table filtered by PDB_dbResNum...") if uniprot is not None: - table = row_selector(table, 'UniProt_dbAccessionId', uniprot, method="isin") + table = row_selector(table, 'UniProt_dbAccessionId', uniprot) logger.info("SIFTS table filtered by UniProt_dbAccessionId...") if site is not None: - table = row_selector(table, 'UniProt_dbResNum', site, method="isin") + table = row_selector(table, 'UniProt_dbResNum', site) logger.info("SIFTS table filtered by UniProt_dbResNum...") return table diff --git a/prointvar/utils.py b/prointvar/utils.py index c32a4c4..3dcb6d1 100755 --- a/prointvar/utils.py +++ b/prointvar/utils.py @@ -474,34 +474,36 @@ def compute_rsa(acc, resname, method="Sander"): return rsa -def row_selector(data, key=None, value=None, method="isin"): +def row_selector(data, key=None, value=None, reverse=False): """ Generic method to filter columns :param data: pandas DataFrame :param key: pandas DataFrame column name :param value: value(s) to be looked for - :param method: operator method - :return: + :param reverse: opposite behavior (e.g. 'isIn' becomes 'isNotIn' + and 'equals' becomes 'differs') + :return: returns a modified pandas DataFrame """ table = data assert type(table) is pd.core.frame.DataFrame - if ((key is not None and value is not None) or - (key is not None and method == 'first')): + if key is not None and value is not None: assert type(key) is str - assert type(method) is str if key in table: - if method == "isin": - assert hasattr(value, '__iter__') - table = table.loc[table[key].isin(value)] - elif method == "equals": - # assert type(values) is str - table = table.loc[table[key] == value] - elif method == "diffs": - table = table.loc[table[key] != value] - elif method == "first": + if value == 'first': value = table[key].iloc[0] table = table.loc[table[key] == value] + elif (hasattr(value, '__iter__') and + (type(value) is tuple or type(value) is list)): + if not reverse: + table = table.loc[table[key].isin(value)] + else: + table = table.loc[~table[key].isin(value)] + else: + if not reverse: + table = table.loc[table[key] == value] + else: + table = table.loc[table[key] != value] else: logger.debug("%s not in the DataFrame...", key) diff --git a/prointvar/variants.py b/prointvar/variants.py index f265ccd..7a628a2 100644 --- a/prointvar/variants.py +++ b/prointvar/variants.py @@ -113,7 +113,7 @@ def flatten_ensembl_variants(data, excluded=(), synonymous=True): # filter synonymous if not synonymous: table = row_selector(table, key='consequenceType', - value='synonymous_variant', method="diffs") + value='synonymous_variant', reverse=True) return table diff --git a/tests/test_utils.py b/tests/test_utils.py index dd8f8d5..7c3bf6f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -395,13 +395,13 @@ def test_get_new_pro_ids(self): self.assertEqual(seq_id_2, '2') def test_row_selector(self): - d = self.row_selector(self.mock_df, key='value', value=3, method='equals') + d = self.row_selector(self.mock_df, key='value', value=3) self.assertEqual(len(d.index), 1) - d = self.row_selector(self.mock_df, key='value', value=3, method='diffs') + d = self.row_selector(self.mock_df, key='value', value=3, reverse=True) self.assertEqual(len(d.index), 4) - d = self.row_selector(self.mock_df, key='value', value=None, method='first') + d = self.row_selector(self.mock_df, key='value', value='first') self.assertEqual(len(d.index), 2) - d = self.row_selector(self.mock_df, key='value', value=(2, 3), method='isin') + d = self.row_selector(self.mock_df, key='value', value=(2, 3)) self.assertEqual(len(d.index), 2) def test_merging_down_by_key(self):