Merge pull request #104 from raphaelreinauer/master
Add detailed docstring to the OneHotEncodedPersistenceDiagram class Fixes #98
matteocao authored Sep 28, 2022
2 parents 480c2da + c27a4ca commit a0608c9
Showing 4 changed files with 39 additions and 16 deletions.
16 changes: 11 additions & 5 deletions README.md
@@ -45,6 +45,16 @@ The first step to install the developer version of the package is to `git clone`
git clone https://github.com/giotto-ai/giotto-deep.git
```
Then change the current working directory to the repository root folder, e.g. `cd giotto-deep`.
It is best practice to create a virtual environment for the project, e.g. using `virtualenv`:
```
virtualenv -p python3.9 venv
```
Activate the virtual environment (e.g. `source venv/bin/activate` on Linux or `venv\Scripts\activate` on Windows).

First, make sure you have upgraded to the latest version of `pip` with
```
python -m pip install --upgrade pip
```
Make sure you have the latest version of pytorch installed.
You can do this by running the following command (if you have a GPU):
```
@@ -54,10 +64,6 @@ Once you are in the root folder, install the package dynamically with:
```
pip install -e .
```
Make sure you have upgraded to the last version of `pip` with
```
python -m pip install --upgrade pip
```
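As a quick sanity check that the editable install above worked, the following minimal Python snippet can be used; it assumes the package is importable as `gdeep`, matching the top-level package directory in this repository:
```
# Hypothetical sanity check for the editable install; assumes the import
# name is `gdeep`, matching the package directory of this repository.
import gdeep

print(gdeep.__file__)  # should point into the cloned repository, not site-packages
```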


## Contributing
@@ -87,4 +93,4 @@ In order to run your analysis on TPU cores, you can use the following lines:
```
Once you have run the lines above, please make sure to restart the runtime.

-The code will automatically detect the TPU core and use it as deffault to run the experiments. GPUs are also automatically supported.
+The code will automatically detect the TPU core and use it as default to run the experiments. GPUs are also automatically supported.
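For context on the automatic device detection mentioned above, here is a hedged, standalone sketch (not giotto-deep's actual detection code) of how a TPU device is typically obtained with `torch_xla`, falling back to GPU or CPU:
```
import torch

try:
    # torch_xla is only available on TPU runtimes (e.g. Colab TPU instances)
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
except ImportError:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Running experiments on: {device}")
```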
@@ -57,5 +57,5 @@ def test_mutag_load_and_save():
os.path.join(file_path, f"{graph_idx}.npy")
)
assert pd_gudhi.all_close(
-    pd, atol=1e-6
+    pd, atol=1e-2
), "Generated persistence diagram is not equal to the one from GUDHI"
35 changes: 26 additions & 9 deletions gdeep/data/persistence_diagrams/one_hot_persistence_diagram.py
@@ -15,7 +15,32 @@ class OneHotEncodedPersistenceDiagram:
Args:
data:
The data of the persistence diagram.
The data of the persistence diagram. The data must be a tensor of shape
(num_points, 2 + num_homology_dimensions) and the last dimension must be
the concatenation of the birth-death-coordinates and the one-hot encoded homology
dimension.
The invariants of the persistence diagram are checked in the constructor.
homology_dimension_names:
The names of the homology dimensions. If None, the names are set to H_0, H_1, ...
Example::
pd = torch.tensor\
([[0.0928, 0.0995, 0.0000, 0.0000, 1.0000, 0.0000],
[0.0916, 0.1025, 1.0000, 0.0000, 0.0000, 0.0000],
[0.0978, 0.1147, 1.0000, 0.0000, 0.0000, 0.0000],
[0.0978, 0.1147, 0.0000, 0.0000, 1.0000, 0.0000],
[0.0916, 0.1162, 0.0000, 0.0000, 0.0000, 1.0000],
[0.0740, 0.0995, 1.0000, 0.0000, 0.0000, 0.0000],
[0.0728, 0.0995, 1.0000, 0.0000, 0.0000, 0.0000],
[0.0740, 0.1162, 0.0000, 0.0000, 0.0000, 1.0000],
[0.0728, 0.1162, 0.0000, 0.0000, 1.0000, 0.0000],
[0.0719, 0.1343, 0.0000, 0.0000, 0.0000, 1.0000],
[0.0830, 0.2194, 1.0000, 0.0000, 0.0000, 0.0000],
[0.0830, 0.2194, 1.0000, 0.0000, 0.0000, 0.0000],
[0.0719, 0.2194, 0.0000, 1.0000, 0.0000, 0.0000]])
names = ["Ord0", "Ext0", "Rel1", "Ext1"]
pd = OneHotEncodedPersistenceDiagram(pd, names)
"""

_data: Tensor
@@ -235,14 +260,6 @@ def _check_if_valid(data) -> None:
), "The homology dimension should be one-hot encoded."


# def _sort_by_lifetime(data: Tensor) -> Tensor:
# """This function sorts the points by their lifetime.
# """
# return data[(
# data[:, 1] - data[:, 0]
# ).argsort()]


def get_one_hot_encoded_persistence_diagram_from_gtda(
persistence_diagram: Array,
) -> OneHotEncodedPersistenceDiagram:
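A minimal sketch of how the converter above might be called, assuming a giotto-tda style array in which each row is `(birth, death, homology dimension)`; the exact accepted shape is defined by the function itself:
```
import numpy as np

from gdeep.data.persistence_diagrams.one_hot_persistence_diagram import (
    get_one_hot_encoded_persistence_diagram_from_gtda,
)

# Assumed giotto-tda layout: one row per point, columns (birth, death, homology dimension).
gtda_diagram = np.array(
    [[0.00, 0.35, 0.0],
     [0.00, 0.21, 0.0],
     [0.12, 0.30, 1.0]]
)

one_hot_pd = get_one_hot_encoded_persistence_diagram_from_gtda(gtda_diagram)
# one_hot_pd now stores birth-death pairs plus the one-hot encoded homology dimensions.
```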
2 changes: 1 addition & 1 deletion gdeep/search/hpo.py
@@ -18,7 +18,7 @@
import plotly.express as px
from optuna.pruners import MedianPruner, BasePruner
from optuna.trial._base import BaseTrial # noqa
-from optuna.study import BaseStudy as Study
+from optuna.study import Study
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler # noqa
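The import change above swaps the `BaseStudy` alias for the concrete `Study` class returned by `optuna.create_study`. For context, a minimal standalone optuna example (unrelated to `hpo.py` itself) showing the object that annotation refers to; it assumes optuna >= 1.3 for `suggest_float`:
```
import optuna
from optuna.study import Study

def objective(trial: optuna.trial.Trial) -> float:
    # Toy objective: minimize (x - 2)^2 over a single float hyperparameter.
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2

study: Study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```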
