Merge pull request #93 from giotto-ai/persistence_diagrams_workflow
Persistence diagrams workflow
matteocao authored Jun 16, 2022
2 parents affbc59 + c71b041 commit 7ffe3fd
Showing 59 changed files with 2,706 additions and 1,215 deletions.
1 change: 1 addition & 0 deletions .github/workflows/deploy-gh-pages.yml
```diff
@@ -17,6 +17,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          python -m pip install torch
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
       - name: Install sphinx
         run: |
```
2 changes: 1 addition & 1 deletion .github/workflows/python-package-windows.yml
```diff
@@ -26,7 +26,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install flake8 pytest nbconvert jupyter
+          pip install flake8 pytest nbconvert jupyter torch
           if(Test-Path "requirements.txt")
           {
             pip install -r requirements.txt
```
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
```diff
@@ -26,7 +26,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install flake8 pytest nbconvert jupyter mypy
+          pip install flake8 pytest nbconvert jupyter mypy torch
           if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
           pip install -e .
       - name: Lint with flake8
```
3 changes: 3 additions & 0 deletions .gitignore
```diff
@@ -77,6 +77,7 @@ examples/diagrams_5000_1000_0_1.npy
 # virtualenv
 giotto-deep/
 venv/
+myenv/
 
 examples/model_data_specifications
 runs/
@@ -91,3 +92,5 @@ state_dicts/*
 
 # pylance files
 *.pyi
+
+giotto-env/
```
7 changes: 6 additions & 1 deletion README.md
````diff
@@ -34,7 +34,12 @@ The first step to install the developer version of the package is to `git clone`
 ```
 git clone https://github.com/giotto-ai/giotto-deep.git
 ```
-The change the current working directory to the Repository root folder, e.g. `cd giotto-deep`.
+Then change the current working directory to the repository root folder, e.g. `cd giotto-deep`.
+Make sure you have the latest version of PyTorch installed.
+You can do this by running the following command (if you have a GPU):
+```
+pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
+```
 Once you are in the root folder, install the package dynamically with:
 ```
 pip install -e .
````
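Since the README now points GPU users at a CUDA-specific wheel, a quick post-install check confirms that the right build was picked up. The snippet below is an illustrative sketch, not part of this commit:

```python
# Sanity-check the PyTorch install described in the README.
import torch

print(torch.__version__)          # a GPU install from the cu113 index reports e.g. "1.11.0+cu113"
print(torch.cuda.is_available())  # True only if the CUDA wheel matches the local driver
```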
123 changes: 0 additions & 123 deletions examples/basic_tutorial_persformer_benchmark.ipynb

This file was deleted.

38 changes: 0 additions & 38 deletions examples/hpo_space/Mutag_hyperparameter_space.json

This file was deleted.

1 change: 1 addition & 0 deletions examples/mutag_pipeline.ipynb
@@ -0,0 +1 @@
The new notebook is added as a single JSON line; its code cells are rendered below for readability.

Cell 1 — imports, parameters, and creation of the persistence diagram dataset:

```python
import os
from typing import Tuple

import torch.nn as nn
from gdeep.data import PreprocessingPipeline
from gdeep.data.datasets import PersistenceDiagramFromFiles
from gdeep.data.datasets.base_dataloaders import (DataLoaderBuilder,
                                                  DataLoaderParamsTuples)
from gdeep.data.datasets.persistence_diagrams_from_graphs_builder import \
    PersistenceDiagramFromGraphBuilder
from gdeep.data.persistence_diagrams.one_hot_persistence_diagram import (
    OneHotEncodedPersistenceDiagram, collate_fn_persistence_diagrams)
from gdeep.data.preprocessors import (
    FilterPersistenceDiagramByHomologyDimension,
    FilterPersistenceDiagramByLifetime, NormalizationPersistenceDiagram)
from gdeep.search.hpo import GiottoSummaryWriter
from gdeep.topology_layers import Persformer, PersformerConfig, PersformerWrapper
from gdeep.topology_layers.persformer_config import PoolerType
from gdeep.trainer.trainer import Trainer
from gdeep.search import HyperParameterOptimization
from gdeep.utility import DEFAULT_GRAPH_DIR, PoolerType
from gdeep.utility.utils import autoreload_if_notebook
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import Subset
from torch.utils.tensorboard.writer import SummaryWriter

autoreload_if_notebook()

# Parameters
name_graph_dataset: str = 'MUTAG'
diffusion_parameter: float = 0.1
num_homology_types: int = 4

# Create the persistence diagram dataset
pd_creator = PersistenceDiagramFromGraphBuilder(name_graph_dataset, diffusion_parameter)
pd_creator.create()
```

Cell 2 — plot a sample extended persistence diagram:

```python
file_path: str = os.path.join(DEFAULT_GRAPH_DIR,
                              f"MUTAG_{diffusion_parameter}_extended_persistence",
                              "diagrams")
graph_idx = 1
pd: OneHotEncodedPersistenceDiagram = \
    OneHotEncodedPersistenceDiagram.load(os.path.join(file_path, f"{graph_idx}.npy"))
pd.set_homology_dimension_names(["Ord0", "Ext0", "Rel1", "Ext1"])
pd.plot()
```

Cell 3 — load the dataset from files and plot one sample:

```python
pd_mutag_ds = PersistenceDiagramFromFiles(
    os.path.join(
        DEFAULT_GRAPH_DIR, f"MUTAG_{diffusion_parameter}_extended_persistence"
    )
)

pd_sample: OneHotEncodedPersistenceDiagram = pd_mutag_ds[0][0]

fig = pd_sample.plot(["Ord0", "Ext0", "Rel1", "Ext1"])
# add title
fig.show()
```

Cell 4 — create the train/validation/test split and fit the preprocessing pipeline:

```python
train_indices, test_indices = train_test_split(
    range(len(pd_mutag_ds)),
    test_size=0.2,
    random_state=42,
)

train_indices, validation_indices = train_test_split(
    train_indices,
    test_size=0.2,
    random_state=42,
)

# Create the data loaders
train_dataset = Subset(pd_mutag_ds, train_indices)
validation_dataset = Subset(pd_mutag_ds, validation_indices)
test_dataset = Subset(pd_mutag_ds, test_indices)

# Preprocess the data
preprocessing_pipeline = PreprocessingPipeline[Tuple[OneHotEncodedPersistenceDiagram, int]](
    (
        FilterPersistenceDiagramByHomologyDimension[int]([0, 1]),
        FilterPersistenceDiagramByLifetime[int](min_lifetime=-0.1,
                                                max_lifetime=1.0),
        NormalizationPersistenceDiagram[int](num_homology_dimensions=4),
    )
)

preprocessing_pipeline.fit_to_dataset(train_dataset)
```

Cell 5 — attach the fitted transform to each split:

```python
train_dataset = preprocessing_pipeline.attach_transform_to_dataset(train_dataset)  # type: ignore
validation_dataset = preprocessing_pipeline.attach_transform_to_dataset(validation_dataset)  # type: ignore
test_dataset = preprocessing_pipeline.attach_transform_to_dataset(test_dataset)  # type: ignore
```

Cell 6 — build the data loaders:

```python
dl_params = DataLoaderParamsTuples.default(
    batch_size=32,
    num_workers=0,
    collate_fn=collate_fn_persistence_diagrams,
    with_validation=True,
)

# Build the data loaders
dlb = DataLoaderBuilder((train_dataset, validation_dataset, test_dataset))  # type: ignore
dl_train, dl_val, dl_test = dlb.build(dl_params)  # type: ignore
```

Cell 7 — direct training without HPO, kept commented out in the notebook:

```python
# # Define the model
# model_config = PersformerConfig(
#     num_layers=6,
#     num_attention_heads=4,
#     input_size=2 + num_homology_types,
#     ouptut_size=2,
#     pooler_type=PoolerType.ATTENTION,
# )
#
# model = Persformer(model_config)
# writer = SummaryWriter()
#
# loss_function = nn.CrossEntropyLoss()
#
# trainer = Trainer(model, [dl_train, dl_val, dl_test], loss_function, writer)
#
# trainer.train(Adam, 3, False,
#               {"lr": 0.01},
#               {"batch_size": 16, "collate_fn": collate_fn_persistence_diagrams})
```

Cell 8 — wrap the Persformer model and run hyperparameter optimization:

```python
# Define the model by using a Wrapper for the Persformer model
wrapped_model = PersformerWrapper(
    num_attention_layers=3,
    num_attention_heads=4,
    input_size=2 + num_homology_types,
    ouptut_size=2,
    pooler_type=PoolerType.ATTENTION,
)
writer = GiottoSummaryWriter()

loss_function = nn.CrossEntropyLoss()

trainer = Trainer(wrapped_model, [dl_train, dl_val, dl_test], loss_function, writer)

# initialise hpo object
search = HyperParameterOptimization(trainer, "accuracy", 2, best_not_last=True)

# if you want to store pickle files of the models instead of the state_dicts
search.store_pickle = True

# dictionaries of hyperparameters
optimizers_params = {"lr": [0.001, 0.01]}
dataloaders_params = {"batch_size": [32, 64, 16],
                      "collate_fn": [collate_fn_persistence_diagrams]}
models_hyperparams = {
    "num_attention_layers": [2, 6, 1],
    "num_attention_heads": [8, 16, 8],
}

# starting the HPO
search.start(
    [Adam],
    3,
    False,
    optimizers_params,
    dataloaders_params,
    models_hyperparams,
    n_accumulated_grads=2,
)
```

The notebook ends with two empty cells.
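The notebook passes collate_fn_persistence_diagrams to the data loaders because persistence diagrams from different graphs contain different numbers of points and cannot be stacked directly. The sketch below illustrates the general idea only; the name pad_collate, the zero-padding scheme, and the returned mask layout are assumptions, not gdeep's actual implementation:

```python
from typing import List, Tuple

import torch


def pad_collate(
    batch: List[Tuple[torch.Tensor, int]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Batch variable-size diagrams of shape (n_points_i, 2 + num_homology_types)
    by zero-padding to the longest diagram, recording real points in a mask."""
    diagrams, labels = zip(*batch)
    max_points = max(d.shape[0] for d in diagrams)
    padded = torch.zeros(len(diagrams), max_points, diagrams[0].shape[1])
    mask = torch.zeros(len(diagrams), max_points, dtype=torch.bool)
    for i, d in enumerate(diagrams):
        padded[i, : d.shape[0]] = d
        mask[i, : d.shape[0]] = True
    return padded, mask, torch.tensor(labels)
```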
