Skip to content

Commit

Permalink
Fix batching of missing_atom_masks, and clean up code (lucidrains#160)
Browse files Browse the repository at this point in the history
* Update filter_pdb_test_mmcifs.py

* Update filter_pdb_train_mmcifs.py

* Update inputs.py

* Update weighted_pdb_sampler.py

* Update inputs.py
  • Loading branch information
amorehead authored Aug 9, 2024
1 parent addf096 commit cdbc900
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 33 deletions.
40 changes: 18 additions & 22 deletions alphafold3_pytorch/data/weighted_pdb_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
alpha_nuc: float = 3.0,
alpha_ligand: float = 1.0,
pdb_ids_to_skip: List[str] = [],
subset_to_ids: list[int] | None = None,
pdb_ids_to_keep: list[str] | None = None,
):
# Load chain and interface mappings
if not isinstance(chain_mapping_paths, list):
Expand Down Expand Up @@ -226,28 +226,24 @@ def __init__(
"Precomputing chain and interface weights. This may take several minutes to complete."
)

# Subset to specific indices if provided
if exists(subset_to_ids):
chain_mapping = (
chain_mapping.with_row_index()
.filter(pl.col("index").is_in(subset_to_ids))
.select(["pdb_id", "chain_id", "molecule_id", "cluster_id"])
# Subset to specific PDB IDs if provided
if exists(pdb_ids_to_keep):
chain_mapping = chain_mapping.filter(pl.col("pdb_id").is_in(pdb_ids_to_keep)).select(
["pdb_id", "chain_id", "molecule_id", "cluster_id"]
)
interface_mapping = (
interface_mapping.with_row_index()
.filter(pl.col("index").is_in(subset_to_ids))
.select(
[
"pdb_id",
"interface_chain_id_1",
"interface_chain_id_2",
"interface_molecule_id_1",
"interface_molecule_id_2",
"interface_chain_cluster_id_1",
"interface_chain_cluster_id_2",
"interface_cluster_id",
]
)
interface_mapping = interface_mapping.filter(
pl.col("pdb_id").is_in(pdb_ids_to_keep)
).select(
[
"pdb_id",
"interface_chain_id_1",
"interface_chain_id_2",
"interface_molecule_id_1",
"interface_molecule_id_2",
"interface_chain_cluster_id_1",
"interface_chain_cluster_id_2",
"interface_cluster_id",
]
)

chain_mapping.insert_column(
Expand Down
20 changes: 11 additions & 9 deletions alphafold3_pytorch/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
is_polymer,
)
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
from alphafold3_pytorch.utils.utils import default, exists, first, identity
from alphafold3_pytorch.utils.utils import default, exists, first

# silence RDKit's warnings

Expand Down Expand Up @@ -166,7 +166,8 @@ def inner(x, *args, **kwargs):
}

ATOM_DEFAULT_PAD_VALUES = dict(
molecule_atom_lens = 0
molecule_atom_lens = 0,
missing_atom_mask = True
)

@typecheck
Expand Down Expand Up @@ -2678,10 +2679,6 @@ def __init__(
assert folder.exists() and folder.is_dir(), f"{str(folder)} does not exist for PDBDataset"
self.folder = folder

self.files = {
os.path.splitext(os.path.basename(file.name))[0]: file
for file in folder.glob(os.path.join("**", "*.cif"))
}
self.sampler = sampler
self.sample_type = sample_type
self.training = training
Expand All @@ -2700,9 +2697,14 @@ def __init__(
if exists(self.sampler):
sampler_pdb_ids = set(self.sampler.mappings.get_column("pdb_id").to_list())
self.files = {
file: filepath
for (file, filepath) in self.files.items()
if file in sampler_pdb_ids
os.path.splitext(os.path.basename(filepath.name))[0]: filepath
for filepath in folder.glob(os.path.join("**", "*.cif"))
if os.path.splitext(os.path.basename(filepath.name))[0] in sampler_pdb_ids
}
else:
self.files = {
os.path.splitext(os.path.basename(file.name))[0]: file
for file in folder.glob(os.path.join("**", "*.cif"))
}

if exists(sample_only_pdb_ids):
Expand Down
1 change: 0 additions & 1 deletion scripts/filter_pdb_test_mmcifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import glob
import os
from datetime import datetime
from typing import List, Tuple

import timeout_decorator
from tqdm.contrib.concurrent import process_map
Expand Down
1 change: 0 additions & 1 deletion scripts/filter_pdb_train_mmcifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import random
from datetime import datetime
from operator import itemgetter
from typing import Dict, List, Set, Tuple

import numpy as np
import timeout_decorator
Expand Down

0 comments on commit cdbc900

Please sign in to comment.