Skip to content

Remove collector #576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 0 additions & 129 deletions pina/collector.py

This file was deleted.

26 changes: 15 additions & 11 deletions pina/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from torch.utils.data.distributed import DistributedSampler
from ..label_tensor import LabelTensor
from .dataset import PinaDatasetFactory, PinaTensorDataset
from ..collector import Collector

# from ..collector import Collector


class DummyDataloader:
Expand Down Expand Up @@ -330,9 +331,10 @@ def __init__(
self.pin_memory = pin_memory

# Collect data
collector = Collector(problem)
collector.store_fixed_data()
collector.store_sample_domains()
# collector = Collector(problem)
# collector.store_fixed_data()
# collector.store_sample_domains()
problem.collect_data()

# Check if the splits are correct
self._check_slit_sizes(train_size, test_size, val_size)
Expand Down Expand Up @@ -361,7 +363,9 @@ def __init__(
# raises NotImplementedError
self.val_dataloader = super().val_dataloader

self.collector_splits = self._create_splits(collector, splits_dict)
self.data_splits = self._create_splits(
problem.collected_data, splits_dict
)
self.transfer_batch_to_device = self._transfer_batch_to_device

def setup(self, stage=None):
Expand All @@ -376,23 +380,23 @@ def setup(self, stage=None):
"""
if stage == "fit" or stage is None:
self.train_dataset = PinaDatasetFactory(
self.collector_splits["train"],
self.data_splits["train"],
max_conditions_lengths=self.find_max_conditions_lengths(
"train"
),
automatic_batching=self.automatic_batching,
)
if "val" in self.collector_splits.keys():
if "val" in self.data_splits.keys():
self.val_dataset = PinaDatasetFactory(
self.collector_splits["val"],
self.data_splits["val"],
max_conditions_lengths=self.find_max_conditions_lengths(
"val"
),
automatic_batching=self.automatic_batching,
)
elif stage == "test":
self.test_dataset = PinaDatasetFactory(
self.collector_splits["test"],
self.data_splits["test"],
max_conditions_lengths=self.find_max_conditions_lengths("test"),
automatic_batching=self.automatic_batching,
)
Expand Down Expand Up @@ -473,7 +477,7 @@ def _apply_shuffle(condition_dict, len_data):
for (
condition_name,
condition_dict,
) in collector.data_collections.items():
) in collector.items():
len_data = len(condition_dict["input"])
if self.shuffle:
_apply_shuffle(condition_dict, len_data)
Expand Down Expand Up @@ -540,7 +544,7 @@ def find_max_conditions_lengths(self, split):
"""

max_conditions_lengths = {}
for k, v in self.collector_splits[split].items():
for k, v in self.data_splits[split].items():
if self.batch_size is None:
max_conditions_lengths[k] = len(v["input"])
elif self.repeat:
Expand Down
59 changes: 46 additions & 13 deletions pina/problem/abstract_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,11 @@ def __init__(self):
Initialization of the :class:`AbstractProblem` class.
"""
self._discretised_domains = {}
# create collector to manage problem data

# create hook conditions <-> problems
for condition_name in self.conditions:
self.conditions[condition_name].problem = self

self._batching_dimension = 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we removing this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not used!


# Store in domains dict all the domains object directly passed to
# ConditionInterface. Done for back compatibility with PINA <0.2
if not hasattr(self, "domains"):
Expand All @@ -41,26 +38,33 @@ def __init__(self):
self.domains[cond_name] = cond.domain
cond.domain = cond_name

self._collect_data = {}

@property
def batching_dimension(self):
def collected_data(self):
"""
Get batching dimension.
Return the collected data from the problem's conditions.

:return: The batching dimension.
:rtype: int
:return: The collected data.
:rtype: dict
"""
return self._batching_dimension
if not self._collect_data:
raise RuntimeError(
"You have to call collect_data() before accessing the data."
)
return self._collect_data

@batching_dimension.setter
def batching_dimension(self, value):
@collected_data.setter
def collected_data(self, data):
"""
Set the batching dimension.
Set the collected data from the problem's conditions.

:param int value: The batching dimension.
:param dict data: The collected data.
"""
self._batching_dimension = value
self._collect_data = data

# back compatibility 0.1

@property
def input_pts(self):
"""
Expand Down Expand Up @@ -300,3 +304,32 @@ def add_points(self, new_points_dict):
self.discretised_domains[k] = LabelTensor.vstack(
[self.discretised_domains[k], v]
)

def collect_data(self):
"""
Aggregate data from the problem's conditions into a single dictionary.
"""
data = {}
# check if all domains are discretised
if not self.are_all_domains_discretised:
raise RuntimeError(
"All domains must be discretised before aggregating data."
)
# Iterate over the conditions and collect data
for condition_name in self.conditions:
condition = self.conditions[condition_name]
# Check if the condition has an domain attribute
if hasattr(condition, "domain"):
# Store the discretisation points
samples = self.discretised_domains[condition.domain]
data[condition_name] = {
"input": samples,
"equation": condition.equation,
}
else:
# If the condition does not have a domain attribute, store
# the input and target points
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
data[condition_name] = dict(zip(keys, values))
self.collected_data = data
Loading