diff --git a/checkpoint_manager.py b/checkpoint_manager.py new file mode 100644 index 0000000..61ef7eb --- /dev/null +++ b/checkpoint_manager.py @@ -0,0 +1,59 @@ +import pickle +import os + +class CheckpointManager: + """ + Encapsulate checkpoint functionality using pickle + """ + + def __init__(self, filename: str): + """ + Initialise with given filename + + Args: + filename (str): Filename to use for saving/loading checkpoint + """ + + self.filename = filename + + def save(self, data): + """ + Save data to checkpoint file with pickle + + Args: + data: Object to be pickled + """ + # mkdir if doesnt exist + directory = os.path.dirname(self.filename) + if directory and not os.path.exists(directory): + os.makedirs(directory, exist_ok=True) + + # Save data + with open(self.filename, "wb") as f: + pickle.dump(data, f) + print(f"Checkpoint saved: {self.filename}") + + def load(self): + """ + Load and return data from checkpoint file .pkl + + Returns: + Unpickled data object + """ + if not os.path.exists(self.filename): + raise FileNotFoundError(f"Checkpoint file {self.filename} not found.") + with open(self.filename,"rb") as f: + data = pickle.load(f) + print(f"Checkpoint loaded: {self.filename}") + + return data + + def exists(self): + """ + Check if the checkpoint file exists + + Returns: + (bool) True if checkpoint exists, else False. + """ + return os.path.exists(self.filename) + diff --git a/config.py b/config.py index d2f5377..2cc550c 100644 --- a/config.py +++ b/config.py @@ -14,8 +14,10 @@ CONDITIONS = ["pre_heat_exposure", "intra_heat_exposure", "post_heat_exposure"] # For checkpointing -USE_CHECKPOINT = False -CHECKPOINT_NAME = "processed_data.pkl" +LOAD_CHECKPOINT = True +SAVE_CHECKPOINT = False +CHECKPOINT_FILE = "data/pickled/ID00_loaded_data_000.pkl" +CHECKPOINT_ID = 0 # Function to get participant directories def get_participant_dirs(): diff --git a/main.py b/main.py index 20a0226..b6ef398 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,27 @@ +import os + +from config import LOAD_CHECKPOINT, SAVE_CHECKPOINT, CHECKPOINT_ID, CHECKPOINT_FILE from loader import load_all_participants +from checkpoint_manager import CheckpointManager from preprocessor import merge_data, compute_sample_rate_for_sensor from visualiser import visualise_data_availability, plot_data_coverage_per_participant, plot_individual_participant_heatmap, visualise_ppg_ch0_minutes_stacked + def main(): - all_data = load_all_participants() - + + checkpoint_mgr = CheckpointManager(CHECKPOINT_FILE) + + # Data loading + if CHECKPOINT_ID == 0: + if LOAD_CHECKPOINT and checkpoint_mgr.exists(): + # Load data from pickle file + all_data = checkpoint_mgr.load() + elif CHECKPOINT_ID == 0 and LOAD_CHECKPOINT == False: + # Load data from raw + all_data = load_all_participants() + if SAVE_CHECKPOINT: + checkpoint_mgr.save(all_data) + + breakpoint() merged_data = { participant: {cat: merge_data(all_data[participant], cat) for cat in all_data[participant]} for participant in all_data diff --git a/preprocessor.py b/preprocessor.py index 4bbb33f..4d08573 100644 --- a/preprocessor.py +++ b/preprocessor.py @@ -2,7 +2,12 @@ import numpy as np def merge_data(data: dict, category: str) -> dict: - """ Merge accelerometer, PPG, HR, and gyro data on sensor clock timestamps. """ + """ + Merge accelerometer, PPG, HR, and gyro data on sensor clock timestamps. + + Returns: + dict( subject_id{ condition{ pd.DataFrame}}) + """ if category not in data: print(f"Warning: Category '{category}' is missing.") return pd.DataFrame() diff --git a/tests/test_checkpoint_manager.py b/tests/test_checkpoint_manager.py new file mode 100644 index 0000000..2856f38 --- /dev/null +++ b/tests/test_checkpoint_manager.py @@ -0,0 +1,42 @@ +import os +import tempfile +import unittest +from checkpoint_manager import CheckpointManager + +class TestCheckpointManager(unittest.TestCase): + def setUp(self): + temp = tempfile.NamedTemporaryFile(delete=False) + self.temp_file = temp.name + temp.close() + + if os.path.exists(self.temp_file): + os.remove(self.temp_file) + self.checkpoint_mgr = CheckpointManager(self.temp_file) + + def tearDown(self): + if os.path.exists(self.temp_file): + os.remove(self.temp_file) + + def test_save_and_load(self): + data = {"key": "value", "numbers": [1,2,3]} + self.checkpoint_mgr.save(data) + loaded_data = self.checkpoint_mgr.load() + + self.assertEqual(data, loaded_data) + + def test_exists(self): + self.assertFalse(self.checkpoint_mgr.exists()) + data = {"test": 123} + self.checkpoint_mgr.save(data) + + self.assertTrue(self.checkpoint_mgr.exists()) + + def test_load_missing_file(self): + if os.path.exists(self.temp_file): + os.remove(self.temp_file) + + with self.assertRaises(FileNotFoundError): + self.checkpoint_mgr.load() + +if __name__ == "__main__": + unittest.main()