Skip to content

Commit

Permalink
Change signature of get_robust_pairs, refactor docs using fastcore.do…
Browse files Browse the repository at this point in the history
…cments, simplify README, add custom type aliases, fix/improve some typing annotations
  • Loading branch information
ulupo committed May 10, 2024
1 parent 6baf3b5 commit ebb160c
Show file tree
Hide file tree
Showing 14 changed files with 527 additions and 182 deletions.
45 changes: 20 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ containing interacting biological sequences, find the optimal one-to-one
pairing between the sequences in A and B.

<figure>
<img src="https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/MSA_pairing_problem.svg" width="640" height="201.6" alt="MSA pairing problem" />
<img src="https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/MSA_pairing_problem.svg" alt="MSA pairing problem" />
<figcaption>
Pairing problem for two multiple sequence alignments, where pairings are
restricted to be within the same species
Expand Down Expand Up @@ -84,7 +84,7 @@ ingredients are as follows:
the DiffPaSS-Iterative Pairing Algorithm (DiffPaSS-IPA).

<figure>
<video src="https://github.com/Bitbol-Lab/DiffPaSS/assets/46537483/e411fe8c-2fed-4723-a25c-ff69a1abccec" width="640" height="360" controls>
<video src="https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/DiffPaSS_bootstrap.mp4" width="432" height="243" controls>
</video>
<figcaption>
The DiffPaSS bootstrap technique and robust pairs
Expand Down Expand Up @@ -129,9 +129,9 @@ into a list of tuples `(header, sequence)` using
``` python
from diffpass.msa_parsing import read_msa

# Parse the MSAs into lists of tuples (header, sequence)
msa_A = read_msa("path/to/msa_A.fasta")
msa_B = read_msa("path/to/msa_B.fasta")
# Parse and one-hot encode the MSAs
msa_data_A = read_msa("path/to/msa_A.fasta")
msa_data_B = read_msa("path/to/msa_B.fasta")
```

We assume that the MSAs contain species information in the headers,
Expand All @@ -150,8 +150,8 @@ This function will be used to group the sequences by species:
``` python
from diffpass.data_utils import create_groupwise_seq_records

msa_A_by_sp = create_groupwise_seq_records(msa_A, species_name_func)
msa_B_by_sp = create_groupwise_seq_records(msa_B, species_name_func)
msa_data_A_species_by_species = create_groupwise_seq_records(msa_data_A, species_name_func)
msa_data_B_species_by_species = create_groupwise_seq_records(msa_data_B, species_name_func)
```

If one of the MSAs contains sequences from species not present in the
Expand All @@ -160,8 +160,8 @@ other MSA, we can remove these species from both MSAs:
``` python
from diffpass.data_utils import remove_groups_not_in_both

msa_A_by_sp, msa_B_by_sp = remove_groups_not_in_both(
msa_A_by_sp, msa_B_by_sp
msa_data_A_species_by_species, msa_data_B_species_by_species = remove_groups_not_in_both(
msa_data_A_species_by_species, msa_data_B_species_by_species
)
```

Expand All @@ -173,12 +173,12 @@ consisting entirely of gap symbols:
``` python
from diffpass.data_utils import pad_msas_with_dummy_sequences

msa_A_by_sp_pad, msa_B_by_sp_pad = pad_msas_with_dummy_sequences(
msa_A_by_sp, msa_B_by_sp
msa_data_A_species_by_species_padded, msa_data_B_species_by_species_padded = pad_msas_with_dummy_sequences(
msa_data_A_species_by_species, msa_data_B_species_by_species
)

species = list(msa_A_by_sp_pad.keys())
species_sizes = list(map(len, msa_A_by_sp_pad.values()))
species = list(msa_data_A_species_by_species_padded.keys())
species_sizes = list(map(len, msa_data_A_species_by_species_padded.values()))
```

