diff --git a/src/HealthGPS.Core/column_builder.h b/src/HealthGPS.Core/column_builder.h index a56547335..31b18b79e 100644 --- a/src/HealthGPS.Core/column_builder.h +++ b/src/HealthGPS.Core/column_builder.h @@ -2,6 +2,7 @@ #include #include #include +#include #include #include "column.h" diff --git a/src/HealthGPS.Core/math_util.cpp b/src/HealthGPS.Core/math_util.cpp index db5f348e5..d8e90d90f 100644 --- a/src/HealthGPS.Core/math_util.cpp +++ b/src/HealthGPS.Core/math_util.cpp @@ -1,4 +1,6 @@ #include "math_util.h" + +#include #include namespace hgps::core { diff --git a/src/HealthGPS.Tests/AnalysisModule.Test.cpp b/src/HealthGPS.Tests/AnalysisModule.Test.cpp new file mode 100644 index 000000000..fc781eaef --- /dev/null +++ b/src/HealthGPS.Tests/AnalysisModule.Test.cpp @@ -0,0 +1,85 @@ +#include "pch.h" + +#include "HealthGPS/analysis_module.h" + +#include "simulation.h" + +hgps::Person create_test_person(int age = 20, + hgps::core::Gender gender = hgps::core::Gender::male) { + auto person = hgps::Person{}; + person.age = age; + person.gender = gender; + return person; +} + +namespace hgps { +class TestAnalysisModule : public ::testing::Test { + protected: + hgps::core::DataTable data; + + hgps::input::DataManager manager = hgps::input::DataManager(test_datastore_path); + hgps::CachedRepository repository = hgps::CachedRepository(manager); + + hgps::ModelInput inputs = create_test_configuration(data); + + std::unique_ptr analysis_module = + build_analysis_module(repository, inputs); + + hgps::Person test_person_1 = create_test_person(16, hgps::core::Gender::male); + hgps::Person test_person_2 = create_test_person(19, hgps::core::Gender::male); + hgps::DefaultEventBus bus = DefaultEventBus{}; + hgps::SyncChannel channel; + std::unique_ptr rnd = std::make_unique(123456789); + std::unique_ptr scenario = std::make_unique(channel); + hgps::SimulationDefinition definition = + SimulationDefinition(inputs, std::move(scenario), std::move(rnd)); + + hgps::RuntimeContext context = RuntimeContext(bus, definition); + + TestAnalysisModule() { + create_test_datatable(data); + + auto config = create_test_configuration(data); + + context.set_current_time(2024); + + context.reset_population(12); + + // Let's set some ages for the population. + // We will create a pair of male and female persons with sequential ages + for (size_t i = 0, j = 15; i < context.population().size(); i = i + 2, j++) { + context.population()[i].age = static_cast(j); + context.population()[i + 1].age = static_cast(j); + } + + // Let's set half the population gender to male, and the other half to female + for (size_t i = 0; i < context.population().size(); i = i + 2) { + context.population()[i].gender = core::Gender::male; + context.population()[i + 1].gender = core::Gender::female; + } + + // For each person, we need to set the risk factors which we can get from channels_ + for (size_t i = 0; i < context.population().size(); i++) { + for (const auto &factor : context.mapping().entries()) { + context.population()[i].risk_factors[factor.key()] = 1.0 + i; + } + } + + auto ses_module = build_ses_noise_module(repository, config); + ses_module->initialise_population(context); + + analysis_module->initialise_population(context); + } +}; + +TEST_F(TestAnalysisModule, CalculateIndex) { + // Test that the index is calculated correctly + + size_t index_1 = analysis_module->calculate_index(test_person_1); + size_t index_2 = analysis_module->calculate_index(test_person_2); + + ASSERT_EQ(index_1, 7 * 87); + ASSERT_EQ(index_2, 10 * 87); +} + +} // namespace hgps diff --git a/src/HealthGPS.Tests/CMakeLists.txt b/src/HealthGPS.Tests/CMakeLists.txt index 891920759..f67f4eed3 100644 --- a/src/HealthGPS.Tests/CMakeLists.txt +++ b/src/HealthGPS.Tests/CMakeLists.txt @@ -41,9 +41,10 @@ target_sources( "TestMain.cpp" "WeightModel.Test.cpp" "CountryModule.h" - "RiskFactorData.h" "Interval.Test.cpp" - "LoadNutrientTable.cpp") + "LoadNutrientTable.cpp" + "AnalysisModule.Test.cpp" + "simulation.cpp") target_link_libraries( HealthGPS.Tests diff --git a/src/HealthGPS.Tests/RiskFactorData.h b/src/HealthGPS.Tests/RiskFactorData.h deleted file mode 100644 index 4ff83fcd3..000000000 --- a/src/HealthGPS.Tests/RiskFactorData.h +++ /dev/null @@ -1,273 +0,0 @@ -#pragma once -#include - -#include "HealthGPS.Input/jsonparser.h" -#include "HealthGPS/riskfactor.h" - -template -std::string join_string(const std::vector &v, std::string_view delimiter, - bool quoted = false) { - std::stringstream s; - if (!v.empty()) { - std::string_view q = quoted ? "\"" : ""; - s << q << v.front() << q; - for (size_t i = 1; i < v.size(); ++i) { - s << delimiter << " " << q << v[i] << q; - } - } - - return s.str(); -} - -std::string join_string_map(const std::vector &v, std::string_view delimiter) { - std::stringstream s; - if (!v.empty()) { - std::string_view q = "\""; - s << "{{" << q << v.front() << q << ",0}"; - for (size_t i = 1; i < v.size(); ++i) { - s << delimiter << " {" << q << v[i] << q << delimiter << i << "}"; - } - - s << "}"; - } - - return s.str(); -} - -// std::string generate_test_code(hgps::RiskFactorModelType model_type, std::string filename) { -// std::stringstream ss; -// HierarchicalModelInfo hmodel; -// std::ifstream ifs(filename, std::ifstream::in); -// if (ifs) { -// try { -// auto opt = json::parse(ifs); -// hmodel.models = opt["models"].get>(); hmodel.levels = -// opt["levels"].get>(); -// -// ss << "using namespace hgps;\n"; -// -// ss << "\n/*---- LINEAR MODELS ---- */\n"; -// ss << "std::unordered_map models;\n"; -// ss << "std::unordered_map coeffs;\n"; -// for (auto& item : hmodel.models) { -// auto& at = item.second; -// ss << std::format("\n/* {} */\n", item.first); -// for (auto& entry : at.coefficients) { -// ss << std::format("coeffs.emplace(\"{}\", Coefficient{{ -//.value={}, -//.pvalue={}, .tvalue={}, .std_error={} }});\n", -// entry.first, std::to_string(entry.second.value), std::to_string(entry.second.pvalue), -// std::to_string(entry.second.tvalue), std::to_string(entry.second.std_error)); -// } -// -// ss << std::format("models.emplace(\"{}\", LinearModel{{ -//\n\t.coefficients=coeffs, '\' -//\n\t.fitted_values={{{}}},\n\t.residuals={{{}}},\n\t.rsquared={} -//}});\n", item.first, join_string(at.fitted_values, ","), -// join_string(at.residuals, ","), std::to_string(at.rsquared)); -// -// ss << "coeffs.clear();\n"; -// } -// -// ss << "\n/*---- HIERARCHICAL LEVELS ---- */\n"; -// ss << "std::map levels;\n"; -// ss << "std::vector tmat_m;\n"; -// ss << "std::vector itmat_w;\n"; -// ss << "std::vector rmat_s;\n"; -// ss << "std::vector corr_mat;\n"; -// for (auto& item : hmodel.levels) { -// auto& at = item.second; -// ss << std::format("\n/* {} */\n", item.first); -// ss << std::format("tmat_m = {{{}}};\n", -// join_string(at.transition.data, -//",")); ss << std::format("itmat_w = {{{}}};\n", -// join_string(at.inverse_transition.data, ",")); ss -//<< std::format("rmat_s = {{{}}};\n", join_string(at.residual_distribution.data, ",")); -// ss << std::format("corr_mat = {{{}}};\n", join_string(at.correlation.data, ",")); -// -// ss << std::format("levels.emplace({0}, HierarchicalLevel{{ '\' -// \n\t.variables = {1}, '\' -// \n\t.transition = -// core::DoubleArray2D({2}, {2}, -// tmat_m), '\' \n\t.inverse_transition = -// core::DoubleArray2D({2}, {2}, itmat_w), '\' -// \n\t.residual_distribution = core::DoubleArray2D({3}, {2}, rmat_s), '\' -// \n\t.correlation = core::DoubleArray2D({2}, {2}, corr_mat),\n\t.variances = {{{4}}} }});\n", -// item.first, join_string_map(at.variables, ","), at.transition.rows, -// at.residual_distribution.rows, join_string(at.variances, ",")); -// } -// -// ss << "\nreturn HierarchicalLinearModelDefinition{std::move(models), -// std::move(levels), baseline_data);\n"; -// } -// catch (const std::exception& ex) { -// std::cout << std::format("Failed to parse model file: {}. {}\n", filename, -// ex.what()); -// } -// } -// else { -// std::cout << std::format("Model file: {} not found.\n", filename); -// } -// -// ifs.close(); -// return ss.str(); -// } - -// hgps::HierarchicalLinearModelDefinition get_static_test_model(hgps::RiskFactorSexAgeTable& -// baseline_data) { -// /* Auto-generated code, do not change **** */ -// -// using namespace hgps; -// -// /*---- LINEAR MODELS ---- */ -// std::unordered_map models; -// std::unordered_map coeffs; -// -// /* AlcoholConsumption */ -// coeffs.emplace("Age", Coefficient{ .value = -0.118861, .pvalue = 0.377382, .tvalue = -//-0.991924, .std_error = 0.119829 }); coeffs.emplace("Gender", Coefficient{ .value = 16.466127, -//.pvalue = 0.010112, .tvalue = 4.589376, .std_error = 3.587879 }); coeffs.emplace("Intercept", -// Coefficient{ .value = 6.382514, .pvalue = 0.405954, .tvalue = 0.927959, .std_error = 6.878011 }); -// models.emplace("AlcoholConsumption", LinearModel{ -// .coefficients = coeffs, -// .rsquared = 0.848093 }); -// coeffs.clear(); -// -// /* SmokingStatus */ -// coeffs.emplace("Age", Coefficient{ .value = -0.000268, .pvalue = 0.982634, .tvalue = -//-0.023157, .std_error = 0.011577 }); coeffs.emplace("Gender", Coefficient{ .value = -0.663159, -//.pvalue = 0.128275, .tvalue = -1.913180, .std_error = 0.346627 }); coeffs.emplace("Intercept", -// Coefficient{ .value = 0.681054, .pvalue = 0.363331, .tvalue = 1.024930, .std_error = 0.664488 }); -// models.emplace("SmokingStatus", LinearModel{ -// .coefficients = coeffs, -// .rsquared = 0.533396 }); -// coeffs.clear(); -// -// /* BMI */ -// coeffs.emplace("AlcoholConsumption", Coefficient{ .value = 0.319474, .pvalue = 0.131175, -//.tvalue = 2.481636, .std_error = 0.128735 }); coeffs.emplace("Age", Coefficient{ .value = -// 0.039381, .pvalue = 0.329978, .tvalue = 1.276441, .std_error = 0.030852 }); -// coeffs.emplace("Gender", Coefficient{ .value = -0.057628, .pvalue = 0.955930, .tvalue = -//-0.062385, .std_error = 0.923759 }); coeffs.emplace("Intercept", Coefficient{ .value = 22.407962, -//.pvalue = 0.006188, .tvalue = 12.653734, .std_error = 1.770858 }); -// coeffs.emplace("SmokingStatus", Coefficient{ .value = -10.048320, .pvalue = 0.017135, .tvalue = -// -7.540840, .std_error = 1.332520 -//}); models.emplace("BMI", LinearModel{ .coefficients = coeffs, -//.rsquared = 0.970226 }); coeffs.clear(); -// -// /*---- HIERARCHICAL LEVELS ---- */ -// std::map levels; -// std::vector tmat_m; -// std::vector itmat_w; -// std::vector rmat_s; -// std::vector corr_mat; -// -// /* 1 */ -// tmat_m = { -0.0035606, 0.308565, 3.19407, 0.0184036 }; -// itmat_w = { -0.0186717, 0.313059, 3.24059, 0.00361245 }; -// rmat_s = { 0.0249242, -2.15996, 0.415528, 0.00761843, 0.786364, 1.10788, -0.748982, -//-0.000606345, 1.73733, 0.00723274, -0.811288, 1.05208, -1.40388, -0.0142448 }; corr_mat = { -// 1, -0.0057769, -0.0057769, 1 }; levels.emplace(1, HierarchicalLevel{ .variables = -//{{"SmokingStatus",0}, {"AlcoholConsumption",1}}, .transition = core::DoubleArray2D(2, -// 2, tmat_m), .inverse_transition = core::DoubleArray2D(2, 2, itmat_w), -// .residual_distribution = core::DoubleArray2D(7, 2, rmat_s), .correlation = -// core::DoubleArray2D(2, 2, corr_mat), .variances = {0.990721, 0.00927894} }); -// -// /* 2 */ -// tmat_m = { 0.581511 }; -// itmat_w = { 1.71966 }; -// rmat_s = { -1.15916e-16, -2.237, 0.303002, 0.817005, 0.971242, -0.303002, 0.448757 }; -// corr_mat = { 1 }; -// levels.emplace(2, HierarchicalLevel{ -// .variables = {{"BMI",0}}, -// .transition = core::DoubleArray2D(1, 1, tmat_m), -// .inverse_transition = core::DoubleArray2D(1, 1, itmat_w), -// .residual_distribution = core::DoubleArray2D(7, 1, rmat_s), -// .correlation = core::DoubleArray2D(1, 1, corr_mat), -// .variances = {1} }); -// -// return HierarchicalLinearModelDefinition{ std::move(models), std::move(levels), -// std::move(baseline_data) }; -// } -// -// hgps::HierarchicalLinearModelDefinition get_dynamic_test_model(hgps::RiskFactorSexAgeTable& -// baseline_data) { -// /* Auto-generated code, do not change **** */ -// -// using namespace hgps; -// -// /*---- LINEAR MODELS ---- */ -// std::unordered_map models; -// std::unordered_map coeffs; -// -// /* AlcoholConsumption */ -// coeffs.emplace("Year", Coefficient{ .value = 1.493626, .pvalue = 0.458706, .tvalue = -// 0.848063, .std_error = 1.761221 }); coeffs.emplace("Age", Coefficient{ .value = -0.148091, -// .pvalue = 0.334091, .tvalue = -1.148338, .std_error = 0.128961 }); coeffs.emplace("Gender", -// Coefficient{ .value = 15.852800, .pvalue = 0.024910, .tvalue = 4.182263, .std_error = 3.790484 -// }); coeffs.emplace("Intercept", Coefficient{ .value = -2990.254569, .pvalue = 0.459571, .tvalue -//= -0.846255, .std_error = 3533.514256 }); models.emplace("AlcoholConsumption", LinearModel{ -// .coefficients = coeffs, -// .rsquared = 0.877468 }); -// coeffs.clear(); -// -// /* SmokingStatus */ -// coeffs.emplace("Year", Coefficient{ .value = 0.216439, .pvalue = 0.225850, .tvalue -//= 1.519946, .std_error = 0.142399 }); coeffs.emplace("Age", Coefficient{ .value = -//-0.004504, .pvalue = 0.694950, .tvalue = -0.431936, .std_error = 0.010427 }); -// coeffs.emplace("Gender", Coefficient{ .value = -0.752035, .pvalue = 0.091365, .tvalue = -// -2.453866, -// .std_error = 0.306470 -//}); coeffs.emplace("Intercept", Coefficient{ .value = -433.556037, .pvalue = 0.226411, .tvalue = -//-1.517559, .std_error = 285.693038 }); models.emplace("SmokingStatus", LinearModel{ -//.coefficients = coeffs, .rsquared = 0.736394 }); coeffs.clear(); -// -// /* BMI */ -// coeffs.emplace("Year", Coefficient{ .value = -1.339372, .pvalue = 0.086731, .tvalue = -//-7.294724, .std_error = 0.183608 }); coeffs.emplace("AlcoholConsumption", Coefficient{ .value = -// 0.192391, .pvalue = 0.213226, .tvalue = 2.873171, .std_error = 0.066961 }); coeffs.emplace("Age", -// Coefficient{ .value = 0.065592, .pvalue = 0.128705, .tvalue = 4.878789, .std_error = 0.013444 }); -// coeffs.emplace("Gender", Coefficient{ .value = 0.492357, .pvalue = 0.430557, .tvalue -//= 1.245968, .std_error = 0.395160 }); coeffs.emplace("Intercept", Coefficient{ .value = -// 2709.566770, .pvalue = 0.086022, .tvalue = 7.355539, .std_error = 368.370949 }); -// coeffs.emplace("SmokingStatus", Coefficient{ .value = -12.011927, .pvalue = 0.043824, -//.tvalue = -14.503793, .std_error = 0.828192 }); models.emplace("BMI", LinearModel{ -//.coefficients = coeffs, .rsquared = 0.997375 }); coeffs.clear(); -// -// /*---- HIERARCHICAL LEVELS ---- */ -// std::map levels; -// std::vector tmat_m; -// std::vector itmat_w; -// std::vector rmat_s; -// std::vector corr_mat; -// -// /* 1 */ -// tmat_m = { -0.111198, -0.203549, 2.86563, -0.132974 }; -// itmat_w = { -0.222334, 0.340336, -4.79136, -0.185925 }; -// rmat_s = { 0.778387, 1.42484, 0.94104, -1.64932, 0.21149, -0.452958, -0.731434, -// 0.194915, 1.29645, 0.681437, -0.989878, -0.971885, -1.50605, 0.772969 }; corr_mat = { 1, -//-0.438228, -0.438228, 1 }; levels.emplace(1, HierarchicalLevel{ .variables = -//{{"SmokingStatus",0}, -//{"AlcoholConsumption",1}}, .transition = core::DoubleArray2D(2, 2, tmat_m), -//.inverse_transition = core::DoubleArray2D(2, 2, itmat_w), .residual_distribution = -// core::DoubleArray2D(7, 2, rmat_s), .correlation = core::DoubleArray2D(2, 2, corr_mat), -// .variances -//= {0.992863, 0.00713654} }); -// -// /* 2 */ -// tmat_m = { 0.172665 }; -// itmat_w = { 5.79155 }; -// rmat_s = { 7.89387e-17, -0.552753, -0.300238, 2.03918, 0.0477231, 0.300238, -1.53415 }; -// corr_mat = { 1 }; -// levels.emplace(2, HierarchicalLevel{ -// .variables = {{"BMI",0}}, -// .transition = core::DoubleArray2D(1, 1, tmat_m), -// .inverse_transition = core::DoubleArray2D(1, 1, itmat_w), -// .residual_distribution = core::DoubleArray2D(7, 1, rmat_s), -// .correlation = core::DoubleArray2D(1, 1, corr_mat), -// .variances = {1} }); -// -// return HierarchicalLinearModelDefinition{ std::move(models), std::move(levels), -// std::move(baseline_data) }; -// } diff --git a/src/HealthGPS.Tests/Simulation.Test.cpp b/src/HealthGPS.Tests/Simulation.Test.cpp index eaed856c7..594c6764a 100644 --- a/src/HealthGPS.Tests/Simulation.Test.cpp +++ b/src/HealthGPS.Tests/Simulation.Test.cpp @@ -1,82 +1,6 @@ -#include "data_config.h" -#include "pch.h" - -#include "HealthGPS.Input/api.h" -#include "HealthGPS/api.h" -#include "HealthGPS/event_bus.h" -#include "HealthGPS/random_algorithm.h" - +#include "simulation.h" #include "CountryModule.h" -#include "RiskFactorData.h" - -#include -#include -#include - -namespace { -std::vector mapping_entries{ - {{"Gender", 0}, {"Age", 0}, {"SmokingStatus", 1}, {"AlcoholConsumption", 1}, {"BMI", 2}}}; -} // anonymous namespace - -void create_test_datatable(hgps::core::DataTable &data) { - using namespace hgps; - using namespace hgps::core; - - auto gender_values = std::vector{1, 0, 0, 1, 0}; - auto age_values = std::vector{4, 9, 14, 19, 25}; - auto edu_values = std::vector{6.0f, 10.0f, 2.0f, 9.0f, 12.0f}; - auto inc_values = std::vector{2.0, 10.0, 5.0, std::nan(""), 13.0}; - - auto gender_builder = IntegerDataTableColumnBuilder{"Gender"}; - auto age_builder = IntegerDataTableColumnBuilder{"Age"}; - auto edu_builder = FloatDataTableColumnBuilder{"Education"}; - auto inc_builder = DoubleDataTableColumnBuilder{"Income"}; - - for (size_t i = 0; i < gender_values.size(); i++) { - gender_builder.append(gender_values[i]); - age_builder.append(age_values[i]); - edu_builder.append(edu_values[i]); - if (std::isnan(inc_values[i])) { - inc_builder.append_null(); - } else { - inc_builder.append(inc_values[i]); - } - } - - data.add(gender_builder.build()); - data.add(age_builder.build()); - data.add(edu_builder.build()); - data.add(inc_builder.build()); -} - -hgps::ModelInput create_test_configuration(hgps::core::DataTable &data) { - using namespace hgps; - using namespace hgps::core; - - auto uk = core::Country{.code = 826, .name = "United Kingdom", .alpha2 = "GB", .alpha3 = "GBR"}; - - auto age_range = core::IntegerInterval(0, 30); - auto settings = Settings{uk, 0.1f, age_range}; - auto info = RunInfo{.start_time = 2018, .stop_time = 2025, .seed = std::nullopt}; - auto ses_mapping = std::map{ - {"gender", "Gender"}, {"age", "Age"}, {"education", "Education"}, {"income", "Income"}}; - auto ses = SESDefinition{.fuction_name = "normal", .parameters = std::vector{0.0, 1.0}}; - - auto mapping = HierarchicalMapping(mapping_entries); - - auto diseases = std::vector{ - DiseaseInfo{ - .group = DiseaseGroup::other, .code = core::Identifier{"asthma"}, .name = "Asthma"}, - DiseaseInfo{.group = DiseaseGroup::other, - .code = core::Identifier{"diabetes"}, - .name = "Diabetes Mellitus"}, - DiseaseInfo{.group = DiseaseGroup::cancer, - .code = core::Identifier{"colorectalcancer"}, - .name = "Colorectal cancer"}, - }; - - return {data, settings, info, ses, mapping, diseases}; -} +#include "pch.h" TEST(TestSimulation, RandomBitGenerator) { using namespace hgps; diff --git a/src/HealthGPS.Tests/simulation.cpp b/src/HealthGPS.Tests/simulation.cpp new file mode 100644 index 000000000..ae4ded2f6 --- /dev/null +++ b/src/HealthGPS.Tests/simulation.cpp @@ -0,0 +1,66 @@ +#include "simulation.h" + +namespace { +std::vector mapping_entries{ + {{"Gender", 0}, {"Age", 0}, {"SmokingStatus", 1}, {"AlcoholConsumption", 1}, {"BMI", 2}}}; +} // anonymous namespace + +void create_test_datatable(hgps::core::DataTable &data) { + using namespace hgps; + using namespace hgps::core; + + auto gender_values = std::vector{1, 0, 0, 1, 0}; + auto age_values = std::vector{4, 9, 14, 19, 25}; + auto edu_values = std::vector{6.0f, 10.0f, 2.0f, 9.0f, 12.0f}; + auto inc_values = std::vector{2.0, 10.0, 5.0, std::nan(""), 13.0}; + + auto gender_builder = IntegerDataTableColumnBuilder{"Gender"}; + auto age_builder = IntegerDataTableColumnBuilder{"Age"}; + auto edu_builder = FloatDataTableColumnBuilder{"Education"}; + auto inc_builder = DoubleDataTableColumnBuilder{"Income"}; + + for (size_t i = 0; i < gender_values.size(); i++) { + gender_builder.append(gender_values[i]); + age_builder.append(age_values[i]); + edu_builder.append(edu_values[i]); + if (std::isnan(inc_values[i])) { + inc_builder.append_null(); + } else { + inc_builder.append(inc_values[i]); + } + } + + data.add(gender_builder.build()); + data.add(age_builder.build()); + data.add(edu_builder.build()); + data.add(inc_builder.build()); +} + +hgps::ModelInput create_test_configuration(hgps::core::DataTable &data) { + using namespace hgps; + using namespace hgps::core; + + auto uk = core::Country{.code = 826, .name = "United Kingdom", .alpha2 = "GB", .alpha3 = "GBR"}; + + auto age_range = core::IntegerInterval(0, 30); + auto settings = Settings{uk, 0.1f, age_range}; + auto info = RunInfo{.start_time = 2018, .stop_time = 2025, .seed = std::nullopt}; + auto ses_mapping = std::map{ + {"gender", "Gender"}, {"age", "Age"}, {"education", "Education"}, {"income", "Income"}}; + auto ses = SESDefinition{.fuction_name = "normal", .parameters = std::vector{0.0, 1.0}}; + + auto mapping = HierarchicalMapping(mapping_entries); + + auto diseases = std::vector{ + DiseaseInfo{ + .group = DiseaseGroup::other, .code = core::Identifier{"asthma"}, .name = "Asthma"}, + DiseaseInfo{.group = DiseaseGroup::other, + .code = core::Identifier{"diabetes"}, + .name = "Diabetes Mellitus"}, + DiseaseInfo{.group = DiseaseGroup::cancer, + .code = core::Identifier{"colorectalcancer"}, + .name = "Colorectal cancer"}, + }; + + return {data, settings, info, ses, mapping, diseases}; +} diff --git a/src/HealthGPS.Tests/simulation.h b/src/HealthGPS.Tests/simulation.h new file mode 100644 index 000000000..e45a10dbb --- /dev/null +++ b/src/HealthGPS.Tests/simulation.h @@ -0,0 +1,18 @@ +#pragma once +#include "data_config.h" + +#include "HealthGPS.Core/column_builder.h" +#include "HealthGPS.Input/api.h" +#include "HealthGPS/analysis_module.h" +#include "HealthGPS/api.h" +#include "HealthGPS/converter.h" +#include "HealthGPS/event_bus.h" +#include "HealthGPS/random_algorithm.h" + +#include +#include +#include + +void create_test_datatable(hgps::core::DataTable &data); + +hgps::ModelInput create_test_configuration(hgps::core::DataTable &data); diff --git a/src/HealthGPS/analysis_module.cpp b/src/HealthGPS/analysis_module.cpp index bc731b7a4..637ac8303 100644 --- a/src/HealthGPS/analysis_module.cpp +++ b/src/HealthGPS/analysis_module.cpp @@ -50,17 +50,26 @@ void AnalysisModule::initialise_vector(RuntimeContext &context) { factor_min_values_.push_back(min_factor); + int factor_range = static_cast(max_factor - min_factor); + // The number of bins to use for each factor is the number of integer values of the factor, // or 100 bins of equal size, whichever is smaller (100 is an arbitrary number, it could be // any other number depending on the desired resolution of the map) - factor_bins_.push_back(std::min(100, static_cast(max_factor - min_factor))); - - // The width of each bin is the range of the factor divided by the number of bins - factor_bin_widths_.push_back((max_factor - min_factor) / factor_bins_.back()); + factor_bins_.push_back(std::min(100, factor_range + 1)); + + // The width of each bin is the factor_range divided by the number of bins. + // We need a special case for when the factor_range is 0, in which case we set the bin width + // to 1. E.g. when entire population is male. This may never happen in practice, but it's + // probably better to handle it just in case. + if (factor_range == 0) { + factor_bin_widths_.push_back(1.0); + } else { + factor_bin_widths_.push_back((max_factor + 1 - min_factor) / factor_bins_.back()); + } } // The number of factors to calculate stats for is the number of factors minus the length of the - // `factors` vector. + // `factors_to_calculate_` vector. size_t num_stats_to_calc = context.mapping().entries().size() - factors_to_calculate_.size(); // And for each factor, we calculate the stats described in `channels_`, so we @@ -74,6 +83,8 @@ void AnalysisModule::initialise_vector(RuntimeContext &context) { // Set the vector size and initialise all values to 0.0 calculated_stats_.resize(total_num_bins * num_stats_to_calc); + + num_stats_to_calc_ = num_stats_to_calc; } const std::string &AnalysisModule::name() const noexcept { return name_; } @@ -112,7 +123,7 @@ void AnalysisModule::initialise_population(RuntimeContext &context) { } initialise_output_channels(context); - + initialise_vector(context); publish_result_message(context); } @@ -151,10 +162,10 @@ void AnalysisModule::publish_result_message(RuntimeContext &context) const { auto result = ModelResult{sample_size}; auto handle = core::run_async(&AnalysisModule::calculate_historical_statistics, this, std::ref(context), std::ref(result)); - - calculate_population_statistics(context, result.series); handle.get(); + calculate_population_statistics(context); + context.publish(std::make_unique( context.identifier(), context.current_run(), context.time_now(), result)); } @@ -318,36 +329,95 @@ DALYsIndicator AnalysisModule::calculate_dalys(Population &population, unsigned .disability_adjusted_life_years = yll + yld}; } -void AnalysisModule::calculate_population_statistics(RuntimeContext &context) { - size_t num_factors_to_calculate = - context.mapping().entries().size() - factors_to_calculate_.size(); +void AnalysisModule::update_death_and_migration_stats(const Person &person, size_t index, + RuntimeContext &context) const { - for (const auto &person : context.population()) { - // Get the bin index for each factor - std::vector bin_indices; - for (size_t i = 0; i < factors_to_calculate_.size(); i++) { - double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]); - auto bin_index = - static_cast((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]); - bin_indices.push_back(bin_index); + auto current_time = static_cast(context.time_now()); + + if (!person.is_alive() && person.time_of_death() == current_time) { + calculated_stats_[index + channel_index_.at("deaths")]++; + float expected_life = definition_.life_expectancy().at(context.time_now(), person.gender); + double yll = std::max(expected_life - person.age, 0.0f) * DALY_UNITS; + calculated_stats_[index + channel_index_.at("mean_yll")] += yll; + calculated_stats_[index + channel_index_.at("mean_daly")] += yll; + } + + if (person.has_emigrated() && person.time_of_migration() == current_time) { + calculated_stats_[index + channel_index_.at("emigrations")]++; + } +} + +void AnalysisModule::update_calculated_stats_for_person(RuntimeContext &context, + const Person &person, size_t index) const { + calculated_stats_[index + channel_index_.at("count")]++; + + for (const auto &factor : context.mapping().entries()) { + double value = person.get_risk_factor_value(factor.key()); + calculated_stats_[index + channel_index_.at("mean_" + factor.key().to_string())] += value; + } + + for (const auto &[disease_name, disease_state] : person.diseases) { + if (disease_state.status == DiseaseStatus::active) { + calculated_stats_[index + + channel_index_.at("prevalence_" + disease_name.to_string())]++; + if (disease_state.start_time == context.time_now()) { + calculated_stats_[index + + channel_index_.at("incidence_" + disease_name.to_string())]++; + } } + } +} + +void AnalysisModule::calculate_population_statistics(RuntimeContext &context) const { + + for (const auto &person : context.population()) { + // First let's fetch the correct `calculated_stats_` bin index for this person + size_t index = calculate_index(person); - // Calculate the index in the calculated_stats_ vector - size_t index = 0; - for (size_t i = 0; i < bin_indices.size() - 1; i++) { - size_t accumulated_bins = - std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(), - size_t{1}, std::multiplies<>()); - index += bin_indices[i] * accumulated_bins * num_factors_to_calculate; + if (!person.is_active()) { + update_death_and_migration_stats(person, index, context); + continue; } - index += bin_indices.back() * num_factors_to_calculate; - // Now we can add the values of the factors that are not in factors_to_calculate_ + update_calculated_stats_for_person(context, person, index); + + double dw = calculate_disability_weight(person); + double yld = dw * DALY_UNITS; + calculated_stats_[index + channel_index_.at("mean_yld")] += yld; + calculated_stats_[index + channel_index_.at("mean_daly")] += yld; + + classify_weight(person); + } + + // For each bin in the calculated stats... + for (size_t i = 0; i < calculated_stats_.size(); i += channels_.size()) { + double count_F = calculated_stats_[i + channel_index_.at("count")]; + double count_M = calculated_stats_[i + channel_index_.at("count")]; + double deaths_F = calculated_stats_[i + channel_index_.at("deaths")]; + double deaths_M = calculated_stats_[i + channel_index_.at("deaths")]; + + // Calculate in-place factor averages. for (const auto &factor : context.mapping().entries()) { - if (std::find(factors_to_calculate_.cbegin(), factors_to_calculate_.cend(), - factor.key()) == factors_to_calculate_.cend()) { - calculated_stats_[index++] += person.get_risk_factor_value(factor.key()); - } + calculated_stats_[i + channel_index_.at("mean_" + factor.key().to_string())] /= count_F; + calculated_stats_[i + channel_index_.at("mean_" + factor.key().to_string())] /= count_M; + } + + // Calculate in-place disease prevalence and incidence rates. + for (const auto &disease : context.diseases()) { + calculated_stats_[i + channel_index_.at("prevalence_" + disease.code.to_string())] /= + count_F; + calculated_stats_[i + channel_index_.at("prevalence_" + disease.code.to_string())] /= + count_M; + calculated_stats_[i + channel_index_.at("incidence_" + disease.code.to_string())] /= + count_F; + calculated_stats_[i + channel_index_.at("incidence_" + disease.code.to_string())] /= + count_M; + } + + // Calculate in-place YLL/YLD/DALY averages. + for (const auto &column : {"mean_yll", "mean_yld", "mean_daly"}) { + calculated_stats_[i + channel_index_.at(column)] /= (count_F + deaths_F); + calculated_stats_[i + channel_index_.at(column)] /= (count_M + deaths_M); } } } @@ -531,6 +601,25 @@ void AnalysisModule::classify_weight(DataSeries &series, const Person &entity) c } } +void AnalysisModule::classify_weight(const Person &person) const { + auto weight_class = weight_classifier_.classify_weight(person); + switch (weight_class) { + case WeightCategory::normal: + calculated_stats_[channel_index_.at("normal_weight")]++; + break; + case WeightCategory::overweight: + calculated_stats_[channel_index_.at("over_weight")]++; + calculated_stats_[channel_index_.at("above_weight")]++; + break; + case WeightCategory::obese: + calculated_stats_[channel_index_.at("obese_weight")]++; + calculated_stats_[channel_index_.at("above_weight")]++; + break; + default: + throw std::logic_error("Unknown weight classification category."); + } +} + void AnalysisModule::initialise_output_channels(RuntimeContext &context) { if (!channels_.empty()) { return; @@ -560,6 +649,35 @@ void AnalysisModule::initialise_output_channels(RuntimeContext &context) { channels_.emplace_back("std_yld"); channels_.emplace_back("mean_daly"); channels_.emplace_back("std_daly"); + + // Since we will be performing frequent lookups, we will store the strings and indexes in a map + // for quick access. + for (size_t i = 0; i < channels_.size(); i++) { + channel_index_.emplace(channels_[i], i); + } +} + +size_t AnalysisModule::calculate_index(const Person &person) const { + // Get the bin index for each factor + std::vector bin_indices; + for (size_t i = 0; i < factors_to_calculate_.size(); i++) { + double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]); + auto bin_index = + static_cast((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]); + bin_indices.push_back(bin_index); + } + + // Calculate the index in the calculated_stats_ vector + size_t index = 0; + for (size_t i = 0; i < bin_indices.size() - 1; i++) { + size_t accumulated_bins = + std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(), size_t{1}, + std::multiplies<>()); + index += bin_indices[i] * accumulated_bins * num_stats_to_calc_; + } + index += bin_indices.back() * num_stats_to_calc_; + + return index; } std::unique_ptr build_analysis_module(Repository &repository, diff --git a/src/HealthGPS/analysis_module.h b/src/HealthGPS/analysis_module.h index 58ea43fe7..d3bf69aa3 100644 --- a/src/HealthGPS/analysis_module.h +++ b/src/HealthGPS/analysis_module.h @@ -8,11 +8,16 @@ #include "runtime_context.h" #include "weight_model.h" +#include + namespace hgps { /// @brief Implements the burden of diseases (BoD) analysis module class AnalysisModule final : public UpdatableModule { public: + friend class TestAnalysisModule; + FRIEND_TEST(TestAnalysisModule, CalculateIndex); + AnalysisModule() = delete; /// @brief Initialises a new instance of the AnalysisModule class. @@ -46,13 +51,15 @@ class AnalysisModule final : public UpdatableModule { WeightModel weight_classifier_; DoubleAgeGenderTable residual_disability_weight_; std::vector channels_; + std::unordered_map channel_index_; unsigned int comorbidities_; std::string name_{"Analysis"}; std::vector factors_to_calculate_ = {"Gender"_id, "Age"_id}; - std::vector calculated_stats_; + mutable std::vector calculated_stats_; std::vector factor_bins_; std::vector factor_bin_widths_; std::vector factor_min_values_; + size_t num_stats_to_calc_; void initialise_vector(RuntimeContext &context); @@ -65,13 +72,23 @@ class AnalysisModule final : public UpdatableModule { double calculate_disability_weight(const Person &entity) const; DALYsIndicator calculate_dalys(Population &population, unsigned int max_age, unsigned int death_year) const; + void update_death_and_migration_stats(const Person &person, size_t index, + RuntimeContext &context) const; + void update_calculated_stats_for_person(RuntimeContext &context, const Person &person, + size_t index) const; - void calculate_population_statistics(RuntimeContext &context); + void calculate_population_statistics(RuntimeContext &context) const; void calculate_population_statistics(RuntimeContext &context, DataSeries &series) const; void classify_weight(hgps::DataSeries &series, const hgps::Person &entity) const; + void classify_weight(const Person &person) const; void initialise_output_channels(RuntimeContext &context); + /// @brief Calculates the bin index in `calculated_stats_` for a given person + /// @param person The person to calculate the index for + /// @return The index in `calculated_stats_` + size_t calculate_index(const Person &person) const; + /// @brief Calculates the standard deviation of factors given data series containing means /// @param context The runtime context /// @param series The data series containing factor means