Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stats calculations for new calculated_stats_ container #468

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
153 changes: 130 additions & 23 deletions src/HealthGPS/analysis_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,36 +318,95 @@ DALYsIndicator AnalysisModule::calculate_dalys(Population &population, unsigned
.disability_adjusted_life_years = yll + yld};
}

void AnalysisModule::update_death_and_migration_stats(const Person &person, size_t index,
RuntimeContext &context) {
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved

auto current_time = static_cast<unsigned int>(context.time_now());

TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
if (!person.is_alive() && person.time_of_death() == current_time) {
calculated_stats_[index + channel_index_.at("deaths")]++;
float expected_life = definition_.life_expectancy().at(context.time_now(), person.gender);
double yll = std::max(expected_life - person.age, 0.0f) * DALY_UNITS;
calculated_stats_[index + channel_index_.at("mean_yll")] += yll;
calculated_stats_[index + channel_index_.at("mean_daly")] += yll;
}

if (person.has_emigrated() && person.time_of_migration() == current_time) {
calculated_stats_[index + channel_index_.at("emigrations")]++;
}
}

void AnalysisModule::update_calculated_stats_for_person(RuntimeContext &context,
const Person &person, size_t index) {
calculated_stats_[index + channel_index_.at("count")]++;

for (const auto &factor : context.mapping().entries()) {
double value = person.get_risk_factor_value(factor.key());
calculated_stats_[index + channel_index_.at("mean_" + factor.key().to_string())] += value;
}

for (const auto &[disease_name, disease_state] : person.diseases) {
if (disease_state.status == DiseaseStatus::active) {
calculated_stats_[index +
channel_index_.at("prevalence_" + disease_name.to_string())]++;
if (disease_state.start_time == context.time_now()) {
calculated_stats_[index +
channel_index_.at("incidence_" + disease_name.to_string())]++;
}
}
}
}

void AnalysisModule::calculate_population_statistics(RuntimeContext &context) {
size_t num_factors_to_calculate =
context.mapping().entries().size() - factors_to_calculate_.size();

for (const auto &person : context.population()) {
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}
// First let's fetch the correct `calculated_stats_` bin index for this person
size_t index = calculate_index(person);

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(),
size_t{1}, std::multiplies<>());
index += bin_indices[i] * accumulated_bins * num_factors_to_calculate;
if (!person.is_active()) {
update_death_and_migration_stats(person, index, context);
continue;
}
index += bin_indices.back() * num_factors_to_calculate;

// Now we can add the values of the factors that are not in factors_to_calculate_
update_calculated_stats_for_person(context, person, index);

double dw = calculate_disability_weight(person);
double yld = dw * DALY_UNITS;
calculated_stats_[index + channel_index_.at("mean_yld")] += yld;
calculated_stats_[index + channel_index_.at("mean_daly")] += yld;

classify_weight(person);
}

