Skip to content

Commit 47a897c

Browse files
author
cgalelli
committed
Move axis outside of class
1 parent 70abfa0 commit 47a897c

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

gammapy/modeling/fit.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -901,12 +901,10 @@ def _repr_html_(self):
901901

902902

903903
class FitResults(collections.abc.Sequence):
904-
def __init__(self, results, axis=None):
905-
if axis and axis.nbin != len(results):
906-
raise ValueError("Axis length does not correspond to number of results")
907-
904+
def __init__(self, results):
905+
if np.array([not isinstance(result, FitResult) for result in results]).any():
906+
raise TypeError(f"Elements in {results!r} are not FitResult objects")
908907
self.results = results
909-
self.axis = axis
910908

911909
def __add__(self, other):
912910
if isinstance(other, FitResult):
@@ -935,19 +933,21 @@ def covariance_results(self):
935933
def optimize_results(self):
936934
return [f.optimize_result for f in self.results]
937935

938-
def select_by_interval(self, coord_min, coord_max):
939-
if self.axis is None:
940-
raise ValueError("No axis to convert coordinate to index!")
936+
def select_by_interval(self, axis, coord_min, coord_max):
937+
if axis and axis.nbin != len(self.results):
938+
raise ValueError("Axis length does not correspond to number of results")
941939

942-
keymin = self.axis.coord_to_idx(coord_min)
943-
keymax = self.axis.coord_to_idx(coord_max)
940+
keymin = axis.coord_to_idx(coord_min)
941+
keymax = axis.coord_to_idx(coord_max)
944942

945943
return self[keymin:keymax]
946944

947945
def write_models(self, path, **kwargs):
948946
self.models().write(path, **kwargs)
949947

950-
def create_model_table(self):
948+
def create_model_table(self, axis=None):
949+
if axis and axis.nbin != len(self.results):
950+
raise ValueError("Axis length does not correspond to number of results")
951951
t = QTable()
952952

953953
t["convergence"] = [result.success for result in self.results]
@@ -961,9 +961,9 @@ def create_model_table(self):
961961
for result in self.results
962962
]
963963

964-
if isinstance(self.axis, TimeMapAxis):
965-
t.add_columns(self.axis.to_gti().table.columns, indexes=[0, 0])
966-
elif isinstance(self.axis, MapAxis):
967-
t.add_columns(self.axis.to_table().columns, indexes=[0, 0])
964+
if isinstance(axis, TimeMapAxis):
965+
t.add_columns(axis.to_gti().table.columns, indexes=[0, 0])
966+
elif isinstance(axis, MapAxis):
967+
t.add_columns(axis.to_table().columns, indexes=[0, 0])
968968

969969
return t

0 commit comments

Comments
 (0)