Skip to content

Commit

Permalink
Add ops, modules, and training classes
Browse files Browse the repository at this point in the history
Co-authored-by: Umberto Lupo <[email protected]>
Co-authored-by: Damiano Sgarbossa <[email protected]>
  • Loading branch information
3 people committed Oct 13, 2023
1 parent ec31485 commit e99c2f5
Show file tree
Hide file tree
Showing 30 changed files with 210,531 additions and 108 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Pre-commit configuration: basic hygiene hooks plus nbdev notebook sync.
# (Indentation restored — the nesting below is required for valid YAML.)
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml            # validate YAML syntax of committed files
      - id: end-of-file-fixer     # ensure files end with a single newline
      - id: trailing-whitespace   # strip trailing whitespace
  - repo: https://github.com/fastai/nbdev
    rev: 2.3.11
    hooks:
      - id: nbdev_clean           # strip notebook outputs/metadata
      - id: nbdev_export          # keep exported modules in sync with notebooks
36 changes: 29 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,43 @@

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

This file will become your README and also the index of your
documentation.
Description: TODO.

## Install

Clone this repository on your local machine by running:

``` sh
git clone [email protected]:Bitbol-Lab/DiffPASS.git
```

and move inside the root folder. We recommend creating and activating a
dedicated conda or virtualenv Python virtual environment. Then, make an
editable install of the package:

``` sh
pip install DiffPASS
python -m pip install -e .
```

## How to use

