Skip to content

Commit

Permalink
optimise add_arpeggio_res_split by parsing arpeggio residue ID with…
Browse files Browse the repository at this point in the history
… `DataFrame.str` methods.
  • Loading branch information
stuartmac committed Sep 27, 2017
1 parent 728530a commit a8fe3c4
Showing 1 changed file with 38 additions and 71 deletions.
109 changes: 38 additions & 71 deletions prointvar/arpeggio.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,77 +213,44 @@ def add_arpeggio_res_split(data):
"""
table = data

# get most frequent chain in ENTRY_A and use it to define the inter. direction
# for multiple chains use decreasing frequency
chains_a = [v.split('/')[0] for v in table['ENTRY_A'].tolist()]
Chains = namedtuple('Chains', 'key numb')
freqs = [Chains(key=k, numb=n) for k, n in zip(Counter(chains_a).keys(),
Counter(chains_a).values())]
freqs_dict = {ent.key: ent.numb for ent in sorted(freqs, key=attrgetter('numb'),
reverse=True)}

def get_chain_id(entry):
return entry.split('/')[0]

def get_res_id(entry):
values = string_split(entry.split('/')[1])
return values[0]

def get_icode_id(entry):
values = string_split(entry.split('/')[1])
if len(values) == 2:
return values[1]
else:
return '?'

def get_res_full_id(entry):
return entry.split('/')[1]

def get_atom_id(entry):
return entry.split('/')[2]

chain_a = []
res_a = []
inscode_a = []
res_full_a = []
atom_a = []
chain_b = []
res_b = []
inscode_b = []
res_full_b = []
atom_b = []
for ix in table.index:
chain1 = get_chain_id(table.loc[ix, 'ENTRY_A'])
chain2 = get_chain_id(table.loc[ix, 'ENTRY_B'])
# sort the A->B order based on the frequency of each chain ID
if freqs_dict[chain1] >= freqs_dict[chain2]:
entry1 = 'ENTRY_A'
entry2 = 'ENTRY_B'
else:
entry1 = 'ENTRY_B'
entry2 = 'ENTRY_A'
chain_a.append(get_chain_id(table.loc[ix, entry1]))
res_a.append(get_res_id(table.loc[ix, entry1]))
inscode_a.append(get_icode_id(table.loc[ix, entry1]))
res_full_a.append(get_res_full_id(table.loc[ix, entry1]))
atom_a.append(get_atom_id(table.loc[ix, entry1]))
chain_b.append(get_chain_id(table.loc[ix, entry2]))
res_b.append(get_res_id(table.loc[ix, entry2]))
inscode_b.append(get_icode_id(table.loc[ix, entry2]))
res_full_b.append(get_res_full_id(table.loc[ix, entry2]))
atom_b.append(get_atom_id(table.loc[ix, entry2]))

assert len(chain_a) == len(table)
table['CHAIN_A'] = chain_a
table['RES_A'] = res_a
table['INSCODE_A'] = inscode_a
table['RES_FULL_A'] = res_a
table['ATOM_A'] = atom_a
table['CHAIN_B'] = chain_b
table['RES_B'] = res_b
table['INSCODE_B'] = inscode_b
table['RES_FULL_B'] = res_b
table['ATOM_B'] = atom_b
#FIXME: This is no longer used... I don't think it was working as intended in the first place
# # get most frequent chain in ENTRY_A and use it to define the inter. direction
# # for multiple chains use decreasing frequency
# chains_a = [v.split('/')[0] for v in table['ENTRY_A'].tolist()]
# Chains = namedtuple('Chains', 'key numb')
# freqs = [Chains(key=k, numb=n) for k, n in zip(Counter(chains_a).keys(),
# Counter(chains_a).values())]
# freqs_dict = {ent.key: ent.numb for ent in sorted(freqs, key=attrgetter('numb'),
# reverse=True)}

def _parse_arpeggio_atom(s):
# Note that this method does not reorder contact direction...
table = s.str.split('/', expand=True)

# Rename columns
suffix = s.name[-1]
column_name_dict = {k: v.format(suffix) for k, v in enumerate(['CHAIN_{}', 'RES_FULL_{}', 'ATOM_{}', 'X_{}'])}
table = table.rename(columns=column_name_dict)

# Parse RES_FULL_X to DataFrame
res_full = table['RES_FULL_{}'.format(suffix)].str.split(r'(\d+)', expand=True)
res_full = res_full.drop(0, axis=1) # the split always returns 3 columns, we'll never need the first

# Format residues with no insertion code
res_full.loc[:, 2] = res_full.loc[:, 2].str.replace('', '?')

# Rename columns
column_name_dict = {k: v.format(suffix) for k, v in enumerate(['RES_{}', 'INSCODE_{}'], 1)}
res_full = res_full.rename(columns=column_name_dict)

table = table.join(res_full)

return table

tbA = _parse_arpeggio_atom(table['ENTRY_A'])
tbB = _parse_arpeggio_atom(table['ENTRY_B'])
table = table.join(tbA.join(tbB))

return table


Expand Down

0 comments on commit a8fe3c4

Please sign in to comment.