Next, one-hot encode the MSAs using the
Expand All @@ -191,28 +191,23 @@ from diffpass.data_utils import one_hot_encode_msa
device = "cuda" if torch.cuda.is_available() else "cpu"

# Unpack the padded MSAs into a list of records
msa_A_for_pairing = [
rec for recs_this_sp in msa_A_by_sp_pad.values() for rec in recs_this_sp
]
msa_B_for_pairing = [
rec for recs_this_sp in msa_B_by_sp_pad.values() for rec in recs_this_sp
]
msa_data_A_for_pairing = [record for records_this_species in msa_data_A_species_by_species_padded.values() for record in records_this_species]
msa_data_B_for_pairing = [record for records_this_species in msa_data_B_species_by_species_padded.values() for record in records_this_species]

# One-hot encode the MSAs and load them to a device
msa_A_oh = one_hot_encode_msa(msa_A_for_pairing, device=device)
msa_B_oh = one_hot_encode_msa(msa_B_for_pairing, device=device)
msa_A_oh = one_hot_encode_msa(msa_data_A_for_pairing, device=device)
msa_B_oh = one_hot_encode_msa(msa_data_B_for_pairing, device=device)
```

### Pairing optimization

Finally, we can instantiate an
[`InformationPairing`](https://Bitbol-Lab.github.io/DiffPaSS/train.html#informationpairing)
object and optimize the mutual information between the paired MSAs using
the DiffPaSS bootstrapped optimization algorithm. The results are stored
in a
the DiffPaSS bootstrap algorithm. The results are stored in a
[`DiffPaSSResults`](https://Bitbol-Lab.github.io/DiffPaSS/base.html#diffpassresults)
container. The lists of (hard) losses and permutations found during the
optimization can be accessed as attributes of the container.
container. The lists of (hard) losses and permutations found can be
accessed as attributes of the container.

``` python
from diffpass.train import InformationPairing
Expand Down
96 changes: 64 additions & 32 deletions diffpass/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/base.ipynb.

# %% auto 0
__all__ = ['INGROUP_IDX_DTYPE', 'BootstrapList', 'GradientDescentList', 'GroupByGroupList', 'dccn', 'make_pbar',
'DiffPaSSResults', 'DiffPaSSModel']
__all__ = ['INGROUP_IDX_DTYPE', 'BootstrapList', 'GradientDescentList', 'GroupByGroupList', 'IndexPair', 'IndexPairsInGroup',
'IndexPairsInGroups', 'dccn', 'make_pbar', 'DiffPaSSResults', 'DiffPaSSModel']

# %% ../nbs/base.ipynb 4
# Stdlib imports
Expand Down Expand Up @@ -32,9 +32,15 @@
INGROUP_IDX_DTYPE = np.int16

# Type aliases
BootstrapList = list
GradientDescentList = list
GroupByGroupList = list
BootstrapList = list # List indexed by bootstrap iteration
GradientDescentList = list # List indexed by gradient descent iteration
GroupByGroupList = list # List indexed by group index

IndexPair = tuple[int, int] # Pair of indices
IndexPairsInGroup = Sequence[IndexPair] # Pairs of indices in a group of sequences
IndexPairsInGroups = Sequence[
IndexPairsInGroup
] # Pairs of indices in groups of sequences

# %% ../nbs/base.ipynb 6
def dccn(x: torch.Tensor) -> np.ndarray:
Expand All @@ -51,29 +57,31 @@ def make_pbar(epochs: int, show_pbar: bool) -> Any:
class DiffPaSSResults:
"""Container for results of DiffPaSS fits."""

