Skip to content

Commit

Permalink
Default to atomizing ligand and modified polymer residues for `Biomolecules` (lucidrains#129)
Browse files Browse the repository at this point in the history

* Update biomolecule.py

* Update filter_pdb_val_mmcifs.py
  • Loading branch information
amorehead authored Aug 4, 2024
1 parent c806ad2 commit c4c74ad
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
46 changes: 35 additions & 11 deletions alphafold3_pytorch/common/biomolecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ def crop_chains_with_masks(
list(zip(self.chain_index[chain_mask], self.residue_index[chain_mask]))
)
# NOTE: We must only consider unique chain-residue index pairs here,
# as otherwise we might count each ligand heavy atom as a residue in this mapping
# as otherwise we might count each ligand or modified polymer residue
# heavy atom as a residue in this mapping
subset_chain_residue_mapping = set(map(tuple, chain_residue_index))

# manually subset certain Biomolecule metadata
Expand Down Expand Up @@ -368,10 +369,20 @@ def spatial_crop(
for chemtype in self.chemtype
]
)
# NOTE: ligand atom position indices vary per ligand residue,
# NOTE: ligand and modified residue atom position indices vary per "pseudoresidue",
# so we can't rely on representative atom indices here
token_res_rep_atom_indices[self.chemtype == 3] = np.where(
self.atom_mask[self.chemtype == 3]
is_ligand_residue = self.chemtype == 3
is_modified_polymer_residue = np.array(
[
chemtype < 3
and get_residue_constants(res_chem_index=chemtype).restype_3to1.get(chemid, "X")
== "X"
for (chemtype, chemid) in zip(self.chemtype, self.chemid)
]
)
atomized_residue_mask = is_ligand_residue | is_modified_polymer_residue
token_res_rep_atom_indices[atomized_residue_mask] = np.where(
self.atom_mask[atomized_residue_mask]
)[1]
token_res_atom_position_mask[
np.arange(self.chain_id.size), token_res_rep_atom_indices
Expand Down Expand Up @@ -581,18 +592,22 @@ def get_ligand_atom_name(atom_name: str, atom_types_set: Set[str]) -> str:
def get_unique_res_atom_names(
mmcif_object: mmcif_parsing.MmcifObject,
) -> List[Tuple[List[List[str]], str, int]]:
"""Get atom name-chain ID tuples for each (e.g. ligand) "pseudoresidue" of each residue in each chain."""
"""Get atom name-chain ID tuples for each (e.g. ligand) "pseudoresidue" of each residue in each
chain."""
unique_res_atom_names = []
for chain in mmcif_object.structure:
chain_chem_comp = mmcif_object.chem_comp_details[chain.id]
for res, res_chem_comp in zip(chain, chain_chem_comp):
is_polymer_residue = is_polymer(res_chem_comp.type)
residue_constants = get_residue_constants(res_chem_type=res_chem_comp.type)
if is_polymer_residue:
# For polymer residues, append the atom types directly.
is_modified_polymer_residue = (
is_polymer_residue and residue_constants.restype_3to1.get(res.resname, "X") == "X"
)
if is_polymer_residue and not is_modified_polymer_residue:
# For unmodified polymer residues, append the atom types directly.
atoms_to_append = [residue_constants.atom_types]
else:
# For non-polymer residues, create a nested list of atom names.
# For non-polymer or modified polymer residues, create a nested list of atom names.
atoms_to_append = [
[atom.name for _ in range(residue_constants.atom_type_num)] for atom in res
]
Expand All @@ -605,7 +620,7 @@ def _from_mmcif_object(
mmcif_object: mmcif_parsing.MmcifObject,
chain_ids: Optional[Set[str]] = None,
atomize_ligand_residues: bool = True,
atomize_modified_polymer_residues: bool = False,
atomize_modified_polymer_residues: bool = True,
) -> Biomolecule:
"""Takes a Biopython structure/model mmCIF object and creates a `Biomolecule` instance.
Expand Down Expand Up @@ -857,7 +872,11 @@ def _from_mmcif_object(

@typecheck
def from_mmcif_string(
mmcif_str: str, file_id: str, chain_ids: Optional[Set[str]] = None
mmcif_str: str,
file_id: str,
chain_ids: Optional[Set[str]] = None,
atomize_ligand_residues: bool = True,
atomize_modified_polymer_residues: bool = True,
) -> Biomolecule:
"""Takes a mmCIF string and constructs a `Biomolecule` object.
Expand All @@ -881,7 +900,12 @@ def from_mmcif_string(
if parsing_result.mmcif_object is None:
raise list(parsing_result.errors.values())[0]

return _from_mmcif_object(parsing_result.mmcif_object, chain_ids=chain_ids)
return _from_mmcif_object(
parsing_result.mmcif_object,
chain_ids=chain_ids,
atomize_ligand_residues=atomize_ligand_residues,
atomize_modified_polymer_residues=atomize_modified_polymer_residues,
)


@typecheck
Expand Down
5 changes: 2 additions & 3 deletions scripts/filter_pdb_val_mmcifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import glob
import os
from datetime import datetime
from typing import Tuple

import timeout_decorator
from tqdm.contrib.concurrent import process_map
Expand Down Expand Up @@ -63,9 +62,9 @@ def filter_num_tokens(
) -> bool:
"""Filter based on number of tokens."""
biomol = (
_from_mmcif_object(mmcif_object, atomize_modified_polymer_residues=True)
_from_mmcif_object(mmcif_object)
if "assembly" in mmcif_object.file_id
else get_assembly(_from_mmcif_object(mmcif_object, atomize_modified_polymer_residues=True))
else get_assembly(_from_mmcif_object(mmcif_object))
)
return (
len(biomol.atom_mask) < max_tokens
Expand Down

0 comments on commit c4c74ad

Please sign in to comment.