Persistence diagrams workflow #93

Merged 68 commits on Jun 16, 2022
Commits
4dffe32
added modules from different branches
raphaelreinauer May 23, 2022
82968a4
refactored construction of persistence diagrams
raphaelreinauer May 23, 2022
ba53b0d
refactored mutag pipeline
raphaelreinauer May 24, 2022
c73bb99
added functionality to OneHotEncodedPersistenceDiagram
raphaelreinauer May 24, 2022
6ab4b31
added test for OneHotEncodedPersistenceDiagram
raphaelreinauer May 24, 2022
c223410
fixed OneHotEncodedPersistenceDiagram.from_numpy
raphaelreinauer May 24, 2022
93d5dda
added plot functionality to OneHotEncodedPersistenceDiagram
raphaelreinauer May 24, 2022
dda8155
added plotting histogram
raphaelreinauer May 25, 2022
8bb1a88
fixed normalizing pd
raphaelreinauer May 25, 2022
660c0cc
made data a member variable of OneHotEncodedPersistenceDiagram
raphaelreinauer May 25, 2022
7e989db
added kwargs dataloader
raphaelreinauer May 27, 2022
439eb1e
refactored persistence diagram preprocessing components
raphaelreinauer Jun 1, 2022
8fa0a13
added various pooling layers
raphaelreinauer Jun 1, 2022
9c7bb39
refactored dataloader builder
raphaelreinauer Jun 2, 2022
065d36b
extended DataLoaderParams dataclass
raphaelreinauer Jun 6, 2022
c4d4deb
changed model hyperparams
raphaelreinauer Jun 6, 2022
1d08ec0
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 6, 2022
cea2209
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 6, 2022
06c31f9
changed public api persformer
raphaelreinauer Jun 14, 2022
4d6792d
Merge branch 'master' into persistence_diagrams_workflow
raphaelreinauer Jun 14, 2022
fe98955
fixed format collate_fn
raphaelreinauer Jun 14, 2022
9cb037c
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 14, 2022
fb4b9e3
Merge branch 'master' into persistence_diagrams_workflow
raphaelreinauer Jun 14, 2022
822deb3
fixed masked error in attentions
raphaelreinauer Jun 14, 2022
d97d930
removed print statements
raphaelreinauer Jun 14, 2022
5101e52
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 14, 2022
99f633c
Merge branch 'master' into persistence_diagrams_workflow
raphaelreinauer Jun 14, 2022
3b26054
added test for persformer
raphaelreinauer Jun 15, 2022
648f033
added notbook for mutag_pipeline
raphaelreinauer Jun 15, 2022
2700e49
removed graph_dataloaders.py
raphaelreinauer Jun 15, 2022
412320c
removed graph_dataloaders.py
raphaelreinauer Jun 15, 2022
320f45f
added torch_geometric to requirements
raphaelreinauer Jun 15, 2022
9f8fbc4
added torch_sparse
raphaelreinauer Jun 15, 2022
774adee
added dependencies of torch-geomtric
raphaelreinauer Jun 15, 2022
06845a6
torch version error
raphaelreinauer Jun 15, 2022
e6874d1
updated to newest version
raphaelreinauer Jun 15, 2022
ebfa1b5
put url at beginning
raphaelreinauer Jun 15, 2022
b1204b4
without -f
raphaelreinauer Jun 15, 2022
e352ecc
removed test data
raphaelreinauer Jun 15, 2022
12135b4
new try
raphaelreinauer Jun 15, 2022
51f719a
new
raphaelreinauer Jun 15, 2022
dfcea7a
new
raphaelreinauer Jun 15, 2022
3e00e41
ew
raphaelreinauer Jun 15, 2022
7ee4d7d
--find-links
raphaelreinauer Jun 15, 2022
e562326
new try
raphaelreinauer Jun 15, 2022
4909f6a
removed modules
raphaelreinauer Jun 15, 2022
88b6351
small changes
raphaelreinauer Jun 15, 2022
3d36baf
changed order
raphaelreinauer Jun 15, 2022
642a2cc
changes suggested by Matteo
raphaelreinauer Jun 15, 2022
2b71460
added torch-sparse
raphaelreinauer Jun 15, 2022
af99367
added torch-scatter
raphaelreinauer Jun 15, 2022
95d1529
added gudhi
raphaelreinauer Jun 15, 2022
156ef74
added hpo initialization from model builder
raphaelreinauer Jun 15, 2022
48b187a
added persformer_wrapper
raphaelreinauer Jun 15, 2022
669973f
added additional parameters to the static from_model_builder method
raphaelreinauer Jun 15, 2022
6ff5095
added mutag hpo
raphaelreinauer Jun 16, 2022
2de76b1
fixed change in persformer_config
raphaelreinauer Jun 16, 2022
f11f65f
commented out old api in orbit_5k_train.py
raphaelreinauer Jun 16, 2022
19fd6eb
removed all persformer benchmark notebook
raphaelreinauer Jun 16, 2022
f3db492
updated jupyter notebooks
raphaelreinauer Jun 16, 2022
8aa688a
fixed issue with the number of heads
raphaelreinauer Jun 16, 2022
0275a29
fixed error in persformer wrapper
raphaelreinauer Jun 16, 2022
26fb45a
added
raphaelreinauer Jun 16, 2022
a0602d1
fixed error with a too big batch size
raphaelreinauer Jun 16, 2022
d292c37
fixed typo in PersformerConfig
raphaelreinauer Jun 16, 2022
8862400
updated orbit_5k_pipeline notebook
raphaelreinauer Jun 16, 2022
01ac322
added changes requested by Matteo
raphaelreinauer Jun 16, 2022
c71b041
added new line
raphaelreinauer Jun 16, 2022
1 change: 1 addition & 0 deletions .github/workflows/deploy-gh-pages.yml
@@ -17,6 +17,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
+        python -m pip install torch
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
     - name: Install sphinx
       run: |
2 changes: 1 addition & 1 deletion .github/workflows/python-package-windows.yml
@@ -26,7 +26,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install flake8 pytest nbconvert jupyter
+        pip install flake8 pytest nbconvert jupyter torch
         if(Test-Path "requirements.txt")
         {
           pip install -r requirements.txt
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -26,7 +26,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install flake8 pytest nbconvert jupyter mypy
+        pip install flake8 pytest nbconvert jupyter mypy torch
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
         pip install -e .
     - name: Lint with flake8
3 changes: 3 additions & 0 deletions .gitignore
@@ -77,6 +77,7 @@ examples/diagrams_5000_1000_0_1.npy
 # virtualenv
 giotto-deep/
 venv/
+myenv/
 
 examples/model_data_specifications
 runs/
@@ -91,3 +92,5 @@ state_dicts/*
 
 # pylance files
 *.pyi
+
+giotto-env/
7 changes: 6 additions & 1 deletion README.md
@@ -34,7 +34,12 @@ The first step to install the developer version of the package is to `git clone`
 ```
 git clone https://github.com/giotto-ai/giotto-deep.git
 ```
-The change the current working directory to the Repository root folder, e.g. `cd giotto-deep`.
+The change the current working directory to the Repository root folder, e.g. `cd giotto-deep`.
+Make sure you have the latest version of pytorch installed.
+You can do this by running the following command (if you have a GPU):
+```
+pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
+```
 Once you are in the root folder, install the package dynamically with:
 ```
 pip install -e .
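The README change above asks for an up-to-date PyTorch before `pip install -e .`. A quick way to see which build is active is sketched below. This is not part of the PR: `describe_torch` is a made-up helper name, and the sketch relies only on `torch.__version__`, `torch.version.cuda` and `torch.cuda.is_available()`. It falls back to a plain message when torch is missing.

```python
import importlib.util


def describe_torch() -> str:
    """Return a one-line summary of the installed PyTorch build.

    `describe_torch` is a hypothetical helper used for illustration only.
    """
    # find_spec lets us test for the package without raising ImportError.
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"

    import torch

    # torch.version.cuda is None on CPU-only builds.
    cuda_build = torch.version.cuda or "cpu-only"
    return (
        f"torch {torch.__version__} "
        f"(CUDA build: {cuda_build}, "
        f"GPU available: {torch.cuda.is_available()})"
    )


if __name__ == "__main__":
    print(describe_torch())
```

If the summary reports a CPU-only build on a machine with a GPU, reinstalling with the `--extra-index-url` command from the README hunk is the likely fix.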
123 changes: 0 additions & 123 deletions examples/basic_tutorial_persformer_benchmark.ipynb

This file was deleted.

38 changes: 0 additions & 38 deletions examples/hpo_space/Mutag_hyperparameter_space.json

This file was deleted.

1 change: 1 addition & 0 deletions examples/mutag_pipeline.ipynb
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n","from typing import Tuple\n","\n","import torch.nn as nn\n","from gdeep.data import PreprocessingPipeline\n","from gdeep.data.datasets import PersistenceDiagramFromFiles\n","from gdeep.data.datasets.base_dataloaders import (DataLoaderBuilder,\n"," DataLoaderParamsTuples)\n","from gdeep.data.datasets.persistence_diagrams_from_graphs_builder import \\\n"," PersistenceDiagramFromGraphBuilder\n","from gdeep.data.persistence_diagrams.one_hot_persistence_diagram import (\n"," OneHotEncodedPersistenceDiagram, collate_fn_persistence_diagrams)\n","from gdeep.data.preprocessors import (\n"," FilterPersistenceDiagramByHomologyDimension,\n"," FilterPersistenceDiagramByLifetime, NormalizationPersistenceDiagram)\n","from gdeep.search.hpo import GiottoSummaryWriter\n","from gdeep.topology_layers import Persformer, PersformerConfig, PersformerWrapper\n","from gdeep.topology_layers.persformer_config import PoolerType\n","from gdeep.trainer.trainer import Trainer\n","from gdeep.search import HyperParameterOptimization\n","from gdeep.utility import DEFAULT_GRAPH_DIR, PoolerType\n","from gdeep.utility.utils import autoreload_if_notebook\n","from sklearn.model_selection import train_test_split\n","from torch.optim import Adam\n","from torch.utils.data import Subset\n","from torch.utils.tensorboard.writer import SummaryWriter\n","\n","autoreload_if_notebook()\n","\n","# Parameters\n","name_graph_dataset: str = 'MUTAG'\n","diffusion_parameter: float = 0.1\n","num_homology_types: int = 4\n","\n","\n","# Create the persistence diagram dataset\n","pd_creator = PersistenceDiagramFromGraphBuilder(name_graph_dataset, diffusion_parameter)\n","pd_creator.create()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Plot sample extended persistence diagram\n","file_path: str = os.path.join(DEFAULT_GRAPH_DIR,\n"," 
f\"MUTAG_{diffusion_parameter}_extended_persistence\", \"diagrams\")\n","graph_idx = 1\n","pd: OneHotEncodedPersistenceDiagram = \\\n"," OneHotEncodedPersistenceDiagram.load(os.path.join(file_path, \n"," f\"{graph_idx}.npy\"))\n","pd.set_homology_dimension_names([\"Ord0\", \"Ext0\", \"Rel1\", \"Ext1\"])\n","pd.plot()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","pd_mutag_ds = PersistenceDiagramFromFiles(\n"," os.path.join(\n"," DEFAULT_GRAPH_DIR, f\"MUTAG_{diffusion_parameter}_extended_persistence\"\n"," )\n",")\n","\n","pd_sample: OneHotEncodedPersistenceDiagram = pd_mutag_ds[0][0]\n","\n","fig = pd_sample.plot([\"Ord0\", \"Ext0\", \"Rel1\", \"Ext1\"])\n","# add title\n","fig.show()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","# Create the train/validation/test split\n","\n","train_indices, test_indices = train_test_split(\n"," range(len(pd_mutag_ds)),\n"," test_size=0.2,\n"," random_state=42,\n",")\n","\n","train_indices , validation_indices = train_test_split(\n"," train_indices,\n"," test_size=0.2,\n"," random_state=42,\n",")\n","\n","# Create the data loaders\n","train_dataset = Subset(pd_mutag_ds, train_indices)\n","validation_dataset = Subset(pd_mutag_ds, validation_indices)\n","test_dataset = Subset(pd_mutag_ds, test_indices)\n","\n","# Preprocess the data\n","preprocessing_pipeline = PreprocessingPipeline[Tuple[OneHotEncodedPersistenceDiagram, int]](\n"," (\n"," FilterPersistenceDiagramByHomologyDimension[int]([0, 1]),\n"," FilterPersistenceDiagramByLifetime[int](min_lifetime=-0.1, max_lifetime=1.0),\n"," NormalizationPersistenceDiagram[int](num_homology_dimensions=4),\n"," )\n",")\n","\n","preprocessing_pipeline.fit_to_dataset(train_dataset)\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["train_dataset = preprocessing_pipeline.attach_transform_to_dataset(train_dataset) # type: ignore\n","validation_dataset = 
preprocessing_pipeline.attach_transform_to_dataset(validation_dataset) # type: ignore\n","test_dataset = preprocessing_pipeline.attach_transform_to_dataset(test_dataset) # type: ignore\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","dl_params = DataLoaderParamsTuples.default(\n"," batch_size=32,\n"," num_workers=0,\n"," collate_fn=collate_fn_persistence_diagrams,\n"," with_validation=True,\n",")\n","\n","\n","# Build the data loaders\n","dlb = DataLoaderBuilder((train_dataset, validation_dataset, test_dataset)) # type: ignore\n","dl_train, dl_val, dl_test = dlb.build(dl_params) # type: ignore"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","# # Define the model\n","# model_config = PersformerConfig(\n","# num_layers=6,\n","# num_attention_heads=4,\n","# input_size= 2 + num_homology_types,\n","# ouptut_size=2,\n","# pooler_type=PoolerType.ATTENTION,\n","# )\n","\n","# model = Persformer(model_config)\n","# writer = SummaryWriter()\n","\n","# loss_function = nn.CrossEntropyLoss()\n","\n","# trainer = Trainer(model, [dl_train, dl_val, dl_test], loss_function, writer)\n","\n","# trainer.train(Adam, 3, False, \n","# {\"lr\":0.01}, \n","# {\"batch_size\":16, \"collate_fn\": collate_fn_persistence_diagrams})"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","# Define the model by using a Wrapper for the Persformer model\n","wrapped_model = PersformerWrapper(\n"," num_attention_layers=3,\n"," num_attention_heads=4,\n"," input_size= 2 + num_homology_types,\n"," ouptut_size=2,\n"," pooler_type=PoolerType.ATTENTION,\n",")\n","writer = GiottoSummaryWriter()\n","\n","loss_function = nn.CrossEntropyLoss()\n","\n","trainer = Trainer(wrapped_model, [dl_train, dl_val, dl_test], loss_function, writer)\n","\n","# initialise hpo object\n","search = HyperParameterOptimization(trainer, \"accuracy\", 2, best_not_last=True)\n","\n","# if you want to store pickle 
files of the models instead of the state_dicts\n","search.store_pickle = True\n","\n","# dictionaries of hyperparameters\n","optimizers_params = {\"lr\": [0.001, 0.01]}\n","dataloaders_params = {\"batch_size\": [32, 64, 16], \n"," \"collate_fn\": [collate_fn_persistence_diagrams]}\n","models_hyperparams = {\n"," \"num_attention_layers\": [2, 6, 1],\n"," \"num_attention_heads\": [8, 16, 8],\n","}\n","\n","# starting the HPO\n","search.start(\n"," [Adam],\n"," 3,\n"," False,\n"," optimizers_params,\n"," dataloaders_params,\n"," models_hyperparams,\n"," n_accumulated_grads=2,\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[""]}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":3},"orig_nbformat":4}}
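The notebook sets `input_size = 2 + num_homology_types` and names four homology types (`Ord0`, `Ext0`, `Rel1`, `Ext1`). That is consistent with each diagram point being stored as birth, death, and a one-hot vector over the types. The sketch below illustrates that layout and a lifetime filter in plain Python. It is not giotto-deep's implementation: `encode_point` and `filter_by_lifetime` are hypothetical helpers, and both the definition of lifetime as `death - birth` and the inclusive bounds are assumptions.

```python
from typing import List, Sequence

# Names taken from the notebook; the index order is assumed.
HOMOLOGY_NAMES = ["Ord0", "Ext0", "Rel1", "Ext1"]


def encode_point(birth: float, death: float, homology_type: int,
                 num_types: int = len(HOMOLOGY_NAMES)) -> List[float]:
    """Return [birth, death, one-hot...] for a single diagram point."""
    if not 0 <= homology_type < num_types:
        raise ValueError(f"homology_type must be in [0, {num_types})")
    one_hot = [0.0] * num_types
    one_hot[homology_type] = 1.0
    return [birth, death, *one_hot]


def filter_by_lifetime(points: Sequence[Sequence[float]],
                       min_lifetime: float,
                       max_lifetime: float) -> List[List[float]]:
    """Keep rows whose lifetime (death - birth, assumed) is in [min, max]."""
    return [list(p) for p in points
            if min_lifetime <= p[1] - p[0] <= max_lifetime]


diagram = [
    encode_point(0.0, 0.5, 0),   # Ord0, lifetime 0.5: kept
    encode_point(0.2, 1.5, 1),   # Ext0, lifetime 1.3: above the upper bound
    encode_point(0.6, 0.3, 2),   # Rel1, lifetime -0.3: below the lower bound
    encode_point(0.4, 0.35, 3),  # Ext1, lifetime about -0.05: kept
]

# Same bounds as the notebook's FilterPersistenceDiagramByLifetime call.
kept = filter_by_lifetime(diagram, min_lifetime=-0.1, max_lifetime=1.0)
print(len(diagram[0]))  # row width, i.e. 2 + number of homology types
print(len(kept))
```

The negative lower bound matters here because some extended-persistence points can have a death value below their birth value, so a filter starting at zero would discard them.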
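The notebook calls `train_test_split` twice with `test_size=0.2`, so the validation set is 20% of the remaining 80%, giving roughly a 64/16/20 split. The standard-library sketch below shows only that arithmetic. It does not reproduce scikit-learn's shuffling, so the actual indices will differ. `split_indices` and `nested_split` are hypothetical names, rounding the held-out part up is an assumption about how the sizes are computed, and 188 is used only as an example size (the figure usually quoted for MUTAG).

```python
import math
import random
from typing import List, Tuple


def split_indices(indices: List[int], test_size: float,
                  seed: int) -> Tuple[List[int], List[int]]:
    """Shuffle indices and hold out a `test_size` fraction (rounded up)."""
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)
    n_held_out = math.ceil(test_size * len(shuffled))
    return shuffled[n_held_out:], shuffled[:n_held_out]


def nested_split(n_samples: int, test_size: float = 0.2,
                 seed: int = 42) -> Tuple[List[int], List[int], List[int]]:
    """Split into train/validation/test by applying the split twice."""
    train_val, test = split_indices(list(range(n_samples)), test_size, seed)
    train, val = split_indices(train_val, test_size, seed)
    return train, val, test


train, val, test = nested_split(188)
print(len(train), len(val), len(test))
```

Splitting the test set off first, as the notebook does, keeps it untouched by any later choice about validation size.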