From 0d02abb65591b18411364d404306df5a6842ff85 Mon Sep 17 00:00:00 2001 From: SylviaDu99 Date: Mon, 22 Jul 2024 12:48:37 -0700 Subject: [PATCH] WIP: add class SimulationMacroCache --- policyengine_core/simulations/simulation.py | 78 +-------------------- tests/core/test_simulations.py | 16 ++++- 2 files changed, 16 insertions(+), 78 deletions(-) diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index 59123e11..6faa8546 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -20,9 +20,6 @@ SimpleTracer, TracingParameterNodeAtInstant, ) -import h5py -from pathlib import Path -import shutil import json @@ -1348,8 +1345,8 @@ def check_macro_cache(self, variable_name: str, period: str) -> bool: Check if the variable is able to have cached value """ is_cache_available = True - if not self.is_over_dataset: - return not is_cache_available + if self.is_over_dataset: + return is_cache_available variable = self.tax_benefit_system.get_variable(variable_name) parameter_deps = variable.exhaustive_parameter_dependencies @@ -1366,77 +1363,6 @@ def check_macro_cache(self, variable_name: str, period: str) -> bool: return is_cache_available - # def _get_macro_cache( - # self, - # variable_name: str, - # period: str, - # ): - # """ - # Get the cache location of a variable for a given period, if it exists. - # """ - # if not self.is_over_dataset: - # return None - # - # variable = self.tax_benefit_system.get_variable(variable_name) - # parameter_deps = variable.exhaustive_parameter_dependencies - # - # if parameter_deps is None: - # return None - # - # for parameter in parameter_deps: - # param = get_parameter( - # self.tax_benefit_system.parameters, parameter - # ) - # if param.modified: - # return None - # - # storage_folder = ( - # self.dataset.file_path.parent - # / f"{self.dataset.name}_variable_cache" - # ) - # storage_folder.mkdir(exist_ok=True) - # - # cache_file_path = ( - # storage_folder / f"{variable_name}_{period}_{self.branch_name}.h5" - # ) - # - # return cache_file_path - # - # def clear_macro_cache(self): - # """ - # Clear the cache of all variables. - # """ - # storage_folder = ( - # self.dataset.file_path.parent - # / f"{self.dataset.name}_variable_cache" - # ) - # if storage_folder.exists(): - # shutil.rmtree(storage_folder) - # - # def _get_macro_cache_value( - # self, - # cache_file_path: Path, - # ): - # """ - # Get the value of a variable from a cache file. - # """ - # if not self.macro_cache_read or self.tax_benefit_system.data_modified: - # return None - # return SimulationMacroCache().get_cache_value(self.version, cache_file_path) - # - # def _set_macro_cache_value( - # self, - # cache_file_path: Path, - # value: ArrayLike, - # ): - # """ - # Set the value of a variable in a cache file. - # """ - # if not self.macro_cache_write or self.tax_benefit_system.data_modified: - # return None - # message = SimulationMacroCache().set_cache_value(self.version, cache_file_path, value) - # return message - class NpEncoder(json.JSONEncoder): def default(self, obj): diff --git a/tests/core/test_simulations.py b/tests/core/test_simulations.py index f1454d49..f86581c2 100644 --- a/tests/core/test_simulations.py +++ b/tests/core/test_simulations.py @@ -65,7 +65,6 @@ def test_get_memory_usage(tax_benefit_system): assert len(memory_usage["by_variable"]) == 1 -# TODO(SylviaDu99) def test_version(tax_benefit_system): simulation = SimulationBuilder().build_from_entities( tax_benefit_system, single @@ -78,9 +77,22 @@ def test_version(tax_benefit_system): assert cache.country_version == "0.0.0" -def test_macro_cache(tax_benefit_system): +def test_check_macro_cache(tax_benefit_system): + simulation = SimulationBuilder().build_from_entities( + tax_benefit_system, + single, + ) + simulation.calculate("disposable_income", "2017-01") + simulation.is_over_dataset = True + assert simulation.check_macro_cache("disposable_income", "2017-01") is True + + +# TODO(SylviaDu99) +def test_simulation_macro_cache(tax_benefit_system): simulation = SimulationBuilder().build_from_entities( tax_benefit_system, single, ) simulation.calculate("disposable_income", "2017-01") + simulation.is_over_dataset = True + simulation.calculate("disposable_income", "2018-01")