From cdbc9003f75259fbcb548a9b94f3c8f24042d476 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Fri, 9 Aug 2024 13:36:32 -0500 Subject: [PATCH] Fix batching of `missing_atom_mask`s, and clean up code (#160) * Update filter_pdb_test_mmcifs.py * Update filter_pdb_train_mmcifs.py * Update inputs.py * Update weighted_pdb_sampler.py * Update inputs.py --- .../data/weighted_pdb_sampler.py | 40 +++++++++---------- alphafold3_pytorch/inputs.py | 20 +++++----- scripts/filter_pdb_test_mmcifs.py | 1 - scripts/filter_pdb_train_mmcifs.py | 1 - 4 files changed, 29 insertions(+), 33 deletions(-) diff --git a/alphafold3_pytorch/data/weighted_pdb_sampler.py b/alphafold3_pytorch/data/weighted_pdb_sampler.py index 36066262..4ae749dd 100644 --- a/alphafold3_pytorch/data/weighted_pdb_sampler.py +++ b/alphafold3_pytorch/data/weighted_pdb_sampler.py @@ -193,7 +193,7 @@ def __init__( alpha_nuc: float = 3.0, alpha_ligand: float = 1.0, pdb_ids_to_skip: List[str] = [], - subset_to_ids: list[int] | None = None, + pdb_ids_to_keep: list[str] | None = None, ): # Load chain and interface mappings if not isinstance(chain_mapping_paths, list): @@ -226,28 +226,24 @@ def __init__( "Precomputing chain and interface weights. This may take several minutes to complete." ) - # Subset to specific indices if provided - if exists(subset_to_ids): - chain_mapping = ( - chain_mapping.with_row_index() - .filter(pl.col("index").is_in(subset_to_ids)) - .select(["pdb_id", "chain_id", "molecule_id", "cluster_id"]) + # Subset to specific PDB IDs if provided + if exists(pdb_ids_to_keep): + chain_mapping = chain_mapping.filter(pl.col("pdb_id").is_in(pdb_ids_to_keep)).select( + ["pdb_id", "chain_id", "molecule_id", "cluster_id"] ) - interface_mapping = ( - interface_mapping.with_row_index() - .filter(pl.col("index").is_in(subset_to_ids)) - .select( - [ - "pdb_id", - "interface_chain_id_1", - "interface_chain_id_2", - "interface_molecule_id_1", - "interface_molecule_id_2", - "interface_chain_cluster_id_1", - "interface_chain_cluster_id_2", - "interface_cluster_id", - ] - ) + interface_mapping = interface_mapping.filter( + pl.col("pdb_id").is_in(pdb_ids_to_keep) + ).select( + [ + "pdb_id", + "interface_chain_id_1", + "interface_chain_id_2", + "interface_molecule_id_1", + "interface_molecule_id_2", + "interface_chain_cluster_id_1", + "interface_chain_cluster_id_2", + "interface_cluster_id", + ] ) chain_mapping.insert_column( diff --git a/alphafold3_pytorch/inputs.py b/alphafold3_pytorch/inputs.py index 32596b78..1c27ff56 100644 --- a/alphafold3_pytorch/inputs.py +++ b/alphafold3_pytorch/inputs.py @@ -63,7 +63,7 @@ is_polymer, ) from alphafold3_pytorch.utils.model_utils import exclusive_cumsum -from alphafold3_pytorch.utils.utils import default, exists, first, identity +from alphafold3_pytorch.utils.utils import default, exists, first # silence RDKit's warnings @@ -166,7 +166,8 @@ def inner(x, *args, **kwargs): } ATOM_DEFAULT_PAD_VALUES = dict( - molecule_atom_lens = 0 + molecule_atom_lens = 0, + missing_atom_mask = True ) @typecheck @@ -2678,10 +2679,6 @@ def __init__( assert folder.exists() and folder.is_dir(), f"{str(folder)} does not exist for PDBDataset" self.folder = folder - self.files = { - os.path.splitext(os.path.basename(file.name))[0]: file - for file in folder.glob(os.path.join("**", "*.cif")) - } self.sampler = sampler self.sample_type = sample_type self.training = training @@ -2700,9 +2697,14 @@ def __init__( if exists(self.sampler): sampler_pdb_ids = set(self.sampler.mappings.get_column("pdb_id").to_list()) self.files = { - file: filepath - for (file, filepath) in self.files.items() - if file in sampler_pdb_ids + os.path.splitext(os.path.basename(filepath.name))[0]: filepath + for filepath in folder.glob(os.path.join("**", "*.cif")) + if os.path.splitext(os.path.basename(filepath.name))[0] in sampler_pdb_ids + } + else: + self.files = { + os.path.splitext(os.path.basename(file.name))[0]: file + for file in folder.glob(os.path.join("**", "*.cif")) } if exists(sample_only_pdb_ids): diff --git a/scripts/filter_pdb_test_mmcifs.py b/scripts/filter_pdb_test_mmcifs.py index 99e9b42f..c2adc7fc 100644 --- a/scripts/filter_pdb_test_mmcifs.py +++ b/scripts/filter_pdb_test_mmcifs.py @@ -24,7 +24,6 @@ import glob import os from datetime import datetime -from typing import List, Tuple import timeout_decorator from tqdm.contrib.concurrent import process_map diff --git a/scripts/filter_pdb_train_mmcifs.py b/scripts/filter_pdb_train_mmcifs.py index 72c6ccd5..6f1fd9de 100644 --- a/scripts/filter_pdb_train_mmcifs.py +++ b/scripts/filter_pdb_train_mmcifs.py @@ -37,7 +37,6 @@ import random from datetime import datetime from operator import itemgetter -from typing import Dict, List, Set, Tuple import numpy as np import timeout_decorator