Skip to content

Commit

Permalink
WIP: add class SimulationMacroCache
Browse files Browse the repository at this point in the history
  • Loading branch information
SylviaDu99 committed Jul 10, 2024
1 parent 699898d commit fbf018d
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
27 changes: 16 additions & 11 deletions policyengine_core/simulations/sim_macro_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,35 @@ def __call__(cls, *args, **kwargs):

class SimulationMacroCache(metaclass=Singleton):
def __init__(self):
self.cache = {}
self.version = None
self.path = None
self.core_version = None
self.country_version = None
self.cache_file_path = None

def set_version(self, version):
self.version = version
def set_version(self, core_version: str, country_version: {}):
self.core_version = core_version
self.country_version = country_version

def get_cache_value(self, version: str, cache_file_path: Path):
def set_cache_path(self, parent_path: Path, dataset_name: str):
self.cache_file_path = parent_path / f"{dataset_name}_variable_cache"

def get_cache_value(self, version: str, country_version: {}, cache_file_path: Path):
with h5py.File(cache_file_path, "r") as f:
if "metadata:version" in f:
if "metadata:core_version" in f and "metadata:country_version" in f:
# Validate version is correct, otherwise flush the cache
if f["metadata:version"][()] != version:
if f["metadata:core_version"][()] != version or f["metadata:country_version"][()] != country_version:
self.clear_cache(cache_file_path)
return None
else:
self.clear_cache(cache_file_path)
return None
return f["values"][()]

def set_cache_value(self, version: str, cache_file_path: Path, value: ArrayLike):
self.set_version(version)
def set_cache_value(self, core_version: str, country_version: {}, cache_file_path: Path, value: ArrayLike):
self.set_version(core_version, country_version)
with h5py.File(cache_file_path, "w") as f:
f.create_dataset("values", data=value)
f.create_dataset("metadata:version", data=self.version)
f.create_dataset("metadata:core_version", data=self.core_version)
f.create_dataset("metadata:country_version", data=self.country_version)
return "cache set successfully"

def clear_cache(self, cache_file_path: Path):
Expand Down
27 changes: 23 additions & 4 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import pandas as pd
import importlib.metadata
import pkg_resources
from numpy.typing import ArrayLike

from policyengine_core import commons, periods
Expand Down Expand Up @@ -78,8 +77,11 @@ class Simulation:
macro_cache_write: bool = True
"""Whether to write to the macro cache."""

version: str = None
"""version number of the simulation."""
core_version: str = None
"""Version number of the simulation core."""

country_version: Dict[str, Any] = None
"""Version number of the country packages."""

def __init__(
self,
Expand Down Expand Up @@ -204,7 +206,24 @@ def __init__(
self.baseline = None

self.parent_branch = None
self.version = importlib.metadata.version("policyengine-core")
self.core_version = importlib.metadata.version("policyengine-core")

# TODO(SylviaDu99)
COUNTRIES = ("uk", "us", "ca", "ng", "il")
COUNTRY_PACKAGE_NAMES = (
"policyengine_uk",
"policyengine_us",
"policyengine_canada",
"policyengine_ng",
"policyengine_il",
)
try:
self.country_version = {
country: importlib.metadata.version(package_name)
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES)
}
except:
self.country_version = {country: "0.0.0" for country in COUNTRIES}

def apply_reform(self, reform: Union[tuple, Reform]):
if isinstance(reform, tuple):
Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def test_version(tax_benefit_system):
simulation = SimulationBuilder().build_from_entities(
tax_benefit_system, single
)
print(simulation.version)
assert simulation.core_version is not None
assert simulation.country_version is not None


def test_macro_cache(tax_benefit_system):
Expand Down

0 comments on commit fbf018d

Please sign in to comment.