Skip to content

Commit

Permalink
Refactor trend init update code.
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesturner246 committed Aug 20, 2024
1 parent 228fba0 commit 4b15f14
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions src/HealthGPS/static_linear_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,18 @@ void StaticLinearModel::initialise_factors(RuntimeContext &context, Person &pers
// Initialise residual.
auto residual_name = core::Identifier{names_[i].to_string() + "_residual"};
double residual = residuals[i];

// Save residual.
person.risk_factors[residual_name] = residual;

// Initialise risk factor.
double expected = get_expected(context, person.gender, person.age, names_[i], ranges_[i]);
double factor = linear[i] + residual * stddev_[i];
factor = expected * inverse_box_cox(factor, lambda_[i]);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors[names_[i]] = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors[names_[i]] = factor;
}
}

Expand All @@ -159,15 +162,18 @@ void StaticLinearModel::update_factors(RuntimeContext &context, Person &person,
double residual_old = person.risk_factors.at(residual_name);
double residual = residuals[i] * info_speed_;
residual += sqrt(1.0 - info_speed_ * info_speed_) * residual_old;

// Save residual.
person.risk_factors.at(residual_name) = residual;

// Update risk factor.
double expected = get_expected(context, person.gender, person.age, names_[i], ranges_[i]);
double factor = linear[i] + residual * stddev_[i];
factor = expected * inverse_box_cox(factor, lambda_[i]);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors.at(names_[i]) = factor;
}
}

Expand All @@ -176,24 +182,19 @@ void StaticLinearModel::initialise_trends(RuntimeContext &context, Person &perso
// Approximate trends with linear models.
auto linear = compute_linear_models(person, *trend_models_);

// Get elapsed time (years).
int elapsed_time = context.time_now() - context.start_time();

// Initialise and apply trends (do not exist yet).
for (size_t i = 0; i < names_.size(); i++) {

// Initialise trend.
auto trend_name = core::Identifier{names_[i].to_string() + "_trend"};
double trend = (*trend_ranges_)[i].clamp(linear[i]);
person.risk_factors[trend_name] = trend;

// Apply trend to risk factor.
double factor = person.risk_factors.at(names_[i]);
factor *= pow(trend, elapsed_time);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save trend.
person.risk_factors[trend_name] = trend;
}

// Apply trends.
update_trends(context, person);
}

void StaticLinearModel::update_trends(RuntimeContext &context, Person &person) const {
Expand All @@ -211,9 +212,10 @@ void StaticLinearModel::update_trends(RuntimeContext &context, Person &person) c
// Apply trend to risk factor.
double factor = person.risk_factors.at(names_[i]);
factor *= pow(trend, elapsed_time);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors.at(names_[i]) = factor;
}
}

Expand Down Expand Up @@ -258,9 +260,10 @@ void StaticLinearModel::update_policies(Person &person, bool intervene) const {
// Compute policy.
auto policy_name = core::Identifier{names_[i].to_string() + "_policy"};
double policy = linear[i] + residual;
policy = policy_ranges_[i].clamp(policy);

// Save clamped policy.
person.risk_factors[policy_name] = policy_ranges_[i].clamp(policy);
// Save policy.
person.risk_factors[policy_name] = policy;
}
}

Expand All @@ -281,9 +284,10 @@ void StaticLinearModel::apply_policies(Person &person, bool intervene) const {
// Apply policy to risk factor.
double factor_old = person.risk_factors.at(names_[i]);
double factor = factor_old * (1.0 + policy / 100.0);
factor = ranges_[i].clamp(factor);

// Save clamped risk factor.
person.risk_factors.at(names_[i]) = ranges_[i].clamp(factor);
// Save risk factor.
person.risk_factors.at(names_[i]) = factor;
}
}

Expand Down

0 comments on commit 4b15f14

Please sign in to comment.