# Optionally, log alphas for fine-grained information
# Optionally, log-alphas for fine-grained information
log_alphas: Optional[
Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
]
]
# Soft and hard permutations
# Soft permutations
soft_perms: Optional[
Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
]
]
# Hard permutations
hard_perms: Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
]
# Losses
# Hard losses
hard_losses: Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
]
# Soft losses
soft_losses: Optional[
Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
Expand Down Expand Up @@ -101,7 +109,7 @@ class DiffPaSSModel(Module):
allowed_best_hits_cfg_keys = {"tau", "reciprocal"}

group_sizes: Sequence[int]
fixed_pairings: Optional[Sequence[Sequence[Sequence[int]]]]
fixed_pairings: Optional[IndexPairsInGroups]
permutation_cfg: Optional[dict[str, Any]]
effective_permutation_cfg_: dict[str, Any]
information_measure: str
Expand Down Expand Up @@ -182,7 +190,7 @@ def validate_best_hits_cfg(self, best_hits_cfg: Optional[dict]) -> None:
def init_permutation(
self,
group_sizes: Sequence[int],
fixed_pairings: Optional[Sequence[Sequence[Sequence[int]]]] = None,
fixed_pairings: Optional[IndexPairsInGroups] = None,
permutation_cfg: Optional[dict[str, Any]] = None,
) -> None:
self.group_sizes = tuple(s for s in group_sizes)
Expand Down Expand Up @@ -441,21 +449,39 @@ def _fit(

def fit(
self,
x: torch.Tensor,
y: torch.Tensor,
x: torch.Tensor, # The object (MSA or adjacency matrix of graphs) to be permuted
y: torch.Tensor, # The target object (MSA or adjacency matrix of graphs), that the objects represented by `x` should be paired with. Not acted upon by soft/hard permutations
*,
epochs: int = single_fit_default_cfg["epochs"],
optimizer_name: Optional[str] = single_fit_default_cfg["optimizer_name"],
epochs: int = single_fit_default_cfg[
"epochs"
], # Number of gradient descent steps
optimizer_name: Optional[str] = single_fit_default_cfg[
"optimizer_name"
], # If not ``None``, name of the optimizer. Default: ``"SGD"``
optimizer_kwargs: Optional[dict[str, Any]] = single_fit_default_cfg[
"optimizer_kwargs"
],
mean_centering: bool = single_fit_default_cfg["mean_centering"],
show_pbar: bool = single_fit_default_cfg["show_pbar"],
compute_final_soft: bool = single_fit_default_cfg["compute_final_soft"],
record_log_alphas: bool = single_fit_default_cfg["record_log_alphas"],
record_soft_perms: bool = single_fit_default_cfg["record_soft_perms"],
record_soft_losses: bool = single_fit_default_cfg["record_soft_losses"],
) -> DiffPaSSResults:
], # If not ``None``, keyword arguments for the optimizer. Default: ``None``
mean_centering: bool = single_fit_default_cfg[
"mean_centering"
], # If ``True``, mean-center log-alphas (stopping gradients) after each gradient descent step. Default: ``False``
show_pbar: bool = single_fit_default_cfg[
"show_pbar"
], # If ``True``, show progress bar. Default: ``False``
compute_final_soft: bool = single_fit_default_cfg[
"compute_final_soft"
], # If ``True``, compute soft permutations and losses after the last gradient descent step. Default: ``False``
record_log_alphas: bool = single_fit_default_cfg[
"record_log_alphas"
], # If ``True``, record log-alphas at each gradient descent step. Default: ``False``
record_soft_perms: bool = single_fit_default_cfg[
"record_soft_perms"
], # If ``True``, record soft permutations at each gradient descent step. Default: ``False``
record_soft_losses: bool = single_fit_default_cfg[
"record_soft_losses"
], # If ``True``, record soft losses at each gradient descent step. Default: ``False``
) -> (
DiffPaSSResults
): # `DiffPaSSResults` container for fit results. All attributes are lists indexed by gradient descent iteration
"""Fit permutations to data using gradient descent."""
self.prepare_fit(x, y)

