Skip to content

Commit

Permalink
Improvements to MSA subsampling (#305)
Browse files Browse the repository at this point in the history
Bias towards MSA hits with larger coverage of the query sequence.
  • Loading branch information
wukevin authored Feb 14, 2025
1 parent d56a8af commit 169dd4d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
4 changes: 1 addition & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,7 @@ def make_all_atom_feature_context(

# Load templates
if templates_path is None:
assert (
not use_templates_server
), "Templates path should never be none when querying server for templates"
assert not use_templates_server, "Server should have written a path"
template_context = TemplateContext.empty(
n_tokens=n_actual_tokens,
n_templates=MAX_NUM_TEMPLATES,
Expand Down
2 changes: 1 addition & 1 deletion chai_lab/data/dataset/msas/msa_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __getitem__(self, subscript: tuple) -> "MSAContext":
mask=self.mask[subscript],
)

def take_rows_with_padding(self, row_indices_with_nones: list):
def take_rows_with_padding(self, row_indices_with_nones: list[int | None]):
"""
allows specifying index=None, which will be filled with empty sequence,
helpful to align multiple sequences
Expand Down
25 changes: 16 additions & 9 deletions chai_lab/data/dataset/msas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# See the LICENSE file for details.

import torch
from einops import rearrange, repeat
from einops import rearrange, reduce, repeat
from torch import Tensor

from chai_lab.utils.typing import Bool, typecheck
Expand All @@ -18,20 +18,27 @@ def subsample_msa_rows(
"""Adjust masking to look at a random subset of msas.
Returns input mask as-is if select_n_rows <= 0 or depth < select_n_rows."""
nonnull_rows_mask = rearrange(mask.any(dim=-1), "1 d -> d")
# Count the number of non-padding residues in each row of the MSA
msa_sizes = rearrange(
reduce(mask, "b depth tok -> b depth", reduction="sum"), "1 depth -> depth"
)
nonnull_rows_mask = msa_sizes > 0
input_depth = nonnull_rows_mask.sum().item()
if select_n_rows <= 0 or input_depth <= select_n_rows:
return mask

# Select from rows of the MSA that are not fully masked out
(nonnull_row_indices,) = torch.where(nonnull_rows_mask)
assert (n := nonnull_row_indices.numel()) > select_n_rows
permuted = torch.randperm(n, device=mask.device, generator=generator)
selected_row_indices = nonnull_row_indices[permuted[:select_n_rows]]
# Bias towards bigger hit MSAs; 0 size is automatically nulled out
mask_ranking = msa_sizes * torch.rand(
size=msa_sizes.shape,
dtype=torch.float16,
device=msa_sizes.device,
generator=generator,
)
# Ascending sort -> choose the last (highest scoring) rows
selected_row_indices = mask_ranking.argsort()[-select_n_rows:]

# Create a mask for selected row indices
selection_mask = torch.zeros_like(nonnull_rows_mask)
selection_mask[selected_row_indices] = True
selection_mask = repeat(selection_mask, "d -> 1 d 1")

return mask & selection_mask
return mask & repeat(selection_mask, "d -> 1 d 1")

0 comments on commit 169dd4d

Please sign in to comment.