Skip to content

Commit 506d0ae

Browse files
bugfix output parameter names with optimization
1 parent ddda5e7 commit 506d0ae

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

src/f3dasm/_src/design/domain.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ def names(self) -> List[str]:
8484
"""Return a list of the names of the parameters"""
8585
return list(self.keys())
8686

87+
@property
88+
def output_names(self) -> List[str]:
89+
"""Return a list of the names of the output parameters"""
90+
return list(self.output_space.keys())
91+
8792
@property
8893
def continuous(self) -> Domain:
8994
"""Returns a Domain object containing only the continuous parameters"""
@@ -748,6 +753,7 @@ def _check_output(self, names: List[str]):
748753
"""
749754
for output_name in names:
750755
if not self.is_in_output(output_name):
756+
print(f"Output {output_name} not in domain. Adding it.")
751757
self.add_output(output_name, to_disk=False)
752758

753759
def is_in_output(self, output_name: str) -> bool:
@@ -824,7 +830,7 @@ def _domain_factory(domain: Domain | DictConfig | None,
824830
input_data: pd.DataFrame,
825831
output_data: pd.DataFrame) -> Domain:
826832
if isinstance(domain, Domain):
827-
domain._check_output(output_data.columns)
833+
# domain._check_output(output_data.columns)
828834
return domain
829835

830836
elif isinstance(domain, (Path, str)):

src/f3dasm/_src/experimentdata/experimentdata.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def __init__(self,
125125
job_value = Status.FINISHED
126126

127127
self.domain = _domain_factory(
128-
domain, self._input_data.to_dataframe(),
129-
self._output_data.to_dataframe())
128+
domain=domain, input_data=self._input_data.to_dataframe(),
129+
output_data=self._output_data.to_dataframe())
130130

131131
# Create empty input_data from domain if input_data is empty
132132
if self._input_data.is_empty():
@@ -139,6 +139,9 @@ def __init__(self,
139139
if not self._input_data.has_columnnames(self.domain.names):
140140
self._input_data.set_columnnames(self.domain.names)
141141

142+
if not self._output_data.has_columnnames(self.domain.output_names):
143+
self._output_data.set_columnnames(self.domain.output_names)
144+
142145
# For backwards compatibility; if the output_data has
143146
# only one column, rename it to 'y'
144147
if self._output_data.names == [0]:
@@ -1220,13 +1223,19 @@ def _iterate_scipy(self, optimizer: Optimizer,
12201223
# n_data_before_iterate + iterations amount of elements!
12211224
# If x_new is empty, repeat best x0 to fill up total iteration
12221225
if len(self) == n_data_before_iterate:
1223-
repeated_last_element = self.get_n_best_output(
1224-
n_samples=1).to_numpy()[0].ravel()
1226+
repeated_x, repeated_y = self.get_n_best_output(
1227+
n_samples=1).to_numpy()
1228+
# repeated_last_element = self.get_n_best_output(
1229+
# n_samples=1).to_numpy()[0].ravel()
12251230

12261231
for repetition in range(iterations):
1227-
self._add_experiments(
1228-
ExperimentSample.from_numpy(repeated_last_element,
1229-
domain=self.domain))
1232+
# self._add_experiments(
1233+
# ExperimentSample.from_numpy(repeated_last_element,
1234+
# domain=self.domain))
1235+
1236+
self.add(
1237+
domain=self.domain, input_data=repeated_x,
1238+
output_data=repeated_y)
12301239

12311240
# Repeat last iteration to fill up total iteration
12321241
if len(self) < n_data_before_iterate + iterations:

src/f3dasm/_src/optimization/adapters/scipy_implementations.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def run_algorithm(self, iterations: int, data_generator: DataGenerator):
4646
"""
4747

4848
def fun(x):
49-
sample: ExperimentSample = data_generator._run(x)
49+
sample: ExperimentSample = data_generator._run(
50+
x, domain=self.domain)
5051
_, y = sample.to_numpy()
5152
return float(y)
5253

@@ -63,3 +64,5 @@ def fun(x):
6364
bounds=self.domain.get_bounds(),
6465
tol=0.0,
6566
)
67+
68+
# self.data.evaluate(data_generator=data_generator)

0 commit comments

Comments
 (0)