Add PDB distillation support (lucidrains#274)
* Update data_pipeline.py

* Update test_af3.py

* Update test_input.py

* Update test_dataloading.py

* Create test_msa_loading.py

* Update test_template_loading.py

* Update data/test dir

* Create uniprot_to_pdb_id_mapping.dat

* Update README.md

* Create reduce_uniprot_ids_to_pdb.py

* Update inputs.py

* Create distillation_data_download.sh

* Update distillation_data_download.sh

* Update distillation_data_download.sh

* Create reduce_uniprot_predictions_to_pdb.py

* Update reduce_uniprot_predictions_to_pdb.py

* Update reduce_uniprot_predictions_to_pdb.py

* Update reduce_uniprot_predictions_to_pdb.py

* Update reduce_uniprot_predictions_to_pdb.py

* Update distillation_data_download.sh

* Update biomolecule.py

* Update data_pipeline.py

* Update mmcif_writing.py

* Update msa_parsing.py

* Update template_parsing.py

* Update inputs.py

* Update test_af3.py

* Update test_dataloading.py

* Update test_input.py

* Update test_msa_loading.py

* Update test_template_loading.py

* Update test_weighted_sampling.py

* Add new test data

* Update reduce_uniprot_predictions_to_pdb.py

* Update README.md

* Update trainer_with_pdb_dataset_and_weighted_sampling.yaml

* Update test_trainer.py

* Update trainer_with_pdb_dataset.yaml

* Update trainer_with_pdb_dataset_and_weighted_sampling.yaml

* Update trainer_with_pdb_dataset_and_weighted_sampling.yaml

* Update training_with_pdb_dataset.yaml

* Update trainer_with_pdb_dataset_and_weighted_sampling.yaml

* Update trainer_with_atom_dataset_created_from_pdb.yaml

* Update trainer_with_atom_dataset_created_from_pdb.yaml

* Update trainer_with_pdb_dataset.yaml

* Update training_with_pdb_dataset.yaml

* Update test_weighted_sampling.py

* Update trainer_with_pdb_dataset_and_weighted_sampling.yaml
amorehead authored Sep 20, 2024
1 parent 0af38c7 commit 54ace0e
Showing 67 changed files with 889,161 additions and 148 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -317,7 +317,7 @@ python scripts/cluster_pdb_test_mmcifs.py --mmcif_dir <mmcif_dir> --reference_1_

**Note**: The `--clustering_filtered_pdb_dataset` flag is recommended when clustering the filtered PDB dataset as curated using the scripts above, as this flag will enable faster runtimes in this context (since filtering leaves each chain's residue IDs 1-based). However, this flag must **not** be provided when clustering other (i.e., non-PDB) datasets of mmCIF files. Otherwise, interface clustering may be performed incorrectly, as these datasets' mmCIF files may not use strict 1-based residue indexing for each chain.
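For illustration, a minimal invocation of the clustering script on the filtered PDB dataset might look like the sketch below; `<mmcif_dir>` is a placeholder, and any additional flags from the (truncated) command above are intentionally omitted rather than guessed.

```bash
# Cluster the filtered PDB dataset. The flag assumes each chain's residue IDs
# are strictly 1-based, which holds only for the curated/filtered PDB dataset.
python scripts/cluster_pdb_test_mmcifs.py \
    --mmcif_dir <mmcif_dir> \
    --clustering_filtered_pdb_dataset
```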

**Note**: One can instead download preprocessed (i.e., filtered) mmCIF (`train`/`val`/`test`) files (~25GB, comprising 148k complexes) and chain/interface clustering (`train`/`val`/`test`) files (~3GB) for the PDB's `20240101` AWS snapshot via a [shared OneDrive folder](https://mailmissouri-my.sharepoint.com/:f:/g/personal/acmwhb_umsystem_edu/EqU8tjUmmKxJr-FAlq4tzaIBi2TIBtmw5Vl3k_kmgNlepA?e=mzlyv6). Each of these `tar.gz` archives should be decompressed within the `data/pdb_data/` directory e.g., via `tar -xzf data_caches.tar.gz -C data/pdb_data/`. Moreover, mappings of UniProt accession IDs to taxonomic IDs for MSA pairing can be downloaded and extracted via the commands `wget https://colabfold.steineggerlab.workers.dev/af3/uniref30_2202_accession_mapping.tsv.gz -P data/pdb_data/data_caches/` and `gunzip data/pdb_data/data_caches/uniref30_2202_accession_mapping.tsv.gz`.
**Note**: One can instead download preprocessed (i.e., filtered) mmCIF (`train`/`val`/`test`) files (~25GB, comprising 148k complexes) and chain/interface clustering (`train`/`val`/`test`) files (~3GB) for the PDB's `20240101` AWS snapshot via a [shared OneDrive folder](https://mailmissouri-my.sharepoint.com/:f:/g/personal/acmwhb_umsystem_edu/EqU8tjUmmKxJr-FAlq4tzaIBi2TIBtmw5Vl3k_kmgNlepA?e=mzlyv6). Each of these `tar.gz` archives should be decompressed within the `data/pdb_data/` directory, e.g., via `tar -xzf data_caches.tar.gz -C data/pdb_data/`. One can also download and prepare PDB distillation data using the script `scripts/distillation_data_download.sh` as a reference. Once downloaded, one can run `scripts/reduce_uniprot_predictions_to_pdb.py` to filter this dataset down to only those examples associated with at least one PDB entry. Moreover, for convenience, a [mapping](https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/idmapping/idmapping.dat.gz) of UniProt accession IDs to PDB IDs for training on PDB distillation data has already been downloaded and extracted as `data/afdb_data/data_caches/uniprot_to_pdb_id_mapping.dat`.
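As a rough end-to-end sketch of the distillation-data preparation described above: the `tar` command is taken verbatim from the note, while the argument-free invocations of the two scripts are assumptions, since their exact CLI options are not shown in this diff.

```bash
# 1. Decompress the preprocessed PDB archives (downloaded from the shared OneDrive folder).
tar -xzf data_caches.tar.gz -C data/pdb_data/

# 2. Download the PDB distillation (AFDB) predictions, using the reference
#    script added in this commit (exact arguments, if any, not shown here).
bash scripts/distillation_data_download.sh

# 3. Reduce the downloaded predictions to only those associated with at least
#    one PDB entry (presumably consulting the mapping file at
#    data/afdb_data/data_caches/uniprot_to_pdb_id_mapping.dat).
python scripts/reduce_uniprot_predictions_to_pdb.py
```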

## Contributing

4 changes: 2 additions & 2 deletions alphafold3_pytorch/common/biomolecule.py
@@ -460,10 +460,10 @@ def crop(
if exists(chain_1) and exists(chain_2):
crop_fn_weights = [contiguous_weight, spatial_weight, spatial_interface_weight]
elif exists(chain_1) or exists(chain_2):
crop_fn_weights = [contiguous_weight, spatial_weight + spatial_interface_weight, 0.0]
crop_fn_weights = [contiguous_weight, spatial_weight, 0.0]
else:
crop_fn_weights = [
contiguous_weight + spatial_weight + spatial_interface_weight,
contiguous_weight,
0.0,
0.0,
]
11 changes: 4 additions & 7 deletions alphafold3_pytorch/data/data_pipeline.py
@@ -113,7 +113,6 @@ def make_msa_features(
msas: Dict[str, msa_parsing.Msa],
chain_id_to_residue: Dict[str, Dict[str, List[int]]],
num_msa_one_hot: int,
uniprot_accession_to_tax_id_mapping: Dict[str, str] | None = None,
ligand_chemtype_index: int = 3,
) -> List[Dict[str, np.ndarray]]:
"""
@@ -123,7 +122,6 @@
:param msas: The mapping of chain IDs to lists of MSAs for each chain.
:param chain_id_to_residue: The mapping of chain IDs to residue information.
:param num_msa_one_hot: The number of one-hot classes for MSA features.
:param uniprot_accession_to_tax_id_mapping: The mapping of UniProt accession IDs to NCBI taxonomy IDs.
:param ligand_chemtype_index: The index of the ligand in the chemical type list.
:return: The MSA chain feature dictionaries.
"""
@@ -216,10 +214,9 @@ def make_msa_features(
int_msa.append(msa_res_types)
deletion_matrix.append(msa_deletion_values)

species_id = ""
if exists(uniprot_accession_to_tax_id_mapping):
accession_id = msa_parsing.get_accession_id(msa.descriptions[sequence_index])
species_id = uniprot_accession_to_tax_id_mapping.get(accession_id, "")
# Parse species ID for MSA pairing if possible.
species_id = msa_parsing.get_identifiers(msa.descriptions[sequence_index]).species_id

if sequence_index == 0:
species_id = "-1" # Tag target sequence for filtering.
species_ids.append(species_id)
@@ -654,7 +651,7 @@ def make_mmcif_features(


if __name__ == "__main__":
filepath = os.path.join("data", "test", "mmcifs", "a4", "7a4d-assembly1.cif")
filepath = os.path.join("data", "test", "pdb_data", "mmcifs", "a4", "7a4d-assembly1.cif")
file_id = os.path.splitext(os.path.basename(filepath))[0]

mmcif_object = mmcif_parsing.parse_mmcif_object(
11 changes: 9 additions & 2 deletions alphafold3_pytorch/data/mmcif_writing.py
@@ -2,6 +2,8 @@

import numpy as np

from loguru import logger

from alphafold3_pytorch.common.biomolecule import _from_mmcif_object, to_mmcif
from alphafold3_pytorch.data.data_pipeline import get_assembly
from alphafold3_pytorch.data.mmcif_parsing import MmcifObject, parse_mmcif_object
@@ -13,8 +15,13 @@ def write_mmcif_from_filepath_and_id(
):
"""Write an input mmCIF file to an output mmCIF filepath using the provided keyword arguments
(e.g., sampled coordinates)."""
mmcif_object = parse_mmcif_object(filepath=input_filepath, file_id=file_id)
return write_mmcif(mmcif_object, output_filepath=output_filepath, **kwargs)
try:
mmcif_object = parse_mmcif_object(filepath=input_filepath, file_id=file_id)
return write_mmcif(mmcif_object, output_filepath=output_filepath, **kwargs)
except Exception as e:
logger.warning(
f"Failed to write mmCIF file {output_filepath} due to: {e}. Perhaps cropping was performed on this example?"
)


def write_mmcif(
9 changes: 9 additions & 0 deletions alphafold3_pytorch/data/msa_parsing.py
@@ -159,6 +159,15 @@ def __len__(self):
"""Returns the number of sequences in the MSA."""
return len(self.sequences)

def __add__(self, other):
"""Concatenates two MSAs."""
return Msa(
sequences=self.sequences + other.sequences,
deletion_matrix=self.deletion_matrix + other.deletion_matrix,
descriptions=self.descriptions + other.descriptions,
msa_type=self.msa_type,
)

def truncate(self, max_seqs: int):
"""Truncates the MSA to the first `max_seqs` sequences."""
max_seqs = min(len(self.sequences), max_seqs)
159 changes: 157 additions & 2 deletions alphafold3_pytorch/data/template_parsing.py
@@ -88,12 +88,12 @@ def parse_m8(

# Filter the DataFrame to only include rows where
# (1) the template ID does not contain any part of the query ID;
# (2) the template's identity is between 0.1 and 0.95, exclusively;
# (2) the template's identity is between 0.3 and 0.95, exclusively;
# (3) the alignment length is greater than 0;
# (4) the template's length is at least 10; and
# (5) the number of templates is less than the (optional) maximum number of templates.
df = df.filter(~pl.col("Template ID").str.contains(query_id))
df = df.filter((pl.col("Identity") > 0.1) & (pl.col("Identity") < 0.95))
df = df.filter((pl.col("Identity") > 0.3) & (pl.col("Identity") < 0.95))
df = df.filter(pl.col("Alignment Length") > 0)
df = df.filter((pl.col("Template End") - pl.col("Template Start")) >= 9)
if exists(max_templates):
@@ -135,6 +135,161 @@ def parse_m8(
return template_biomols


@typecheck
def parse_hhr(
hhr_filepath: str,
template_type: TEMPLATE_TYPE,
query_id: str,
mmcif_dir: str,
max_templates: int | None = None,
num_templates: int | None = None,
template_cutoff_date: datetime | None = None,
randomly_sample_num_templates: bool = False,
verbose: bool = False,
) -> List[Tuple[Biomolecule, TEMPLATE_TYPE]]:
"""Parse an HHR file and return a list of template Biomolecule objects.
:param hhr_filepath: The path to the HHR file.
:param template_type: The type of template to parse.
:param query_id: The ID of the query sequence.
:param mmcif_dir: The directory containing mmCIF files.
:param max_templates: The (optional) maximum number of templates to return.
:param num_templates: The (optional) number of templates to return.
:param template_cutoff_date: The (optional) cutoff date for templates.
:param randomly_sample_num_templates: Whether to randomly sample the number of templates to
return.
:param verbose: Whether to log verbose output.
:return: A list of template Biomolecule objects and their template types.
"""
# Define the column names and types.
schema = {
"No": pl.Int32,
"Hit": pl.Utf8,
"Prob": pl.Float64,
"E-value": pl.Float64,
"P-value": pl.Float64,
"Score": pl.Float64,
"SS": pl.Float64,
"Cols": pl.Int32,
"Query HMM": pl.Utf8,
"Template HMM": pl.Utf8,
}

# Identify how many rows to parse.
rows = []
rows_found = False
with open(hhr_filepath, "r") as f:
for line in f:
if line.startswith(" No Hit"):
rows_found = True
elif line.startswith("No 1"):
rows_found = False

if rows_found and not line.startswith(" No Hit") and line.strip():
line_parts = line.strip().split()

# NOTE: The `Hit` and `Template HMM` columns may contain spaces.
if len(line_parts) > 10:
num_hit_parts = len(line_parts) - 10
line_parts = (
[line_parts[0], " ".join(line_parts[1 : 1 + num_hit_parts])]
+ line_parts[1 + num_hit_parts : -2]
+ [" ".join(line_parts[-2:])]
)

rows.append(line_parts)

assert len(rows) > 0, f"No parseable rows found in HHR file {hhr_filepath}."

# Read the HHR file as a DataFrame.
try:
df = pl.DataFrame(rows, schema=schema, orient="row")
except Exception as e:
if verbose:
logger.warning(f"Skipping loading HHR file {hhr_filepath} due to: {e}")
return []

# Add shortcut columns to the DataFrame.
df = df.with_columns(
[
# `Identity` is `Cols` divided by the second integer in `Query HMM`
(
pl.col("Cols")
/ pl.col("Query HMM").map_elements(
lambda x: int(x.split("-")[1]), return_dtype=pl.Float64
)
).alias("Identity"),
# `Template Start` extracted from `Template HMM`
pl.col("Template HMM")
.map_elements(lambda x: extract_template_hmm_range(x)[0], return_dtype=pl.Int32)
.alias("Template Start"),
# `Template End` extracted from `Template HMM`
pl.col("Template HMM")
.map_elements(lambda x: extract_template_hmm_range(x)[1], return_dtype=pl.Int32)
.alias("Template End"),
]
)

# Filter the DataFrame to only include rows where
# (1) the template hit ID does not contain any part of the query ID;
# (2) the template's identity is between 0.3 and 0.95, exclusively;
# (3) the alignment length is greater than 0;
# (4) the template's length is at least 10; and
# (5) the number of templates is less than the (optional) maximum number of templates.
df = df.filter(~pl.col("Hit").str.contains(query_id.upper()))
df = df.filter((pl.col("Identity") > 0.3) & (pl.col("Identity") < 0.95))
df = df.filter(pl.col("Cols") > 0)
df = df.filter((pl.col("Template End") - pl.col("Template Start")) >= 9)
if exists(max_templates):
df = df.head(max_templates)

# Select the number of templates to return.
if len(df) and exists(num_templates) and randomly_sample_num_templates:
df = df.sample(min(len(df), num_templates))
elif exists(num_templates):
df = df.head(num_templates)

# Load each template chain as a Biomolecule object.
template_biomols = []
for i in range(len(df)):
row = df[i]
row_template_id = row["Hit"].item().lower()
template_id, template_chain = row_template_id.split(" ")[0].split("_")
template_fpath = os.path.join(mmcif_dir, template_id[1:3], f"{template_id}-assembly1.cif")
if not os.path.exists(template_fpath):
continue
try:
template_mmcif_object = mmcif_parsing.parse_mmcif_object(
template_fpath, row_template_id
)
template_release_date = extract_mmcif_metadata_field(
template_mmcif_object, "release_date"
)
if (
exists(template_cutoff_date)
and datetime.strptime(template_release_date, "%Y-%m-%d") > template_cutoff_date
):
continue
template_biomol = _from_mmcif_object(
template_mmcif_object, chain_ids=set(template_chain.upper())
)
if len(template_biomol.atom_positions):
template_biomols.append((template_biomol, template_type))
except Exception as e:
if verbose:
logger.warning(f"Skipping loading template {template_id} due to: {e}")

return template_biomols


def extract_template_hmm_range(template_hmm: str):
"""Extract the start and end indices of the template HMM range.
:param template_hmm: The template HMM range string.
:return: A tuple containing the start and end indices of the template HMM range.
"""
start, end = template_hmm.split("-")
return int(start), int(end.split()[0]) # Split further to remove the parenthesis


def _extract_template_features(
template_biomol: Biomolecule,
mapping: Mapping[int, int],
