@@ -56,7 +56,7 @@ class Subframe:
5656 """
5757
5858 def __init__ (self , time : sc .Variable , wavelength : sc .Variable ):
59- if time .sizes != wavelength .sizes :
59+ if { dim : time .sizes . get ( dim ) for dim in wavelength . sizes } != wavelength .sizes :
6060 raise sc .DimensionError (
6161 f'Inconsistent dims or shape: { time .sizes } vs { wavelength .sizes } '
6262 )
@@ -106,22 +106,33 @@ def propagate_by(self, distance: sc.Variable) -> Subframe:
106106
107107 @property
108108 def start_time (self ) -> sc .Variable :
109- """The start time of the subframe."""
110- return self .time .min ()
109+ """The start time of the subframe for each of the distances in self.time."""
110+ # The `self.time` may have an additional dimension for distance, compared to
111+ # `self.wavelength`, and we need to keep that dimension in the output.
112+ out = self .time
113+ for dim in self .wavelength .dims :
114+ out = out .min (dim )
115+ return out
111116
112117 @property
113118 def end_time (self ) -> sc .Variable :
114- """The end time of the subframe."""
115- return self .time .max ()
119+ """The end time of the subframe for each of the distances in self.time."""
120+ # The `self.time` may have an additional dimension for distance, compared to
121+ # `self.wavelength`, and we need to keep that dimension in the output.
122+ out = self .time
123+ for dim in self .wavelength .dims :
124+ out = out .max (dim )
125+ return out
116126
117127 @property
118128 def start_wavelength (self ) -> sc .Variable :
119- """The start wavelength of the subframe."""
129+ """The start wavelength of the subframe for each of the distances in
130+ self.time"""
120131 return self .wavelength .min ()
121132
122133 @property
123134 def end_wavelength (self ) -> sc .Variable :
124- """The end wavelength of the subframe."""
135+ """The end wavelength of the subframe for each of the distances in self.time ."""
125136 return self .wavelength .max ()
126137
127138
@@ -206,19 +217,14 @@ def bounds(self) -> sc.DataGroup:
206217
207218 def subbounds (self ) -> sc .DataGroup :
208219 """
209- The bounds of the subframes, defined as the union over subframes.
210-
211- This is not the same as the bounds of the individual subframes, but defined as
212- the union of all subframes. Subframes that overlap in time are "merged" into a
213- single subframe.
214-
215- This function is to some extent experimental: It is not clear if taking the
216- union of overlapping subframes has any utility in practice, since this may
217- simply indicate a problem with the chopper cascade. Attempts to handle this
218- automatically may be misguided.
220+ The bounds of the individual subframes, stored as a DataGroup.
219221 """
220- starts = [subframe .start_time for subframe in self .subframes ]
221- ends = [subframe .end_time for subframe in self .subframes ]
222+ starts = sc .concat (
223+ [subframe .start_time for subframe in self .subframes ], dim = 'subframe'
224+ )
225+ ends = sc .concat (
226+ [subframe .end_time for subframe in self .subframes ], dim = 'subframe'
227+ )
222228 # Given how time-propagation and chopping works, the min wavelength is always
223229 # given by the same vertex as the min time, and the max wavelength by the same
224230 # vertex as the max time. Thus, this check should generally always pass.
@@ -228,48 +234,21 @@ def subbounds(self) -> sc.DataGroup:
228234 'Subframes must be regular, i.e., min/max time and wavelength must '
229235 'coincide.'
230236 )
231- wav_starts = [subframe .start_wavelength for subframe in self .subframes ]
232- wav_ends = [subframe .end_wavelength for subframe in self .subframes ]
233-
234- @dataclass
235- class Bound :
236- start : sc .Variable
237- end : sc .Variable
238- wav_start : sc .Variable
239- wav_end : sc .Variable
240-
241- bounds = [
242- Bound (start , end , wav_start , wav_end )
243- for start , end , wav_start , wav_end in zip (
244- starts , ends , wav_starts , wav_ends , strict = True
245- )
246- ]
247- bounds = sorted (bounds , key = lambda x : x .start )
248- current = bounds [0 ]
249- merged_bounds = []
250- for bound in bounds [1 :]:
251- # If start is before current end, merge
252- if bound .start <= current .end :
253- current = Bound (
254- current .start ,
255- max (current .end , bound .end ),
256- current .wav_start ,
257- max (current .wav_end , bound .wav_end ),
258- )
259- else :
260- merged_bounds .append (current )
261- current = bound
262- merged_bounds .append (current )
263- time_bounds = [
264- sc .concat ([bound .start , bound .end ], dim = 'bound' ) for bound in merged_bounds
265- ]
266- wav_bounds = [
267- sc .concat ([bound .wav_start , bound .wav_end ], dim = 'bound' )
268- for bound in merged_bounds
269- ]
237+ wav_starts = sc .concat (
238+ [subframe .start_wavelength for subframe in self .subframes ], dim = 'subframe'
239+ )
240+ wav_ends = sc .concat (
241+ [subframe .end_wavelength for subframe in self .subframes ], dim = 'subframe'
242+ )
243+
244+ time = sc .concat ([starts , ends ], dim = 'bound' )
245+ wavelength = sc .concat ([wav_starts , wav_ends ], dim = 'bound' )
246+ time_dims = list (set (time .dims ) - {'subframe' , 'bound' })
247+ wavelength_dims = list (set (wavelength .dims ) - {'subframe' , 'bound' })
248+
270249 return sc .DataGroup (
271- time = sc . concat ( time_bounds , dim = 'subframe' ),
272- wavelength = sc . concat ( wav_bounds , dim = 'subframe' ),
250+ time = time . transpose ([ * time_dims , 'subframe' , 'bound' ] ),
251+ wavelength = wavelength . transpose ([ * wavelength_dims , 'subframe' , 'bound' ] ),
273252 )
274253
275254
0 commit comments