Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stats calculations for new calculated_stats_ container #468

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
161 changes: 139 additions & 22 deletions src/HealthGPS/analysis_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,36 +319,88 @@
}

void AnalysisModule::calculate_population_statistics(RuntimeContext &context) {
jamesturner246 marked this conversation as resolved.
Show resolved Hide resolved
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
size_t num_factors_to_calculate =
context.mapping().entries().size() - factors_to_calculate_.size();

auto current_time = static_cast<unsigned int>(context.time_now());

for (const auto &person : context.population()) {
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}
// First let's fetch the correct `calculated_stats_` bin index for this person
size_t index = calculate_index(person);

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(),
size_t{1}, std::multiplies<>());
index += bin_indices[i] * accumulated_bins * num_factors_to_calculate;
// Now we can add the calculated stats for this person to the correct index
if (!person.is_active()) {
if (!person.is_alive() && person.time_of_death() == current_time) {
calculated_stats_[index + get_channel_index("deaths")]++;
float expected_life =
definition_.life_expectancy().at(context.time_now(), person.gender);
double yll = std::max(expected_life - person.age, 0.0f) * DALY_UNITS;
calculated_stats_[index + get_channel_index("mean_yll")] += yll;
calculated_stats_[index + get_channel_index("mean_daly")] += yll;
}

if (person.has_emigrated() && person.time_of_migration() == current_time) {
calculated_stats_[index + get_channel_index("emigrations")]++;
}

continue;
}
index += bin_indices.back() * num_factors_to_calculate;

// Now we can add the values of the factors that are not in factors_to_calculate_
calculated_stats_[index + get_channel_index("count")]++;

for (const auto &factor : context.mapping().entries()) {
if (std::find(factors_to_calculate_.cbegin(), factors_to_calculate_.cend(),
factor.key()) == factors_to_calculate_.cend()) {
calculated_stats_[index++] += person.get_risk_factor_value(factor.key());
double value = person.get_risk_factor_value(factor.key());
calculated_stats_[index + get_channel_index("mean_" + factor.key().to_string())] +=
value;
}

for (const auto &[disease_name, disease_state] : person.diseases) {
if (disease_state.status == DiseaseStatus::active) {
calculated_stats_[index +
get_channel_index("prevalence_" + disease_name.to_string())]++;
if (disease_state.start_time == context.time_now()) {
calculated_stats_[index +
get_channel_index("incidence_" + disease_name.to_string())]++;
}
}
}

double dw = calculate_disability_weight(person);
double yld = dw * DALY_UNITS;
calculated_stats_[index + get_channel_index("mean_yld")] += yld;
calculated_stats_[index + get_channel_index("mean_daly")] += yld;

classify_weight(person);
}

// For each bin in the calculated stats...
for (size_t i = 0; i < calculated_stats_.size(); i += channels_.size()) {
double count_F = calculated_stats_[i + get_channel_index("count")];
double count_M = calculated_stats_[i + get_channel_index("count")];
double deaths_F = calculated_stats_[i + get_channel_index("deaths")];
double deaths_M = calculated_stats_[i + get_channel_index("deaths")];

// Calculate in-place factor averages.
for (const auto &factor : context.mapping().entries()) {
calculated_stats_[i + get_channel_index("mean_" + factor.key().to_string())] /= count_F;
calculated_stats_[i + get_channel_index("mean_" + factor.key().to_string())] /= count_M;
}

// Calculate in-place disease prevalence and incidence rates.
for (const auto &disease : context.diseases()) {
calculated_stats_[i + get_channel_index("prevalence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + get_channel_index("prevalence_" + disease.code.to_string())] /=
count_M;
calculated_stats_[i + get_channel_index("incidence_" + disease.code.to_string())] /=
count_F;
calculated_stats_[i + get_channel_index("incidence_" + disease.code.to_string())] /=
count_M;
}

// Calculate in-place YLL/YLD/DALY averages.
for (const auto &column : {"mean_yll", "mean_yld", "mean_daly"}) {
calculated_stats_[i + get_channel_index(column)] /= (count_F + deaths_F);
calculated_stats_[i + get_channel_index(column)] /= (count_M + deaths_M);
}
}
}

Expand Down Expand Up @@ -531,6 +583,26 @@
}
}

void AnalysisModule::classify_weight(const Person &person) {
auto weight_class = weight_classifier_.classify_weight(person);
switch (weight_class) {
case WeightCategory::normal:
calculated_stats_[get_channel_index("normal_weight")]++;
break;
case WeightCategory::overweight:
calculated_stats_[get_channel_index("over_weight")]++;
calculated_stats_[get_channel_index("above_weight")]++;
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
break;
case WeightCategory::obese:
calculated_stats_[get_channel_index("obese_weight")]++;
calculated_stats_[get_channel_index("above_weight")]++;
break;
default:
throw std::logic_error("Unknown weight classification category.");
break;
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
}
}

