-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/01 add pickle checkpoint (#1)
* updated config * new class file checkpoint_manager * checkpointing working * test for checkpoint manager * changed pickled directory and fixed save bug * config for load
- Loading branch information
Showing
5 changed files
with
131 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import pickle | ||
import os | ||
|
||
class CheckpointManager: | ||
""" | ||
Encapsulate checkpoint functionality using pickle | ||
""" | ||
|
||
def __init__(self, filename: str): | ||
""" | ||
Initialise with given filename | ||
Args: | ||
filename (str): Filename to use for saving/loading checkpoint | ||
""" | ||
|
||
self.filename = filename | ||
|
||
def save(self, data): | ||
""" | ||
Save data to checkpoint file with pickle | ||
Args: | ||
data: Object to be pickled | ||
""" | ||
# mkdir if doesnt exist | ||
directory = os.path.dirname(self.filename) | ||
if directory and not os.path.exists(directory): | ||
os.makedirs(directory, exist_ok=True) | ||
|
||
# Save data | ||
with open(self.filename, "wb") as f: | ||
pickle.dump(data, f) | ||
print(f"Checkpoint saved: {self.filename}") | ||
|
||
def load(self): | ||
""" | ||
Load and return data from checkpoint file .pkl | ||
Returns: | ||
Unpickled data object | ||
""" | ||
if not os.path.exists(self.filename): | ||
raise FileNotFoundError(f"Checkpoint file {self.filename} not found.") | ||
with open(self.filename,"rb") as f: | ||
data = pickle.load(f) | ||
print(f"Checkpoint loaded: {self.filename}") | ||
|
||
return data | ||
|
||
def exists(self): | ||
""" | ||
Check if the checkpoint file exists | ||
Returns: | ||
(bool) True if checkpoint exists, else False. | ||
""" | ||
return os.path.exists(self.filename) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
import tempfile | ||
import unittest | ||
from checkpoint_manager import CheckpointManager | ||
|
||
class TestCheckpointManager(unittest.TestCase): | ||
def setUp(self): | ||
temp = tempfile.NamedTemporaryFile(delete=False) | ||
self.temp_file = temp.name | ||
temp.close() | ||
|
||
if os.path.exists(self.temp_file): | ||
os.remove(self.temp_file) | ||
self.checkpoint_mgr = CheckpointManager(self.temp_file) | ||
|
||
def tearDown(self): | ||
if os.path.exists(self.temp_file): | ||
os.remove(self.temp_file) | ||
|
||
def test_save_and_load(self): | ||
data = {"key": "value", "numbers": [1,2,3]} | ||
self.checkpoint_mgr.save(data) | ||
loaded_data = self.checkpoint_mgr.load() | ||
|
||
self.assertEqual(data, loaded_data) | ||
|
||
def test_exists(self): | ||
self.assertFalse(self.checkpoint_mgr.exists()) | ||
data = {"test": 123} | ||
self.checkpoint_mgr.save(data) | ||
|
||
self.assertTrue(self.checkpoint_mgr.exists()) | ||
|
||
def test_load_missing_file(self): | ||
if os.path.exists(self.temp_file): | ||
os.remove(self.temp_file) | ||
|
||
with self.assertRaises(FileNotFoundError): | ||
self.checkpoint_mgr.load() | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |