from abc import ABCMeta as _ABCMeta, abstractmethod as _abstractmethod
-from typing import Any as _Any, override as _override, Generator as _Generator
+from typing import Any as _Any, override as _override, Generator as _Generator, Literal as _Literal

from leads.data_persistence.analyzer.utils import time_invalid, speed_invalid, acceleration_invalid, \
    mileage_invalid, latitude_invalid, longitude_invalid, distance_between
-from leads.data_persistence.core import CSVDataset, DEFAULT_HEADER
+from leads.data_persistence.core import CSVDataset, DEFAULT_HEADER, VISUAL_HEADER_ONLY


class Inference(object, metaclass=_ABCMeta):
-    def __init__(self, required_depth: tuple[int, int] = (0, 0)) -> None:
+    def __init__(self, required_depth: tuple[int, int] = (0, 0),
+                 required_header: tuple[str, ...] = DEFAULT_HEADER) -> None:
        """
        Declare the scale of data this inference requires.
        :param required_depth: (-depth backward, depth forward)
+        :param required_header: the necessary header that the dataset must contain for this inference to work
        """
        self._required_depth: tuple[int, int] = required_depth
+        self._required_header: tuple[str, ...] = required_header

    def depth(self) -> tuple[int, int]:
        """
        :return: (-depth backward, depth forward)
        """
        return self._required_depth

+    def header(self) -> tuple[str, ...]:
+        return self._required_header
+
    @_abstractmethod
    def complete(self, *rows: dict[str, _Any], backward: bool = False) -> dict[str, _Any] | None:
        """
@@ -45,7 +51,7 @@ class SafeSpeedInference(SpeedInferenceBase):
        """

    def __init__(self) -> None:
-        super().__init__((0, 0))
+        super().__init__()

    @_override
    def complete(self, *rows: dict[str, _Any], backward: bool = False) -> dict[str, _Any] | None:
@@ -111,7 +117,7 @@ class SpeedInferenceByGPSGroundSpeed(SpeedInferenceBase):
        """

    def __init__(self) -> None:
-        super().__init__((0, 0))
+        super().__init__()

    @_override
    def complete(self, *rows: dict[str, _Any], backward: bool = False) -> dict[str, _Any] | None:
@@ -225,6 +231,27 @@ def complete(self, *rows: dict[str, _Any], backward: bool = False) -> dict[str,
        }


+class VisualDataRealignmentByLatency(Inference):
+    def __init__(self, *channels: _Literal["front", "left", "right", "rear"]) -> None:
+        super().__init__((0, 1), VISUAL_HEADER_ONLY)
+        self._channels: tuple[_Literal["front", "left", "right", "rear"], ...] = channels if channels else (
+            "front", "left", "right", "rear")
+
+    @_override
+    def complete(self, *rows: dict[str, _Any], backward: bool = False) -> dict[str, _Any] | None:
+        if backward:
+            return None
+        target, base = rows
+        original_target = target.copy()
+        t_0, t = target["t"], base["t"]
+        for channel in self._channels:
+            if (new_latency := t_0 - t + base[f"{channel}_view_latency"]) > 0:
+                continue
+            target[f"{channel}_view_base64"] = base[f"{channel}_view_base64"]
+            target[f"{channel}_view_latency"] = new_latency
+        return None if target == original_target else target
+
+
class InferredDataset(CSVDataset):
    def __init__(self, file: str, chunk_size: int = 100) -> None:
        super().__init__(file, chunk_size)
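
To make the arithmetic in VisualDataRealignmentByLatency.complete concrete, here is a walk-through on two adjacent rows with invented timestamps and latencies (front channel only):

target = {"t": 1000, "front_view_base64": "frame_a", "front_view_latency": -20}
base = {"t": 1050, "front_view_base64": "frame_b", "front_view_latency": -80}
# new_latency = t_0 - t + base["front_view_latency"] = 1000 - 1050 - 80 = -130
# -130 is not > 0, so the base row's frame and the shifted latency are written back
# into the target row and the modified row is returned
result = VisualDataRealignmentByLatency("front").complete(target, base)
assert result is not None
assert result["front_view_base64"] == "frame_b" and result["front_view_latency"] == -130

If every channel would end up with a positive new_latency, the target row is left untouched and None is returned.
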
@@ -291,8 +318,9 @@ def complete(self, *inferences: Inference, enhanced: bool = False, assume_initia
        :param enhanced: True: use inferred data to infer other data; False: use only raw data to infer other data
        :param assume_initial_zeros: True: reasonably set any missing data in the first row to zero; False: no change
        """
-        if DEFAULT_HEADER in self.read_header():
-            raise KeyError("Your dataset must include the default header")
+        for inference in inferences:
+            if not set(rh := inference.header()).issubset(ah := self.read_header()):
+                raise KeyError(f"Inference {inference} requires header {rh} but the dataset only contains {ah}")
        if assume_initial_zeros:
            self.assume_initial_zeros()
        self._complete(inferences, enhanced, False)
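
With the check now scoped to the requested inferences, a dataset that only carries DEFAULT_HEADER can still run non-visual inferences, while asking it for visual realignment fails with a descriptive KeyError. A usage sketch (the file path and the assumption that it lacks camera columns are invented):

dataset = InferredDataset("data/run.csv")  # assumed: contains DEFAULT_HEADER columns only
dataset.complete(SafeSpeedInference())  # accepted: DEFAULT_HEADER is a subset of the file's header
try:
    dataset.complete(VisualDataRealignmentByLatency("front"))  # requires VISUAL_HEADER_ONLY
except KeyError as e:
    print(e)  # reports the required header and the header actually found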