Fill me in please! Don’t forget code examples:
See the
[`_example_prokaryotic.ipynb`](https://github.com/Bitbol-Lab/DiffPASS/blob/main/nbs/_example_prokaryotic.ipynb)
notebook for an example of paired MSA optimization in the case of
well-known prokaryotic datasets, for which ground truth matchings are
given by genome proximity.

## Citation

TODO

``` bibtex
@article{
}
```
## nbdev

This project has been developed using [nbdev](https://nbdev.fast.ai/).
152 changes: 151 additions & 1 deletion diffpass/_modidx.py

Large diffs are not rendered by default.

128 changes: 128 additions & 0 deletions diffpass/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/base.ipynb.

# %% auto 0
__all__ = ['DiffPASSMixin', 'scalar_or_1d_tensor', 'EnsembleMixin']

# %% ../nbs/base.ipynb 2
# Stdlib imports
from collections.abc import Iterable, Sequence
from typing import Optional, Union, Any

# NumPy
# import numpy as np

# PyTorch
import torch

# PLOTTING
# from matplotlib import colormaps as cm
# import matplotlib.pyplot as plt
# from matplotlib.colors import CenteredNorm

# %% ../nbs/base.ipynb 3
class DiffPASSMixin:
    """Shared input-validation and preprocessing utilities for DiffPASS objects."""

    # Sizes of the groups the input alignments are partitioned into;
    # expected to be provided by the inheriting class.
    group_sizes: Iterable[int]

    @staticmethod
    def reduce_num_tokens(x: torch.Tensor) -> torch.Tensor:
        """Reduce the number of tokens in a one-hot encoded tensor.

        Drops positions of the last (alphabet) axis that are never used
        anywhere in ``x``, keeping only columns with at least one nonzero
        entry.
        """
        used_tokens = x.clone()
        # Collapse every axis except the alphabet axis with a logical OR,
        # leaving a 1D boolean mask over alphabet positions.
        for _ in range(x.ndim - 1):
            used_tokens = used_tokens.any(-2)

        return x[..., used_tokens]

    def validate_inputs(
        self, x: torch.Tensor, y: torch.Tensor, check_same_alphabet_size: bool = False
    ) -> None:
        """Validate input tensors representing aligned objects.

        Both inputs are expected to be 3D: (size, length, alphabet_size).
        Raises ValueError if the leading (size) dimensions disagree, if
        `check_same_alphabet_size` is set and alphabet sizes differ, or if
        the size is inconsistent with ``self.group_sizes``.
        """
        size_x, length_x, alphabet_size_x = x.shape
        size_y, length_y, alphabet_size_y = y.shape
        # Bug fix: the original compared `size_x` with itself, so this
        # check could never fire.
        if size_x != size_y:
            raise ValueError(f"Size mismatch between x ({size_x}) and y ({size_y}).")
        if check_same_alphabet_size and (alphabet_size_x != alphabet_size_y):
            raise ValueError("Inputs must have the same alphabet size.")

        # Validate size attribute
        total_size = sum(self.group_sizes)
        if size_x != total_size:
            # Bug fix: the original message interpolated `total_size` twice,
            # hiding the actual input size.
            raise ValueError(
                f"Inputs have size {size_x} but `group_sizes` implies a total "
                f"size of {total_size}."
            )

# %% ../nbs/base.ipynb 4
def scalar_or_1d_tensor(
    *, param: Any, param_name: str, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Coerce a float into a 0D tensor, or validate an at-most-1D tensor.

    Floats are converted to a scalar tensor of the given ``dtype``; tensors
    are returned as-is after checking they have at most one dimension.
    Raises TypeError for any other input type.
    """
    if isinstance(param, float):
        return torch.tensor(param, dtype=dtype)
    if not isinstance(param, torch.Tensor):
        raise TypeError(f"`{param_name}` must be a float or a torch.Tensor.")
    if param.ndim > 1:
        raise ValueError(
            f"`{param_name}` must be a scalar or a tensor of dimension <= 1."
        )
    return param


class EnsembleMixin:
    """Helpers for validating and broadcasting per-ensemble scalar parameters."""

    def _validate_ensemble_param(
        self,
        *,
        param: Union[float, torch.Tensor],
        param_name: str,
        ensemble_shape: Sequence[int],
        dim_in_ensemble: Optional[int] = None,
        n_dims_per_instance: Optional[int] = None,
    ) -> torch.Tensor:
        """Coerce `param` to a tensor and reshape it for ensemble broadcasting.

        First validates that `param` is a float or an at-most-1D tensor
        (via `scalar_or_1d_tensor`), then views it so it broadcasts along
        the `dim_in_ensemble`-th dimension of an ensemble of shape
        `ensemble_shape`, with `n_dims_per_instance` trailing singleton
        dimensions for the per-instance axes.
        """
        param = scalar_or_1d_tensor(param=param, param_name=param_name)

        param = self._reshape_ensemble_param(
            param=param,
            ensemble_shape=ensemble_shape,
            dim_in_ensemble=dim_in_ensemble,
            n_dims_per_instance=n_dims_per_instance,
            param_name=param_name,
        )

        return param

    @staticmethod
    def _reshape_ensemble_param(
        *,
        param: torch.Tensor,
        ensemble_shape: Sequence[int],
        dim_in_ensemble: Optional[int],
        n_dims_per_instance: Optional[int],  # only used (and required) when `param` is 1D
        param_name: str,
    ) -> torch.Tensor:
        """Reshape a validated 0D/1D `param` for broadcasting over an ensemble.

        Scalar (0D) params are returned unchanged — they broadcast as-is.
        1D params are viewed with singleton dimensions on either side so
        they align with `ensemble_shape[dim_in_ensemble]`.
        """
        n_ensemble_dims = len(ensemble_shape)
        if param.ndim == 1:
            # A 1D param must be assigned to a specific ensemble dimension.
            if dim_in_ensemble is None:
                raise ValueError(
                    f"`dim_in_ensemble` cannot be None if {param_name} is 1D."
                )
            param = param.to(torch.float32)
            # If param is not a scalar, broadcast it along the `ensemble_dim`-th ensemble dimension
            if dim_in_ensemble >= n_ensemble_dims or dim_in_ensemble < -n_ensemble_dims:
                raise ValueError(
                    f"Ensemble dimension for {param_name} must be an available index "
                    f"in `ensemble_shape`."
                )
            elif len(param) != ensemble_shape[dim_in_ensemble]:
                raise ValueError(
                    f"Parameter `{param_name}` must have the same length as "
                    f"``ensemble_shape[dim_in_ensemble]`` = "
                    f"{ensemble_shape[dim_in_ensemble]}."
                )
            # NOTE(review): negative `dim_in_ensemble` values pass the bounds
            # check above, but `(1,) * dim_in_ensemble` is empty for negative
            # values, so the resulting shape would not place `param` on the
            # intended axis — presumably callers always pass non-negative
            # indices; TODO confirm.
            new_shape = (
                (1,) * dim_in_ensemble
                + param.shape
                + (1,) * (n_ensemble_dims - dim_in_ensemble - 1)
                + (1,) * n_dims_per_instance
            )
            param = param.view(*new_shape)

        return param
98 changes: 98 additions & 0 deletions diffpass/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/constants.ipynb.

# %% auto 0
__all__ = ['BLOSUM62', 'DEFAULT_TOKENS', 'DEFAULT_AA_TO_INT', 'SubstitutionMatrix', 'TokenizedSubstitutionMatrix',
'get_blosum62_data']

# %% ../nbs/constants.ipynb 2
from dataclasses import dataclass
from typing import Optional
from io import StringIO

import pandas as pd

import torch


@dataclass
class SubstitutionMatrix:
    """A named amino-acid substitution matrix in pandas form."""

    # Matrix name (e.g. "BLOSUM62")
    name: str
    # Square score matrix, indexed by amino-acid letters on both axes
    mat: pd.DataFrame
    # presumably the matrix's expected score under background frequencies — TODO confirm
    expected_value: float


@dataclass
class TokenizedSubstitutionMatrix:
    """A substitution matrix converted to a tensor, rows/cols ordered by token index."""

    # Matrix name (e.g. "BLOSUM62")
    name: str
    # Square score tensor; axis order follows the token-to-integer mapping
    mat: torch.Tensor
    # presumably the matrix's expected score under background frequencies — TODO confirm
    expected_value: float


# The standard BLOSUM62 substitution matrix (24x24: 20 amino acids plus
# B, Z, X ambiguity codes and the "*" stop/any-mismatch symbol), parsed
# from its conventional whitespace-separated text representation.
BLOSUM62 = SubstitutionMatrix(
    name="BLOSUM62",
    mat=pd.read_csv(
        StringIO(
            """ A R N D C Q E G H I L K M F P S T W Y V B Z X *
A 4 -1 -2 -2 0 -1 -1 0 -2 -1 -1 -1 -1 -2 -1 1 0 -3 -2 0 -2 -1 0 -4
R -1 5 0 -2 -3 1 0 -2 0 -3 -2 2 -1 -3 -2 -1 -1 -3 -2 -3 -1 0 -1 -4
N -2 0 6 1 -3 0 0 0 1 -3 -3 0 -2 -3 -2 1 0 -4 -2 -3 3 0 -1 -4
D -2 -2 1 6 -3 0 2 -1 -1 -3 -4 -1 -3 -3 -1 0 -1 -4 -3 -3 4 1 -1 -4
C 0 -3 -3 -3 9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1 -3 -3 -2 -4
Q -1 1 0 0 -3 5 2 -2 0 -3 -2 1 0 -3 -1 0 -1 -2 -1 -2 0 3 -1 -4
E -1 0 0 2 -4 2 5 -2 0 -3 -3 1 -2 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
G 0 -2 0 -1 -3 -2 -2 6 -2 -4 -4 -2 -3 -3 -2 0 -2 -2 -3 -3 -1 -2 -1 -4
H -2 0 1 -1 -3 0 0 -2 8 -3 -3 -1 -2 -1 -2 -1 -2 -2 2 -3 0 0 -1 -4
I -1 -3 -3 -3 -1 -3 -3 -4 -3 4 2 -3 1 0 -3 -2 -1 -3 -1 3 -3 -3 -1 -4
L -1 -2 -3 -4 -1 -2 -3 -4 -3 2 4 -2 2 0 -3 -2 -1 -2 -1 1 -4 -3 -1 -4
K -1 2 0 -1 -3 1 1 -2 -1 -3 -2 5 -1 -3 -1 0 -1 -3 -2 -2 0 1 -1 -4
M -1 -1 -2 -3 -1 0 -2 -3 -2 1 2 -1 5 0 -2 -1 -1 -1 -1 1 -3 -1 -1 -4
F -2 -3 -3 -3 -2 -3 -3 -3 -1 0 0 -3 0 6 -4 -2 -2 1 3 -1 -3 -3 -1 -4
P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4 7 -1 -1 -4 -3 -2 -2 -1 -2 -4
S 1 -1 1 0 -1 0 0 0 -1 -2 -2 0 -1 -2 -1 4 1 -3 -2 -2 0 0 0 -4
T 0 -1 0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1 1 5 -2 -2 0 -1 -1 0 -4
W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1 1 -4 -3 -2 11 2 -3 -4 -3 -2 -4
Y -2 -2 -2 -3 -2 -1 -2 -3 2 -1 -1 -2 -1 3 -3 -2 -2 2 7 -1 -3 -2 -1 -4
V 0 -3 -3 -3 -1 -2 -2 -3 -3 3 1 -2 1 -1 -2 -2 0 -3 -1 4 -3 -2 -1 -4
B -2 -1 3 4 -3 0 1 -1 0 -3 -4 0 -3 -3 -2 0 -1 -4 -3 -3 4 1 -1 -4
Z -1 0 0 1 -3 3 4 -2 0 -3 -3 1 -1 -3 -1 0 -1 -3 -2 -2 1 4 -1 -4
X 0 -1 -1 -1 -2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -2 0 0 -2 -1 -1 -1 -1 -1 -4
* -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 -4 1
"""
        ),
        index_col=0,
        header=0,
        # Fix: use a raw string for the regex separator — "\s" in a plain
        # string is an invalid escape sequence (warning, and an error in
        # future Python versions).
        sep=r"\s+",
    ),
    expected_value=-0.5209,
)


# Default token alphabet: the gap character followed by the 20 standard
# amino acids in alphabetical order.
DEFAULT_TOKENS = "-ACDEFGHIKLMNPQRSTVWY"
# Maps each token to its integer index within `DEFAULT_TOKENS`.
DEFAULT_AA_TO_INT = dict(zip(DEFAULT_TOKENS, range(len(DEFAULT_TOKENS))))


def get_blosum62_data(
    aa_to_int: Optional[dict[str, int]] = None,
    gaps_as_stars: bool = False,
) -> TokenizedSubstitutionMatrix:
    """Return BLOSUM62 as a tensor, rows/columns ordered by token index.

    If `gaps_as_stars` is True, the gap token "-" in the mapping is
    relabeled to the BLOSUM62 "*" symbol; otherwise a "-" row/column of
    zeros is added so gaps score zero against everything. The mapping
    defaults to `DEFAULT_AA_TO_INT`.
    """
    mapping = dict(DEFAULT_AA_TO_INT if aa_to_int is None else aa_to_int)

    scores = BLOSUM62.mat.copy()
    if gaps_as_stars:
        if "-" in mapping and "*" in mapping:
            raise ValueError(
                "Cannot have both gaps and stars in `aa_to_int` if `gaps_as_stars` is True"
            )
        # NOTE(review): presumably the mapping always contains "-" here,
        # otherwise this `pop` raises KeyError — TODO confirm.
        mapping["*"] = mapping.pop("-")
    else:
        # Gaps score zero against every token (enlarges the matrix with a
        # "-" row and column).
        scores.loc["-"] = 0
        scores.loc[:, "-"] = 0

    # Order tokens by their integer index, then slice the matrix accordingly.
    ordered_tokens = sorted(mapping, key=mapping.get)
    scores = scores.loc[ordered_tokens, ordered_tokens]

    return TokenizedSubstitutionMatrix(
        name=BLOSUM62.name,
        mat=torch.tensor(scores.to_numpy(), dtype=torch.float32),
        expected_value=BLOSUM62.expected_value,
    )
8 changes: 0 additions & 8 deletions diffpass/core.py

This file was deleted.

Loading

0 comments on commit e99c2f5

Please sign in to comment.