Skip to content

Commit

Permalink
Improved the row_selector function, making it more generic and easier to…
Browse files Browse the repository at this point in the history
… get right, by allowing iterables as well as single characters/numeric values to be passed.
  • Loading branch information
biomadeira committed Sep 22, 2017
1 parent 4fb9733 commit 1479fc3
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 52 deletions.
16 changes: 8 additions & 8 deletions prointvar/arpeggio.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,35 +382,35 @@ def get_arpeggio_selected_from_table(data, chain_A=None, chain_B=None,
table = data

if chain_A is not None:
table = row_selector(table, 'CHAIN_A', chain_A, method="isin")
table = row_selector(table, 'CHAIN_A', chain_A)
logger.info("Arpeggio table filtered by CHAIN_A...")

if chain_B is not None:
table = row_selector(table, 'CHAIN_B', chain_B, method="isin")
table = row_selector(table, 'CHAIN_B', chain_B)
logger.info("Arpeggio table filtered by CHAIN_B...")

if res_A is not None:
table = row_selector(table, 'RES_A', res_A, method="isin")
table = row_selector(table, 'RES_A', res_A)
logger.info("Arpeggio table filtered by RES_A...")

if res_B is not None:
table = row_selector(table, 'RES_B', res_B, method="isin")
table = row_selector(table, 'RES_B', res_B)
logger.info("Arpeggio table filtered by RES_B...")

if res_full_A is not None:
table = row_selector(table, 'RES_FULL_A', res_full_A, method="isin")
table = row_selector(table, 'RES_FULL_A', res_full_A)
logger.info("Arpeggio table filtered by RES_FULL_A...")

if res_full_B is not None:
table = row_selector(table, 'RES_FULL_B', res_full_B, method="isin")
table = row_selector(table, 'RES_FULL_B', res_full_B)
logger.info("Arpeggio table filtered by RES_FULL_B...")

if atom_A is not None:
table = row_selector(table, 'ATOM_A', atom_A, method="isin")
table = row_selector(table, 'ATOM_A', atom_A)
logger.info("Arpeggio table filtered by ATOM_A...")

if atom_B is not None:
table = row_selector(table, 'ATOM_B', atom_B, method="isin")
table = row_selector(table, 'ATOM_B', atom_B)
logger.info("Arpeggio table filtered by ATOM_B...")

return table
Expand Down
6 changes: 3 additions & 3 deletions prointvar/dssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,15 @@ def get_dssp_selected_from_table(data, chain=None, chain_full=None, res=None):
# excluding rows
table = data
if chain is not None:
table = row_selector(table, 'CHAIN', chain, method="isin")
table = row_selector(table, 'CHAIN', chain)
logger.info("DSSP table filtered by CHAIN...")

if chain_full is not None:
table = row_selector(table, 'CHAIN_FULL', chain_full, method="isin")
table = row_selector(table, 'CHAIN_FULL', chain_full)
logger.info("DSSP table filtered by CHAIN_FULL...")

if res is not None:
table = row_selector(table, 'RES', res, method="isin")
table = row_selector(table, 'RES', res)
logger.info("DSSP table filtered by RES...")

return table
Expand Down
8 changes: 4 additions & 4 deletions prointvar/hbplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,19 @@ def get_hbplus_selected_from_table(data, chain_A=None, chain_D=None,
# excluding rows
table = data
if chain_D is not None:
table = row_selector(table, 'CHAIN_D', chain_D, method="isin")
table = row_selector(table, 'CHAIN_D', chain_D)
logger.info("HBPLUS table filtered by CHAIN_D...")

if chain_A is not None:
table = row_selector(table, 'CHAIN_A', chain_A, method="isin")
table = row_selector(table, 'CHAIN_A', chain_A)
logger.info("HBPLUS table filtered by CHAIN_A...")

if res_D is not None:
table = row_selector(table, 'RES_D', res_D, method="isin")
table = row_selector(table, 'RES_D', res_D)
logger.info("HBPLUS table filtered by RES_D...")

if res_A is not None:
table = row_selector(table, 'RES_A', res_A, method="isin")
table = row_selector(table, 'RES_A', res_A)
logger.info("HBPLUS table filtered by RES_A...")

return table
Expand Down
22 changes: 10 additions & 12 deletions prointvar/pdbx.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def parse_mmcif_atoms_from_file(inputfile, excluded=(), add_res_full=True,

# if only first model (>1 in NMR structures)
if first_model:
table = row_selector(table, key='pdbx_PDB_model_num', value=None,
method='first')
table = row_selector(table, key='pdbx_PDB_model_num', value='first')

# table modular extensions
if add_contacts:
Expand All @@ -126,7 +125,7 @@ def parse_mmcif_atoms_from_file(inputfile, excluded=(), add_res_full=True,
logger.info("PDBx removed altlocs...")

if remove_hydrogens:
table = row_selector(table, key='type_symbol', value='H', method='diffs')
table = row_selector(table, key='type_symbol', value='H', reverse=True)
logger.info("PDBx removed existing hydrogens...")

if remove_partial_res:
Expand Down Expand Up @@ -220,8 +219,7 @@ def parse_pdb_atoms_from_file(inputfile, excluded=(), add_contacts=False,

# if only first model (>1 in NMR structures)
if first_model:
table = row_selector(table, key='pdbx_PDB_model_num', value=None,
method='first')
table = row_selector(table, key='pdbx_PDB_model_num', value='first')

# fixes the 'pdbx_PDB_ins_code'
table = fix_pdb_ins_code(table)
Expand Down Expand Up @@ -249,7 +247,7 @@ def parse_pdb_atoms_from_file(inputfile, excluded=(), add_contacts=False,
logger.info("PDBx removed altlocs...")

if remove_hydrogens:
table = row_selector(table, key='type_symbol', value='H', method='diffs')
table = row_selector(table, key='type_symbol', value='H', reverse=True)
logger.info("PDBx removed existing hydrogens...")

if remove_partial_res:
Expand Down Expand Up @@ -375,27 +373,27 @@ def get_mmcif_selected_from_table(data, chain=None, res=None, res_full=None, com
# excluding rows
table = data
if chain is not None:
table = row_selector(table, '{}_asym_id'.format(category), chain, method="isin")
table = row_selector(table, '{}_asym_id'.format(category), chain)
logger.info("PDBx table filtered by %s_asym_id...", category)

if res is not None:
table = row_selector(table, '{}_seq_id'.format(category), res, method="isin")
table = row_selector(table, '{}_seq_id'.format(category), res)
logger.info("PDBx table filtered by %s_seq_id...", category)

if res_full is not None:
table = row_selector(table, '{}_seq_id_full'.format(category), res_full, method="isin")
table = row_selector(table, '{}_seq_id_full'.format(category), res_full)
logger.info("PDBx table filtered by %s_seq_id_full...", category)

if comp is not None:
table = row_selector(table, '{}_comp_id'.format(category), comp, method="isin")
table = row_selector(table, '{}_comp_id'.format(category), comp)
logger.info("PDBx table filtered by %s_comp_id...", category)

if atom is not None:
table = row_selector(table, '{}_atom_id'.format(category), atom, method="isin")
table = row_selector(table, '{}_atom_id'.format(category), atom)
logger.info("PDBx table filtered by %s_atom_id...", category)

if lines is not None:
table = row_selector(table, 'group_PDB', lines, method="isin")
table = row_selector(table, 'group_PDB', lines)
logger.info("PDBx table filtered by group_PDB...")

return table
Expand Down
10 changes: 5 additions & 5 deletions prointvar/sifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,23 +355,23 @@ def get_sifts_selected_from_table(data, chain=None, chain_auth=None, res=None,
# excluding rows
table = data
if chain is not None:
table = row_selector(table, 'PDB_entityId', chain, method="isin")
table = row_selector(table, 'PDB_entityId', chain)
logger.info("SIFTS table filtered by PDB_entityId...")

if chain_auth is not None:
table = row_selector(table, 'PDB_dbChainId', chain_auth, method="isin")
table = row_selector(table, 'PDB_dbChainId', chain_auth)
logger.info("SIFTS table filtered by PDB_dbChainId...")

if res is not None:
table = row_selector(table, 'PDB_dbResNum', res, method="isin")
table = row_selector(table, 'PDB_dbResNum', res)
logger.info("SIFTS table filtered by PDB_dbResNum...")

if uniprot is not None:
table = row_selector(table, 'UniProt_dbAccessionId', uniprot, method="isin")
table = row_selector(table, 'UniProt_dbAccessionId', uniprot)
logger.info("SIFTS table filtered by UniProt_dbAccessionId...")

if site is not None:
table = row_selector(table, 'UniProt_dbResNum', site, method="isin")
table = row_selector(table, 'UniProt_dbResNum', site)
logger.info("SIFTS table filtered by UniProt_dbResNum...")

return table
Expand Down
32 changes: 17 additions & 15 deletions prointvar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,34 +474,36 @@ def compute_rsa(acc, resname, method="Sander"):
return rsa


def row_selector(data, key=None, value=None, method="isin"):
def row_selector(data, key=None, value=None, reverse=False):
"""
Generic method to filter columns
:param data: pandas DataFrame
:param key: pandas DataFrame column name
:param value: value(s) to be looked for
:param method: operator method
:return:
:param reverse: opposite behavior (e.g. 'isIn' becomes 'isNotIn'
and 'equals' becomes 'differs')
:return: returns a modified pandas DataFrame
"""

table = data
assert type(table) is pd.core.frame.DataFrame
if ((key is not None and value is not None) or
(key is not None and method == 'first')):
if key is not None and value is not None:
assert type(key) is str
assert type(method) is str
if key in table:
if method == "isin":
assert hasattr(value, '__iter__')
table = table.loc[table[key].isin(value)]
elif method == "equals":
# assert type(values) is str
table = table.loc[table[key] == value]
elif method == "diffs":
table = table.loc[table[key] != value]
elif method == "first":
if value == 'first':
value = table[key].iloc[0]
table = table.loc[table[key] == value]
elif (hasattr(value, '__iter__') and
(type(value) is tuple or type(value) is list)):
if not reverse:
table = table.loc[table[key].isin(value)]
else:
table = table.loc[~table[key].isin(value)]
else:
if not reverse:
table = table.loc[table[key] == value]
else:
table = table.loc[table[key] != value]
else:
logger.debug("%s not in the DataFrame...", key)

Expand Down
2 changes: 1 addition & 1 deletion prointvar/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def flatten_ensembl_variants(data, excluded=(), synonymous=True):
# filter synonymous
if not synonymous:
table = row_selector(table, key='consequenceType',
value='synonymous_variant', method="diffs")
value='synonymous_variant', reverse=True)
return table


Expand Down
8 changes: 4 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,13 @@ def test_get_new_pro_ids(self):
self.assertEqual(seq_id_2, '2')

def test_row_selector(self):
d = self.row_selector(self.mock_df, key='value', value=3, method='equals')
d = self.row_selector(self.mock_df, key='value', value=3)
self.assertEqual(len(d.index), 1)
d = self.row_selector(self.mock_df, key='value', value=3, method='diffs')
d = self.row_selector(self.mock_df, key='value', value=3, reverse=True)
self.assertEqual(len(d.index), 4)
d = self.row_selector(self.mock_df, key='value', value=None, method='first')
d = self.row_selector(self.mock_df, key='value', value='first')
self.assertEqual(len(d.index), 2)
d = self.row_selector(self.mock_df, key='value', value=(2, 3), method='isin')
d = self.row_selector(self.mock_df, key='value', value=(2, 3))
self.assertEqual(len(d.index), 2)

def test_merging_down_by_key(self):
Expand Down

0 comments on commit 1479fc3

Please sign in to comment.