Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Even more improvements to MSA subsampling #306

Merged
merged 10 commits on Feb 16, 2025
22 changes: 19 additions & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
from chai_lab.data.dataset.msas.load import get_msa_contexts
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.msas.utils import subsample_msa_rows
from chai_lab.data.dataset.msas.utils import (
subsample_and_reorder_msa_feats_n_mask,
)
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
Expand Down Expand Up @@ -705,14 +707,28 @@ def run_folding_on_context(
token_single_trunk_repr = token_single_initial_repr
token_pair_trunk_repr = token_pair_initial_repr
for _ in tqdm(range(num_trunk_recycles), desc="Trunk recycles"):
subsampled_msa_input_feats, subsampled_msa_mask = None, None
if recycle_msa_subsample > 0:
subsampled_msa_input_feats, subsampled_msa_mask = (
subsample_and_reorder_msa_feats_n_mask(
msa_input_feats,
msa_mask,
)
)
(token_single_trunk_repr, token_pair_trunk_repr) = trunk.forward(
move_to_device=device,
token_single_trunk_initial_repr=token_single_initial_repr,
token_pair_trunk_initial_repr=token_pair_initial_repr,
token_single_trunk_repr=token_single_trunk_repr, # recycled
token_pair_trunk_repr=token_pair_trunk_repr, # recycled
msa_input_feats=msa_input_feats,
msa_mask=subsample_msa_rows(msa_mask, select_n_rows=recycle_msa_subsample),
msa_input_feats=(
subsampled_msa_input_feats
if subsampled_msa_input_feats is not None
else msa_input_feats
),
msa_mask=(
subsampled_msa_mask if subsampled_msa_mask is not None else msa_mask
),
template_input_feats=template_input_feats,
template_input_masks=template_input_masks,
token_single_mask=token_single_mask,
Expand Down
58 changes: 50 additions & 8 deletions chai_lab/data/dataset/msas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,33 @@
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.

import logging

import torch
from einops import rearrange, reduce, repeat
from einops import rearrange, reduce
from torch import Tensor
from torch.nn import functional as F

from chai_lab.utils.typing import Bool, typecheck
from chai_lab.utils.typing import Bool, Float, typecheck


@typecheck
def subsample_msa_rows(
def _subsample_msa_rows(
mask: Bool[Tensor, "1 depth tokens"],
select_n_rows: int = 4096,
generator: torch.Generator | None = None,
) -> Bool[Tensor, "1 depth tokens"]:
) -> Bool[Tensor, "depth"] | None:
"""Adjust masking to look at a random subset of msas.

Returns input mask as-is if select_n_rows <= 0 or depth < select_n_rows."""
Returns None if select_n_rows <= 0 or depth < select_n_rows."""
# Count the number of non-padding residues in each row of the MSA
msa_sizes = rearrange(
reduce(mask, "b depth tok -> b depth", reduction="sum"), "1 depth -> depth"
)
nonnull_rows_mask = msa_sizes > 0
input_depth = nonnull_rows_mask.sum().item()
if select_n_rows <= 0 or input_depth <= select_n_rows:
return mask
return None

# Bias towards bigger hit MSAs; 0 size is automatically nulled out
mask_ranking = msa_sizes * torch.rand(
Expand All @@ -36,9 +39,48 @@ def subsample_msa_rows(
)
# Ascending sort -> choose the last (highest scoring) rows
selected_row_indices = mask_ranking.argsort()[-select_n_rows:]
# We should never sample empty MSA rows
assert not (~nonnull_rows_mask[selected_row_indices]).any()

# Create a mask for selected row indices
selection_mask = torch.zeros_like(nonnull_rows_mask)
selection_mask[selected_row_indices] = True
return selection_mask


return mask & repeat(selection_mask, "d -> 1 d 1")
@typecheck
def subsample_and_reorder_msa_feats_n_mask(
    feats: Float[Tensor, "1 depth tokens dim"],
    mask: Bool[Tensor, "1 depth tokens"],
    select_n_rows: int = 4096,
    generator: torch.Generator | None = None,
) -> tuple[Float[Tensor, "1 depth tokens dim"], Bool[Tensor, "1 depth tokens"]]:
    """Subsample MSA rows and move the selected rows to the front.

    Chooses up to ``select_n_rows`` rows via ``_subsample_msa_rows`` (which
    biases toward rows with more non-padding tokens), reorders ``feats`` so
    the selected rows come first (original relative order preserved), and
    returns a mask that is True only for the selected rows, zero-padded back
    to the original depth.

    Returns ``(feats, mask)`` unchanged when no subsampling applies, i.e.
    ``select_n_rows <= 0`` or there are not enough non-empty rows.
    """
    selection_mask = _subsample_msa_rows(
        mask=mask,
        select_n_rows=select_n_rows,
        generator=generator,
    )
    if selection_mask is None:  # No subsampling
        return feats, mask

    # Select the rows; where returns in order from top to bottom, preserving order
    (selection_idx,) = torch.where(selection_mask)
    # Lazy %-args: the message is only built if INFO logging is enabled
    logging.info("Subsampling %s...", selection_idx.tolist()[:5])
    (unselected_idx,) = torch.where(~selection_mask)
    combo_idx = torch.cat([selection_idx, unselected_idx])
    # Features are reordered, while mask is selected + padded
    feats_sampled = torch.index_select(feats, dim=1, index=combo_idx)
    mask_sampled = torch.index_select(mask, dim=1, index=selection_idx)
    # Every sampled row should have nonzero coverage
    assert mask_sampled.any(dim=-1).all()

    # Pad with zeros back to the original depth
    _, orig_depth, _ = mask.shape
    _, new_depth, _ = mask_sampled.shape
    # Subsampling only happens when depth exceeds select_n_rows, so n_pad > 0
    assert (n_pad := orig_depth - new_depth) > 0
    # Padding is last dim, moving forward, e.g., for last two dimensions, it is:
    # (left, right, top, bottom)
    # [0, 0, 0, n_pad] ignores the token dim and pads out the depth dim
    return (
        feats_sampled,
        F.pad(mask_sampled, pad=[0, 0, 0, n_pad], value=False),
    )