void AnalysisModule::initialise_output_channels(RuntimeContext &context) {
if (!channels_.empty()) {
return;
Expand Down Expand Up @@ -560,8 +632,53 @@
channels_.emplace_back("std_yld");
channels_.emplace_back("mean_daly");
channels_.emplace_back("std_daly");

// Since we will be performing frequent lookups, we will store the strings and indexes in a map
// for quick access.
for (size_t i = 0; i < channels_.size(); i++) {
channel_index_.emplace(channels_[i], i);
}
}

size_t AnalysisModule::calculate_index(const Person &person) const {
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
// Get the bin index for each factor
std::vector<size_t> bin_indices;
for (size_t i = 0; i < factors_to_calculate_.size(); i++) {
double factor_value = person.get_risk_factor_value(factors_to_calculate_[i]);
auto bin_index =
static_cast<size_t>((factor_value - factor_min_values_[i]) / factor_bin_widths_[i]);
bin_indices.push_back(bin_index);
}

// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
std::accumulate(std::next(factor_bins_.cbegin(), i + 1), factor_bins_.cend(), size_t{1},
std::multiplies<>());
index += bin_indices[i] * accumulated_bins * channels_.size();
}
index += bin_indices.back() * channels_.size();

return index;
}

size_t AnalysisModule::get_channel_index(const std::string &channel) const {
auto it = channel_index_.find(channel);
if (it == channel_index_.end()) {
throw std::out_of_range("Unknown channel: " + channel);
}

return it->second;
TinyMarsh marked this conversation as resolved.
Show resolved Hide resolved
}

void AnalysisModule::accumulate_squared_diffs(size_t bin_index, size_t channel_index) const {
for (const auto &factor : factors_to_calculate_) {

Check failure on line 676 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘factor’ [-Werror=unused-variable]

Check failure on line 676 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘factor’ [-Werror=unused-variable]
const double mean = calculated_stats_[bin_index + channel_index];
const double diff = value - mean;

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

‘value’ was not declared in this scope

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘diff’ [-Werror=unused-variable]

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

‘value’ was not declared in this scope

Check failure on line 678 in src/HealthGPS/analysis_module.cpp

View workflow job for this annotation

GitHub Actions / Build and test (linux, ubuntu-22.04, gcc-latest, false)

unused variable ‘diff’ [-Werror=unused-variable]
jamesturner246 marked this conversation as resolved.
Show resolved Hide resolved
}
};

std::unique_ptr<AnalysisModule> build_analysis_module(Repository &repository,
const ModelInput &config) {
auto analysis_entity = repository.manager().get_disease_analysis(config.settings().country());
Expand Down
17 changes: 17 additions & 0 deletions src/HealthGPS/analysis_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class AnalysisModule final : public UpdatableModule {
WeightModel weight_classifier_;
DoubleAgeGenderTable residual_disability_weight_;
std::vector<std::string> channels_;
std::unordered_map<std::string, int> channel_index_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might not be completely understanding what these data structures do... But am I right in thinking channel_index_ contains indexes into channels_? If so, could you just make channels_ a std::map<std::string, std::string> (I'm guessing it needs to be ordered...) and drop channels_index_?

unsigned int comorbidities_;
std::string name_{"Analysis"};
std::vector<core::Identifier> factors_to_calculate_ = {"Gender"_id, "Age"_id};
Expand All @@ -70,8 +71,24 @@ class AnalysisModule final : public UpdatableModule {
void calculate_population_statistics(RuntimeContext &context, DataSeries &series) const;

void classify_weight(hgps::DataSeries &series, const hgps::Person &entity) const;
void classify_weight(const Person &person);
void initialise_output_channels(RuntimeContext &context);

/// @brief Calculates the bin index in `calculated_stats_` for a given person
/// @param person The person to calculate the index for
/// @return The index in `calculated_stats_`
size_t calculate_index(const Person &person) const;

/// @brief Gets the index of the given channel name
/// @param channel The channel name
/// @return The channel index
size_t get_channel_index(const std::string &channel) const;

/// @brief Accumulates the squared differences between the given value and the mean
/// @param bin_index The index of the bin in `calculated_stats_`
/// @param channel_index The index of the channel in `channels_`
void accumulate_squared_diffs(size_t bin_index, size_t channel_index) const;

/// @brief Calculates the standard deviation of factors given data series containing means
/// @param context The runtime context
/// @param series The data series containing factor means
Expand Down
Loading