Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache197 #223

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions policyengine_core/simulations/sim_macro_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import shutil
from pathlib import Path
import h5py
from numpy.typing import ArrayLike


class Singleton(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


class SimulationMacroCache(metaclass=Singleton):
def __init__(self):
self.core_version = None
self.country_version = None
self.cache_file_path = None

def set_version(self, core_version: str, country_version: {}):
self.core_version = core_version
self.country_version = country_version

def set_cache_path(self, parent_path: Path, dataset_name: str):
self.cache_file_path = parent_path / f"{dataset_name}_variable_cache"

def get_cache_value(self, version: str, country_version: {}, cache_file_path: Path):
with h5py.File(cache_file_path, "r") as f:
if "metadata:core_version" in f and "metadata:country_version" in f:
# Validate version is correct, otherwise flush the cache
if f["metadata:core_version"][()] != version or f["metadata:country_version"][()] != country_version:
self.clear_cache(cache_file_path)
return None
else:
self.clear_cache(cache_file_path)
return None
return f["values"][()]

def set_cache_value(self, core_version: str, country_version: {}, cache_file_path: Path, value: ArrayLike):
self.set_version(core_version, country_version)
with h5py.File(cache_file_path, "w") as f:
f.create_dataset("values", data=value)
f.create_dataset("metadata:core_version", data=self.core_version)
f.create_dataset("metadata:country_version", data=self.country_version)
return "cache set successfully"

def clear_cache(self, cache_file_path: Path):
shutil.rmtree(cache_file_path)
36 changes: 31 additions & 5 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy
import numpy as np
import pandas as pd
import importlib.metadata
from numpy.typing import ArrayLike

from policyengine_core import commons, periods
Expand Down Expand Up @@ -35,6 +36,7 @@
from policyengine_core.variables import Variable, QuantityType
from policyengine_core.reforms.reform import Reform
from policyengine_core.parameters import get_parameter
from policyengine_core.simulations.sim_macro_cache import SimulationMacroCache


class Simulation:
Expand Down Expand Up @@ -75,6 +77,12 @@ class Simulation:
macro_cache_write: bool = True
"""Whether to write to the macro cache."""

core_version: str = None
"""Version number of the simulation core."""

country_version: Dict[str, Any] = None
"""Version number of the country packages."""

def __init__(
self,
tax_benefit_system: "TaxBenefitSystem" = None,
Expand Down Expand Up @@ -198,6 +206,24 @@ def __init__(
self.baseline = None

self.parent_branch = None
self.core_version = importlib.metadata.version("policyengine-core")

# TODO(SylviaDu99)
COUNTRIES = ("uk", "us", "ca", "ng", "il")
COUNTRY_PACKAGE_NAMES = (
"policyengine_uk",
"policyengine_us",
"policyengine_canada",
"policyengine_ng",
"policyengine_il",
)
try:
self.country_version = {
country: importlib.metadata.version(package_name)
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES)
}
except:
self.country_version = {country: "0.0.0" for country in COUNTRIES}

def apply_reform(self, reform: Union[tuple, Reform]):
if isinstance(reform, tuple):
Expand Down Expand Up @@ -568,7 +594,8 @@ def _calculate(

if alternate_period_handling:
if cache_path is not None:
self._set_macro_cache_value(cache_path, values)
message = self._set_macro_cache_value(cache_path, values)
print(message is not None)
return values

self._check_period_consistency(period, variable)
Expand Down Expand Up @@ -1379,8 +1406,7 @@ def _get_macro_cache_value(
"""
if not self.macro_cache_read or self.tax_benefit_system.data_modified:
return None
with h5py.File(cache_file_path, "r") as f:
return f["values"][()]
return SimulationMacroCache().get_cache_value(self.version, cache_file_path)

def _set_macro_cache_value(
self,
Expand All @@ -1392,8 +1418,8 @@ def _set_macro_cache_value(
"""
if not self.macro_cache_write or self.tax_benefit_system.data_modified:
return None
with h5py.File(cache_file_path, "w") as f:
f.create_dataset("values", data=value)
message = SimulationMacroCache().set_cache_value(self.version, cache_file_path, value)
return message


class NpEncoder(json.JSONEncoder):
Expand Down
17 changes: 16 additions & 1 deletion tests/core/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from policyengine_core.country_template.situation_examples import single
from policyengine_core.simulations import SimulationBuilder

from policyengine_core.country_template.data.datasets import country_template_dataset

def test_calculate_full_tracer(tax_benefit_system):
simulation = SimulationBuilder().build_default_simulation(
Expand Down Expand Up @@ -61,3 +61,18 @@ def test_get_memory_usage(tax_benefit_system):
memory_usage = simulation.get_memory_usage(variables=["salary"])
assert memory_usage["total_nb_bytes"] > 0
assert len(memory_usage["by_variable"]) == 1


def test_version(tax_benefit_system):
simulation = SimulationBuilder().build_from_entities(
tax_benefit_system, single
)
assert simulation.core_version is not None
assert simulation.country_version is not None


def test_macro_cache(tax_benefit_system):
simulation = SimulationBuilder().build_from_entities(
tax_benefit_system, single,
)
simulation.calculate("disposable_income", "2017-01")
Loading