Skip to content

Commit dc1de01

Browse files
committed
Fixes
1 parent e21715a commit dc1de01

File tree

3 files changed

+44
-16
lines changed

3 files changed

+44
-16
lines changed

pina/data/data_module.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def __init__(
334334
# collector = Collector(problem)
335335
# collector.store_fixed_data()
336336
# collector.store_sample_domains()
337-
problem.aggregate_data()
337+
problem.collect_data()
338338

339339
# Check if the splits are correct
340340
self._check_slit_sizes(train_size, test_size, val_size)
@@ -363,7 +363,9 @@ def __init__(
363363
# raises NotImplementedError
364364
self.val_dataloader = super().val_dataloader
365365

366-
self.data_splits = self._create_splits(problem.data, splits_dict)
366+
self.data_splits = self._create_splits(
367+
problem.collected_data, splits_dict
368+
)
367369
self.transfer_batch_to_device = self._transfer_batch_to_device
368370

369371
def setup(self, stage=None):

pina/problem/abstract_problem.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,30 @@ def __init__(self):
3838
self.domains[cond_name] = cond.domain
3939
cond.domain = cond_name
4040

41-
self.data = None
41+
self._collect_data = {}
42+
43+
@property
44+
def collected_data(self):
45+
"""
46+
Return the collected data from the problem's conditions.
47+
48+
:return: The collected data.
49+
:rtype: dict
50+
"""
51+
if not self._collect_data:
52+
raise RuntimeError(
53+
"You have to call collect_data() before accessing the data."
54+
)
55+
return self._collect_data
56+
57+
@collected_data.setter
58+
def collected_data(self, data):
59+
"""
60+
Set the collected data from the problem's conditions.
61+
62+
:param dict data: The collected data.
63+
"""
64+
self._collect_data = data
4265

4366
# back compatibility 0.1
4467
@property
@@ -281,11 +304,11 @@ def add_points(self, new_points_dict):
281304
[self.discretised_domains[k], v]
282305
)
283306

284-
def aggregate_data(self):
307+
def collect_data(self):
285308
"""
286309
Aggregate data from the problem's conditions into a single dictionary.
287310
"""
288-
self.data = {}
311+
data = {}
289312
if not self.are_all_domains_discretised:
290313
raise RuntimeError(
291314
"All domains must be discretised before aggregating data."
@@ -295,11 +318,12 @@ def aggregate_data(self):
295318
if hasattr(condition, "domain"):
296319
samples = self.discretised_domains[condition.domain]
297320

298-
self.data[condition_name] = {
321+
data[condition_name] = {
299322
"input": samples,
300323
"equation": condition.equation,
301324
}
302325
else:
303326
keys = condition.__slots__
304327
values = [getattr(condition, name) for name in keys]
305-
self.data[condition_name] = dict(zip(keys, values))
328+
data[condition_name] = dict(zip(keys, values))
329+
self.collected_data = data

tests/test_problem.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -98,21 +98,23 @@ def test_aggregate_data():
9898
target=LabelTensor(torch.tensor([[0.0]]), labels=["u"]),
9999
)
100100
poisson_problem.discretise_domain(0, "random", domains="all")
101-
poisson_problem.aggregate_data()
102-
assert isinstance(poisson_problem.data, dict)
101+
poisson_problem.collect_data()
102+
assert isinstance(poisson_problem.collected_data, dict)
103103
for name, conditions in poisson_problem.conditions.items():
104-
assert name in poisson_problem.data.keys()
104+
assert name in poisson_problem.collected_data.keys()
105105
if isinstance(conditions, InputTargetCondition):
106-
assert "input" in poisson_problem.data[name].keys()
107-
assert "target" in poisson_problem.data[name].keys()
106+
assert "input" in poisson_problem.collected_data[name].keys()
107+
assert "target" in poisson_problem.collected_data[name].keys()
108108
elif isinstance(conditions, DomainEquationCondition):
109-
assert "input" in poisson_problem.data[name].keys()
110-
assert "target" not in poisson_problem.data[name].keys()
111-
assert "equation" in poisson_problem.data[name].keys()
109+
assert "input" in poisson_problem.collected_data[name].keys()
110+
assert "target" not in poisson_problem.collected_data[name].keys()
111+
assert "equation" in poisson_problem.collected_data[name].keys()
112112

113113

114114
def test_wrong_aggregate_data():
115115
poisson_problem = Poisson()
116116
poisson_problem.discretise_domain(0, "random", domains=["D"])
117117
with pytest.raises(RuntimeError):
118-
poisson_problem.aggregate_data()
118+
poisson_problem.collected_data()
119+
with pytest.raises(RuntimeError):
120+
poisson_problem.collect_data()

0 commit comments

Comments
 (0)