Skip to content

Commit

Permalink
Merge pull request #195 from imperialCHEPI/rf_refactor_166
Browse files Browse the repository at this point in the history
Risk factor refactor
  • Loading branch information
jamesturner246 authored Aug 7, 2023
2 parents 8cdb171 + b685faa commit 19b4519
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 137 deletions.
57 changes: 0 additions & 57 deletions example/France.EBHLM.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,63 +7,6 @@
"Alpha3": "FRA"
},
"BoundaryPercentage": 0.05,
"RiskFactors": [
{
"Name": "Gender",
"Level": 0,
"Range": [0, 1]
},
{
"Name": "Age",
"Level": 0,
"Range": [1, 87]
},
{
"Name": "Age2",
"Level": 0,
"Range": [1, 7569]
},
{
"Name": "Age3",
"Level": 0,
"Range": [1, 658503]
},
{
"Name": "SES",
"Level": 0,
"Range": [-2.316299, 2.296689]
},
{
"Name": "Sodium",
"Level": 1,
"Range": [1.127384, 8.656519]
},
{
"Name": "Protein",
"Level": 1,
"Range": [43.50682, 238.4145]
},
{
"Name": "Fat",
"Level": 1,
"Range": [45.04756, 382.664098658922]
},
{
"Name": "PA",
"Level": 2,
"Range": [22.22314, 9765.512]
},
{
"Name": "Energy",
"Level": 2,
"Range": [1326.14051277588, 7522.496]
},
{
"Name": "BMI",
"Level": 3,
"Range": [13.88, 39.48983]
}
],
"Variables": [
{
"Name": "dPA",
Expand Down
8 changes: 1 addition & 7 deletions src/HealthGPS.Console/configuration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,7 @@ std::vector<core::DiseaseInfo> get_diseases_info(core::Datastore &data_api, Conf
config.diseases.size());

for (const auto &code : config.diseases) {
auto code_key = core::Identifier{code};
auto item = data_api.get_disease_info(code_key);
if (item.has_value()) {
result.emplace_back(item.value());
} else {
fmt::print(fg(fmt::color::salmon), "Unknown disease code: {}.\n", code);
}
result.emplace_back(data_api.get_disease_info(code));
}

return result;
Expand Down
1 change: 1 addition & 0 deletions src/HealthGPS.Console/csvparser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ create_fields_index_mapping(const std::vector<std::string> &column_names,

return mapping;
}

} // namespace detail

bool load_datatable_from_csv(hc::DataTable &out_table, std::string full_filename,
Expand Down
16 changes: 4 additions & 12 deletions src/HealthGPS.Console/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,9 @@ int main(int argc, char *argv[]) { // NOLINT(bugprone-exception-escape)
auto factory = get_default_simulation_module_factory(data_repository);

// Validate the configuration's target country for the simulation
auto countries = data_api.get_countries();
fmt::print("\nThere are {} countries in storage.\n", countries.size());
auto target = data_api.get_country(config.settings.country);
if (target.has_value()) {
fmt::print("Target country: {} - {}, population: {:0.3g}%.\n", target.value().code,
target.value().name, config.settings.size_fraction * 100.0f);
} else {
fmt::print(fg(fmt::color::red), "\nTarget country: {} not found.\n",
config.settings.country);
return exit_application(EXIT_FAILURE);
}
auto country = data_api.get_country(config.settings.country);
fmt::print("Target country: {} - {}, population: {:0.3g}%.\n", country.code, country.name,
config.settings.size_fraction * 100.0f);

// Validate the configuration diseases list, must exists in back-end data store
auto diseases = get_diseases_info(data_api, config);
Expand All @@ -95,7 +87,7 @@ int main(int argc, char *argv[]) { // NOLINT(bugprone-exception-escape)
std::cout << input_table;

// Create complete model input from configuration
auto model_input = create_model_input(input_table, target.value(), config, diseases);
auto model_input = create_model_input(input_table, country, config, diseases);

// Create event bus and event monitor with a results file writer
auto event_bus = DefaultEventBus();
Expand Down
9 changes: 4 additions & 5 deletions src/HealthGPS.Core/datastore.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include "interval.h"
#include "poco.h"
#include <optional>
#include <vector>

namespace hgps::core {
Expand All @@ -21,8 +20,8 @@ class Datastore {

/// @brief Gets a single country by the alpha code
/// @param alpha The country alpha 2 or 3 format code to search
/// @return The country's definition, if found, otherwise empty
virtual std::optional<Country> get_country(std::string alpha) const = 0;
/// @return The country's definition
virtual Country get_country(const std::string &alpha) const = 0;

/// @brief Gets the population growth trend for a country filtered by time
/// @param country The target country definition
Expand All @@ -46,8 +45,8 @@ class Datastore {

/// @brief Gets a single disease information by identifier
/// @param code The target disease identifier
/// @return The disease information, if found, otherwise empty
virtual std::optional<DiseaseInfo> get_disease_info(Identifier code) const = 0;
/// @return The disease information
virtual DiseaseInfo get_disease_info(const Identifier &code) const = 0;

/// @brief Gets a disease full definition by identifier for a country
/// @param info The target disease information
Expand Down
25 changes: 13 additions & 12 deletions src/HealthGPS.Datastore/datamanager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,19 @@ std::vector<Country> DataManager::get_countries() const {
return results;
}

std::optional<Country> DataManager::get_country(std::string alpha) const {
auto v = get_countries();
auto is_target = [&alpha](const hgps::core::Country &obj) {
return core::case_insensitive::equals(obj.alpha2, alpha) ||
core::case_insensitive::equals(obj.alpha3, alpha);
Country DataManager::get_country(const std::string &alpha) const {
auto c = get_countries();
auto is_target = [&alpha](const hgps::core::Country &c) {
return core::case_insensitive::equals(c.alpha2, alpha) ||
core::case_insensitive::equals(c.alpha3, alpha);
};

if (auto it = std::find_if(v.begin(), v.end(), is_target); it != v.end()) {
return (*it);
auto country = std::find_if(c.begin(), c.end(), is_target);
if (country != c.end()) {
return *country;
}

return std::nullopt;
throw std::invalid_argument(fmt::format("Target country: '{}' not found.", alpha));
}

std::vector<PopulationItem> DataManager::get_population(Country country) const {
Expand Down Expand Up @@ -200,7 +201,7 @@ std::vector<DiseaseInfo> DataManager::get_diseases() const {
return result;
}

std::optional<DiseaseInfo> DataManager::get_disease_info(core::Identifier code) const {
DiseaseInfo DataManager::get_disease_info(const core::Identifier &code) const {
if (index_.contains("diseases")) {
auto &registry = index_["diseases"]["registry"];
auto disease_code_str = code.to_string();
Expand All @@ -221,11 +222,11 @@ std::optional<DiseaseInfo> DataManager::get_disease_info(core::Identifier code)
return info;
}
}
} else {
notify_warning("index has no 'diseases' entry.");

throw std::invalid_argument(fmt::format("Disease code: '{}' not found.", code.to_string()));
}

return std::optional<DiseaseInfo>();
throw std::runtime_error("Index has no 'diseases' entry.");
}

DiseaseEntity DataManager::get_disease(DiseaseInfo info, Country country) const {
Expand Down
7 changes: 3 additions & 4 deletions src/HealthGPS.Datastore/datamanager.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@ class DataManager : public Datastore {
/// @brief Initialises a new instance of the hgps::data::DataManager class.
/// @param root_directory The store root folder containing the index.json file.
/// @param verbosity The terminal logging verbosity mode to use.
/// @throws std::invalid_argument if the given folder does exists or contains the index.json
/// file.
/// @throws std::invalid_argument if the root directory or index.json is missing.
/// @throws std::runtime_error for invalid or unsupported index.json file schema version.
explicit DataManager(const std::filesystem::path root_directory,
VerboseMode verbosity = VerboseMode::none);

std::vector<Country> get_countries() const override;

std::optional<Country> get_country(std::string alpha) const override;
Country get_country(const std::string &alpha) const override;

std::vector<PopulationItem> get_population(Country country) const;

Expand All @@ -52,7 +51,7 @@ class DataManager : public Datastore {

std::vector<DiseaseInfo> get_diseases() const override;

std::optional<DiseaseInfo> get_disease_info(core::Identifier code) const override;
DiseaseInfo get_disease_info(const core::Identifier &code) const override;

DiseaseEntity get_disease(DiseaseInfo code, Country country) const override;

Expand Down
70 changes: 30 additions & 40 deletions src/HealthGPS.Tests/Datastore.Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ class DatastoreTest : public ::testing::Test {
DatastoreTest() : manager{test_datastore_path} {}

hgps::data::DataManager manager;

// We don't need to check because this is just for testing
hgps::core::Country uk =
manager.get_country("GB").value(); // NOLINT(bugprone-unchecked-optional-access)
hgps::core::Country uk = manager.get_country("GB");
};

TEST_F(DatastoreTest, CreateDataManager) {
Expand All @@ -23,24 +20,20 @@ TEST_F(DatastoreTest, CreateDataManager) {
}

TEST_F(DatastoreTest, CreateDataManagerFailWithWrongPath) {
using namespace hgps::data;
ASSERT_THROW(DataManager{"C:\\x\\y"}, std::invalid_argument);
ASSERT_THROW(DataManager{"C:/x/y"}, std::invalid_argument);
ASSERT_THROW(DataManager{"/home/x/y/z"}, std::invalid_argument);
EXPECT_THROW(hgps::data::DataManager{"C:\\x\\y"}, std::invalid_argument);
EXPECT_THROW(hgps::data::DataManager{"C:/x/y"}, std::invalid_argument);
EXPECT_THROW(hgps::data::DataManager{"/home/x/y/z"}, std::invalid_argument);
}

TEST_F(DatastoreTest, CountryIsCaseInsensitive) {
auto countries = manager.get_countries();
auto gb2_lower = manager.get_country("gb");
auto gb2_upper = manager.get_country("GB");
auto gb3_lower = manager.get_country("gbr");
auto gb3_upper = manager.get_country("GBR");
TEST_F(DatastoreTest, CountryMissingThrowsException) {
ASSERT_THROW(manager.get_country("FAKE"), std::invalid_argument);
}

ASSERT_GT(countries.size(), 0);
ASSERT_TRUE(gb2_lower.has_value());
ASSERT_TRUE(gb2_upper.has_value());
ASSERT_TRUE(gb3_lower.has_value());
ASSERT_TRUE(gb3_upper.has_value());
TEST_F(DatastoreTest, CountryIsCaseInsensitive) {
EXPECT_NO_THROW(manager.get_country("gb"));
EXPECT_NO_THROW(manager.get_country("GB"));
EXPECT_NO_THROW(manager.get_country("gbr"));
EXPECT_NO_THROW(manager.get_country("GBR"));
}

TEST_F(DatastoreTest, CountryPopulation) {
Expand Down Expand Up @@ -130,15 +123,24 @@ TEST_F(DatastoreTest, CountryMortality) {
}
}

TEST_F(DatastoreTest, RetrieveDeseasesInfo) {
TEST_F(DatastoreTest, GetDiseases) {
auto diseases = manager.get_diseases();
ASSERT_GT(diseases.size(), 0);
}

TEST_F(DatastoreTest, GetDiseaseInfoMatchesGetDisases) {
auto diseases = manager.get_diseases();
for (auto &item : diseases) {
auto info = manager.get_disease_info(item.code);
EXPECT_TRUE(info.has_value());
EXPECT_EQ(item.code, info.value().code); // NOLINT(bugprone-unchecked-optional-access)
auto call = [&] {
auto info = manager.get_disease_info(item.code);
EXPECT_EQ(item.code, info.code);
};
EXPECT_NO_THROW(call());
}
}

ASSERT_GT(diseases.size(), 0);
TEST_F(DatastoreTest, GetDiseaseInfoMissingThrowsException) {
EXPECT_THROW(manager.get_disease_info("FAKE"), std::invalid_argument);
}

TEST_F(DatastoreTest, RetrieveDeseasesTypeInInfo) {
Expand All @@ -156,12 +158,6 @@ TEST_F(DatastoreTest, RetrieveDeseasesTypeInInfo) {
ASSERT_GT(cancer_count, 0);
}

TEST_F(DatastoreTest, RetrieveDeseasesInfoHasNoValue) {
using namespace hgps::core;
auto info = manager.get_disease_info(Identifier{"ghost369"});
EXPECT_FALSE(info.has_value());
}

TEST_F(DatastoreTest, RetrieveDeseaseDefinition) {
auto diseases = manager.get_diseases();
for (auto &item : diseases) {
Expand Down Expand Up @@ -192,12 +188,8 @@ TEST_F(DatastoreTest, RetrieveDeseaseDefinitionIsEmpty) {
}

TEST_F(DatastoreTest, DiseaseRelativeRiskToDisease) {
using namespace hgps::core;

// NOLINTBEGIN(bugprone-unchecked-optional-access)
auto asthma = manager.get_disease_info(Identifier{"asthma"}).value();
auto diabetes = manager.get_disease_info(Identifier{"diabetes"}).value();
// NOLINTEND(bugprone-unchecked-optional-access)
auto asthma = manager.get_disease_info("asthma");
auto diabetes = manager.get_disease_info("diabetes");

auto table_self = manager.get_relative_risk_to_disease(diabetes, diabetes);
auto table_other = manager.get_relative_risk_to_disease(diabetes, asthma);
Expand All @@ -217,8 +209,7 @@ TEST_F(DatastoreTest, DiseaseRelativeRiskToDisease) {
TEST_F(DatastoreTest, DefaultDiseaseRelativeRiskToDisease) {
using namespace hgps::core;

// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
auto diabetes = manager.get_disease_info(Identifier{"diabetes"}).value();
auto diabetes = manager.get_disease_info("diabetes");
auto info = DiseaseInfo{.group = DiseaseGroup::other,
.code = Identifier{"ghost369"},
.name = "Look at the flowers."};
Expand All @@ -235,8 +226,7 @@ TEST_F(DatastoreTest, DiseaseRelativeRiskToRiskFactor) {
using namespace hgps::core;

auto risk_factor = Identifier{"bmi"};
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
auto diabetes = manager.get_disease_info(Identifier{"diabetes"}).value();
auto diabetes = manager.get_disease_info("diabetes");

auto col_size = 8;

Expand Down

0 comments on commit 19b4519

Please sign in to comment.