Expand Down Expand Up @@ -485,21 +511,27 @@ def fit(

def fit_bootstrap(
self,
x: torch.Tensor,
y: torch.Tensor,
x: torch.Tensor, # The object (MSA or adjacency matrix of graphs) to be permuted
y: torch.Tensor, # The target object (MSA or adjacency matrix of graphs), that the objects represented by `x` should be paired with. Not acted upon by soft/hard permutations
*,
n_start: int = 1,
n_end: Optional[int] = None,
step_size: int = 1,
show_pbar: bool = True,
single_fit_cfg: Optional[dict] = None,
) -> DiffPaSSResults:
n_start: int = 1, # Number of fixed pairings to choose among the pairs not already fixed by `self.fixed_pairings`, using the results of the first call to `fit`
n_end: Optional[
int
] = None, # If ``None``, the bootstrap will end when all pairs are fixed. Otherwise, the bootstrap will end when `n_end` pairs are fixed
step_size: int = 1, # Difference between the number of fixed pairings chosen at consecutive bootstrap iterations
show_pbar: bool = True, # If ``True``, show progress bar. Default: ``True``
single_fit_cfg: Optional[
dict
] = None, # If not ``None``, configuration dictionary for gradient optimization in each bootstrap iteration (call to `fit`). See `fit` for details
) -> (
DiffPaSSResults
): # `DiffPaSSResults` container for fit results. All attributes are lists indexed by bootstrap iteration, containing lists indexed by gradient descent iteration as per `fit`
"""Fit permutations to data using the DiffPaSS bootstrap.
The DiffPaSS bootstrap consists of a sequence of short gradient descent runs (default: one epoch per run).
At the end of each run, a subset of the found pairings is chosen uniformly at random
and fixed for the next run. The number of pairings fixed at each iteration ranges
between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`.
and fixed for the next run.
The number of pairings fixed at each iteration ranges between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`.
"""
self.prepare_fit(x, y)
if self.fixed_pairings is None:
Expand Down
2 changes: 1 addition & 1 deletion diffpass/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .constants import DEFAULT_AA_TO_INT


# Type aliases
SeqRecord = tuple[str, str]
SeqRecords = list[SeqRecord]
GroupwiseSeqRecords = dict[str, SeqRecords]
Expand Down
25 changes: 19 additions & 6 deletions diffpass/ipa_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/ipa_utils.ipynb.

# %% auto 0
__all__ = ['get_robust_pairs']
__all__ = ['BootstrapList', 'GradientDescentList', 'GroupByGroupList', 'IndexPair', 'IndexPairsInGroup', 'IndexPairsInGroups',
'get_robust_pairs']

# %% ../nbs/ipa_utils.ipynb 2
# %% ../nbs/ipa_utils.ipynb 4
from collections import defaultdict

import numpy as np
from .base import DiffPaSSResults

# %% ../nbs/ipa_utils.ipynb 3
# Type aliases
BootstrapList = list # List indexed by bootstrap iteration
GradientDescentList = list # List indexed by gradient descent iteration
GroupByGroupList = list # List indexed by group index

IndexPair = tuple[int, int] # Pair of indices
IndexPairsInGroup = list[IndexPair] # Pairs of indices in a group of sequences
IndexPairsInGroups = list[IndexPairsInGroup] # Pairs of indices in groups of sequences

# %% ../nbs/ipa_utils.ipynb 6
def get_robust_pairs(
all_hard_perms: list[list[list[np.ndarray]]], cutoff: float = 1.0
) -> list[list[tuple[int, int]]]:
bootstrap_results: DiffPaSSResults, # E.g. results of a run of `DiffPaSSModel.fit_bootstrap`
cutoff: float = 1.0, # Fraction of iterations a pair must be present in to be considered robust
) -> IndexPairsInGroups: # Robust pairs of indices in each group of sequences
"""Get robust pairs of indices from a `DiffPaSSResults` object."""
all_hard_perms = bootstrap_results.hard_perms
group_sizes = [len(hp) for hp in all_hard_perms[0][0]]
run_length = len(all_hard_perms)
absolute_cutoff = cutoff * run_length
Expand Down
Loading

0 comments on commit ebb160c

Please sign in to comment.