Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean handling of events #151

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/progpy/predictors/monte_carlo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from collections import abc
from copy import deepcopy
from typing import Callable
from progpy.sim_result import SimResult, LazySimResult
Expand Down Expand Up @@ -99,16 +100,28 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, event
params['n_samples'] = len(state) # number of samples is from provided state

if events is None:
# Predict to all events
# change to list because of limits of jsonify
if 'events' in params and params['events'] is not None:
# Set at a model level
events = list(params['events'])
# Set at a predictor construction
events = params['events']
else:
# Otherwise, all events
events = list(self.model.events)
events = self.model.events

if not isinstance(events, (abc.Iterable)) or isinstance(events, (dict, bytes)):
# must be string or list-like (list, tuple, set)
# using abc.Iterable adds support for custom data structures
# that implement that abstract base class
raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.')
if len(events) == 0 and 'horizon' not in params:
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")
if isinstance(events, str):
# A single event
events = [events]
if not all([key in self.model.events for key in events]):
raise ValueError("`events` must be event names")
if not isinstance(events, list):
# Change to list because of the limits of jsonify
events = list(events)

if 'events' in params:
# Params is provided as a argument in construction
Expand Down
24 changes: 21 additions & 3 deletions src/progpy/predictors/unscented_transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from collections import abc
from copy import deepcopy
from filterpy import kalman
from numpy import diag, array, transpose, isnan
Expand Down Expand Up @@ -180,11 +181,28 @@ def predict(self, state, future_loading_eqn: Callable = None, events=None, **kwa
raise ValueError(f"`event_strategy` {params['event_strategy']} not supported. Currently, only 'all' event strategy is supported")

if events is None:
# Predict to all events
# change to list because of limits of jsonify
events = list(self.model.events)
if 'events' in params and params['events'] is not None:
# Set at a predictor construction
events = params['events']
else:
# Otherwise, all events
events = self.model.events

if not isinstance(events, (abc.Iterable)) or isinstance(events, (dict, bytes)):
# must be string or list-like (list, tuple, set)
# using abc.Iterable adds support for custom data structures
# that implement that abstract base class
raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.')
if len(events) == 0 and 'horizon' not in params:
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")
if isinstance(events, str):
# A single event
events = [events]
if not all([key in self.model.events for key in events]):
raise ValueError("`events` must be event names")
if not isinstance(events, list):
# Change to list because of the limits of jsonify
events = list(events)

# Optimizations
dt = params['dt']
Expand Down
13 changes: 10 additions & 3 deletions src/progpy/prognostics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,13 +871,20 @@ def simulate_to_threshold(self, future_loading_eqn: abc.Callable = None, first_o
events = kwargs['threshold_keys']
else:
warn('Both `events` and `threshold_keys` were set. `events` will be used.')


if events is None:
events = self.events.copy()
if not isinstance(events, abc.Iterable):
# must be string or list-like
raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.')
if isinstance(events, str):
# A single threshold key
# A single event
events = [events]

if (events is not None) and not all([key in self.events for key in events]):
raise ValueError("`events` must be event names")
if not isinstance(events, list):
# Change to list because of the limits of jsonify
events = list(events)

# Configure
config = { # Defaults
Expand Down
55 changes: 55 additions & 0 deletions tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ def test_UTP_ThrownObject(self):
self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0)
# self.assertAlmostEqual(mc_results.times[-1], 9, 1) # Saving every second, last time should be around the 1s after impact event (because one of the sigma points fails afterwards)

# Setting event manually
results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling'])
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)
self.assertNotIn('impact', results.time_of_event.mean)

# Setting event in construction
pred = UnscentedTransformPredictor(m, events=['falling'])
results = pred.predict(samples, dt=0.01, save_freq=1)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)
self.assertNotIn('impact', results.time_of_event.mean)

# Override event set in construction
results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling', 'impact'])
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.21, 0)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0)

# String event
results = pred.predict(samples, dt=0.01, save_freq=1, events='impact')
self.assertAlmostEqual(results.time_of_event.mean['impact'], 7.785, 5)
self.assertNotIn('falling', results.time_of_event.mean)

# Invalid event
with self.assertRaises(ValueError):
results = pred.predict(samples, dt=0.01, save_freq=1, events='invalid')
with self.assertRaises(ValueError):
# Mix valid, invalid
results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling', 'invalid'])
with self.assertRaises(ValueError):
# Empty
results = pred.predict(samples, dt=0.01, save_freq=1, events=[])
with self.assertRaises(TypeError):
results = pred.predict(samples, dt=0.01, save_freq=1, events=45)

def test_UTP_ThrownObject_One_Event(self):
# Test thrown object, similar to test_UKP_ThrownObject, but with only the 'falling' event
m = ThrownObject()
Expand Down Expand Up @@ -168,6 +201,28 @@ def test_MC_ThrownObject(self):
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=['falling', 'impact'])
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5)

# String event
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events='impact')
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5)
self.assertNotIn('falling', results.time_of_event.mean)

# Invalid event
with self.assertRaises(ValueError):
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events='invalid')
with self.assertRaises(ValueError):
# Mix valid, invalid
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=['falling', 'invalid'])
with self.assertRaises(ValueError):
# Empty
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=[])
with self.assertRaises(TypeError):
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=45)

# Empty with horizon
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, horizon=3, events=[])

# TODO(CT): Events in other predictor

def test_MC_ThrownObject_First(self):
# Test thrown object, similar to test_UKP_ThrownObject, but with only the first event (i.e., 'falling')
Expand Down
Loading