Skip to content

Commit

Permalink
Implement calcluated stats
Browse files Browse the repository at this point in the history
  • Loading branch information
TinyMarsh committed Jul 3, 2024
1 parent 0b0ce7a commit d00133e
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 22 deletions.
161 changes: 139 additions & 22 deletions src/HealthGPS/analysis_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,36 +319,88 @@ DALYsIndicator AnalysisModule::calculate_dalys(Population &population, unsigned
}

void AnalysisModule::calculate_population_statistics(RuntimeContext &context) {
size_t num_factors_to_calculate =
context.mapping().entries().size() - factors_to_calculate_.size();

auto current_time = static_cast<unsigned int>(context.time_now());

for (const auto &person : context.population()) {
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}
// First let's fetch the correct `calculated_stats_` bin index for this person
size_t index = calculate_index(person);

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(),
size_t{1}, std::multiplies<>());
index += bin_indices[i] * accumulated_bins * num_factors_to_calculate;
// Now we can add the calculated stats for this person to the correct index
if (!person.is_active()) {
if (!person.is_alive() && person.time_of_death() == current_time) {
calculated_stats_[index + get_channel_index("deaths")]++;
float expected_life =
definition_.life_expectancy().at(context.time_now(), person.gender);
double yll = std::max(expected_life - person.age, 0.0f) * DALY_UNITS;
calculated_stats_[index + get_channel_index("mean_yll")] += yll;
calculated_stats_[index + get_channel_index("mean_daly")] += yll;
}

if (person.has_emigrated() && person.time_of_migration() == current_time) {
calculated_stats_[index + get_channel_index("emigrations")]++;
}

continue;
}
index += bin_indices.back() * num_factors_to_calculate;

// Now we can add the values of the factors that are not in factors_to_calculate_
calculated_stats_[index + get_channel_index("count")]++;

for (const auto &factor : context.mapping().entries()) {
if (std::find(factors_to_calculate_.cbegin(), factors_to_calculate_.cend(),
factor.key()) == factors_to_calculate_.cend()) {
calculated_stats_[index++] += person.get_risk_factor_value(factor.key());
double value = person.get_risk_factor_value(factor.key());
calculated_stats_[index + get_channel_index("mean_" + factor.key().to_string())] +=
value;
}

for (const auto &[disease_name, disease_state] : person.diseases) {
if (disease_state.status == DiseaseStatus::active) {
calculated_stats_[index +
get_channel_index("prevalence_" + disease_name.to_string())]++;
if (disease_state.start_time == context.time_now()) {
calculated_stats_[index +
get_channel_index("incidence_" + disease_name.to_string())]++;
}
}
}

double dw = calculate_disability_weight(person);
double yld = dw * DALY_UNITS;
calculated_stats_[index + get_channel_index("mean_yld")] += yld;
calculated_stats_[index + get_channel_index("mean_daly")] += yld;

classify_weight(person);
}

