Skip to content

Commit

Permalink
Merge pull request #498 from imperialCHEPI/time_trend_policy_fix
Browse files Browse the repository at this point in the history
Time trend policy fix
  • Loading branch information
jamesturner246 authored Sep 3, 2024
2 parents c608444 + fb26efc commit ca565d4
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 65 deletions.
16 changes: 8 additions & 8 deletions src/HealthGPS/risk_factor_adjustable_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ RiskFactorAdjustableModel::RiskFactorAdjustableModel(
: expected_{std::move(expected)}, expected_trend_{std::move(expected_trend)} {}

double RiskFactorAdjustableModel::get_expected(RuntimeContext &context, core::Gender sex, int age,
const core::Identifier &factor,
OptionalRange range) const noexcept {
const core::Identifier &factor, OptionalRange range,
bool apply_trend) const noexcept {
double expected = expected_->at(sex, factor).at(age);

// Apply trend to expected value.
if (expected_trend_->contains(factor)) {
// Apply optional trend to expected value.
if (apply_trend) {
int elapsed_time = context.time_now() - context.start_time();
expected *= pow(expected_trend_->at(factor), elapsed_time);
}
Expand All @@ -57,12 +57,12 @@ double RiskFactorAdjustableModel::get_expected(RuntimeContext &context, core::Ge

void RiskFactorAdjustableModel::adjust_risk_factors(RuntimeContext &context,
const std::vector<core::Identifier> &factors,
OptionalRanges ranges) const {
OptionalRanges ranges, bool apply_trend) const {
RiskFactorSexAgeTable adjustments;

// Baseline scenatio: compute adjustments.
if (context.scenario().type() == ScenarioType::baseline) {
adjustments = calculate_adjustments(context, factors, ranges);
adjustments = calculate_adjustments(context, factors, ranges, apply_trend);
}

// Intervention scenario: receive adjustments from baseline scenario.
Expand Down Expand Up @@ -115,7 +115,7 @@ void RiskFactorAdjustableModel::adjust_risk_factors(RuntimeContext &context,
RiskFactorSexAgeTable
RiskFactorAdjustableModel::calculate_adjustments(RuntimeContext &context,
const std::vector<core::Identifier> &factors,
OptionalRanges ranges) const {
OptionalRanges ranges, bool apply_trend) const {
auto age_range = context.age_range();
auto age_count = age_range.upper() + 1;

Expand All @@ -135,7 +135,7 @@ RiskFactorAdjustableModel::calculate_adjustments(RuntimeContext &context,

adjustments.emplace(sex, factor, std::vector<double>(age_count));
for (auto age = age_range.lower(); age <= age_range.upper(); age++) {
double expect = get_expected(context, sex, age, factor, range);
double expect = get_expected(context, sex, age, factor, range, apply_trend);
double sim_mean = simulated_means_by_sex.at(factor).at(age);

// Delta should remain zero if simulated mean is NaN.
Expand Down
11 changes: 7 additions & 4 deletions src/HealthGPS/risk_factor_adjustable_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,29 @@ class RiskFactorAdjustableModel : public RiskFactorModel {
/// @param age The age key to get the expected value
/// @param factor The risk factor to get the expected value
/// @param range An optional expected value range
/// @param apply_trend Whether to apply expected value time trend
/// @returns The person's expected risk factor value
double get_expected(RuntimeContext &context, core::Gender sex, int age,
const core::Identifier &factor,
OptionalRange range = std::nullopt) const noexcept;
const core::Identifier &factor, OptionalRange range = std::nullopt,
bool apply_trend = false) const noexcept;

/// @brief Adjust risk factors such that mean sim value matches expected value
/// @param context The simulation run-time context
/// @param factors A list of risk factors to be adjusted
/// @param ranges An optional list of risk factor value boundaries
/// @param apply_trend Whether to apply expected value time trend
void adjust_risk_factors(RuntimeContext &context, const std::vector<core::Identifier> &factors,
OptionalRanges ranges = std::nullopt) const;
OptionalRanges ranges = std::nullopt, bool apply_trend = false) const;

private:
/// @brief Adjust risk factors such that mean sim value matches expected value
/// @param context The simulation run-time context
/// @param factors A list of risk factors to be adjusted
/// @param ranges An optional list of risk factor value boundaries
/// @param apply_trend Whether to apply expected value time trend
RiskFactorSexAgeTable calculate_adjustments(RuntimeContext &context,
const std::vector<core::Identifier> &factors,
OptionalRanges ranges) const;
OptionalRanges ranges, bool apply_trend) const;

static RiskFactorSexAgeTable
calculate_simulated_mean(Population &population, core::IntegerInterval age_range,
Expand Down
131 changes: 78 additions & 53 deletions src/HealthGPS/static_linear_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,30 @@ void StaticLinearModel::generate_risk_factors(RuntimeContext &context) {
initialise_sector(person, context.random());
initialise_income(person, context.random());
initialise_factors(context, person, context.random());
initialise_trends(context, person);
initialise_physical_activity(context, person, context.random());
}

// Adjust risk factors to match expected values.
adjust_risk_factors(context, names_, ranges_);
// Adjust such that risk factor means match expected values.
adjust_risk_factors(context, names_, ranges_, false);

// Initialise everyone.
for (auto &person : context.population()) {
initialise_policies(person, context.random(), false);
initialise_trends(context, person);
}

// Adjust such that trended risk factor means match trended expected values.
adjust_risk_factors(context, names_, ranges_, true);
}

void StaticLinearModel::update_risk_factors(RuntimeContext &context) {

// HACK: start intervening after 2 years from sim start.
// HACK: start intervening two years into the simulation.
bool intervene = (context.scenario().type() == ScenarioType::intervention &&
(context.time_now() - context.start_time()) >= 2);

// Initialise newborns and update others.
for (auto &person : context.population()) {
// Ignore if inactive.
if (!person.is_active()) {
continue;
}
Expand All @@ -70,18 +72,16 @@ void StaticLinearModel::update_risk_factors(RuntimeContext &context) {
initialise_sector(person, context.random());
initialise_income(person, context.random());
initialise_factors(context, person, context.random());
initialise_trends(context, person);
initialise_physical_activity(context, person, context.random());
} else {
update_sector(person, context.random());
update_income(person, context.random());
update_factors(context, person, context.random());
update_trends(context, person);
}
}

// Adjust risk factors to match expected values.
adjust_risk_factors(context, names_, ranges_);
// Adjust such that risk factor means match expected values.
adjust_risk_factors(context, names_, ranges_, false);

// Initialise newborns and update others.
for (auto &person : context.population()) {
Expand All @@ -91,10 +91,24 @@ void StaticLinearModel::update_risk_factors(RuntimeContext &context) {

if (person.age == 0) {
initialise_policies(person, context.random(), intervene);
initialise_trends(context, person);
} else {
update_policies(person, intervene);
update_trends(context, person);
}
}

// Adjust such that trended risk factor means match trended expected values.
adjust_risk_factors(context, names_, ranges_, true);

// Apply policies if intervening.
for (auto &person : context.population()) {
if (!person.is_active()) {
continue;
}

apply_policies(person, intervene);
}
}

double StaticLinearModel::inverse_box_cox(double factor, double lambda) {
Expand All @@ -107,7 +121,7 @@ void StaticLinearModel::initialise_factors(RuntimeContext &context, Person &pers
// Correlated residual sampling.
auto residuals = compute_residuals(random, cholesky_);

// Approximate risk factor values with linear models.
// Approximate risk factors with linear models.
auto linear = compute_linear_models(person, models_);

// Initialise residuals and risk factors (do not exist yet).
Expand All @@ -116,15 +130,18 @@ void StaticLinearModel::initialise_factors(RuntimeContext &context, Person &pers
// Initialise residual.
auto residual_name = core::Identifier{names_[i].to_string() + "_residual"};
double residual = residuals[i];

// Save residual.
person.risk_factors[residual_name] = residual;

// Initialise risk factor.
double expected = get_expected(context, person.gender, person.age, names_[i], ranges_[i]);
double factor = linear[i] + residual * stddev_[i];
factor = expected * inverse_box_cox(factor, lambda_[i]);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors[names_[i]] = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors[names_[i]] = factor;
}
}

Expand All @@ -134,7 +151,7 @@ void StaticLinearModel::update_factors(RuntimeContext &context, Person &person,
// Correlated residual sampling.
auto residuals = compute_residuals(random, cholesky_);

// Approximate risk factor values with linear models.
// Approximate risk factors with linear models.
auto linear = compute_linear_models(person, models_);

// Update residuals and risk factors (should exist).
Expand All @@ -145,41 +162,39 @@ void StaticLinearModel::update_factors(RuntimeContext &context, Person &person,
double residual_old = person.risk_factors.at(residual_name);
double residual = residuals[i] * info_speed_;
residual += sqrt(1.0 - info_speed_ * info_speed_) * residual_old;

// Save residual.
person.risk_factors.at(residual_name) = residual;

// Update risk factor.
double expected = get_expected(context, person.gender, person.age, names_[i], ranges_[i]);
double factor = linear[i] + residual * stddev_[i];
factor = expected * inverse_box_cox(factor, lambda_[i]);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors.at(names_[i]) = factor;
}
}

void StaticLinearModel::initialise_trends(RuntimeContext &context, Person &person) const {

// Approximate risk factor values with linear models.
// Approximate trends with linear models.
auto linear = compute_linear_models(person, *trend_models_);

// Get elapsed time (years).
int elapsed_time = context.time_now() - context.start_time();

// Initialise and apply trends (do not exist yet).
for (size_t i = 0; i < names_.size(); i++) {

// Initialise trend.
auto trend_name = core::Identifier{names_[i].to_string() + "_trend"};
double trend = (*trend_ranges_)[i].clamp(linear[i]);
person.risk_factors[trend_name] = trend;

// Apply trend to risk factor.
double factor = person.risk_factors.at(names_[i]);
factor *= pow(trend, elapsed_time);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save trend.
person.risk_factors[trend_name] = trend;
}

// Apply trends.
update_trends(context, person);
}

void StaticLinearModel::update_trends(RuntimeContext &context, Person &person) const {
Expand All @@ -190,21 +205,23 @@ void StaticLinearModel::update_trends(RuntimeContext &context, Person &person) c
// Apply trends (should exist).
for (size_t i = 0; i < names_.size(); i++) {

// Load time trend.
// Load trend.
auto trend_name = core::Identifier{names_[i].to_string() + "_trend"};
double trend = person.risk_factors.at(trend_name);

// Apply trend to risk factor.
double factor = person.risk_factors.at(names_[i]);
factor *= pow(trend, elapsed_time);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors.at(names_[i]) = factor;
}
}

void StaticLinearModel::initialise_policies(Person &person, Random &random, bool intervene) const {
// NOTE: we need to keep baseline and intervention scenario RNGs in sync.
// NOTE: we need to keep baseline and intervention scenario RNGs in sync,
// so we compute residuals even though they are not used in baseline.

// Intervention policy residual components.
auto residuals = compute_residuals(random, policy_cholesky_);
Expand All @@ -215,54 +232,62 @@ void StaticLinearModel::initialise_policies(Person &person, Random &random, bool
person.risk_factors[residual_name] = residuals[i];
}

// No-op if not intervening.
// Compute policies.
update_policies(person, intervene);
}

void StaticLinearModel::update_policies(Person &person, bool intervene) const {

// Set zero policy if not intervening.
if (!intervene) {
for (const auto &name : names_) {
auto policy_name = core::Identifier{name.to_string() + "_policy"};
person.risk_factors[policy_name] = 0.0;
}
return;
}

// Intervention policy linear components.
auto linear = compute_linear_models(person, policy_models_);

// Compute and apply all intervention policies.
// Compute all intervention policies.
for (size_t i = 0; i < names_.size(); i++) {

// Compute intervention policy.
double policy = linear[i] + residuals[i];
// Load residual component.
auto residual_name = core::Identifier{names_[i].to_string() + "_policy_residual"};
double residual = person.risk_factors.at(residual_name);

// Compute policy.
auto policy_name = core::Identifier{names_[i].to_string() + "_policy"};
double policy = linear[i] + residual;
policy = policy_ranges_[i].clamp(policy);
double factor_old = person.risk_factors.at(names_[i]);
double factor = factor_old * (1.0 + policy / 100.0);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save policy.
person.risk_factors[policy_name] = policy;
}
}

void StaticLinearModel::update_policies(Person &person, bool intervene) const {
// NOTE: we need to keep baseline and intervention scenario RNGs in sync.
void StaticLinearModel::apply_policies(Person &person, bool intervene) const {

// No-op if not intervening.
if (!intervene) {
return;
}

// Intervention policy linear components.
auto linear = compute_linear_models(person, policy_models_);

// Compute and apply all intervention policies.
// Apply all intervention policies.
for (size_t i = 0; i < names_.size(); i++) {

// Load residual component.
auto residual_name = core::Identifier{names_[i].to_string() + "_policy_residual"};
double residual = person.risk_factors.at(residual_name);
// Load policy.
auto policy_name = core::Identifier{names_[i].to_string() + "_policy"};
double policy = person.risk_factors.at(policy_name);

// Compute intervention policy.
double policy = linear[i] + residual;
policy = policy_ranges_[i].clamp(policy);
// Apply policy to risk factor.
double factor_old = person.risk_factors.at(names_[i]);
double factor = factor_old * (1.0 + policy / 100.0);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors.at(names_[i]) = factor;
}
}

Expand All @@ -272,7 +297,7 @@ StaticLinearModel::compute_linear_models(Person &person,
std::vector<double> linear{};
linear.reserve(names_.size());

// Approximate risk factor values for person with linear models.
// Approximate risk factors with linear models.
for (size_t i = 0; i < names_.size(); i++) {
auto name = names_[i];
auto model = models[i];
Expand Down
2 changes: 2 additions & 0 deletions src/HealthGPS/static_linear_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class StaticLinearModel final : public RiskFactorAdjustableModel {

void update_policies(Person &person, bool intervene) const;

void apply_policies(Person &person, bool intervene) const;

std::vector<double> compute_linear_models(Person &person,
const std::vector<LinearModelParams> &models) const;

Expand Down

0 comments on commit ca565d4

Please sign in to comment.