// For each bin in the calculated stats...
for (size_t i = 0; i < calculated_stats_.size(); i += channels_.size()) {
double count_F = calculated_stats_[i + channel_index_.at("count")];
double count_M = calculated_stats_[i + channel_index_.at("count")];
double deaths_F = calculated_stats_[i + channel_index_.at("deaths")];
double deaths_M = calculated_stats_[i + channel_index_.at("deaths")];

// Calculate in-place factor averages.
for (const auto &factor : context.mapping().entries()) {
if (std::find(factors_to_calculate_.cbegin(), factors_to_calculate_.cend(),
factor.key()) == factors_to_calculate_.cend()) {
calculated_stats_[index++] += person.get_risk_factor_value(factor.key());
}
calculated_stats_[i + channel_index_.at("mean_" + factor.key().to_string())] /= count_F;
calculated_stats_[i + channel_index_.at("mean_" + factor.key().to_string())] /= count_M;
}

// Calculate in-place disease prevalence and incidence rates.
for (const auto &disease : context.diseases()) {
calculated_stats_[i + channel_index_.at("prevalence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + channel_index_.at("prevalence_" + disease.code.to_string())] /=
count_M;
calculated_stats_[i + channel_index_.at("incidence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + channel_index_.at("incidence_" + disease.code.to_string())] /=
count_M;
}

// Calculate in-place YLL/YLD/DALY averages.
for (const auto &column : {"mean_yll", "mean_yld", "mean_daly"}) {
calculated_stats_[i + channel_index_.at(column)] /= (count_F + deaths_F);
calculated_stats_[i + channel_index_.at(column)] /= (count_M + deaths_M);
}
}
}
Expand Down Expand Up @@ -531,6 +590,25 @@ void AnalysisModule::classify_weight(DataSeries &series, const Person &entity) c
}
}

void AnalysisModule::classify_weight(const Person &person) {
auto weight_class = weight_classifier_.classify_weight(person);
switch (weight_class) {
case WeightCategory::normal:
calculated_stats_[channel_index_.at("normal_weight")]++;
break;
case WeightCategory::overweight:
calculated_stats_[channel_index_.at("over_weight")]++;
calculated_stats_[channel_index_.at("above_weight")]++;
break;
case WeightCategory::obese:
calculated_stats_[channel_index_.at("obese_weight")]++;
calculated_stats_[channel_index_.at("above_weight")]++;
break;
default:
throw std::logic_error("Unknown weight classification category.");
}
}

void AnalysisModule::initialise_output_channels(RuntimeContext &context) {
if (!channels_.empty()) {
return;
Expand Down Expand Up @@ -560,6 +638,35 @@ void AnalysisModule::initialise_output_channels(RuntimeContext &context) {
channels_.emplace_back("std_yld");
channels_.emplace_back("mean_daly");
channels_.emplace_back("std_daly");

// Since we will be performing frequent lookups, we will store the strings and indexes in a map
// for quick access.
for (size_t i = 0; i < channels_.size(); i++) {
channel_index_.emplace(channels_[i], i);
}
}

size_t AnalysisModule::calculate_index(const Person &person) const {
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(), size_t{1},
std::multiplies<>());
index += bin_indices[i] * accumulated_bins * channels_.size();
}
index += bin_indices.back() * channels_.size();

return index;
}

std::unique_ptr<AnalysisModule> build_analysis_module(Repository &repository,
Expand Down
11 changes: 11 additions & 0 deletions src/HealthGPS/analysis_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class AnalysisModule final : public UpdatableModule {
WeightModel weight_classifier_;
DoubleAgeGenderTable residual_disability_weight_;
std::vector<std::string> channels_;
std::unordered_map<std::string, size_t> channel_index_;
unsigned int comorbidities_;
std::string name_{"Analysis"};
std::vector<core::Identifier> factors_to_calculate_ = {"Gender"_id, "Age"_id};
Expand All @@ -65,13 +66,23 @@ class AnalysisModule final : public UpdatableModule {
double calculate_disability_weight(const Person &entity) const;
DALYsIndicator calculate_dalys(Population &population, unsigned int max_age,
unsigned int death_year) const;
void update_death_and_migration_stats(const Person &person, size_t index,
RuntimeContext &context);
void update_calculated_stats_for_person(RuntimeContext &context, const Person &person,
size_t index);

void calculate_population_statistics(RuntimeContext &context);
void calculate_population_statistics(RuntimeContext &context, DataSeries &series) const;

void classify_weight(hgps::DataSeries &series, const hgps::Person &entity) const;
void classify_weight(const Person &person);
void initialise_output_channels(RuntimeContext &context);

/// @brief Calculates the bin index in `calculated_stats_` for a given person
/// @param person The person to calculate the index for
/// @return The index in `calculated_stats_`
size_t calculate_index(const Person &person) const;

/// @brief Calculates the standard deviation of factors given data series containing means
/// @param context The runtime context
/// @param series The data series containing factor means
Expand Down