Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add class: SimulationMacroCache #224

Merged
merged 17 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
added:
- Added class SimulationMacroCache for macro simulation caching purposes.
84 changes: 84 additions & 0 deletions policyengine_core/simulations/sim_macro_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import importlib.metadata
import shutil
from pathlib import Path
from typing import Optional, Union

import h5py
from numpy.typing import ArrayLike

from policyengine_core.taxbenefitsystems import TaxBenefitSystem


class Singleton(type):
    """Metaclass that gives each class using it a single shared instance.

    The first call to a class built with this metaclass constructs the
    instance normally and stores it. Every later call returns that stored
    object without running the constructor again, so the arguments of later
    calls are ignored.
    """

    # Maps each class created with this metaclass to its one instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the stored instance of ``cls``, creating it on first use."""
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance


class SimulationMacroCache(metaclass=Singleton):
    """On-disk HDF5 cache for macro-simulation variable values.

    Each cached array is written to its own ``.h5`` file together with the
    policyengine-core version and the country-package version that produced
    it. When a file is read back, those versions are compared with the
    current ones; if either is missing or different, the cache folder is
    flushed and ``None`` is returned.

    The class is built with the ``Singleton`` metaclass, so ``__init__`` runs
    only on the first instantiation; later calls return that same object and
    their ``tax_benefit_system`` argument is not consulted.
    """

    def __init__(self, tax_benefit_system: TaxBenefitSystem):
        """Record the current core and country-package versions."""
        self.core_version = importlib.metadata.version("policyengine-core")
        self.country_package_metadata = (
            tax_benefit_system.get_package_metadata()
        )
        self.country_version = self.country_package_metadata["version"]
        # Both paths stay None until set_cache_path() has been called.
        self.cache_folder_path = None
        self.cache_file_path = None

    def set_cache_path(
        self,
        parent_path: Union[Path, str],
        dataset_name: str,
        variable_name: str,
        period: str,
        branch_name: str,
    ):
        """Create the dataset's cache folder and remember the target file.

        The folder is ``<parent_path>/<dataset_name>_variable_cache`` and the
        file inside it is ``<variable_name>_<period>_<branch_name>.h5``.
        """
        storage_folder = Path(parent_path) / f"{dataset_name}_variable_cache"
        self.cache_folder_path = storage_folder
        storage_folder.mkdir(exist_ok=True)
        self.cache_file_path = (
            storage_folder / f"{variable_name}_{period}_{branch_name}.h5"
        )

    def set_cache_value(self, cache_file_path: Path, value: ArrayLike):
        """Write ``value`` and the version metadata to ``cache_file_path``."""
        # get_cache_value() may have flushed the whole folder after
        # set_cache_path() created it; recreate it so the write cannot fail
        # on a missing directory.
        Path(cache_file_path).parent.mkdir(parents=True, exist_ok=True)
        with h5py.File(cache_file_path, "w") as f:
            f.create_dataset(
                "metadata:core_version",
                data=self.core_version,
            )
            f.create_dataset(
                "metadata:country_version",
                data=self.country_version,
            )
            f.create_dataset("values", data=value)

    def get_cache_path(self):
        """Return the file path chosen by the last ``set_cache_path`` call."""
        return self.cache_file_path

    def get_cache_value(self, cache_file_path: Path):
        """Return the cached array, or ``None`` if the cache is stale.

        A file is stale when either version entry is missing, when either
        differs from the current versions, or when it holds no ``"values"``
        dataset. A stale file causes the current cache folder to be cleared.
        """
        with h5py.File(cache_file_path, "r") as f:
            is_current = (
                "values" in f
                and self._read_version(f, "metadata:core_version")
                == self.core_version
                and self._read_version(f, "metadata:country_version")
                == self.country_version
            )
            if is_current:
                return f["values"][()]
        # The file is closed here, so the folder can be removed safely.
        self.clear_cache(self.cache_folder_path)
        return None

    @staticmethod
    def _read_version(h5_file, key: str) -> Optional[str]:
        """Return the version string stored under ``key``, or ``None``."""
        if key not in h5_file:
            return None
        raw = h5_file[key][()]
        # Scalar string datasets may come back as bytes or as str depending
        # on the h5py version; normalise to str before comparing.
        if isinstance(raw, bytes):
            return raw.decode("utf-8")
        return str(raw)

    def clear_cache(self, cache_folder_path: Path):
        """Delete ``cache_folder_path`` and its contents if it exists."""
        # cache_folder_path is None until set_cache_path() has been called.
        if cache_folder_path is None:
            return
        folder = Path(cache_folder_path)
        if folder.exists():
            shutil.rmtree(folder)
106 changes: 37 additions & 69 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
SimpleTracer,
TracingParameterNodeAtInstant,
)
import h5py
from pathlib import Path
import shutil

import json

Expand All @@ -36,6 +33,7 @@
from policyengine_core.variables import Variable, QuantityType
from policyengine_core.reforms.reform import Reform
from policyengine_core.parameters import get_parameter
from policyengine_core.simulations.sim_macro_cache import SimulationMacroCache


