Skip to content

Commit

Permalink
Merge pull request #491 from imperialCHEPI/time_trend
Browse files Browse the repository at this point in the history
Time trend for expected values
  • Loading branch information
jamesturner246 committed Aug 16, 2024
2 parents 0e2229f + 845d846 commit e766691
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 149 deletions.
57 changes: 33 additions & 24 deletions src/HealthGPS.Input/model_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,30 @@

namespace hgps::input {

hgps::RiskFactorSexAgeTable load_risk_factor_expected(const Configuration &config) {
std::unique_ptr<hgps::RiskFactorSexAgeTable>
load_risk_factor_expected(const Configuration &config) {
MEASURE_FUNCTION();

const auto &info = config.modelling.baseline_adjustment;
if (!hgps::core::case_insensitive::equals(info.format, "CSV")) {
throw hgps::core::HgpsException{"Unsupported file format: " + info.format};
}

auto table = hgps::RiskFactorSexAgeTable{};
auto table = std::make_unique<hgps::RiskFactorSexAgeTable>();
const auto male_filename = info.file_names.at("factorsmean_male").string();
const auto female_filename = info.file_names.at("factorsmean_female").string();
try {
table.emplace_row(hgps::core::Gender::male,
load_baseline_from_csv(male_filename, info.delimiter));
table.emplace_row(hgps::core::Gender::female,
load_baseline_from_csv(female_filename, info.delimiter));
table->emplace_row(hgps::core::Gender::male,
load_baseline_from_csv(male_filename, info.delimiter));
table->emplace_row(hgps::core::Gender::female,
load_baseline_from_csv(female_filename, info.delimiter));
} catch (const std::runtime_error &ex) {
throw hgps::core::HgpsException{fmt::format("Failed to parse adjustment file: {} or {}. {}",
male_filename, female_filename, ex.what())};
}

const auto max_age = static_cast<std::size_t>(config.settings.age_range.upper());
for (const auto &sex : table) {
for (const auto &sex : *table) {
for (const auto &factor : sex.second) {
if (factor.second.size() <= max_age) {
throw hgps::core::HgpsException{
Expand Down Expand Up @@ -177,13 +178,14 @@ load_staticlinear_risk_model_definition(const nlohmann::json &opt, const Configu
policy_covariance_table.num_columns()};

// Risk factor and intervention policy: names, models, parameters and correlation/covariance.
std::vector<hgps::core::Identifier> names;
std::vector<hgps::LinearModelParams> models;
std::vector<hgps::core::DoubleInterval> ranges;
std::vector<core::Identifier> names;
std::vector<LinearModelParams> models;
std::vector<core::DoubleInterval> ranges;
std::vector<double> lambda;
std::vector<double> stddev;
std::vector<hgps::LinearModelParams> policy_models;
std::vector<hgps::core::DoubleInterval> policy_ranges;
std::vector<LinearModelParams> policy_models;
std::vector<core::DoubleInterval> policy_ranges;
auto expected_trend = std::make_unique<std::unordered_map<core::Identifier, double>>();

size_t i = 0;
for (const auto &[key, json_params] : opt["RiskFactorModels"].items()) {
Expand Down Expand Up @@ -240,6 +242,9 @@ load_staticlinear_risk_model_definition(const nlohmann::json &opt, const Configu
std::any_cast<double>(policy_covariance_table.column(i).value(j));
}

// Load expected value trends.
(*expected_trend)[key] = json_params["ExpectedTrend"].get<double>();

// Increment table column index.
i++;
}
Expand Down Expand Up @@ -268,15 +273,15 @@ load_staticlinear_risk_model_definition(const nlohmann::json &opt, const Configu
Eigen::MatrixXd{Eigen::LLT<Eigen::MatrixXd>{policy_covariance}.matrixL()};

// Risk factor expected values by sex and age.
hgps::RiskFactorSexAgeTable expected = load_risk_factor_expected(config);
auto expected = load_risk_factor_expected(config);

// Check expected values are defined for all risk factors.
for (const auto &name : names) {
if (!expected.at(hgps::core::Gender::male).contains(name)) {
if (!expected->at(hgps::core::Gender::male).contains(name)) {
throw hgps::core::HgpsException{fmt::format(
"'{}' is not defined in male risk factor expected values.", name.to_string())};
}
if (!expected.at(hgps::core::Gender::female).contains(name)) {
if (!expected->at(hgps::core::Gender::female).contains(name)) {
throw hgps::core::HgpsException{fmt::format(
"'{}' is not defined in female risk factor expected values.", name.to_string())};
}
Expand Down Expand Up @@ -327,9 +332,9 @@ load_staticlinear_risk_model_definition(const nlohmann::json &opt, const Configu
const double physical_activity_stddev = opt["PhysicalActivityStdDev"].get<double>();

return std::make_unique<hgps::StaticLinearModelDefinition>(
std::move(expected), std::move(names), std::move(models), std::move(ranges),
std::move(lambda), std::move(stddev), std::move(cholesky), std::move(policy_models),
std::move(policy_ranges), std::move(policy_cholesky), info_speed,
std::move(expected), std::move(expected_trend), std::move(names), std::move(models),
std::move(ranges), std::move(lambda), std::move(stddev), std::move(cholesky),
std::move(policy_models), std::move(policy_ranges), std::move(policy_cholesky), info_speed,
std::move(rural_prevalence), std::move(income_models), physical_activity_stddev);
}

Expand Down Expand Up @@ -425,10 +430,12 @@ load_ebhlm_risk_model_definition(const nlohmann::json &opt, const Configuration
}

// Risk factor expected values by sex and age.
hgps::RiskFactorSexAgeTable expected = load_risk_factor_expected(config);
auto expected = load_risk_factor_expected(config);
auto expected_trend = std::make_unique<std::unordered_map<core::Identifier, double>>();

return std::make_unique<hgps::DynamicHierarchicalLinearModelDefinition>(
std::move(expected), std::move(equations), std::move(variables), percentage);
std::move(expected), std::move(expected_trend), std::move(equations), std::move(variables),
percentage);
}
// NOLINTEND(readability-function-cognitive-complexity)

Expand All @@ -437,7 +444,8 @@ load_kevinhall_risk_model_definition(const nlohmann::json &opt, const Configurat
MEASURE_FUNCTION();

// Risk factor expected values by sex and age.
hgps::RiskFactorSexAgeTable expected = load_risk_factor_expected(config);
auto expected = load_risk_factor_expected(config);
auto expected_trend = std::make_unique<std::unordered_map<core::Identifier, double>>();

// Nutrient groups.
std::unordered_map<hgps::core::Identifier, double> energy_equation;
Expand Down Expand Up @@ -507,9 +515,10 @@ load_kevinhall_risk_model_definition(const nlohmann::json &opt, const Configurat
{hgps::core::Gender::male, opt["HeightSlope"]["Male"].get<double>()}};

return std::make_unique<hgps::KevinHallModelDefinition>(
std::move(expected), std::move(energy_equation), std::move(nutrient_ranges),
std::move(nutrient_equations), std::move(food_prices), std::move(weight_quantiles),
std::move(epa_quantiles), std::move(height_stddev), std::move(height_slope));
std::move(expected), std::move(expected_trend), std::move(energy_equation),
std::move(nutrient_ranges), std::move(nutrient_equations), std::move(food_prices),
std::move(weight_quantiles), std::move(epa_quantiles), std::move(height_stddev),
std::move(height_slope));
}

std::pair<hgps::RiskFactorModelType, std::unique_ptr<hgps::RiskFactorModelDefinition>>
Expand Down
2 changes: 1 addition & 1 deletion src/HealthGPS.Input/model_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace hgps::input {
/// @brief Loads risk factor expected values from a file
/// @param config The model configuration
/// @return An instance of the hgps::RiskFactorSexAgeTable type
hgps::RiskFactorSexAgeTable load_risk_factor_expected(const Configuration &config);
std::unique_ptr<hgps::RiskFactorSexAgeTable> load_risk_factor_expected(const Configuration &config);

/// @brief Loads either a static or dynamic dummy risk factor model from a JSON file
/// @param type Model type (static or dynamic)
Expand Down
21 changes: 12 additions & 9 deletions src/HealthGPS/dynamic_hierarchical_linear_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@

#include "HealthGPS.Core/exception.h"

#include <utility>
#include <vector>

namespace hgps {

DynamicHierarchicalLinearModel::DynamicHierarchicalLinearModel(
const RiskFactorSexAgeTable &expected,
std::shared_ptr<RiskFactorSexAgeTable> expected,
std::shared_ptr<std::unordered_map<core::Identifier, double>> expected_trend,
const std::map<core::IntegerInterval, AgeGroupGenderEquation> &equations,
const std::map<core::Identifier, core::Identifier> &variables, double boundary_percentage)
: RiskFactorAdjustableModel{expected}, equations_{equations}, variables_{variables},
boundary_percentage_{boundary_percentage} {}
: RiskFactorAdjustableModel{std::move(expected), std::move(expected_trend)},
equations_{equations}, variables_{variables}, boundary_percentage_{boundary_percentage} {}

RiskFactorModelType DynamicHierarchicalLinearModel::type() const noexcept {
return RiskFactorModelType::Dynamic;
Expand Down Expand Up @@ -137,11 +139,13 @@ double DynamicHierarchicalLinearModel::sample_normal_with_boundary(Random &rando
}

DynamicHierarchicalLinearModelDefinition::DynamicHierarchicalLinearModelDefinition(
RiskFactorSexAgeTable expected,
std::unique_ptr<RiskFactorSexAgeTable> expected,
std::unique_ptr<std::unordered_map<core::Identifier, double>> expected_trend,
std::map<core::IntegerInterval, AgeGroupGenderEquation> equations,
std::map<core::Identifier, core::Identifier> variables, const double boundary_percentage)
: RiskFactorAdjustableModelDefinition{std::move(expected)}, equations_{std::move(equations)},
variables_{std::move(variables)}, boundary_percentage_{boundary_percentage} {
: RiskFactorAdjustableModelDefinition{std::move(expected), std::move(expected_trend)},
equations_{std::move(equations)}, variables_{std::move(variables)},
boundary_percentage_{boundary_percentage} {

if (equations_.empty()) {
throw core::HgpsException("The model equations definition must not be empty");
Expand All @@ -152,9 +156,8 @@ DynamicHierarchicalLinearModelDefinition::DynamicHierarchicalLinearModelDefiniti
}

std::unique_ptr<RiskFactorModel> DynamicHierarchicalLinearModelDefinition::create_model() const {
const auto &expected = get_risk_factor_expected();
return std::make_unique<DynamicHierarchicalLinearModel>(expected, equations_, variables_,
boundary_percentage_);
return std::make_unique<DynamicHierarchicalLinearModel>(expected_, expected_trend_, equations_,
variables_, boundary_percentage_);
}

} // namespace hgps
8 changes: 6 additions & 2 deletions src/HealthGPS/dynamic_hierarchical_linear_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ class DynamicHierarchicalLinearModel final : public RiskFactorAdjustableModel {
public:
/// @brief Initialises a new instance of the DynamicHierarchicalLinearModel class
/// @param expected The expected values
/// @param expected_trend The expected trend of risk factor values
/// @param equations The linear regression equations
/// @param variables The factors delta variables mapping
/// @param boundary_percentage The boundary percentage to sample
DynamicHierarchicalLinearModel(
const RiskFactorSexAgeTable &expected,
std::shared_ptr<RiskFactorSexAgeTable> expected,
std::shared_ptr<std::unordered_map<core::Identifier, double>> expected_trend,
const std::map<core::IntegerInterval, AgeGroupGenderEquation> &equations,
const std::map<core::Identifier, core::Identifier> &variables,
const double boundary_percentage);
Expand Down Expand Up @@ -80,12 +82,14 @@ class DynamicHierarchicalLinearModelDefinition : public RiskFactorAdjustableMode
public:
/// @brief Initialises a new instance of the DynamicHierarchicalLinearModelDefinition class
/// @param expected The expected values
/// @param expected_trend The expected trend of risk factor values
/// @param equations The linear regression equations
/// @param variables The factors delta variables mapping
/// @param boundary_percentage The boundary percentage to sample
/// @throws std::invalid_argument for empty model equations definition
DynamicHierarchicalLinearModelDefinition(
RiskFactorSexAgeTable expected,
std::unique_ptr<RiskFactorSexAgeTable> expected,
std::unique_ptr<std::unordered_map<core::Identifier, double>> expected_trend,
std::map<core::IntegerInterval, AgeGroupGenderEquation> equations,
std::map<core::Identifier, core::Identifier> variables,
const double boundary_percentage = 0.05);
Expand Down
Loading

0 comments on commit e766691

Please sign in to comment.