Skip to content

Commit

Permalink
Refactor column renaming methods in experimentdata module
Browse files Browse the repository at this point in the history
  • Loading branch information
mpvanderschelling committed Jun 20, 2024
1 parent 82cf830 commit 36af2ce
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
9 changes: 8 additions & 1 deletion src/f3dasm/_src/experimentdata/_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

# Standard
from typing import Dict, List, Optional
from typing import Dict, Iterable, List, Optional

# Authorship & Credits
# =============================================================================
Expand Down Expand Up @@ -123,3 +123,10 @@ def rename(self, old_name: str, new_name: str):
name of the column to replace with
"""
self.columns[new_name] = self.columns.pop(old_name)

def set_columnnames(self, names: Iterable[str]) -> None:
for old_name, new_name in zip(self.names, names):
self.rename(old_name, new_name)

def has_columnnames(self, names: Iterable[str]) -> None:
return set(names).issubset(self.names)
7 changes: 0 additions & 7 deletions src/f3dasm/_src/experimentdata/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,6 @@ def get_index_with_nan(self) -> pd.Index:
"""
return self.indices[self.data.isna().any(axis=1)]

def has_columnnames(self, names: Iterable[str]) -> bool:
return set(names).issubset(self.names)

def set_columnnames(self, names: Iterable[str]) -> None:
for old_name, new_name in zip(self.names, names):
self.columns.rename(old_name, new_name)


def _convert_dict_to_data(dictionary: Dict[str, Any]) -> _Data:
"""Converts a dictionary with scalar values to a data object.
Expand Down
11 changes: 6 additions & 5 deletions src/f3dasm/_src/experimentdata/experimentdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,17 @@ def __init__(self,
jobs, self._input_data, self._output_data, job_value)

# Check if the columns of input_data are in the domain
if not self._input_data.has_columnnames(self.domain.names):
self._input_data.set_columnnames(self.domain.names)
if not self._input_data.columns.has_columnnames(self.domain.names):
self._input_data.columns.set_columnnames(self.domain.names)

if not self._output_data.has_columnnames(self.domain.output_names):
self._output_data.set_columnnames(self.domain.output_names)
if not self._output_data.columns.has_columnnames(
self.domain.output_names):
self._output_data.columns.set_columnnames(self.domain.output_names)

# For backwards compatibility; if the output_data has
# only one column, rename it to 'y'
if self._output_data.names == [0]:
self._output_data.set_columnnames(['y'])
self._output_data.columns.set_columnnames(['y'])

def __len__(self):
"""The len() method returns the number of datapoints"""
Expand Down

0 comments on commit 36af2ce

Please sign in to comment.