Skip to content

Commit

Permalink
Feature/01 add pickle checkpoint (#1)
Browse files Browse the repository at this point in the history
* updated config

* new class file checkpoint_manager

* checkpointing working

* test for checkpoint manager

* changed pickled directory and fixed save bug

* config for load
  • Loading branch information
skoopsy authored Feb 6, 2025
1 parent e65a7d5 commit 42d518d
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 5 deletions.
59 changes: 59 additions & 0 deletions checkpoint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pickle
import os

class CheckpointManager:
    """
    Encapsulate checkpoint functionality using pickle.

    A manager is bound to a single checkpoint path; ``save`` writes an object
    to that path, ``load`` reads it back and ``exists`` reports whether the
    file is present.
    """

    def __init__(self, filename: str):
        """
        Initialise with given filename.

        Args:
            filename (str): Filename to use for saving/loading checkpoint
        """
        self.filename = filename

    def save(self, data) -> None:
        """
        Save data to checkpoint file with pickle.

        The data is pickled to a temporary sibling file first and only moved
        over the real checkpoint once the dump has fully succeeded, so a
        failed or interrupted save never truncates an existing checkpoint.

        Args:
            data: Object to be pickled

        Raises:
            OSError: If the directory or file cannot be created or written.
            pickle.PicklingError: If ``data`` cannot be pickled.
        """
        # Create the parent directory if needed (no-op when it already exists).
        directory = os.path.dirname(self.filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Write next to the target so os.replace stays on one filesystem.
        tmp_path = f"{self.filename}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(data, f)
            os.replace(tmp_path, self.filename)
        finally:
            # Only still present if the dump or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(f"Checkpoint saved: {self.filename}")

    def load(self):
        """
        Load and return data from checkpoint file .pkl

        Returns:
            Unpickled data object

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        # Open directly instead of checking existence first: avoids a race
        # between the check and the open.
        try:
            f = open(self.filename, "rb")
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Checkpoint file {self.filename} not found."
            ) from err

        with f:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # checkpoint files that this project wrote itself.
            data = pickle.load(f)
        print(f"Checkpoint loaded: {self.filename}")

        return data

    def exists(self) -> bool:
        """
        Check if the checkpoint file exists.

        Returns:
            (bool) True if checkpoint exists, else False.
        """
        return os.path.exists(self.filename)

6 changes: 4 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
CONDITIONS = ["pre_heat_exposure", "intra_heat_exposure", "post_heat_exposure"]

# For checkpointing
# NOTE(review): USE_CHECKPOINT and CHECKPOINT_NAME are not imported by main.py
# in this commit; they look like the superseded settings - confirm and remove.
USE_CHECKPOINT = False
CHECKPOINT_NAME = "processed_data.pkl"
# When True, main.py loads data from CHECKPOINT_FILE if that file exists.
LOAD_CHECKPOINT = True
# When True, main.py pickles freshly loaded raw data to CHECKPOINT_FILE.
SAVE_CHECKPOINT = False
# Path handed to CheckpointManager for both saving and loading.
CHECKPOINT_FILE = "data/pickled/ID00_loaded_data_000.pkl"
# Selects the checkpoint stage; main.py handles 0 as the data-loading stage.
CHECKPOINT_ID = 0

# Function to get participant directories
def get_participant_dirs():
Expand Down
22 changes: 20 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
import os

from config import LOAD_CHECKPOINT, SAVE_CHECKPOINT, CHECKPOINT_ID, CHECKPOINT_FILE
from loader import load_all_participants
from checkpoint_manager import CheckpointManager
from preprocessor import merge_data, compute_sample_rate_for_sensor
from visualiser import visualise_data_availability, plot_data_coverage_per_participant, plot_individual_participant_heatmap, visualise_ppg_ch0_minutes_stacked

def main():
all_data = load_all_participants()


checkpoint_mgr = CheckpointManager(CHECKPOINT_FILE)

# Data loading
if CHECKPOINT_ID == 0:
if LOAD_CHECKPOINT and checkpoint_mgr.exists():
# Load data from pickle file
all_data = checkpoint_mgr.load()
elif CHECKPOINT_ID == 0 and LOAD_CHECKPOINT == False:
# Load data from raw
all_data = load_all_participants()
if SAVE_CHECKPOINT:
checkpoint_mgr.save(all_data)

breakpoint()
merged_data = {
participant: {cat: merge_data(all_data[participant], cat) for cat in all_data[participant]}
for participant in all_data
Expand Down
7 changes: 6 additions & 1 deletion preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import numpy as np

def merge_data(data: dict, category: str) -> dict:
""" Merge accelerometer, PPG, HR, and gyro data on sensor clock timestamps. """
"""
Merge accelerometer, PPG, HR, and gyro data on sensor clock timestamps.
Returns:
dict( subject_id{ condition{ pd.DataFrame}})
"""
if category not in data:
print(f"Warning: Category '{category}' is missing.")
return pd.DataFrame()
Expand Down
42 changes: 42 additions & 0 deletions tests/test_checkpoint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import tempfile
import unittest
from checkpoint_manager import CheckpointManager

class TestCheckpointManager(unittest.TestCase):
    """Round-trip, existence and missing-file checks for CheckpointManager."""

    def setUp(self):
        # Reserve a unique path on disk, then delete the placeholder so every
        # test starts with no checkpoint file present.
        with tempfile.NamedTemporaryFile(delete=False) as handle:
            self.temp_file = handle.name

        if os.path.exists(self.temp_file):
            os.remove(self.temp_file)

        self.checkpoint_mgr = CheckpointManager(self.temp_file)

    def tearDown(self):
        # Remove whatever checkpoint the test left behind.
        if os.path.exists(self.temp_file):
            os.remove(self.temp_file)

    def test_save_and_load(self):
        """Saved data must come back equal after loading."""
        original = {"key": "value", "numbers": [1, 2, 3]}
        self.checkpoint_mgr.save(original)
        restored = self.checkpoint_mgr.load()

        self.assertEqual(original, restored)

    def test_exists(self):
        """exists() is False before the first save and True afterwards."""
        self.assertFalse(self.checkpoint_mgr.exists())

        self.checkpoint_mgr.save({"test": 123})

        self.assertTrue(self.checkpoint_mgr.exists())

    def test_load_missing_file(self):
        """Loading without a checkpoint file raises FileNotFoundError."""
        if os.path.exists(self.temp_file):
            os.remove(self.temp_file)

        self.assertRaises(FileNotFoundError, self.checkpoint_mgr.load)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 42d518d

Please sign in to comment.