class Simulation:
Expand Down Expand Up @@ -605,11 +603,30 @@ def _calculate(
if cached_array is not None:
return cached_array

cache_path = self._get_macro_cache(variable_name, str(period))
if cache_path and cache_path.exists():
value = self._get_macro_cache_value(cache_path)
if value is not None:
return self._get_macro_cache_value(cache_path)
smc = SimulationMacroCache(self.tax_benefit_system)

# Check if cache could be used, if available, check if path exists
is_cache_available = self.check_macro_cache(variable_name, str(period))
if is_cache_available:
smc.set_cache_path(
self.dataset.file_path.parent,
self.dataset.name,
variable_name,
str(period),
self.branch_name,
)
cache_path = smc.get_cache_path()
if cache_path.exists():
if (
not self.macro_cache_read
or self.tax_benefit_system.data_modified
):
value = None
else:
value = smc.get_cache_value(cache_path)

if value is not None:
return value

if variable.requires_computation_after is not None:
if variable.requires_computation_after not in [
Expand Down Expand Up @@ -638,8 +655,8 @@ def _calculate(
values = self.calculate_divide(variable_name, period)

if alternate_period_handling:
if cache_path is not None:
self._set_macro_cache_value(cache_path, values)
if is_cache_available:
smc.set_cache_value(cache_path, values)
return values

self._check_period_consistency(period, variable)
Expand Down Expand Up @@ -737,8 +754,8 @@ def _calculate(
f"RecursionError while calculating {variable_name} for period {period}. The full computation stack is:\n{stack_formatted}"
)

if cache_path is not None:
self._set_macro_cache_value(cache_path, array)
if is_cache_available:
smc.set_cache_value(cache_path, array)

return array

Expand Down Expand Up @@ -1395,77 +1412,28 @@ def extract_person(

return json.loads(json.dumps(situation, cls=NpEncoder))

def _get_macro_cache(
self,
variable_name: str,
period: str,
):
def check_macro_cache(self, variable_name: str, period: str) -> bool:
"""
Get the cache location of a variable for a given period, if it exists.
Check if the variable is able to have cached value
"""
if not self.is_over_dataset:
return None
is_cache_available = True
anth-volk marked this conversation as resolved.
Show resolved Hide resolved
if self.is_over_dataset:
anth-volk marked this conversation as resolved.
Show resolved Hide resolved
return is_cache_available
anth-volk marked this conversation as resolved.
Show resolved Hide resolved

variable = self.tax_benefit_system.get_variable(variable_name)
parameter_deps = variable.exhaustive_parameter_dependencies

if parameter_deps is None:
return None
return not is_cache_available
anth-volk marked this conversation as resolved.
Show resolved Hide resolved
anth-volk marked this conversation as resolved.
Show resolved Hide resolved

for parameter in parameter_deps:
param = get_parameter(
self.tax_benefit_system.parameters, parameter
)
if param.modified:
return None

storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
storage_folder.mkdir(exist_ok=True)

cache_file_path = (
storage_folder / f"{variable_name}_{period}_{self.branch_name}.h5"
)

return cache_file_path
return not is_cache_available
anth-volk marked this conversation as resolved.
Show resolved Hide resolved

def clear_macro_cache(self):
"""
Clear the cache of all variables.
"""
storage_folder = (
self.dataset.file_path.parent
/ f"{self.dataset.name}_variable_cache"
)
if storage_folder.exists():
shutil.rmtree(storage_folder)

def _get_macro_cache_value(
self,
cache_file_path: Path,
):
"""
Get the value of a variable from a cache file.
"""
if not self.macro_cache_read or self.tax_benefit_system.data_modified:
return None
with h5py.File(cache_file_path, "r") as f:
return f["values"][()]

def _set_macro_cache_value(
self,
cache_file_path: Path,
value: ArrayLike,
):
"""
Set the value of a variable in a cache file.
"""
if not self.macro_cache_write or self.tax_benefit_system.data_modified:
return None
with h5py.File(cache_file_path, "w") as f:
f.create_dataset("values", data=value)
return is_cache_available
anth-volk marked this conversation as resolved.
Show resolved Hide resolved

def to_input_dataframe(
self,
Expand Down
3 changes: 2 additions & 1 deletion policyengine_core/taxbenefitsystems/tax_benefit_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ def get_package_metadata(self) -> dict:

fallback_metadata = {
"name": self.__class__.__name__,
"version": "",
# For testing purposes
"version": "0.0.0",
"repository_url": "",
"location": "",
}
Expand Down
37 changes: 37 additions & 0 deletions tests/core/test_simulations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from policyengine_core.country_template.situation_examples import single
from policyengine_core.simulations import SimulationBuilder
from policyengine_core.simulations.sim_macro_cache import SimulationMacroCache
import importlib.metadata
import numpy as np
from pathlib import Path


def test_calculate_full_tracer(tax_benefit_system):
Expand Down Expand Up @@ -61,3 +65,36 @@ def test_get_memory_usage(tax_benefit_system):
memory_usage = simulation.get_memory_usage(variables=["salary"])
assert memory_usage["total_nb_bytes"] > 0
assert len(memory_usage["by_variable"]) == 1


def test_macro_cache(tax_benefit_system):
    """Round-trip a value through SimulationMacroCache and clear the cache."""
    cache = SimulationMacroCache(tax_benefit_system)
    assert cache.core_version == importlib.metadata.version(
        "policyengine-core"
    )
    assert cache.country_version == "0.0.0"

    cache.set_cache_path(
        parent_path="tests/core",
        dataset_name="test_dataset",
        variable_name="test_variable",
        period="2020",
        branch_name="test_branch",
    )
    try:
        cache.set_cache_value(
            cache_file_path=cache.cache_file_path,
            value=np.array([1, 2, 3], dtype=np.float32),
        )
        assert cache.get_cache_path() == Path(
            "tests/core/test_dataset_variable_cache/test_variable_2020_test_branch.h5"
        )
        assert np.array_equal(
            cache.get_cache_value(cache.cache_file_path),
            np.array([1, 2, 3], dtype=np.float32),
        )
    finally:
        # Remove the on-disk cache even when an assertion above fails, so a
        # failed run does not leave files behind in the repository tree.
        cache.clear_cache(cache.cache_folder_path)
    assert not cache.cache_folder_path.exists()
Loading