Commit e99c2f5: Add ops, modules, and training classes

Co-authored-by: Umberto Lupo <[email protected]>
Co-authored-by: Damiano Sgarbossa <[email protected]>

1 parent: ec31485

Showing 30 changed files with 210,531 additions and 108 deletions.
@@ -0,0 +1,12 @@

``` yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/fastai/nbdev
    rev: 2.3.11
    hooks:
      - id: nbdev_clean
      - id: nbdev_export
```
@@ -2,21 +2,43 @@

````diff
 
 <!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->
 
 This file will become your README and also the index of your
 documentation.
 
 Description: TODO.
 
 ## Install
 
-Clone this repository on your local machine by running and move inside
-the root folder. We recommend creating and activating a dedicated conda
-or virtualenv Python virtual environment.
+Clone this repository on your local machine by running
 
 ``` sh
 git clone [email protected]:Bitbol-Lab/DiffPASS.git
 ```
 
+and move inside the root folder. We recommend creating and activating a
+dedicated conda or virtualenv Python virtual environment. Then, make an
+editable install of the package:
+
 ``` sh
-pip install DiffPASS
+python -m pip install -e .
 ```
 
 ## How to use
 
-Fill me in please! Don’t forget code examples:
+See the
+[`_example_prokaryotic.ipynb`](https://github.com/Bitbol-Lab/DiffPALM/blob/main/nbs/_example_prokaryotic.ipynb)
+notebook for an example of paired MSA optimization in the case of
+well-known prokaryotic datasets, for which ground truth matchings are
+given by genome proximity.
 
-``` python
-1+1
-```
-
-    2
+## Citation
+
+TODO
+
+``` bibtex
+@article{
+}
+```
 
 ## nbdev
 
 This project has been developed using [nbdev](https://nbdev.fast.ai/).
````
Large diffs are not rendered by default.
@@ -0,0 +1,128 @@

``` python
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/base.ipynb.

# %% auto 0
__all__ = ['DiffPASSMixin', 'scalar_or_1d_tensor', 'EnsembleMixin']

# %% ../nbs/base.ipynb 2
# Stdlib imports
from collections.abc import Iterable, Sequence
from typing import Optional, Union, Any

# NumPy
# import numpy as np

# PyTorch
import torch

# PLOTTING
# from matplotlib import colormaps as cm
# import matplotlib.pyplot as plt
# from matplotlib.colors import CenteredNorm

# %% ../nbs/base.ipynb 3
class DiffPASSMixin:
    group_sizes: Iterable[int]

    @staticmethod
    def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
        """Reduce the number of tokens in a one-hot encoded tensor by
        dropping alphabet entries that are never used."""
        used_tokens = x.clone()
        for _ in range(x.ndim - 1):
            used_tokens = used_tokens.any(-2)

        return x[..., used_tokens]

    def validate_inputs(
        self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False
    ) -> None:
        """Validate input tensors representing aligned objects."""
        size_x, length_x, alphabet_size_x = x.shape
        size_y, length_y, alphabet_size_y = y.shape
        if size_x != size_y:  # Fixed: was `size_x != size_x`, which can never trigger
            raise ValueError(f"Size mismatch between x ({size_x}) and y ({size_y}).")
        if check_same_alphabet_size and (alphabet_size_x != alphabet_size_y):
            raise ValueError("Inputs must have the same alphabet size.")

        # Validate size attribute
        total_size = sum(self.group_sizes)
        if size_x != total_size:
            raise ValueError(
                f"Inputs have size {size_x} but `group_sizes` implies a total "
                f"size of {total_size}."
            )

# %% ../nbs/base.ipynb 4
def scalar_or_1d_tensor(
    *, param: Any, param_name: str, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Coerce a float to a scalar tensor; pass tensors of dimension <= 1 through."""
    if not isinstance(param, (float, torch.Tensor)):
        raise TypeError(f"`{param_name}` must be a float or a torch.Tensor.")
    if isinstance(param, float):
        param = torch.tensor(param, dtype=dtype)
    elif param.ndim > 1:
        raise ValueError(
            f"`{param_name}` must be a scalar or a tensor of dimension <= 1."
        )

    return param


class EnsembleMixin:
    def _validate_ensemble_param(
        self,
        *,
        param: Union[float, torch.Tensor],
        param_name: str,
        ensemble_shape: Sequence[int],
        dim_in_ensemble: Optional[int] = None,
        n_dims_per_instance: Optional[int] = None,
    ) -> torch.Tensor:
        param = scalar_or_1d_tensor(param=param, param_name=param_name)

        param = self._reshape_ensemble_param(
            param=param,
            ensemble_shape=ensemble_shape,
            dim_in_ensemble=dim_in_ensemble,
            n_dims_per_instance=n_dims_per_instance,
            param_name=param_name,
        )

        return param

    @staticmethod
    def _reshape_ensemble_param(
        *,
        param: torch.Tensor,
        ensemble_shape: Sequence[int],
        dim_in_ensemble: Optional[int],
        n_dims_per_instance: int,
        param_name: str,
    ) -> torch.Tensor:
        n_ensemble_dims = len(ensemble_shape)
        if param.ndim == 1:
            if dim_in_ensemble is None:
                raise ValueError(
                    f"`dim_in_ensemble` cannot be None if {param_name} is 1D."
                )
            param = param.to(torch.float32)
            # If param is not a scalar, broadcast it along the
            # `dim_in_ensemble`-th ensemble dimension
            if dim_in_ensemble >= n_ensemble_dims or dim_in_ensemble < -n_ensemble_dims:
                raise ValueError(
                    f"Ensemble dimension for {param_name} must be an available "
                    f"index in `ensemble_shape`."
                )
            elif len(param) != ensemble_shape[dim_in_ensemble]:
                raise ValueError(
                    f"Parameter `{param_name}` must have the same length as "
                    f"`ensemble_shape[dim_in_ensemble]` = "
                    f"{ensemble_shape[dim_in_ensemble]}."
                )
            # Normalize negative indices so the shape arithmetic below is valid
            dim_in_ensemble %= n_ensemble_dims
            new_shape = (
                (1,) * dim_in_ensemble
                + param.shape
                + (1,) * (n_ensemble_dims - dim_in_ensemble - 1)
                + (1,) * n_dims_per_instance
            )
            param = param.view(*new_shape)

        return param
```
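As a quick illustration of the contracts above (not part of the commit itself; the `_Demo` class and all shapes and values are made-up examples), a plain float is promoted to a scalar tensor, and a 1D per-instance parameter is reshaped into a view that broadcasts along its ensemble dimension:

``` python
import torch

# `reduce_num_tokens` drops alphabet entries that never occur: with classes
# {0, 2} out of 4, the trailing one-hot dimension shrinks from 4 to 2.
x = torch.nn.functional.one_hot(torch.tensor([[0, 2], [2, 0]]), num_classes=4).bool()
print(DiffPASSMixin.reduce_num_tokens(x).shape)  # torch.Size([2, 2, 2])

# A plain float is promoted to a 0-dim tensor.
temp = scalar_or_1d_tensor(param=0.5, param_name="temperature")
print(temp.shape)  # torch.Size([])

# A length-3 vector indexed by ensemble dim 1 of ensemble_shape (4, 3), with
# 2 trailing dims per instance, becomes a broadcastable (1, 3, 1, 1) view.
class _Demo(EnsembleMixin):
    pass

param = _Demo()._validate_ensemble_param(
    param=torch.tensor([0.1, 0.2, 0.3]),
    param_name="noise",
    ensemble_shape=(4, 3),
    dim_in_ensemble=1,
    n_dims_per_instance=2,
)
print(param.shape)  # torch.Size([1, 3, 1, 1])
```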
@@ -0,0 +1,98 @@

``` python
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/constants.ipynb.

# %% auto 0
__all__ = ['BLOSUM62', 'DEFAULT_TOKENS', 'DEFAULT_AA_TO_INT', 'SubstitutionMatrix',
           'TokenizedSubstitutionMatrix', 'get_blosum62_data']

# %% ../nbs/constants.ipynb 2
from dataclasses import dataclass
from typing import Optional
from io import StringIO

import pandas as pd

import torch


@dataclass
class SubstitutionMatrix:
    name: str
    mat: pd.DataFrame
    expected_value: float


@dataclass
class TokenizedSubstitutionMatrix:
    name: str
    mat: torch.Tensor
    expected_value: float


BLOSUM62 = SubstitutionMatrix(
    name="BLOSUM62",
    mat=pd.read_csv(
        StringIO(
            """  A R N D C Q E G H I L K M F P S T W Y V B Z X *
A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 -2 -1 0 -4
R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 -1 0 -1 -4
N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 3 0 -1 -4
D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 4 1 -1 -4
C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 -3 -3 -2 -4
Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 3 -1 -4
E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 -1 -2 -1 -4
H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 0 -1 -4
I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 -3 -3 -1 -4
L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 -4 -3 -1 -4
K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 1 -1 -4
M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 -3 -1 -1 -4
F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 -3 -3 -1 -4
P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 -2 -1 -2 -4
S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 0 0 -4
T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 -1 -1 0 -4
W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 -4 -3 -2 -4
Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 -3 -2 -1 -4
V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 -3 -2 -1 -4
B -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 4 1 -1 -4
Z -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
X 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 -1 -1 -1 -4
* -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 1
"""
        ),
        index_col=0,
        header=0,
        sep=r"\s+",  # raw string avoids an invalid-escape warning
    ),
    expected_value=-0.5209,
)


DEFAULT_TOKENS = "-ACDEFGHIKLMNPQRSTVWY"
DEFAULT_AA_TO_INT = dict(zip(DEFAULT_TOKENS, range(len(DEFAULT_TOKENS))))


def get_blosum62_data(
    aa_to_int: Optional[dict[str, int]] = None,
    gaps_as_stars: bool = False,
) -> TokenizedSubstitutionMatrix:
    """Return BLOSUM62 as a float32 tensor whose rows and columns follow
    `aa_to_int`. If `gaps_as_stars` is True, gaps are scored as the BLOSUM62
    '*' symbol; otherwise the gap row and column are set to zero."""
    aa_to_int = (DEFAULT_AA_TO_INT if aa_to_int is None else aa_to_int).copy()

    mat = BLOSUM62.mat.copy()
    if gaps_as_stars:
        if "-" in aa_to_int and "*" in aa_to_int:
            raise ValueError(
                "Cannot have both gaps and stars in `aa_to_int` if `gaps_as_stars` is True"
            )
        aa_to_int["*"] = aa_to_int.pop("-")
    else:
        mat.loc["-"] = 0
        mat.loc[:, "-"] = 0
    aa_to_int = dict(sorted(aa_to_int.items(), key=lambda x: x[1]))

    mat = mat.loc[list(aa_to_int), list(aa_to_int)]

    return TokenizedSubstitutionMatrix(
        name=BLOSUM62.name,
        mat=torch.tensor(mat.to_numpy(), dtype=torch.float32),
        expected_value=BLOSUM62.expected_value,
    )
```
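A small sanity check of the tokenized matrix (again an illustrative sketch, not part of the commit): with the default 21-token alphabet the gap row is zeroed out, while `gaps_as_stars=True` reuses the '*' scores instead.

``` python
# Default: 21 x 21 matrix ordered as "-ACDEFGHIKLMNPQRSTVWY", gap scores zeroed.
blosum = get_blosum62_data()
print(blosum.mat.shape)           # torch.Size([21, 21])
print(blosum.mat[0].abs().sum())  # tensor(0.): the gap row is all zeros

# With gaps scored as '*': gap-vs-gap gets the BLOSUM62 '*'/'*' score.
blosum_stars = get_blosum62_data(gaps_as_stars=True)
print(blosum_stars.mat[0, 0])     # tensor(1.)
```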
This file was deleted.