Skip to content

Commit 7191569

Browse files
FilippoOlivodario-coscia
authored andcommitted
Fix bugs (#387)
1 parent f8ba016 commit 7191569

File tree

9 files changed

+28
-29
lines changed

9 files changed

+28
-29
lines changed

pina/data/data_module.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(self,
9595
logging.debug('Start initialization of Pina DataModule')
9696
logging.info('Start initialization of Pina DataModule')
9797
super().__init__()
98-
self.default_batching = automatic_batching
98+
self.automatic_batching = automatic_batching
9999
self.batch_size = batch_size
100100
self.shuffle = shuffle
101101
self.repeat = repeat
@@ -133,24 +133,24 @@ def setup(self, stage=None):
133133
self.train_dataset = PinaDatasetFactory(
134134
self.collector_splits['train'],
135135
max_conditions_lengths=self.find_max_conditions_lengths(
136-
'train'))
136+
'train'), automatic_batching=self.automatic_batching)
137137
if 'val' in self.collector_splits.keys():
138138
self.val_dataset = PinaDatasetFactory(
139139
self.collector_splits['val'],
140140
max_conditions_lengths=self.find_max_conditions_lengths(
141-
'val')
141+
'val'), automatic_batching=self.automatic_batching
142142
)
143143
elif stage == 'test':
144144
self.test_dataset = PinaDatasetFactory(
145145
self.collector_splits['test'],
146146
max_conditions_lengths=self.find_max_conditions_lengths(
147-
'test')
147+
'test'), automatic_batching=self.automatic_batching
148148
)
149149
elif stage == 'predict':
150150
self.predict_dataset = PinaDatasetFactory(
151151
self.collector_splits['predict'],
152152
max_conditions_lengths=self.find_max_conditions_lengths(
153-
'predict')
153+
'predict'), automatic_batching=self.automatic_batching
154154
)
155155
else:
156156
raise ValueError(
@@ -237,9 +237,9 @@ def val_dataloader(self):
237237
self.val_dataset)
238238

239239
# Use default batching in torch DataLoader (good is batch size is small)
240-
if self.default_batching:
240+
if self.automatic_batching:
241241
collate = Collator(self.find_max_conditions_lengths('val'))
242-
return DataLoader(self.val_dataset, self.batch_size,
242+
return DataLoader(self.val_dataset, batch_size,
243243
collate_fn=collate)
244244
collate = Collator(None)
245245
# Use custom batching (good if batch size is large)
@@ -252,14 +252,16 @@ def train_dataloader(self):
252252
Create the training dataloader
253253
"""
254254
# Use default batching in torch DataLoader (good is batch size is small)
255-
if self.default_batching:
255+
batch_size = self.batch_size if self.batch_size is not None else len(
256+
self.train_dataset)
257+
258+
if self.automatic_batching:
256259
collate = Collator(self.find_max_conditions_lengths('train'))
257-
return DataLoader(self.train_dataset, self.batch_size,
260+
return DataLoader(self.train_dataset, batch_size,
258261
collate_fn=collate)
259262
collate = Collator(None)
260263
# Use custom batching (good if batch size is large)
261-
batch_size = self.batch_size if self.batch_size is not None else len(
262-
self.train_dataset)
264+
263265
sampler = PinaBatchSampler(self.train_dataset, batch_size,
264266
shuffle=False)
265267
return DataLoader(self.train_dataset, sampler=sampler,

pina/data/dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,12 @@ def __getitem__(self, item):
5151

5252
class PinaTensorDataset(PinaDataset):
5353
def __init__(self, conditions_dict, max_conditions_lengths,
54-
):
54+
automatic_batching):
5555
super().__init__(conditions_dict, max_conditions_lengths)
56+
if automatic_batching:
57+
self._getitem_func = self._getitem_int
58+
else:
59+
self._getitem_func = self._getitem_list
5660

5761
def _getitem_int(self, idx):
5862
return {
@@ -72,9 +76,7 @@ def _getitem_list(self, idx):
7276
return to_return_dict
7377

7478
def __getitem__(self, idx):
75-
if isinstance(idx, int):
76-
return self._getitem_int(idx)
77-
return self._getitem_list(idx)
79+
return self._getitem_func(idx)
7880

7981
class PinaGraphDataset(PinaDataset):
8082
pass

pina/label_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55

66

7-
full_labels = True
7+
full_labels = False
88
MATH_FUNCTIONS = {torch.sin, torch.cos}
99

1010
class LabelTensor(torch.Tensor):

pina/solvers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,3 @@
1717
from .supervised import SupervisedSolver
1818
from .rom import ReducedOrderModelSolver
1919
from .garom import GAROM
20-
from .graph import GraphSupervisedSolver

pina/solvers/pinns/basepinn.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
from abc import ABCMeta, abstractmethod
44
import torch
55
from torch.nn.modules.loss import _Loss
6-
from ...condition import InputOutputPointsCondition
76
from ...solvers.solver import SolverInterface
87
from ...utils import check_consistency
98
from ...loss.loss_interface import LossInterface
109
from ...problem import InverseProblem
11-
from ...condition import DomainEquationCondition
1210
from ...optim import TorchOptimizer, TorchScheduler
1311

1412
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
@@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
2624
to the user to choose which problem the implemented solver inheriting from
2725
this class is suitable for.
2826
"""
29-
accepted_condition_types = [DomainEquationCondition.condition_type[0],
30-
InputOutputPointsCondition.condition_type[0]]
27+
3128
def __init__(
3229
self,
3330
models,

pina/solvers/pinns/pinn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
from .basepinn import PINNInterface
14-
from pina.problem import InverseProblem
14+
from ...problem import InverseProblem
1515

1616

1717
class PINN(PINNInterface):

pina/solvers/solver.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,16 +134,18 @@ def on_train_start(self):
134134
return super().on_train_start()
135135

136136
def _check_solver_consistency(self, problem):
137-
"""
138-
TODO
139-
"""
137+
pass
138+
#TODO : Implement this method for the conditions
139+
'''
140+
141+
140142
for _, condition in problem.conditions.items():
141143
if not set(condition.condition_type).issubset(
142144
set(self.accepted_condition_types)):
143145
raise ValueError(
144146
f'{self.__name__} dose not support condition '
145147
f'{condition.condition_type}')
146-
148+
'''
147149
@staticmethod
148150
def get_batch_size(batch):
149151
# Assuming batch is your custom Batch object

pina/solvers/supervised.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from ..label_tensor import LabelTensor
88
from ..utils import check_consistency
99
from ..loss.loss_interface import LossInterface
10-
from ..condition import InputOutputPointsCondition
1110

1211

1312
class SupervisedSolver(SolverInterface):
@@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface):
3837
we are seeking to approximate multiple (discretised) functions given
3938
multiple (discretised) input functions.
4039
"""
41-
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
4240
__name__ = 'SupervisedSolver'
4341

4442
def __init__(self,

pina/trainer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
""" Trainer module. """
2-
import warnings
32
import torch
43
import lightning
54
from .utils import check_consistency

0 commit comments

Comments
 (0)