diff --git a/src/HealthGPS/risk_factor_adjustable_model.cpp b/src/HealthGPS/risk_factor_adjustable_model.cpp index 0e7053d31..bce9ba7a0 100644 --- a/src/HealthGPS/risk_factor_adjustable_model.cpp +++ b/src/HealthGPS/risk_factor_adjustable_model.cpp @@ -37,12 +37,12 @@ RiskFactorAdjustableModel::RiskFactorAdjustableModel( : expected_{std::move(expected)}, expected_trend_{std::move(expected_trend)} {} double RiskFactorAdjustableModel::get_expected(RuntimeContext &context, core::Gender sex, int age, - const core::Identifier &factor, - OptionalRange range) const noexcept { + const core::Identifier &factor, OptionalRange range, + bool apply_trend) const noexcept { double expected = expected_->at(sex, factor).at(age); - // Apply trend to expected value. - if (expected_trend_->contains(factor)) { + // Apply optional trend to expected value. + if (apply_trend) { int elapsed_time = context.time_now() - context.start_time(); expected *= pow(expected_trend_->at(factor), elapsed_time); } @@ -57,12 +57,12 @@ double RiskFactorAdjustableModel::get_expected(RuntimeContext &context, core::Ge void RiskFactorAdjustableModel::adjust_risk_factors(RuntimeContext &context, const std::vector &factors, - OptionalRanges ranges) const { + OptionalRanges ranges, bool apply_trend) const { RiskFactorSexAgeTable adjustments; // Baseline scenatio: compute adjustments. if (context.scenario().type() == ScenarioType::baseline) { - adjustments = calculate_adjustments(context, factors, ranges); + adjustments = calculate_adjustments(context, factors, ranges, apply_trend); } // Intervention scenario: receive adjustments from baseline scenario. @@ -115,7 +115,7 @@ void RiskFactorAdjustableModel::adjust_risk_factors(RuntimeContext &context, RiskFactorSexAgeTable RiskFactorAdjustableModel::calculate_adjustments(RuntimeContext &context, const std::vector &factors, - OptionalRanges ranges) const { + OptionalRanges ranges, bool apply_trend) const { auto age_range = context.age_range(); auto age_count = age_range.upper() + 1; @@ -135,7 +135,7 @@ RiskFactorAdjustableModel::calculate_adjustments(RuntimeContext &context, adjustments.emplace(sex, factor, std::vector(age_count)); for (auto age = age_range.lower(); age <= age_range.upper(); age++) { - double expect = get_expected(context, sex, age, factor, range); + double expect = get_expected(context, sex, age, factor, range, apply_trend); double sim_mean = simulated_means_by_sex.at(factor).at(age); // Delta should remain zero if simulated mean is NaN. diff --git a/src/HealthGPS/risk_factor_adjustable_model.h b/src/HealthGPS/risk_factor_adjustable_model.h index 8ce2e3d23..a5a6988f1 100644 --- a/src/HealthGPS/risk_factor_adjustable_model.h +++ b/src/HealthGPS/risk_factor_adjustable_model.h @@ -41,26 +41,29 @@ class RiskFactorAdjustableModel : public RiskFactorModel { /// @param age The age key to get the expected value /// @param factor The risk factor to get the expected value /// @param range An optional expected value range + /// @param apply_trend Whether to apply expected value time trend /// @returns The person's expected risk factor value double get_expected(RuntimeContext &context, core::Gender sex, int age, - const core::Identifier &factor, - OptionalRange range = std::nullopt) const noexcept; + const core::Identifier &factor, OptionalRange range = std::nullopt, + bool apply_trend = false) const noexcept; /// @brief Adjust risk factors such that mean sim value matches expected value /// @param context The simulation run-time context /// @param factors A list of risk factors to be adjusted /// @param ranges An optional list of risk factor value boundaries + /// @param apply_trend Whether to apply expected value time trend void adjust_risk_factors(RuntimeContext &context, const std::vector &factors, - OptionalRanges ranges = std::nullopt) const; + OptionalRanges ranges = std::nullopt, bool apply_trend = false) const; private: /// @brief Adjust risk factors such that mean sim value matches expected value /// @param context The simulation run-time context /// @param factors A list of risk factors to be adjusted /// @param ranges An optional list of risk factor value boundaries + /// @param apply_trend Whether to apply expected value time trend RiskFactorSexAgeTable calculate_adjustments(RuntimeContext &context, const std::vector &factors, - OptionalRanges ranges) const; + OptionalRanges ranges, bool apply_trend) const; static RiskFactorSexAgeTable calculate_simulated_mean(Population &population, core::IntegerInterval age_range, diff --git a/src/HealthGPS/static_linear_model.cpp b/src/HealthGPS/static_linear_model.cpp index 84ad9a245..449526145 100644 --- a/src/HealthGPS/static_linear_model.cpp +++ b/src/HealthGPS/static_linear_model.cpp @@ -40,28 +40,30 @@ void StaticLinearModel::generate_risk_factors(RuntimeContext &context) { initialise_sector(person, context.random()); initialise_income(person, context.random()); initialise_factors(context, person, context.random()); - initialise_trends(context, person); initialise_physical_activity(context, person, context.random()); } - // Adjust risk factors to match expected values. - adjust_risk_factors(context, names_, ranges_); + // Adjust such that risk factor means match expected values. + adjust_risk_factors(context, names_, ranges_, false); // Initialise everyone. for (auto &person : context.population()) { initialise_policies(person, context.random(), false); + initialise_trends(context, person); } + + // Adjust such that trended risk factor means match trended expected values. + adjust_risk_factors(context, names_, ranges_, true); } void StaticLinearModel::update_risk_factors(RuntimeContext &context) { - // HACK: start intervening after 2 years from sim start. + // HACK: start intervening two years into the simulation. bool intervene = (context.scenario().type() == ScenarioType::intervention && (context.time_now() - context.start_time()) >= 2); // Initialise newborns and update others. for (auto &person : context.population()) { - // Ignore if inactive. if (!person.is_active()) { continue; } @@ -70,18 +72,16 @@ void StaticLinearModel::update_risk_factors(RuntimeContext &context) { initialise_sector(person, context.random()); initialise_income(person, context.random()); initialise_factors(context, person, context.random()); - initialise_trends(context, person); initialise_physical_activity(context, person, context.random()); } else { update_sector(person, context.random()); update_income(person, context.random()); update_factors(context, person, context.random()); - update_trends(context, person); } } - // Adjust risk factors to match expected values. - adjust_risk_factors(context, names_, ranges_); + // Adjust such that risk factor means match expected values. + adjust_risk_factors(context, names_, ranges_, false); // Initialise newborns and update others. for (auto &person : context.population()) { @@ -91,10 +91,24 @@ void StaticLinearModel::update_risk_factors(RuntimeContext &context) { if (person.age == 0) { initialise_policies(person, context.random(), intervene); + initialise_trends(context, person); } else { update_policies(person, intervene); + update_trends(context, person); } } + + // Adjust such that trended risk factor means match trended expected values. + adjust_risk_factors(context, names_, ranges_, true); + + // Apply policies if intervening. + for (auto &person : context.population()) { + if (!person.is_active()) { + continue; + } + + apply_policies(person, intervene); + } } double StaticLinearModel::inverse_box_cox(double factor, double lambda) { @@ -107,7 +121,7 @@ void StaticLinearModel::initialise_factors(RuntimeContext &context, Person &pers // Correlated residual sampling. auto residuals = compute_residuals(random, cholesky_); - // Approximate risk factor values with linear models. + // Approximate risk factors with linear models. auto linear = compute_linear_models(person, models_); // Initialise residuals and risk factors (do not exist yet). @@ -116,15 +130,18 @@ void StaticLinearModel::initialise_factors(RuntimeContext &context, Person &pers // Initialise residual. auto residual_name = core::Identifier{names_[i].to_string() + "_residual"}; double residual = residuals[i]; + + // Save residual. person.risk_factors[residual_name] = residual; // Initialise risk factor. double expected = get_expected(context, person.gender, person.age, names_[i], ranges_[i]); double factor = linear[i] + residual * stddev_[i]; factor = expected * inverse_box_cox(factor, lambda_[i]); + factor = ranges_[i].clamp(factor); - // Save clamped risk factor. - person.risk_factors[names_[i]] = ranges_[i].clamp(factor); + // Save risk factor. + person.risk_factors[names_[i]] = factor; } } @@ -134,7 +151,7 @@ void StaticLinearModel::update_factors(RuntimeContext &context, Person &person, // Correlated residual sampling. auto residuals = compute_residuals(random, cholesky_); - // Approximate risk factor values with linear models. + // Approximate risk factors with linear models. auto linear = compute_linear_models(person, models_); // Update residuals and risk factors (should exist). @@ -145,41 +162,39 @@ void StaticLinearModel::update_factors(RuntimeContext &context, Person &person, double residual_old = person.risk_factors.at(residual_name); double residual = residuals[i] * info_speed_; residual += sqrt(1.0 - info_speed_ * info_speed_) * residual_old; + + // Save residual. person.risk_factors.at(residual_name) = residual; // Update risk factor. double expected = get_expected(context, person.gender, person.age, names_[i], ranges_[i]); double factor = linear[i] + residual * stddev_[i]; factor = expected * inverse_box_cox(factor, lambda_[i]); + factor = ranges_[i].clamp(factor); - // Save clamped risk factor. - person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor); + // Save risk factor. + person.risk_factors.at(names_[i]) = factor; } } void StaticLinearModel::initialise_trends(RuntimeContext &context, Person &person) const { - // Approximate risk factor values with linear models. + // Approximate trends with linear models. auto linear = compute_linear_models(person, *trend_models_); - // Get elapsed time (years). - int elapsed_time = context.time_now() - context.start_time(); - // Initialise and apply trends (do not exist yet). for (size_t i = 0; i < names_.size(); i++) { // Initialise trend. auto trend_name = core::Identifier{names_[i].to_string() + "_trend"}; double trend = (*trend_ranges_)[i].clamp(linear[i]); - person.risk_factors[trend_name] = trend; - - // Apply trend to risk factor. - double factor = person.risk_factors.at(names_[i]); - factor *= pow(trend, elapsed_time); - // Save clamped risk factor. - person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor); + // Save trend. + person.risk_factors[trend_name] = trend; } + + // Apply trends. + update_trends(context, person); } void StaticLinearModel::update_trends(RuntimeContext &context, Person &person) const { @@ -190,21 +205,23 @@ void StaticLinearModel::update_trends(RuntimeContext &context, Person &person) c // Apply trends (should exist). for (size_t i = 0; i < names_.size(); i++) { - // Load time trend. + // Load trend. auto trend_name = core::Identifier{names_[i].to_string() + "_trend"}; double trend = person.risk_factors.at(trend_name); // Apply trend to risk factor. double factor = person.risk_factors.at(names_[i]); factor *= pow(trend, elapsed_time); + factor = ranges_[i].clamp(factor); - // Save clamped risk factor. - person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor); + // Save risk factor. + person.risk_factors.at(names_[i]) = factor; } } void StaticLinearModel::initialise_policies(Person &person, Random &random, bool intervene) const { - // NOTE: we need to keep baseline and intervention scenario RNGs in sync. + // NOTE: we need to keep baseline and intervention scenario RNGs in sync, + // so we compute residuals even though they are not used in baseline. // Intervention policy residual components. auto residuals = compute_residuals(random, policy_cholesky_); @@ -215,54 +232,62 @@ void StaticLinearModel::initialise_policies(Person &person, Random &random, bool person.risk_factors[residual_name] = residuals[i]; } - // No-op if not intervening. + // Compute policies. + update_policies(person, intervene); +} + +void StaticLinearModel::update_policies(Person &person, bool intervene) const { + + // Set zero policy if not intervening. if (!intervene) { + for (const auto &name : names_) { + auto policy_name = core::Identifier{name.to_string() + "_policy"}; + person.risk_factors[policy_name] = 0.0; + } return; } // Intervention policy linear components. auto linear = compute_linear_models(person, policy_models_); - // Compute and apply all intervention policies. + // Compute all intervention policies. for (size_t i = 0; i < names_.size(); i++) { - // Compute intervention policy. - double policy = linear[i] + residuals[i]; + // Load residual component. + auto residual_name = core::Identifier{names_[i].to_string() + "_policy_residual"}; + double residual = person.risk_factors.at(residual_name); + + // Compute policy. + auto policy_name = core::Identifier{names_[i].to_string() + "_policy"}; + double policy = linear[i] + residual; policy = policy_ranges_[i].clamp(policy); - double factor_old = person.risk_factors.at(names_[i]); - double factor = factor_old * (1.0 + policy / 100.0); - // Save clamped risk factor. - person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor); + // Save policy. + person.risk_factors[policy_name] = policy; } } -void StaticLinearModel::update_policies(Person &person, bool intervene) const { - // NOTE: we need to keep baseline and intervention scenario RNGs in sync. +void StaticLinearModel::apply_policies(Person &person, bool intervene) const { // No-op if not intervening. if (!intervene) { return; } - // Intervention policy linear components. - auto linear = compute_linear_models(person, policy_models_); - - // Compute and apply all intervention policies. + // Apply all intervention policies. for (size_t i = 0; i < names_.size(); i++) { - // Load residual component. - auto residual_name = core::Identifier{names_[i].to_string() + "_policy_residual"}; - double residual = person.risk_factors.at(residual_name); + // Load policy. + auto policy_name = core::Identifier{names_[i].to_string() + "_policy"}; + double policy = person.risk_factors.at(policy_name); - // Compute intervention policy. - double policy = linear[i] + residual; - policy = policy_ranges_[i].clamp(policy); + // Apply policy to risk factor. double factor_old = person.risk_factors.at(names_[i]); double factor = factor_old * (1.0 + policy / 100.0); + factor = ranges_[i].clamp(factor); - // Save clamped risk factor. - person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor); + // Save risk factor. + person.risk_factors.at(names_[i]) = factor; } } @@ -272,7 +297,7 @@ StaticLinearModel::compute_linear_models(Person &person, std::vector linear{}; linear.reserve(names_.size()); - // Approximate risk factor values for person with linear models. + // Approximate risk factors with linear models. for (size_t i = 0; i < names_.size(); i++) { auto name = names_[i]; auto model = models[i]; diff --git a/src/HealthGPS/static_linear_model.h b/src/HealthGPS/static_linear_model.h index aad17e5bb..ac7aeb6e4 100644 --- a/src/HealthGPS/static_linear_model.h +++ b/src/HealthGPS/static_linear_model.h @@ -79,6 +79,8 @@ class StaticLinearModel final : public RiskFactorAdjustableModel { void update_policies(Person &person, bool intervene) const; + void apply_policies(Person &person, bool intervene) const; + std::vector compute_linear_models(Person &person, const std::vector &models) const;