Skip to content

Commit

Permalink
Merge pull request #466 from imperialCHEPI/extend_analysis
Browse files Browse the repository at this point in the history
Correct calculation of size of `calculated_stats_` vector
  • Loading branch information
TinyMarsh committed Jul 2, 2024
2 parents 37a99d4 + 7aa9470 commit 0b0ce7a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
26 changes: 15 additions & 11 deletions src/HealthGPS/analysis_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ AnalysisModule::AnalysisModule(AnalysisDefinition &&definition, WeightModel &&cl
residual_disability_weight_{create_age_gender_table<double>(age_range)},
comorbidities_{comorbidities} {}

// Overload constructor with additional parameter for calculated_factors_
// Overload constructor with additional parameter for calculated_stats_
AnalysisModule::AnalysisModule(AnalysisDefinition &&definition, WeightModel &&classifier,
const core::IntegerInterval age_range, unsigned int comorbidities,
std::vector<double> calculated_factors)
std::vector<double> calculated_stats)
: definition_{std::move(definition)}, weight_classifier_{std::move(classifier)},
residual_disability_weight_{create_age_gender_table<double>(age_range)},
comorbidities_{comorbidities}, calculated_factors_{std::move(calculated_factors)} {}
comorbidities_{comorbidities}, calculated_stats_{std::move(calculated_stats)} {}
SimulationModuleType AnalysisModule::type() const noexcept {
return SimulationModuleType::Analysis;
}
Expand Down Expand Up @@ -59,17 +59,21 @@ void AnalysisModule::initialise_vector(RuntimeContext &context) {
factor_bin_widths_.push_back((max_factor - min_factor) / factor_bins_.back());
}

// The number of factors to calculate is the number of factors minus the length of the `factors`
// vector.
size_t num_factors_to_calc = context.mapping().entries().size() - factors_to_calculate_.size();
// The number of factors to calculate stats for is the number of factors minus the length of the
// `factors` vector.
size_t num_stats_to_calc = context.mapping().entries().size() - factors_to_calculate_.size();

// And for each factor, we calculate the stats described in `channels_`, so we
// multiply the size of `channels_` by the number of factors to calculate stats for.
num_stats_to_calc *= channels_.size();

// The product of the number of bins for each factor can be used to calculate the size of the
// `calculated_factors_` in the next step
// `calculated_stats_` in the next step
size_t total_num_bins =
std::accumulate(factor_bins_.cbegin(), factor_bins_.cend(), size_t{1}, std::multiplies<>());

// Set the vector size and initialise all values to 0.0
calculated_factors_.resize(total_num_bins * num_factors_to_calc);
calculated_stats_.resize(total_num_bins * num_stats_to_calc);
}

const std::string &AnalysisModule::name() const noexcept { return name_; }
Expand Down Expand Up @@ -115,7 +119,7 @@ void AnalysisModule::initialise_population(RuntimeContext &context) {
void AnalysisModule::update_population(RuntimeContext &context) {

// Reset the calculated factors vector to 0.0
std::ranges::fill(calculated_factors_, 0.0);
std::ranges::fill(calculated_stats_, 0.0);

publish_result_message(context);
}
Expand Down Expand Up @@ -328,7 +332,7 @@ void AnalysisModule::calculate_population_statistics(RuntimeContext &context) {
bin_indices.push_back(bin_index);
}

// Calculate the index in the calculated_factors_ vector
// Calculate the index in the calculated_stats_ vector
size_t index = 0;
for (size_t i = 0; i < bin_indices.size() - 1; i++) {
size_t accumulated_bins =
Expand All @@ -342,7 +346,7 @@ void AnalysisModule::calculate_population_statistics(RuntimeContext &context) {
for (const auto &factor : context.mapping().entries()) {
if (std::find(factors_to_calculate_.cbegin(), factors_to_calculate_.cend(),
factor.key()) == factors_to_calculate_.cend()) {
calculated_factors_[index++] += person.get_risk_factor_value(factor.key());
calculated_stats_[index++] += person.get_risk_factor_value(factor.key());
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/HealthGPS/analysis_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AnalysisModule final : public UpdatableModule {
unsigned int comorbidities_;
std::string name_{"Analysis"};
std::vector<core::Identifier> factors_to_calculate_ = {"Gender"_id, "Age"_id};
std::vector<double> calculated_factors_;
std::vector<double> calculated_stats_;
std::vector<size_t> factor_bins_;
std::vector<double> factor_bin_widths_;
std::vector<double> factor_min_values_;
Expand Down

0 comments on commit 0b0ce7a

Please sign in to comment.