diff --git a/src/HealthGPS/analysis_module.cpp b/src/HealthGPS/analysis_module.cpp index d6045e8c2..bc731b7a4 100644 --- a/src/HealthGPS/analysis_module.cpp +++ b/src/HealthGPS/analysis_module.cpp @@ -22,13 +22,13 @@ AnalysisModule::AnalysisModule(AnalysisDefinition &&definition, WeightModel &&cl residual_disability_weight_{create_age_gender_table(age_range)}, comorbidities_{comorbidities} {} -// Overload constructor with additional parameter for calculated_factors_ +// Overload constructor with additional parameter for calculated_stats_ AnalysisModule::AnalysisModule(AnalysisDefinition &&definition, WeightModel &&classifier, const core::IntegerInterval age_range, unsigned int comorbidities, - std::vector calculated_factors) + std::vector calculated_stats) : definition_{std::move(definition)}, weight_classifier_{std::move(classifier)}, residual_disability_weight_{create_age_gender_table(age_range)}, - comorbidities_{comorbidities}, calculated_factors_{std::move(calculated_factors)} {} + comorbidities_{comorbidities}, calculated_stats_{std::move(calculated_stats)} {} SimulationModuleType AnalysisModule::type() const noexcept { return SimulationModuleType::Analysis; } @@ -59,17 +59,21 @@ void AnalysisModule::initialise_vector(RuntimeContext &context) { factor_bin_widths_.push_back((max_factor - min_factor) / factor_bins_.back()); } - // The number of factors to calculate is the number of factors minus the length of the `factors` - // vector. - size_t num_factors_to_calc = context.mapping().entries().size() - factors_to_calculate_.size(); + // The number of factors to calculate stats for is the number of factors minus the length of the + // `factors` vector. + size_t num_stats_to_calc = context.mapping().entries().size() - factors_to_calculate_.size(); + + // And for each factor, we calculate the stats described in `channels_`, so we + // multiply the size of `channels_` by the number of factors to calculate stats for. + num_stats_to_calc *= channels_.size(); // The product of the number of bins for each factor can be used to calculate the size of the - // `calculated_factors_` in the next step + // `calculated_stats_` in the next step size_t total_num_bins = std::accumulate(factor_bins_.cbegin(), factor_bins_.cend(), size_t{1}, std::multiplies<>()); // Set the vector size and initialise all values to 0.0 - calculated_factors_.resize(total_num_bins * num_factors_to_calc); + calculated_stats_.resize(total_num_bins * num_stats_to_calc); } const std::string &AnalysisModule::name() const noexcept { return name_; } @@ -115,7 +119,7 @@ void AnalysisModule::initialise_population(RuntimeContext &context) { void AnalysisModule::update_population(RuntimeContext &context) { // Reset the calculated factors vector to 0.0 - std::ranges::fill(calculated_factors_, 0.0); + std::ranges::fill(calculated_stats_, 0.0); publish_result_message(context); } @@ -328,7 +332,7 @@ void AnalysisModule::calculate_population_statistics(RuntimeContext &context) { bin_indices.push_back(bin_index); } - // Calculate the index in the calculated_factors_ vector + // Calculate the index in the calculated_stats_ vector size_t index = 0; for (size_t i = 0; i < bin_indices.size() - 1; i++) { size_t accumulated_bins = @@ -342,7 +346,7 @@ void AnalysisModule::calculate_population_statistics(RuntimeContext &context) { for (const auto &factor : context.mapping().entries()) { if (std::find(factors_to_calculate_.cbegin(), factors_to_calculate_.cend(), factor.key()) == factors_to_calculate_.cend()) { - calculated_factors_[index++] += person.get_risk_factor_value(factor.key()); + calculated_stats_[index++] += person.get_risk_factor_value(factor.key()); } } } diff --git a/src/HealthGPS/analysis_module.h b/src/HealthGPS/analysis_module.h index f018045b1..58ea43fe7 100644 --- a/src/HealthGPS/analysis_module.h +++ b/src/HealthGPS/analysis_module.h @@ -49,7 +49,7 @@ class AnalysisModule final : public UpdatableModule { unsigned int comorbidities_; std::string name_{"Analysis"}; std::vector factors_to_calculate_ = {"Gender"_id, "Age"_id}; - std::vector calculated_factors_; + std::vector calculated_stats_; std::vector factor_bins_; std::vector factor_bin_widths_; std::vector factor_min_values_;