[README] Improve quickstart, reinstate old bootstrap video link
ulupo committed May 9, 2024
1 parent 876d017 commit 6baf3b5
Showing 3 changed files with 83 additions and 74 deletions.
45 changes: 25 additions & 20 deletions README.md
@@ -16,7 +16,7 @@ containing interacting biological sequences, find the optimal one-to-one
pairing between the sequences in A and B.

<figure>
-<img src="https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/MSA_pairing_problem.svg" alt="MSA pairing problem" />
+<img src="https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/MSA_pairing_problem.svg" width="640" height="201.6" alt="MSA pairing problem" />
<figcaption>
Pairing problem for two multiple sequence alignments, where pairings are
restricted to be within the same species
@@ -84,7 +84,7 @@ ingredients are as follows:
the DiffPaSS-Iterative Pairing Algorithm (DiffPaSS-IPA).

<figure>
-<video src="https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/DiffPaSS_bootstrap.mp4" width="432" height="243" controls>
+<video src="https://github.com/Bitbol-Lab/DiffPaSS/assets/46537483/e411fe8c-2fed-4723-a25c-ff69a1abccec" width="640" height="360" controls>
</video>
<figcaption>
The DiffPaSS bootstrap technique and robust pairs
@@ -129,9 +129,9 @@ into a list of tuples `(header, sequence)` using
``` python
from diffpass.msa_parsing import read_msa

-# Parse and one-hot encode the MSAs
-msa_data_A = read_msa("path/to/msa_A.fasta")
-msa_data_B = read_msa("path/to/msa_B.fasta")
+# Parse the MSAs into lists of tuples (header, sequence)
+msa_A = read_msa("path/to/msa_A.fasta")
+msa_B = read_msa("path/to/msa_B.fasta")
```
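
For readers unfamiliar with the format, here is a minimal sketch of how FASTA text maps onto a list of `(header, sequence)` tuples. It is plain Python written for illustration only and is not the implementation of `read_msa`; the helper name `parse_fasta_text` is hypothetical, and the choice to drop the leading `>` from headers is an assumption.

```python
def parse_fasta_text(text):
    """Parse FASTA-formatted text into a list of (header, sequence) tuples."""
    records = []
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            # A new header closes the previous record, if there was one
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []  # assumption: drop the leading ">"
        elif header is not None:
            # Sequence lines may be wrapped, so collect and join them
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


example = ">seq1|species_x\nMK-L\nAV\n>seq2|species_y\nMRTLAV\n"
records = parse_fasta_text(example)
```

Each aligned sequence in an MSA should have the same length once wrapped lines are joined, which is why the sketch concatenates the chunks before storing a record.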

We assume that the MSAs contain species information in the headers,
@@ -150,8 +150,8 @@ This function will be used to group the sequences by species:
``` python
from diffpass.data_utils import create_groupwise_seq_records

-msa_data_A_species_by_species = create_groupwise_seq_records(msa_data_A, species_name_func)
-msa_data_B_species_by_species = create_groupwise_seq_records(msa_data_B, species_name_func)
+msa_A_by_sp = create_groupwise_seq_records(msa_A, species_name_func)
+msa_B_by_sp = create_groupwise_seq_records(msa_B, species_name_func)
```
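
Conceptually, grouping by species amounts to building a dictionary keyed by species name. The sketch below illustrates the idea in plain Python; it is not the implementation of `create_groupwise_seq_records`. Both `species_from_header` and `group_records` are hypothetical names, and the header layout `sequence_id|species_name|...` is assumed.

```python
from collections import defaultdict


def species_from_header(header):
    # Hypothetical extractor: assumes headers look like "sequence_id|species_name|..."
    return header.split("|")[1]


def group_records(records, key_func):
    """Group (header, sequence) tuples into a dict keyed by key_func(header)."""
    groups = defaultdict(list)
    for header, seq in records:
        groups[key_func(header)].append((header, seq))
    return dict(groups)


records = [
    ("a1|sp1", "MK-"),
    ("a2|sp2", "MRL"),
    ("a3|sp1", "MKL"),
]
grouped = group_records(records, species_from_header)
```

Because Python dictionaries preserve insertion order, species appear in the order in which they are first encountered in the MSA.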

If one of the MSAs contains sequences from species not present in the
@@ -160,8 +160,8 @@ other MSA, we can remove these species from both MSAs:
``` python
from diffpass.data_utils import remove_groups_not_in_both

-msa_data_A_species_by_species, msa_data_B_species_by_species = remove_groups_not_in_both(
-    msa_data_A_species_by_species, msa_data_B_species_by_species
+msa_A_by_sp, msa_B_by_sp = remove_groups_not_in_both(
+    msa_A_by_sp, msa_B_by_sp
)
```
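
The effect of this step can be pictured as keeping only the keys shared by both dictionaries. The following is an illustrative sketch under that reading, not the library's code; `keep_shared_groups` is a hypothetical helper name.

```python
def keep_shared_groups(groups_a, groups_b):
    """Return copies of both dicts restricted to the keys present in both."""
    shared = [key for key in groups_a if key in groups_b]
    filtered_a = {key: groups_a[key] for key in shared}
    filtered_b = {key: groups_b[key] for key in shared}
    return filtered_a, filtered_b


groups_a = {"sp1": [("a1|sp1", "MK-")], "sp2": [("a2|sp2", "MRL")]}
groups_b = {"sp2": [("b1|sp2", "AV")], "sp3": [("b2|sp3", "AI")]}
filtered_a, filtered_b = keep_shared_groups(groups_a, groups_b)
```

After filtering, both dictionaries list the same species in the same order, which makes it straightforward to line up the two MSAs species by species later on.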

@@ -173,12 +173,12 @@ consisting entirely of gap symbols:
``` python
from diffpass.data_utils import pad_msas_with_dummy_sequences

-msa_data_A_species_by_species_padded, msa_data_B_species_by_species_padded = pad_msas_with_dummy_sequences(
-    msa_data_A_species_by_species, msa_data_B_species_by_species
+msa_A_by_sp_pad, msa_B_by_sp_pad = pad_msas_with_dummy_sequences(
+    msa_A_by_sp, msa_B_by_sp
)

-species = list(msa_data_A_species_by_species_padded.keys())
-species_sizes = list(map(len, msa_data_A_species_by_species_padded.values()))
+species = list(msa_A_by_sp_pad.keys())
+species_sizes = list(map(len, msa_A_by_sp_pad.values()))
```
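
To make the padding idea concrete, the sketch below adds all-gap dummy records to whichever MSA has fewer sequences in a given species, so that both sides end up with the same number of sequences per species. It is a simplified illustration and not the implementation of `pad_msas_with_dummy_sequences`. The helper name `pad_groups_with_gaps`, the `"dummy"` header, and the `"-"` gap symbol are assumptions, and the sketch further assumes that both dictionaries share the same keys and contain no empty groups.

```python
def pad_groups_with_gaps(groups_a, groups_b, dummy_header="dummy", gap="-"):
    """Pad each species so that both MSAs hold the same number of records."""

    def alignment_length(groups):
        # All sequences in an MSA share one length; read it off the first record
        first_group = next(iter(groups.values()))
        return len(first_group[0][1])

    dummy_a = (dummy_header, gap * alignment_length(groups_a))
    dummy_b = (dummy_header, gap * alignment_length(groups_b))

    padded_a, padded_b = {}, {}
    for key in groups_a:
        target = max(len(groups_a[key]), len(groups_b[key]))
        padded_a[key] = list(groups_a[key]) + [dummy_a] * (target - len(groups_a[key]))
        padded_b[key] = list(groups_b[key]) + [dummy_b] * (target - len(groups_b[key]))
    return padded_a, padded_b


groups_a = {
    "sp1": [("a1|sp1", "MK-"), ("a2|sp1", "MR-")],
    "sp2": [("a3|sp2", "MKL")],
}
groups_b = {
    "sp1": [("b1|sp1", "AV")],
    "sp2": [("b2|sp2", "AI"), ("b3|sp2", "-V")],
}
padded_a, padded_b = pad_groups_with_gaps(groups_a, groups_b)
sizes = [len(group) for group in padded_a.values()]
```

Once padded, the per-species sizes computed from either MSA agree, which is why the snippet above reads `species_sizes` from only one of the two padded dictionaries.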

Next, one-hot encode the MSAs using the
@@ -191,23 +191,28 @@ from diffpass.data_utils import one_hot_encode_msa
device = "cuda" if torch.cuda.is_available() else "cpu"

# Unpack the padded MSAs into a list of records
-msa_data_A_for_pairing = [record for records_this_species in msa_data_A_species_by_species_padded.values() for record in records_this_species]
-msa_data_B_for_pairing = [record for records_this_species in msa_data_B_species_by_species_padded.values() for record in records_this_species]
+msa_A_for_pairing = [
+    rec for recs_this_sp in msa_A_by_sp_pad.values() for rec in recs_this_sp
+]
+msa_B_for_pairing = [
+    rec for recs_this_sp in msa_B_by_sp_pad.values() for rec in recs_this_sp
+]

# One-hot encode the MSAs and load them to a device
-msa_A_oh = one_hot_encode_msa(msa_data_A_for_pairing, device=device)
-msa_B_oh = one_hot_encode_msa(msa_data_B_for_pairing, device=device)
+msa_A_oh = one_hot_encode_msa(msa_A_for_pairing, device=device)
+msa_B_oh = one_hot_encode_msa(msa_B_for_pairing, device=device)
```
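
As a rough picture of what one-hot encoding produces, the sketch below turns each residue into a vector with a single 1 at the index of that residue in an alphabet. It uses nested Python lists rather than tensors and is not the implementation of `one_hot_encode_msa`. The helper name `one_hot_encode_records` and the three-letter alphabet in the example are assumptions chosen to keep the output small.

```python
def one_hot_encode_records(records, alphabet):
    """Encode (header, sequence) records as nested lists of shape
    (n_sequences, sequence_length, alphabet_size)."""
    index = {symbol: i for i, symbol in enumerate(alphabet)}
    encoded = []
    for _, seq in records:
        rows = []
        for symbol in seq:
            row = [0.0] * len(alphabet)
            row[index[symbol]] = 1.0  # single non-zero entry per position
            rows.append(row)
        encoded.append(rows)
    return encoded


toy_alphabet = "-AC"  # hypothetical toy alphabet: gap plus two residues
toy_records = [("h1", "A-"), ("h2", "CA")]
encoded = one_hot_encode_records(toy_records, toy_alphabet)
```

In this representation, all-gap dummy sequences added during padding are encoded like any other sequence, with every position pointing at the gap symbol.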

### Pairing optimization

Finally, we can instantiate an
[`InformationPairing`](https://Bitbol-Lab.github.io/DiffPaSS/train.html#informationpairing)
object and optimize the mutual information between the paired MSAs using
-the DiffPaSS bootstrap algorithm. The results are stored in a
+the DiffPaSS bootstrapped optimization algorithm. The results are stored
+in a
[`DiffPaSSResults`](https://Bitbol-Lab.github.io/DiffPaSS/base.html#diffpassresults)
-container. The lists of (hard) losses and permutations found can be
-accessed as attributes of the container.
+container. The lists of (hard) losses and permutations found during the
+optimization can be accessed as attributes of the container.

``` python
from diffpass.train import InformationPairing
40 changes: 22 additions & 18 deletions nbs/index.ipynb
@@ -22,7 +22,7 @@
"A typical example of the problem DiffPaSS is designed to solve is the following: given two multiple sequence alignments (MSAs) A and B, containing interacting biological sequences, find the optimal one-to-one pairing between the sequences in A and B.\n",
"\n",
"<figure>\n",
-" <img src=\"https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/MSA_pairing_problem.svg\" alt=\"MSA pairing problem\" />\n",
+" <img src=\"https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/MSA_pairing_problem.svg\" width=\"640\" height=\"201.6\" alt=\"MSA pairing problem\" />\n",
" <figcaption>Pairing problem for two multiple sequence alignments, where pairings are restricted to be within the same species</figcaption>\n",
"</figure>\n",
"\n",
@@ -51,7 +51,7 @@
" 4. A notion of \"robust pairs\" that can be used to identify pairs that are consistently found throughout a DiffPaSS bootstrap. These pairs can be used as ground truths in another DiffPaSS run, giving rise to the DiffPaSS-Iterative Pairing Algorithm (DiffPaSS-IPA).\n",
" \n",
"<figure>\n",
-" <video src=\"https://raw.githubusercontent.com/Bitbol-Lab/DiffPaSS/main/media/DiffPaSS_bootstrap.mp4\" width=\"432\" height=\"243\" controls></video>\n",
+" <video src=\"https://github.com/Bitbol-Lab/DiffPaSS/assets/46537483/e411fe8c-2fed-4723-a25c-ff69a1abccec\" width=\"640\" height=\"360\" controls></video>\n",
" <figcaption>The DiffPaSS bootstrap technique and robust pairs</figcaption>\n",
"</figure>"
]
@@ -98,9 +98,9 @@
"```python\n",
"from diffpass.msa_parsing import read_msa\n",
"\n",
-"# Parse and one-hot encode the MSAs\n",
-"msa_data_A = read_msa(\"path/to/msa_A.fasta\")\n",
-"msa_data_B = read_msa(\"path/to/msa_B.fasta\")\n",
+"# Parse the MSAs into lists of tuples (header, sequence)\n",
+"msa_A = read_msa(\"path/to/msa_A.fasta\")\n",
+"msa_B = read_msa(\"path/to/msa_B.fasta\")\n",
"```\n",
"\n",
"We assume that the MSAs contain species information in the headers, which will be used to restrict the pairings to be within the same species (more generally, \"groups\"). We need a simple function to extract the species information from the headers. For instance, if the headers are in the format `>sequence_id|species_name|...`, we can use:\n",
@@ -115,17 +115,17 @@
"```python\n",
"from diffpass.data_utils import create_groupwise_seq_records\n",
"\n",
-"msa_data_A_species_by_species = create_groupwise_seq_records(msa_data_A, species_name_func)\n",
-"msa_data_B_species_by_species = create_groupwise_seq_records(msa_data_B, species_name_func)\n",
+"msa_A_by_sp = create_groupwise_seq_records(msa_A, species_name_func)\n",
+"msa_B_by_sp = create_groupwise_seq_records(msa_B, species_name_func)\n",
"```\n",
"\n",
"If one of the MSAs contains sequences from species not present in the other MSA, we can remove these species from both MSAs:\n",
"\n",
"```python\n",
"from diffpass.data_utils import remove_groups_not_in_both\n",
"\n",
-"msa_data_A_species_by_species, msa_data_B_species_by_species = remove_groups_not_in_both(\n",
-" msa_data_A_species_by_species, msa_data_B_species_by_species\n",
+"msa_A_by_sp, msa_B_by_sp = remove_groups_not_in_both(\n",
+" msa_A_by_sp, msa_B_by_sp\n",
")\n",
"```\n",
"\n",
@@ -134,12 +134,12 @@
"```python\n",
"from diffpass.data_utils import pad_msas_with_dummy_sequences\n",
"\n",
-"msa_data_A_species_by_species_padded, msa_data_B_species_by_species_padded = pad_msas_with_dummy_sequences(\n",
-" msa_data_A_species_by_species, msa_data_B_species_by_species\n",
+"msa_A_by_sp_pad, msa_B_by_sp_pad = pad_msas_with_dummy_sequences(\n",
+" msa_A_by_sp, msa_B_by_sp\n",
")\n",
"\n",
-"species = list(msa_data_A_species_by_species_padded.keys())\n",
-"species_sizes = list(map(len, msa_data_A_species_by_species_padded.values()))\n",
+"species = list(msa_A_by_sp_pad.keys())\n",
+"species_sizes = list(map(len, msa_A_by_sp_pad.values()))\n",
"```\n",
"\n",
"Next, one-hot encode the MSAs using the `one_hot_encode_msa` function.\n",
@@ -150,17 +150,21 @@
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# Unpack the padded MSAs into a list of records\n",
-"msa_data_A_for_pairing = [record for records_this_species in msa_data_A_species_by_species_padded.values() for record in records_this_species]\n",
-"msa_data_B_for_pairing = [record for records_this_species in msa_data_B_species_by_species_padded.values() for record in records_this_species]\n",
+"msa_A_for_pairing = [\n",
+" rec for recs_this_sp in msa_A_by_sp_pad.values() for rec in recs_this_sp\n",
+"]\n",
+"msa_B_for_pairing = [\n",
+" rec for recs_this_sp in msa_B_by_sp_pad.values() for rec in recs_this_sp\n",
+"]\n",
"\n",
"# One-hot encode the MSAs and load them to a device\n",
-"msa_A_oh = one_hot_encode_msa(msa_data_A_for_pairing, device=device)\n",
-"msa_B_oh = one_hot_encode_msa(msa_data_B_for_pairing, device=device)\n",
+"msa_A_oh = one_hot_encode_msa(msa_A_for_pairing, device=device)\n",
+"msa_B_oh = one_hot_encode_msa(msa_B_for_pairing, device=device)\n",
"```\n",
"\n",
"### Pairing optimization\n",
"\n",
-"Finally, we can instantiate an `InformationPairing` object and optimize the mutual information between the paired MSAs using the DiffPaSS bootstrap algorithm. The results are stored in a `DiffPaSSResults` container. The lists of (hard) losses and permutations found can be accessed as attributes of the container.\n",
+"Finally, we can instantiate an `InformationPairing` object and optimize the mutual information between the paired MSAs using the DiffPaSS bootstrapped optimization algorithm. The results are stored in a `DiffPaSSResults` container. The lists of (hard) losses and permutations found during the optimization can be accessed as attributes of the container.\n",
"\n",
"```python\n",
"from diffpass.train import InformationPairing\n",
72 changes: 36 additions & 36 deletions nbs/model.ipynb
@@ -6,7 +6,7 @@
"source": [
"# model\n",
"\n",
-"> DiffPaSS models for optimizing MSA pairing"
+"> DiffPaSS modules for optimizing permutations and computing soft scores"
]
},
{
@@ -392,6 +392,41 @@
" return torch.gather(x_permuted_rows, -1, index)"
]
},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### GeneralizedPermutation\n\n> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n> pairings:Optional[collections.abc.Sequence[collec\n> tions.abc.Sequence[collections.abc.Sequence[int]]\n> ]]=None, tau:float=1.0, n_iter:int=1,\n> noise:bool=False, noise_factor:float=1.0,\n> noise_std:bool=False,\n> mode:Literal['soft','hard']='soft')\n\nGeneralized permutation layer implementing both soft and hard permutations.",
+"text/plain": [
+"---\n",
+"\n",
+"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+"\n",
+"### GeneralizedPermutation\n",
+"\n",
+"> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n",
+"> pairings:Optional[collections.abc.Sequence[collec\n",
+"> tions.abc.Sequence[collections.abc.Sequence[int]]\n",
+"> ]]=None, tau:float=1.0, n_iter:int=1,\n",
+"> noise:bool=False, noise_factor:float=1.0,\n",
+"> noise_std:bool=False,\n",
+"> mode:Literal['soft','hard']='soft')\n",
+"\n",
+"Generalized permutation layer implementing both soft and hard permutations."
+]
+},
+"execution_count": null,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"show_doc(GeneralizedPermutation)"
+]
+},
{
"cell_type": "code",
"execution_count": null,
@@ -449,41 +484,6 @@
"test_batch_perm((2, 5, 4, 4))"
]
},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [
-{
-"data": {
-"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### GeneralizedPermutation\n\n> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n> pairings:Optional[collections.abc.Sequence[collec\n> tions.abc.Sequence[collections.abc.Sequence[int]]\n> ]]=None, tau:float=1.0, n_iter:int=1,\n> noise:bool=False, noise_factor:float=1.0,\n> noise_std:bool=False,\n> mode:Literal['soft','hard']='soft')\n\nGeneralized permutation layer implementing both soft and hard permutations.",
-"text/plain": [
-"---\n",
-"\n",
-"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/model.py#L49){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
-"\n",
-"### GeneralizedPermutation\n",
-"\n",
-"> GeneralizedPermutation (group_sizes:collections.abc.Iterable[int], fixed_\n",
-"> pairings:Optional[collections.abc.Sequence[collec\n",
-"> tions.abc.Sequence[collections.abc.Sequence[int]]\n",
-"> ]]=None, tau:float=1.0, n_iter:int=1,\n",
-"> noise:bool=False, noise_factor:float=1.0,\n",
-"> noise_std:bool=False,\n",
-"> mode:Literal['soft','hard']='soft')\n",
-"\n",
-"Generalized permutation layer implementing both soft and hard permutations."
-]
-},
-"execution_count": null,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
-"source": [
-"show_doc(GeneralizedPermutation)"
-]
-},
{
"cell_type": "markdown",
"metadata": {},