Persistence diagrams workflow #93

Merged
merged 68 commits into from
Jun 16, 2022
Changes from 49 commits
68 commits
4dffe32
added modules from different branches
raphaelreinauer May 23, 2022
82968a4
refactored construction of persistence diagrams
raphaelreinauer May 23, 2022
ba53b0d
refactored mutag pipeline
raphaelreinauer May 24, 2022
c73bb99
added functionality to OneHotEncodedPersistenceDiagram
raphaelreinauer May 24, 2022
6ab4b31
added test for OneHotEncodedPersistenceDiagram
raphaelreinauer May 24, 2022
c223410
fixed OneHotEncodedPersistenceDiagram.from_numpy
raphaelreinauer May 24, 2022
93d5dda
added plot functionality to OneHotEncodedPersistenceDiagram
raphaelreinauer May 24, 2022
dda8155
added plotting histogram
raphaelreinauer May 25, 2022
8bb1a88
fixed normalizing pd
raphaelreinauer May 25, 2022
660c0cc
made data a member variable of OneHotEncodedPersistenceDiagram
raphaelreinauer May 25, 2022
7e989db
added kwargs dataloader
raphaelreinauer May 27, 2022
439eb1e
refactored persistence diagram preprocessing components
raphaelreinauer Jun 1, 2022
8fa0a13
added various pooling layers
raphaelreinauer Jun 1, 2022
9c7bb39
refactored dataloader builder
raphaelreinauer Jun 2, 2022
065d36b
extended DataLoaderParams dataclass
raphaelreinauer Jun 6, 2022
c4d4deb
changed model hyperparams
raphaelreinauer Jun 6, 2022
1d08ec0
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 6, 2022
cea2209
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 6, 2022
06c31f9
changed public api persformer
raphaelreinauer Jun 14, 2022
4d6792d
Merge branch 'master' into persistence_diagrams_workflow
raphaelreinauer Jun 14, 2022
fe98955
fixed format collate_fn
raphaelreinauer Jun 14, 2022
9cb037c
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 14, 2022
fb4b9e3
Merge branch 'master' into persistence_diagrams_workflow
raphaelreinauer Jun 14, 2022
822deb3
fixed masked error in attentions
raphaelreinauer Jun 14, 2022
d97d930
removed print statements
raphaelreinauer Jun 14, 2022
5101e52
Merge branch 'master' of https://github.com/giotto-ai/giotto-deep
raphaelreinauer Jun 14, 2022
99f633c
Merge branch 'master' into persistence_diagrams_workflow
raphaelreinauer Jun 14, 2022
3b26054
added test for persformer
raphaelreinauer Jun 15, 2022
648f033
added notbook for mutag_pipeline
raphaelreinauer Jun 15, 2022
2700e49
removed graph_dataloaders.py
raphaelreinauer Jun 15, 2022
412320c
removed graph_dataloaders.py
raphaelreinauer Jun 15, 2022
320f45f
added torch_geometric to requirements
raphaelreinauer Jun 15, 2022
9f8fbc4
added torch_sparse
raphaelreinauer Jun 15, 2022
774adee
added dependencies of torch-geomtric
raphaelreinauer Jun 15, 2022
06845a6
torch version error
raphaelreinauer Jun 15, 2022
e6874d1
updated to newest version
raphaelreinauer Jun 15, 2022
ebfa1b5
put url at beginning
raphaelreinauer Jun 15, 2022
b1204b4
without -f
raphaelreinauer Jun 15, 2022
e352ecc
removed test data
raphaelreinauer Jun 15, 2022
12135b4
new try
raphaelreinauer Jun 15, 2022
51f719a
new
raphaelreinauer Jun 15, 2022
dfcea7a
new
raphaelreinauer Jun 15, 2022
3e00e41
ew
raphaelreinauer Jun 15, 2022
7ee4d7d
--find-links
raphaelreinauer Jun 15, 2022
e562326
new try
raphaelreinauer Jun 15, 2022
4909f6a
removed modules
raphaelreinauer Jun 15, 2022
88b6351
small changes
raphaelreinauer Jun 15, 2022
3d36baf
changed order
raphaelreinauer Jun 15, 2022
642a2cc
changes suggested by Matteo
raphaelreinauer Jun 15, 2022
2b71460
added torch-sparse
raphaelreinauer Jun 15, 2022
af99367
added torch-scatter
raphaelreinauer Jun 15, 2022
95d1529
added gudhi
raphaelreinauer Jun 15, 2022
156ef74
added hpo initialization from model builder
raphaelreinauer Jun 15, 2022
48b187a
added persformer_wrapper
raphaelreinauer Jun 15, 2022
669973f
added additional parameters to the static from_model_builder method
raphaelreinauer Jun 15, 2022
6ff5095
added mutag hpo
raphaelreinauer Jun 16, 2022
2de76b1
fixed change in persformer_config
raphaelreinauer Jun 16, 2022
f11f65f
commented out old api in orbit_5k_train.py
raphaelreinauer Jun 16, 2022
19fd6eb
removed all persformer benchmark notebook
raphaelreinauer Jun 16, 2022
f3db492
updated jupyter notebooks
raphaelreinauer Jun 16, 2022
8aa688a
fixed issue with the number of heads
raphaelreinauer Jun 16, 2022
0275a29
fixed error in persformer wrapper
raphaelreinauer Jun 16, 2022
26fb45a
added
raphaelreinauer Jun 16, 2022
a0602d1
fixed error with a too big batch size
raphaelreinauer Jun 16, 2022
d292c37
fixed typo in PersformerConfig
raphaelreinauer Jun 16, 2022
8862400
updated orbit_5k_pipeline notebook
raphaelreinauer Jun 16, 2022
01ac322
added changes requested by Matteo
raphaelreinauer Jun 16, 2022
c71b041
added new line
raphaelreinauer Jun 16, 2022
1 change: 1 addition & 0 deletions .github/workflows/deploy-gh-pages.yml
@@ -17,6 +17,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install torch
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Install sphinx
run: |
2 changes: 1 addition & 1 deletion .github/workflows/python-package-windows.yml
@@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest nbconvert jupyter
pip install flake8 pytest nbconvert jupyter torch
if(Test-Path "requirements.txt")
{
pip install -r requirements.txt
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest nbconvert jupyter mypy
pip install flake8 pytest nbconvert jupyter mypy torch
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -e .
- name: Lint with flake8
1 change: 1 addition & 0 deletions .gitignore
@@ -77,6 +77,7 @@ examples/diagrams_5000_1000_0_1.npy
# virtualenv
giotto-deep/
venv/
myenv/

examples/model_data_specifications
runs/
1 change: 1 addition & 0 deletions examples/mutag_pipeline.ipynb
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os\n","from typing import Tuple\n","\n","import torch.nn as nn\n","from gdeep.data import PreprocessingPipeline\n","from gdeep.data.datasets import PersistenceDiagramFromFiles\n","from gdeep.data.datasets.base_dataloaders import (DataLoaderBuilder,\n"," DataLoaderParamsTuples)\n","from gdeep.data.datasets.persistence_diagrams_from_graphs_builder import \\\n"," PersistenceDiagramFromGraphBuilder\n","from gdeep.data.persistence_diagrams.one_hot_persistence_diagram import (\n"," OneHotEncodedPersistenceDiagram, collate_fn_persistence_diagrams)\n","from gdeep.data.preprocessors import (\n"," FilterPersistenceDiagramByHomologyDimension,\n"," FilterPersistenceDiagramByLifetime, NormalizationPersistenceDiagram)\n","from gdeep.topology_layers import Persformer, PersformerConfig\n","from gdeep.topology_layers.persformer_config import PoolerType\n","from gdeep.trainer.trainer import Trainer\n","from gdeep.utility import DEFAULT_GRAPH_DIR, PoolerType\n","from gdeep.utility.utils import autoreload_if_notebook\n","from sklearn.model_selection import train_test_split\n","from torch.optim import Adam\n","from torch.utils.data import Subset\n","from torch.utils.tensorboard.writer import SummaryWriter\n","\n","autoreload_if_notebook()\n","\n","# Parameters\n","name_graph_dataset: str = 'MUTAG'\n","diffusion_parameter: float = 0.1\n","num_homology_types: int = 4\n","\n","\n","# Create the persistence diagram dataset\n","pd_creator = PersistenceDiagramFromGraphBuilder(name_graph_dataset, diffusion_parameter)\n","pd_creator.create()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Plot sample extended persistence diagram\n","file_path: str = os.path.join(DEFAULT_GRAPH_DIR,\n"," f\"MUTAG_{diffusion_parameter}_extended_persistence\", \"diagrams\")\n","graph_idx = 1\n","pd: OneHotEncodedPersistenceDiagram = \\\n"," OneHotEncodedPersistenceDiagram.load(os.path.join(file_path, \n"," f\"{graph_idx}.npy\"))\n","pd.set_homology_dimension_names([\"Ord0\", \"Ext0\", \"Rel1\", \"Ext1\"])\n","pd.plot()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","pd_mutag_ds = PersistenceDiagramFromFiles(\n"," os.path.join(\n"," DEFAULT_GRAPH_DIR, f\"MUTAG_{diffusion_parameter}_extended_persistence\"\n"," )\n",")\n","\n","pd: OneHotEncodedPersistenceDiagram = pd_mutag_ds[0][0]\n","\n","fig = pd.plot([\"Ord0\", \"Ext0\", \"Rel1\", \"Ext1\"])\n","# add title\n","fig.show()"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","# Create the train/validation/test split\n","\n","train_indices, test_indices = train_test_split(\n"," range(len(pd_mutag_ds)),\n"," test_size=0.2,\n"," random_state=42,\n",")\n","\n","train_indices , validation_indices = train_test_split(\n"," train_indices,\n"," test_size=0.2,\n"," random_state=42,\n",")\n","\n","# Create the data loaders\n","train_dataset = Subset(pd_mutag_ds, train_indices)\n","validation_dataset = Subset(pd_mutag_ds, validation_indices)\n","test_dataset = Subset(pd_mutag_ds, test_indices)\n","\n","# Preprocess the data\n","preprocessing_pipeline = PreprocessingPipeline[Tuple[OneHotEncodedPersistenceDiagram, int]](\n"," (\n"," FilterPersistenceDiagramByHomologyDimension[int]([0, 1]),\n"," FilterPersistenceDiagramByLifetime[int](min_lifetime=-0.1, max_lifetime=1.0),\n"," NormalizationPersistenceDiagram[int](num_homology_dimensions=4),\n"," 
)\n",")\n","\n","preprocessing_pipeline.fit_to_dataset(train_dataset)\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["train_dataset = preprocessing_pipeline.attach_transform_to_dataset(train_dataset)\n","validation_dataset = preprocessing_pipeline.attach_transform_to_dataset(validation_dataset)\n","test_dataset = preprocessing_pipeline.attach_transform_to_dataset(test_dataset)\n",""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","dl_params = DataLoaderParamsTuples.default(\n"," batch_size=32,\n"," num_workers=0,\n"," collate_fn=collate_fn_persistence_diagrams,\n"," with_validation=True,\n",")\n","\n","\n","# Build the data loaders\n","dlb = DataLoaderBuilder((train_dataset, validation_dataset, test_dataset)) # type: ignore\n","dl_train, dl_val, dl_test = dlb.build(dl_params) # type: ignore"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["\n","# Define the model\n","model_config = PersformerConfig(\n"," num_layers=6,\n"," num_attention_heads=4,\n"," input_size= 2 + num_homology_types,\n"," ouptut_size=2,\n"," pooler_type=PoolerType.ATTENTION,\n",")\n","\n","model = Persformer(model_config)\n","writer = SummaryWriter()\n","\n","loss_function = nn.CrossEntropyLoss()\n","\n","trainer = Trainer(model, [dl_train, dl_val, dl_test], loss_function, writer)\n","\n","trainer.train(Adam, 3, False, \n"," {\"lr\":0.01}, \n"," {\"batch_size\":16, \"collate_fn\": collate_fn_persistence_diagrams})"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[""]}],"nbformat":4,"nbformat_minor":2,"metadata":{"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":3},"orig_nbformat":4}}
131 changes: 131 additions & 0 deletions examples/mutag_pipeline.py
@@ -0,0 +1,131 @@
# %%
import os
from typing import Tuple

import torch.nn as nn
from gdeep.data import PreprocessingPipeline
from gdeep.data.datasets import PersistenceDiagramFromFiles
from gdeep.data.datasets.base_dataloaders import (DataLoaderBuilder,
DataLoaderParamsTuples)
from gdeep.data.datasets.persistence_diagrams_from_graphs_builder import \
PersistenceDiagramFromGraphBuilder
from gdeep.data.persistence_diagrams.one_hot_persistence_diagram import (
OneHotEncodedPersistenceDiagram, collate_fn_persistence_diagrams)
from gdeep.data.preprocessors import (
FilterPersistenceDiagramByHomologyDimension,
FilterPersistenceDiagramByLifetime, NormalizationPersistenceDiagram)
from gdeep.topology_layers import Persformer, PersformerConfig
from gdeep.topology_layers.persformer_config import PoolerType
from gdeep.trainer.trainer import Trainer
from gdeep.utility import DEFAULT_GRAPH_DIR
from gdeep.utility.utils import autoreload_if_notebook
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import Subset
from torch.utils.tensorboard.writer import SummaryWriter

autoreload_if_notebook()

# Parameters
name_graph_dataset: str = 'MUTAG'
diffusion_parameter: float = 0.1
num_homology_types: int = 4


# Create the persistence diagram dataset
pd_creator = PersistenceDiagramFromGraphBuilder(name_graph_dataset, diffusion_parameter)
pd_creator.create()
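# create() computes one extended persistence diagram per MUTAG graph and
# caches it as a .npy file under DEFAULT_GRAPH_DIR; diffusion_parameter
# controls the underlying filtration (that it is heat-diffusion based is an
# assumption, not confirmed from the source).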
# %%
# Plot sample extended persistence diagram
file_path: str = os.path.join(DEFAULT_GRAPH_DIR,
f"MUTAG_{diffusion_parameter}_extended_persistence", "diagrams")
graph_idx = 1
pd: OneHotEncodedPersistenceDiagram = \
OneHotEncodedPersistenceDiagram.load(os.path.join(file_path,
f"{graph_idx}.npy"))
pd.set_homology_dimension_names(["Ord0", "Ext0", "Rel1", "Ext1"])
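# The four one-hot channels are the extended-persistence classes:
# ordinary H0 ("Ord0"), extended H0 ("Ext0"), relative H1 ("Rel1")
# and extended H1 ("Ext1").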
pd.plot()
# %%

pd_mutag_ds = PersistenceDiagramFromFiles(
os.path.join(
DEFAULT_GRAPH_DIR, f"MUTAG_{diffusion_parameter}_extended_persistence"
)
)

pd: OneHotEncodedPersistenceDiagram = pd_mutag_ds[0][0]

fig = pd.plot(["Ord0", "Ext0", "Rel1", "Ext1"])
# add title
fig.show()
# %%

# Create the train/validation/test split

train_indices, test_indices = train_test_split(
range(len(pd_mutag_ds)),
test_size=0.2,
random_state=42,
)

train_indices, validation_indices = train_test_split(
train_indices,
test_size=0.2,
random_state=42,
)
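# The nested 80/20 splits leave roughly 64% of the diagrams for training,
# 16% for validation and 20% for testing (for MUTAG's 188 graphs: about
# 120 / 30 / 38).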

# Create the data loaders
train_dataset = Subset(pd_mutag_ds, train_indices)
validation_dataset = Subset(pd_mutag_ds, validation_indices)
test_dataset = Subset(pd_mutag_ds, test_indices)

# Preprocess the data
preprocessing_pipeline = PreprocessingPipeline[Tuple[OneHotEncodedPersistenceDiagram, int]](
(
FilterPersistenceDiagramByHomologyDimension[int]([0, 1]),
FilterPersistenceDiagramByLifetime[int](min_lifetime=-0.1, max_lifetime=1.0),
NormalizationPersistenceDiagram[int](num_homology_dimensions=4),
)
)

preprocessing_pipeline.fit_to_dataset(train_dataset)
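# Fitting on the training split only keeps validation/test information out
# of the learned preprocessing statistics (no data leakage).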

# %%
train_dataset = preprocessing_pipeline.attach_transform_to_dataset(train_dataset)
validation_dataset = preprocessing_pipeline.attach_transform_to_dataset(validation_dataset)
test_dataset = preprocessing_pipeline.attach_transform_to_dataset(test_dataset)

# %%

dl_params = DataLoaderParamsTuples.default(
batch_size=32,
num_workers=0,
collate_fn=collate_fn_persistence_diagrams,
with_validation=True,
)


# Build the data loaders
dlb = DataLoaderBuilder((train_dataset, validation_dataset, test_dataset)) # type: ignore
dl_train, dl_val, dl_test = dlb.build(dl_params) # type: ignore
# %%

# Define the model
model_config = PersformerConfig(
    num_layers=6,
    num_attention_heads=4,
    input_size=2 + num_homology_types,  # (birth, death) plus one-hot homology type
    output_size=2,  # binary graph classification
    pooler_type=PoolerType.ATTENTION,
)

model = Persformer(model_config)
writer = SummaryWriter()

loss_function = nn.CrossEntropyLoss()

trainer = Trainer(model, [dl_train, dl_val, dl_test], loss_function, writer)

trainer.train(Adam, 3, False,
{"lr":0.01},
{"batch_size":16, "collate_fn": collate_fn_persistence_diagrams})
6 changes: 5 additions & 1 deletion gdeep/data/datasets/__init__.py
@@ -10,6 +10,8 @@
from .dataloader_cloud import DlBuilderFromDataCloud
from .parallel_orbit import generate_orbit_parallel, create_pd_orbits,\
OrbitsGenerator, DataLoaderKwargs
from .persistence_diagrams_from_files import PersistenceDiagramFromFiles,\
collate_fn_persistence_diagrams


__all__ = [
@@ -27,5 +29,7 @@
'ImageClassificationFromFiles',
'FromArray',
'DatasetCloud',
'DlBuilderFromDataCloud'
'DlBuilderFromDataCloud',
'PersistenceDiagramFromFiles',
'collate_fn_persistence_diagrams',
]
180 changes: 166 additions & 14 deletions gdeep/data/datasets/base_dataloaders.py
@@ -1,9 +1,23 @@

from dataclasses import dataclass
import json
import os
import shutil
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypeVar, List
from os.path import join
from collections.abc import Iterable
from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, \
TypeVar, Union, List

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, ToTensor
from tqdm import tqdm
from torch.utils.data import Sampler

from .build_datasets import get_dataset
from .dataset_cloud import DatasetCloud
from ..transforming_dataset import TransformingDataset


Tensor = torch.Tensor
@@ -17,6 +31,138 @@ class AbstractDataLoaderBuilder(ABC):
def build(self, tuple_of_kwargs: List[Dict[str, Any]]):
pass

@dataclass
class DataLoaderParams:
batch_size: Optional[int] = 1
shuffle: bool = False
num_workers: int = 0
collate_fn: Optional[Callable[[Any], Any]] = None
batch_sampler: Optional[Sampler[Sequence]] = None
pin_memory: bool = False
drop_last: bool = False
timeout: float = 0.0
persistent_workers: bool = False

def copy(self):
return DataLoaderParams(**self.to_dict())

def update_batch_size(self, batch_size: int):
self.batch_size = batch_size
return self

def update_shuffle(self, shuffle: bool):
self.shuffle = shuffle
return self

def to_dict(self):
return {
k: v for k, v in self.__dict__.items() if not k.startswith("_")
}
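# Example (sketch): the fluent helpers make per-split variants cheap, e.g.
#   eval_params = train_params.copy().update_shuffle(False).update_batch_size(64)
#   loader = DataLoader(dataset, **eval_params.to_dict())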

@dataclass
class DataLoaderParamsTuples:
train: DataLoaderParams
test: Optional[DataLoaderParams] = None
validation: Optional[DataLoaderParams] = None

@staticmethod
def default(
collate_fn: Callable[[Any], Any],
batch_size: int = 32,
num_workers: int = 0,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0.0,
persistent_workers: bool = False,
with_validation: bool = False,
) -> "DataLoaderParamsTuples":
dlp = DataLoaderParamsTuples(
train=DataLoaderParams(
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
persistent_workers=persistent_workers
),
test=DataLoaderParams(
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
persistent_workers=persistent_workers
)
)
if with_validation:
dlp.validation = DataLoaderParams(
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
timeout=timeout,
persistent_workers=persistent_workers
)
return dlp
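# Note: default() enables shuffling only for the training loader; the test
# and validation loaders keep a deterministic order.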

def to_tuple_of_dicts(self) -> Tuple[Dict, ...]:
if self.validation is not None and self.test is not None:
return (
self.train.to_dict(),
self.validation.to_dict(),
self.test.to_dict()
)
elif self.test is not None:
return (
self.train.to_dict(),
self.test.to_dict()
)
else:
return (
self.train.to_dict(),
)
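# The emitted order is (train, validation, test), or (train, test) when no
# validation params are set, matching the dataset order that
# DataLoaderBuilder.build zips the dictionaries against.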

@staticmethod
def from_list_of_dicts(
list_of_kwargs: List[Dict[str, Any]]
) -> 'DataLoaderParamsTuples':
"""This method accepts the arguments of the torch
Dataloader and applies them when creating the
tuple

Args:
list_of_kwargs:
List of dictionaries, each one being the
kwargs for the corresponding DataLoader
"""
if len(list_of_kwargs) == 3:
    # Keyword arguments make the (train, validation, test) order explicit,
    # matching the order consumed by DataLoaderBuilder.build
    return DataLoaderParamsTuples(
        train=DataLoaderParams(**list_of_kwargs[0]),
        validation=DataLoaderParams(**list_of_kwargs[1]),
        test=DataLoaderParams(**list_of_kwargs[2]),
    )
elif len(list_of_kwargs) == 2:
    return DataLoaderParamsTuples(
        train=DataLoaderParams(**list_of_kwargs[0]),
        test=DataLoaderParams(**list_of_kwargs[1]),
    )
elif len(list_of_kwargs) == 1:
    return DataLoaderParamsTuples(
        train=DataLoaderParams(**list_of_kwargs[0]),
    )
else:
    raise ValueError(
        "The list of dictionaries must have 1, 2 or 3 elements"
    )

class DataLoaderBuilder(AbstractDataLoaderBuilder):
"""This class builds, out of a tuple of datasets, the
@@ -35,7 +181,7 @@ def __init__(self, tuple_of_datasets: List[Dataset[Any]]) -> None:
assert len(tuple_of_datasets) <= 3, "Too many Dataset inserted: maximum 3."

def build(self,
tuple_of_kwargs: Optional[List[Dict[str, Any]]] = None
tuple_of_kwargs: Union[List[Dict[str, Any]], DataLoaderParamsTuples, None] = None
) -> List[DataLoader[Any]]:
"""This method accepts the arguments of the torch
Dataloader and applies them when creating the
@@ -45,19 +191,25 @@ def build(self,
tuple_of_kwargs:
List of dictionaries, each one being the
kwargs for the corresponding DataLoader
"""
out: List = []
if tuple_of_kwargs:
assert len(tuple_of_kwargs) == len(self.tuple_of_datasets), \
"Cannot match the dataloaders and the parameters. "

for dataset, kwargs in zip(self.tuple_of_datasets, tuple_of_kwargs):
out.append(DataLoader(dataset, **kwargs))
out += [None] * (3 - len(out))
return out
else:
"""
if tuple_of_kwargs is None:
out: List = []
for i, dataset in enumerate(self.tuple_of_datasets):
out.append(DataLoader(dataset))
out += [None] * (3 - len(out))
return out

if isinstance(tuple_of_kwargs, DataLoaderParamsTuples):
tuple_of_kwargs = tuple_of_kwargs.to_tuple_of_dicts() # type: ignore

assert isinstance(tuple_of_kwargs, (list, tuple)), \
    "The kwargs must be a list or a tuple at this point"
assert len(tuple_of_kwargs) == len(self.tuple_of_datasets), \
    "Cannot match the dataloaders and the parameters."
out: List = []
for dataset, kwargs in zip(self.tuple_of_datasets, tuple_of_kwargs):
out.append(DataLoader(dataset, **kwargs))
out += [None] * (3 - len(out))
return out
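A minimal end-to-end sketch of the two pieces above (the synthetic TensorDatasets are hypothetical stand-ins, and the import path assumes these classes are exposed from gdeep.data.datasets.base_dataloaders):

import torch
from torch.utils.data import TensorDataset

from gdeep.data.datasets.base_dataloaders import (
    DataLoaderBuilder, DataLoaderParamsTuples)

# Two tiny fake datasets standing in for real train/test splits
train_ds = TensorDataset(torch.randn(8, 3), torch.zeros(8, dtype=torch.long))
test_ds = TensorDataset(torch.randn(4, 3), torch.ones(4, dtype=torch.long))

# default() requires an explicit collate_fn; None falls back to torch's default
params = DataLoaderParamsTuples.default(collate_fn=None, batch_size=4)

# Two datasets -> (train, test) params; build pads its result to three entries
dl_train, dl_test, _ = DataLoaderBuilder([train_ds, test_ds]).build(params)

for batch_x, batch_y in dl_train:
    print(batch_x.shape, batch_y.shape)  # torch.Size([4, 3]) torch.Size([4])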