// For each bin in the calculated stats...
for (size_t i = 0; i < calculated_stats_.size(); i += channels_.size()) {
double count_F = calculated_stats_[i + get_channel_index("count")];
double count_M = calculated_stats_[i + get_channel_index("count")];
double deaths_F = calculated_stats_[i + get_channel_index("deaths")];
double deaths_M = calculated_stats_[i + get_channel_index("deaths")];

// Calculate in-place factor averages.
for (const auto &factor : context.mapping().entries()) {
calculated_stats_[i + get_channel_index("mean_" + factor.key().to_string())] /= count_F;
calculated_stats_[i + get_channel_index("mean_" + factor.key().to_string())] /= count_M;
}

// Calculate in-place disease prevalence and incidence rates.
for (const auto &disease : context.diseases()) {
calculated_stats_[i + get_channel_index("prevalence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + get_channel_index("prevalence_" + disease.code.to_string())] /=
count_M;
calculated_stats_[i + get_channel_index("incidence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + get_channel_index("incidence_" + disease.code.to_string())] /=
count_M;
}

// Calculate in-place YLL/YLD/DALY averages.
for (const auto &column : {"mean_yll", "mean_yld", "mean_daly"}) {
calculated_stats_[i + get_channel_index(column)] /= (count_F + deaths_F);
calculated_stats_[i + get_channel_index(column)] /= (count_M + deaths_M);
}
}
}

Expand Down Expand Up @@ -531,6 +583,26 @@ void AnalysisModule::classify_weight(DataSeries &series, const Person &entity) c
}
}

void AnalysisModule::classify_weight(const Person &person) {
auto weight_class = weight_classifier_.classify_weight(person);
switch (weight_class) {
case WeightCategory::normal:
calculated_stats_[get_channel_index("normal_weight")]++;
break;
case WeightCategory::overweight:
calculated_stats_[get_channel_index("over_weight")]++;
calculated_stats_[get_channel_index("above_weight")]++;
break;
case WeightCategory::obese:
calculated_stats_[get_channel_index("obese_weight")]++;
calculated_stats_[get_channel_index("above_weight")]++;
break;
default:
throw std::logic_error("Unknown weight classification category.");
break;
}
}

void AnalysisModule::initialise_output_channels(RuntimeContext &context) {
if (!channels_.empty()) {
return;
Expand Down Expand Up @@ -560,8 +632,53 @@ void AnalysisModule::initialise_output_channels(RuntimeContext &context) {
channels_.emplace_back("std_yld");
channels_.emplace_back("mean_daly");
channels_.emplace_back("std_daly");

// Since we will be performing frequent lookups, we will store the strings and indexes in a map
// for quick access.
for (size_t i = 0; i < channels_.size(); i++) {
channel_index_.emplace(channels_[i], i);
}
}

size_t AnalysisModule::calculate_index(const Person &person) const {
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(), size_t{1},
std::multiplies<>());
index += bin_indices[i] * accumulated_bins * channels_.size();
}
index += bin_indices.back() * channels_.size();

return index;
}

size_t AnalysisModule::get_channel_index(const std::string &channel) const {
auto it = channel_index_.find(channel);
if (it == channel_index_.end()) {
throw std::out_of_range("Unknown channel: " + channel);
}

return it->second;
}

void AnalysisModule::accumulate_squared_diffs(size_t bin_index, size_t channel_index) const {
for (const auto &factor : factors_to_calculate_) {

Check failure on line 676 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘factor’ [-Werror=unused-variable]

Check failure on line 676 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘factor’ [-Werror=unused-variable]
const double mean = calculated_stats_[bin_index + channel_index];
const double diff = value - mean;

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

‘value’ was not declared in this scope

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘diff’ [-Werror=unused-variable]

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

‘value’ was not declared in this scope

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘diff’ [-Werror=unused-variable]
}
};

std::unique_ptr<AnalysisModule> build_analysis_module(Repository &repository,
const ModelInput &config) {
auto analysis_entity = repository.manager().get_disease_analysis(config.settings().country());
Expand Down
17 changes: 17 additions & 0 deletions src/HealthGPS/analysis_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class AnalysisModule final : public UpdatableModule {
WeightModel weight_classifier_;
DoubleAgeGenderTable residual_disability_weight_;
std::vector<std::string> channels_;
std::unordered_map<std::string, int> channel_index_;
unsigned int comorbidities_;
std::string name_{"Analysis"};
std::vector<core::Identifier> factors_to_calculate_ = {"Gender"_id, "Age"_id};
Expand All @@ -70,8 +71,24 @@ class AnalysisModule final : public UpdatableModule {
void calculate_population_statistics(RuntimeContext &context, DataSeries &series) const;

void classify_weight(hgps::DataSeries &series, const hgps::Person &entity) const;
void classify_weight(const Person &person);
void initialise_output_channels(RuntimeContext &context);

/// @brief Calculates the bin index in `calculated_stats_` for a given person
/// @param person The person to calculate the index for
/// @return The index in `calculated_stats_`
size_t calculate_index(const Person &person) const;

/// @brief Gets the index of the given channel name
/// @param channel The channel name
/// @return The channel index
size_t get_channel_index(const std::string &channel) const;

/// @brief Accumulates the squared differences between the given value and the mean
/// @param bin_index The index of the bin in `calculated_stats_`
/// @param channel_index The index of the channel in `channels_`
void accumulate_squared_diffs(size_t bin_index, size_t channel_index) const;

/// @brief Calculates the standard deviation of factors given data series containing means
/// @param context The runtime context
/// @param series The data series containing factor means
Expand Down

0 comments on commit d00133e

Please sign in to comment.