[DOCS] Run nbdev_proc_nbs
ulupo committed May 14, 2024
1 parent 835a570 commit 3a7d7d2
Showing 5 changed files with 98 additions and 87 deletions.
23 changes: 11 additions & 12 deletions nbs/base.ipynb
@@ -722,22 +722,21 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### DiffPaSSResults\n\n> DiffPaSSResults (log_alphas:Union[list[list[numpy.ndarray]],list[list[lis\n> t[numpy.ndarray]]],NoneType], soft_perms:Union[list[list\n> [numpy.ndarray]],list[list[list[numpy.ndarray]]],NoneTyp\n> e], hard_perms:Union[list[list[numpy.ndarray]],list[list\n> [list[numpy.ndarray]]]], hard_losses:Union[list[list[num\n> py.ndarray]],list[list[list[numpy.ndarray]]]], soft_loss\n> es:Union[list[list[numpy.ndarray]],list[list[list[numpy.\n> ndarray]]],NoneType])\n\n*Container for results of DiffPaSS fits.*",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L55){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### DiffPaSSResults\n\n> DiffPaSSResults (log_alphas:Union[list[list[numpy.ndarray]],list[list[lis\n> t[numpy.ndarray]]],NoneType], soft_perms:Union[list[list\n> [numpy.ndarray]],list[list[list[numpy.ndarray]]],NoneTyp\n> e], hard_perms:Union[list[list[numpy.ndarray]],list[list\n> [list[numpy.ndarray]]]], hard_losses:Union[list[list[flo\n> at]],list[list[list[float]]]], soft_losses:Union[list[li\n> st[float]],list[list[list[float]]],NoneType])\n\n*Container for results of DiffPaSS fits.*",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L57){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L55){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DiffPaSSResults\n",
"\n",
"> DiffPaSSResults (log_alphas:Union[list[list[numpy.ndarray]],list[list[lis\n",
"> t[numpy.ndarray]]],NoneType], soft_perms:Union[list[list\n",
"> [numpy.ndarray]],list[list[list[numpy.ndarray]]],NoneTyp\n",
"> e], hard_perms:Union[list[list[numpy.ndarray]],list[list\n",
"> [list[numpy.ndarray]]]], hard_losses:Union[list[list[num\n",
"> py.ndarray]],list[list[list[numpy.ndarray]]]], soft_loss\n",
"> es:Union[list[list[numpy.ndarray]],list[list[list[numpy.\n",
"> ndarray]]],NoneType])\n",
"> [list[numpy.ndarray]]]], hard_losses:Union[list[list[flo\n",
"> at]],list[list[list[float]]]], soft_losses:Union[list[li\n",
"> st[float]],list[list[list[float]]],NoneType])\n",
"\n",
"*Container for results of DiffPaSS fits.*"
]
@@ -787,11 +786,11 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L448){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### DiffPaSSModel.fit\n\n> DiffPaSSModel.fit (x:torch.Tensor, y:torch.Tensor, epochs:int=1,\n> optimizer_name:Optional[str]='SGD',\n> optimizer_kwargs:Optional[dict[str,Any]]=None,\n> mean_centering:bool=False, show_pbar:bool=False,\n> compute_final_soft:bool=False,\n> record_log_alphas:bool=False,\n> record_soft_perms:bool=False,\n> record_soft_losses:bool=False)\n\n*Fit permutations to data using gradient descent.*\n\n| | **Type** | **Default** | **Details** |\n| -- | -------- | ----------- | ----------- |\n| x | Tensor | | The object (MSA or adjacency matrix of graphs) to be permuted |\n| y | Tensor | | The target object (MSA or adjacency matrix of graphs), that the objects represented by `x` should be paired with. Not acted upon by soft/hard permutations |\n| epochs | int | 1 | |\n| optimizer_name | Optional | SGD | |\n| optimizer_kwargs | Optional | None | |\n| mean_centering | bool | False | |\n| show_pbar | bool | False | |\n| compute_final_soft | bool | False | |\n| record_log_alphas | bool | False | |\n| record_soft_perms | bool | False | |\n| record_soft_losses | bool | False | |\n| **Returns** | **DiffPaSSResults** | | **`DiffPaSSResults` container for fit results. All attributes are lists indexed by gradient descent iteration** |",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L445){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### DiffPaSSModel.fit\n\n> DiffPaSSModel.fit (x:torch.Tensor, y:torch.Tensor, epochs:int=1,\n> optimizer_name:Optional[str]='SGD',\n> optimizer_kwargs:Optional[dict[str,Any]]=None,\n> mean_centering:bool=False, show_pbar:bool=False,\n> compute_final_soft:bool=False,\n> record_log_alphas:bool=False,\n> record_soft_perms:bool=False,\n> record_soft_losses:bool=False)\n\n*Fit permutations to data using gradient descent.*\n\n| | **Type** | **Default** | **Details** |\n| -- | -------- | ----------- | ----------- |\n| x | Tensor | | The object (MSA or adjacency matrix of graphs) to be permuted |\n| y | Tensor | | The target object (MSA or adjacency matrix of graphs), that the objects represented by `x` should be paired with. Not acted upon by soft/hard permutations |\n| epochs | int | 1 | |\n| optimizer_name | Optional | SGD | |\n| optimizer_kwargs | Optional | None | |\n| mean_centering | bool | False | |\n| show_pbar | bool | False | |\n| compute_final_soft | bool | False | |\n| record_log_alphas | bool | False | |\n| record_soft_perms | bool | False | |\n| record_soft_losses | bool | False | |\n| **Returns** | **DiffPaSSResults** | | |",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L448){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L445){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DiffPaSSModel.fit\n",
"\n",
@@ -819,7 +818,7 @@
"| record_log_alphas | bool | False | |\n",
"| record_soft_perms | bool | False | |\n",
"| record_soft_losses | bool | False | |\n",
"| **Returns** | **DiffPaSSResults** | | **`DiffPaSSResults` container for fit results. All attributes are lists indexed by gradient descent iteration** |"
"| **Returns** | **DiffPaSSResults** | | |"
]
},
"execution_count": null,
@@ -838,11 +837,11 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L492){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### DiffPaSSModel.fit_bootstrap\n\n> DiffPaSSModel.fit_bootstrap (x:torch.Tensor, y:torch.Tensor,\n> n_start:int=1, n_end:Optional[int]=None,\n> step_size:int=1, show_pbar:bool=True,\n> single_fit_cfg:Optional[dict]=None)\n\n*Fit permutations to data using the DiffPaSS bootstrap.\n\nThe DiffPaSS bootstrap consists of a sequence of short gradient descent runs (default: one epoch per run).\nAt the end of each run, a subset of the found pairings is chosen uniformly at random\nand fixed for the next run.\nThe number of pairings fixed at each iteration ranges between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`.*\n\n| | **Type** | **Default** | **Details** |\n| -- | -------- | ----------- | ----------- |\n| x | Tensor | | The object (MSA or adjacency matrix of graphs) to be permuted |\n| y | Tensor | | The target object (MSA or adjacency matrix of graphs), that the objects represented by `x` should be paired with. Not acted upon by soft/hard permutations |\n| n_start | int | 1 | |\n| n_end | Optional | None | |\n| step_size | int | 1 | |\n| show_pbar | bool | True | |\n| single_fit_cfg | Optional | None | |\n| **Returns** | **DiffPaSSResults** | | **`DiffPaSSResults` container for fit results. All attributes are lists indexed by bootstrap iteration, containing lists indexed by gradient descent iteration as per `fit`** |",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L507){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### DiffPaSSModel.fit_bootstrap\n\n> DiffPaSSModel.fit_bootstrap (x:torch.Tensor, y:torch.Tensor,\n> n_start:int=1, n_end:Optional[int]=None,\n> step_size:int=1, n_repeats:int=1,\n> show_pbar:bool=True,\n> single_fit_cfg:Optional[dict]=None)\n\n*Fit permutations to data using the DiffPaSS bootstrap.\n\nThe DiffPaSS bootstrap consists of a sequence of short gradient descent runs (default: one epoch per run).\nAt the end of each run, a subset of the found pairings is chosen uniformly at random\nand fixed for the next run.\nThe number of pairings fixed at each iteration ranges between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`.*\n\n| | **Type** | **Default** | **Details** |\n| -- | -------- | ----------- | ----------- |\n| x | Tensor | | The object (MSA or adjacency matrix of graphs) to be permuted |\n| y | Tensor | | The target object (MSA or adjacency matrix of graphs), that the objects represented by `x` should be paired with. Not acted upon by soft/hard permutations |\n| n_start | int | 1 | |\n| n_end | Optional | None | |\n| step_size | int | 1 | |\n| n_repeats | int | 1 | |\n| show_pbar | bool | True | |\n| single_fit_cfg | Optional | None | |\n| **Returns** | **DiffPaSSResults** | | |",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L492){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/base.py#L507){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DiffPaSSModel.fit_bootstrap\n",
"\n",
@@ -867,7 +866,7 @@
"| step_size | int | 1 | |\n",
"| show_pbar | bool | True | |\n",
"| single_fit_cfg | Optional | None | |\n",
"| **Returns** | **DiffPaSSResults** | | **`DiffPaSSResults` container for fit results. All attributes are lists indexed by bootstrap iteration, containing lists indexed by gradient descent iteration as per `fit`** |"
"| **Returns** | **DiffPaSSResults** | | |"
]
},
"execution_count": null,
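The `fit_bootstrap` docstring in the hunks above describes a schedule: a sequence of short runs, after each of which a uniformly random subset of the found pairings is fixed, with the subset size stepping from `n_start` to `n_end`. A minimal, library-free sketch of that schedule follows. The names `bootstrap_schedule`, `toy_bootstrap`, `fit_once` and `identity_fit` are hypothetical and not part of `diffpass`. Whether `n_end` is inclusive is an assumption. The new `n_repeats` argument is not modeled because the diff does not document its semantics.

```python
import random
from typing import Callable, Optional

Pairing = tuple[int, int]


def bootstrap_schedule(
    n_pairs: int, n_start: int = 1, n_end: Optional[int] = None, step_size: int = 1
) -> list[int]:
    """Number of pairings to fix after each successive run."""
    if n_end is None:
        n_end = n_pairs  # documented default: total number of pairs
    # Assumption: the range includes n_end itself.
    return list(range(n_start, n_end + 1, step_size))


def toy_bootstrap(
    n_pairs: int,
    fit_once: Callable[[list[Pairing]], list[Pairing]],
    n_start: int = 1,
    n_end: Optional[int] = None,
    step_size: int = 1,
    seed: int = 0,
) -> list[list[Pairing]]:
    """Return the pairings fixed after each run of a toy bootstrap."""
    rng = random.Random(seed)
    fixed: list[Pairing] = []
    history: list[list[Pairing]] = []
    for n_fixed in bootstrap_schedule(n_pairs, n_start, n_end, step_size):
        found = fit_once(fixed)  # stand-in for one short gradient descent run
        fixed = rng.sample(found, n_fixed)  # uniform subset, without replacement
        history.append(fixed)
    return history


def identity_fit(fixed: list[Pairing], n_pairs: int = 6) -> list[Pairing]:
    """Placeholder 'fit' that ignores `fixed` and pairs item i with item i."""
    return [(i, i) for i in range(n_pairs)]


history = toy_bootstrap(6, identity_fit, n_start=2, n_end=6, step_size=2)
print([len(fixed) for fixed in history])  # prints [2, 4, 6]
```

In the real method the fit would presumably be conditioned on the fixed pairings; the placeholder here only illustrates how the subset size grows across runs.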
24 changes: 13 additions & 11 deletions nbs/data_utils.ipynb
@@ -272,10 +272,12 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n### remove_groups_not_in_both\n\n> remove_groups_not_in_both\n> (data_group_by_group_x:dict[str,list[tuple[str\n> ,str]]], data_group_by_group_y:dict[str,list[t\n> uple[str,str]]])\n\n*Remove groups that are not present in both input collections.*",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### remove_groups_not_in_both\n\n> remove_groups_not_in_both\n> (data_group_by_group_x:dict[str,list[tuple[str\n> ,str]]], data_group_by_group_y:dict[str,list[t\n> uple[str,str]]])\n\n*Remove groups that are not present in both input collections.*",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### remove_groups_not_in_both\n",
"\n",
"> remove_groups_not_in_both\n",
@@ -303,11 +305,11 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L45){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### pad_msas_with_dummy_sequences\n\n> pad_msas_with_dummy_sequences\n> (data_group_by_group_x:dict[str,list[tuple\n> [str,str]]], data_group_by_group_y:dict[st\n> r,list[tuple[str,str]]],\n> dummy_symbol:str='-')\n\n*Pad MSAs with dummy sequences so that all groups/species contain the same\nnumber of sequences.*",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L64){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### pad_msas_with_dummy_sequences\n\n> pad_msas_with_dummy_sequences\n> (data_group_by_group_x:dict[str,list[tuple\n> [str,str]]], data_group_by_group_y:dict[st\n> r,list[tuple[str,str]]],\n> dummy_symbol:str='-')\n\n*Pad MSAs with dummy sequences so that all groups/species contain the same\nnumber of sequences.*",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L45){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L64){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### pad_msas_with_dummy_sequences\n",
"\n",
@@ -372,21 +374,21 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L100){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### one_hot_encode_msa\n\n> one_hot_encode_msa (seq_records:list[tuple[str,str]],\n> aa_to_int:Optional[dict[str,int]]=None,\n> device:Optional[torch.device]=None)\n\nGiven a list of records of the form (header, sequence), assumed to be a parsed MSA,\ntokenize each sequence and one-hot encode each token. Return a 3D tensor representing the\none-hot encoded MSA.",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L163){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### one_hot_encode_msa\n\n> one_hot_encode_msa (seq_records:list[tuple[str,str]],\n> aa_to_int:Optional[dict[str,int]]=None,\n> device:Optional[torch.device]=None)\n\n*Given a list of records of the form (header, sequence), assumed to be a parsed MSA,\ntokenize each sequence and one-hot encode each token. Return a 3D tensor representing the\none-hot encoded MSA.*",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L100){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L163){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### one_hot_encode_msa\n",
"\n",
"> one_hot_encode_msa (seq_records:list[tuple[str,str]],\n",
"> aa_to_int:Optional[dict[str,int]]=None,\n",
"> device:Optional[torch.device]=None)\n",
"\n",
"Given a list of records of the form (header, sequence), assumed to be a parsed MSA,\n",
"*Given a list of records of the form (header, sequence), assumed to be a parsed MSA,\n",
"tokenize each sequence and one-hot encode each token. Return a 3D tensor representing the\n",
"one-hot encoded MSA."
"one-hot encoded MSA.*"
]
},
"execution_count": null,
@@ -463,11 +465,11 @@
"outputs": [
{
"data": {
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L125){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### compute_num_correct_pairings\n\n> compute_num_correct_pairings (hard_perms_by_group:list[numpy.ndarray],\n> compare_to_identity_permutation:bool, singl\n> e_and_paired_seqs:Optional[dict[str,list]]=\n> None)\n\nCompute the total number of correct pairings.\n'Correct' means that they are present in the original paired MSAs, assumed to be the\nground truth.\n\nIf `compare_to_identity_permutation` is True, then the correct pairings are assumed\nto be given by the identity permutation, and the `x_seqs`, `y_seqs`, and `xy_seqs`\narguments are ignored.",
"text/markdown": "---\n\n[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n\n### compute_num_correct_pairings\n\n> compute_num_correct_pairings (hard_perms_by_group:list[numpy.ndarray],\n> compare_to_identity_permutation:bool, singl\n> e_and_paired_seqs:Optional[dict[str,list]]=\n> None)\n\n*Compute the total number of correct pairings.\n'Correct' means that they are present in the original paired MSAs, assumed to be the\nground truth.\n\nIf `compare_to_identity_permutation` is True, then the correct pairings are assumed\nto be given by the identity permutation, and the `x_seqs`, `y_seqs`, and `xy_seqs`\narguments are ignored.*",
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L125){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Bitbol-Lab/DiffPaSS/blob/main/diffpass/data_utils.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### compute_num_correct_pairings\n",
"\n",
@@ -476,13 +478,13 @@
"> e_and_paired_seqs:Optional[dict[str,list]]=\n",
"> None)\n",
"\n",
"Compute the total number of correct pairings.\n",
"*Compute the total number of correct pairings.\n",
"'Correct' means that they are present in the original paired MSAs, assumed to be the\n",
"ground truth.\n",
"\n",
"If `compare_to_identity_permutation` is True, then the correct pairings are assumed\n",
"to be given by the identity permutation, and the `x_seqs`, `y_seqs`, and `xy_seqs`\n",
"arguments are ignored."
"arguments are ignored.*"
]
},
"execution_count": null,
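Two of the `data_utils` docstrings touched above describe simple operations: one-hot encoding a parsed MSA into a 3D array, and counting correct pairings against the identity permutation. The sketch below illustrates both with plain nested lists instead of `torch` tensors or `numpy` arrays. The function names `one_hot_encode_records` and `count_identity_matches`, the default alphabet, and the convention that `perm[i] == i` marks a correct pairing are all assumptions made for illustration, not the library's implementation.

```python
from typing import Optional

# Assumption: gap symbol plus the 20 standard amino acids, in this order.
DEFAULT_ALPHABET = "-ACDEFGHIKLMNPQRSTVWY"


def one_hot_encode_records(
    seq_records: list[tuple[str, str]],
    aa_to_int: Optional[dict[str, int]] = None,
) -> list[list[list[int]]]:
    """Encode (header, sequence) records as nested lists of shape
    (n_sequences, alignment_length, n_tokens)."""
    if aa_to_int is None:
        aa_to_int = {aa: i for i, aa in enumerate(DEFAULT_ALPHABET)}
    n_tokens = max(aa_to_int.values()) + 1
    if len({len(seq) for _, seq in seq_records}) > 1:
        raise ValueError("All sequences in an MSA must have the same length")
    encoded = []
    for _header, seq in seq_records:
        rows = []
        for aa in seq:
            row = [0] * n_tokens
            row[aa_to_int[aa]] = 1  # exactly one 1 per aligned position
            rows.append(row)
        encoded.append(rows)
    return encoded


def count_identity_matches(hard_perms_by_group: list[list[int]]) -> int:
    """Count positions i with perm[i] == i, summed over all groups."""
    return sum(
        sum(1 for i, j in enumerate(perm) if i == j)
        for perm in hard_perms_by_group
    )


msa = [("seq1", "AC-"), ("seq2", "-CA")]
encoded = one_hot_encode_records(msa)
print(len(encoded), len(encoded[0]), len(encoded[0][0]))  # prints 2 3 21

perms = [[0, 1, 2], [1, 0], [2, 1, 0]]
print(count_identity_matches(perms))  # prints 4
```

Unknown symbols raise a `KeyError` in this sketch; how the library treats them is not shown in the diff.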