Skip to content

Commit

Permalink
Path management
Browse files Browse the repository at this point in the history
Errors management
Unique version management
  • Loading branch information
gtani committed Sep 4, 2023
1 parent 1bb5b7b commit 302ff93
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 69 deletions.
3 changes: 0 additions & 3 deletions configuration/environment/all.yaml

This file was deleted.

11 changes: 11 additions & 0 deletions configuration/environment/notebook_environment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
data:
externals: ../data/externals
raw: ../data/raw
processed: ../data/processed
final: ../data/final
results: ../results


reload: true
extract: true
save_to_disk: false
3 changes: 0 additions & 3 deletions configuration/environment/specific_dataset.yaml

This file was deleted.

13 changes: 3 additions & 10 deletions configuration/main.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,15 @@
# @package _global_
version_base: 0.01
data:
externals: ../data/externals
raw: ../data/raw
processed: ../data/processed
final: ../data/final
results: ../results
defaults:
- environment/notebook_environment

export_path: .

output_file: unit_risk_score.csv


surveys: []
survey_version: all
reload: false
extract: true
save_to_disk: false
survey_version: null

features:
answer_hour_set:
Expand Down
52 changes: 23 additions & 29 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,43 @@
import os
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig
from src.unit_proccessing import *
import hydra
#from memory_profiler import memory_usage
# from memory_profiler import memory_usage
import warnings
warnings.simplefilter(action='ignore', category=Warning)


def manage_relative_path(config, abosulute_path):
for name, relative_path in config.data.items():
if relative_path.startswith('../'):
config['data'][name] = os.path.join(abosulute_path, relative_path.replace('../', ''))
return config


def manage_survey_definition(config):
if config['surveys'] != 'all' and type(config['surveys']) == str:
config['surveys'] = [config['surveys']]
if config['survey_version'] != 'all' and type(config['survey_version']) == str:
config['survey_version'] = [config['survey_version']]
return config
warnings.simplefilter(action='ignore', category=Warning)


def manage_export_path(config):
def manage_path(config):
if config['export_path'] is not None:
config['data']['externals'] = os.path.dirname(config['export_path'])
if os.path.isabs(config['export_path']) is False:
root_path = HydraConfig.get().runtime.cwd
config['export_path'] = os.path.join(root_path, config['export_path'])
config['environment']['data']['externals'] = os.path.dirname(config['export_path'])
config['surveys'] = [os.path.basename(config['export_path'])]
if os.path.isabs(config['output_file']) is False:
root_path = HydraConfig.get().runtime.cwd
config['output_file'] = os.path.join(root_path, config['output_file'])
return config


@hydra.main(config_path='configuration', version_base='1.1', config_name='main.yaml')
def unit_risk_score(config: DictConfig) -> None:
#print(OmegaConf.to_yaml(config))
# print(OmegaConf.to_yaml(config))
print("*" * 12)
config = manage_export_path(config)
config = manage_relative_path(config, hydra.utils.get_original_cwd())
config = manage_survey_definition(config)
features_class = UnitDataProcessing(config)
df_item = features_class.df_item
df_unit = features_class.df_unit
features_class.make_global_score()
features_class.save()
config = manage_path(config)
try:
survey_class = UnitDataProcessing(config)
df_item = survey_class.df_item
df_unit = survey_class.df_unit
survey_class.make_global_score()
survey_class.save()
except ValueError as e:
print(f"An error occurred: {e}")


if __name__ == "__main__":
unit_risk_score()
#mem_usage = memory_usage(unit_risk_score)
#print(f"Memory usage (in MB): {max(mem_usage)}")
# mem_usage = memory_usage(unit_risk_score)
# print(f"Memory usage (in MB): {max(mem_usage)}")
14 changes: 7 additions & 7 deletions src/feature_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ def __init__(self, config):
super().__init__(config)

self.extract()
paradata, questionaire, microdata = self.get_dataframes(reload=self.config['reload'],
save_to_disk=self.config['save_to_disk'])
paradata, questionaire, microdata = self.get_dataframes(reload=self.config['environment']['reload'],
save_to_disk=self.config['environment']['save_to_disk'])
print('Data Loaded')
self._allowed_features = ['f__' + k for k, v in config['features'].items() if v['use']]
self.item_level_columns = ['interview__id', 'variable_name', 'roster_level']
Expand Down Expand Up @@ -87,14 +87,14 @@ def df_paradata(self):

@property
def df_microdata(self):
paradata, questionaire, microdata = self.get_dataframes(reload=self.config['reload'],
save_to_disk=self.config['save_to_disk'])
paradata, questionaire, microdata = self.get_dataframes(reload=self.config['environment']['reload'],
save_to_disk=self.config['environment']['save_to_disk'])
return microdata

@property
def df_questionaire(self):
paradata, questionaire, microdata = self.get_dataframes(reload=self.config['reload'],
save_to_disk=self.config['save_to_disk'])
paradata, questionaire, microdata = self.get_dataframes(reload=self.config['environment']['reload'],
save_to_disk=self.config['environment']['save_to_disk'])
return questionaire

def make_index_col(self, df):
Expand Down Expand Up @@ -274,7 +274,7 @@ def make_df_responsible(self):

def save_data(self, df, file_name):

target_dir = os.path.join(self.config.data.raw, self.config.surveys)
target_dir = os.path.join(self.config['environment']['data']['raw'], self.config.surveys)
survey_path = os.path.join(target_dir, self.config.survey_version)
processed_data_path = os.path.join(survey_path, 'processed_data')
df.to_pickle(os.path.join(processed_data_path, f'{file_name}.pkl'))
Expand Down
33 changes: 20 additions & 13 deletions src/import_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,23 +478,24 @@ def __init__(self, config):
self.file_dict = {}
self.get_survey_version()


def get_files(self):
"""
Get a dictionary with all zip files from the surveys defined in the config.
"""
# code omitted for brevity
# Get a dictionary with all zip files from the surveys defined in config
if self.config.surveys == 'all':
import_path = os.listdir(self.config.data.externals)
import_path = os.listdir(self.config['environment']['data']['externals'])
else:
# Get surveys defined in the config file that are present in the path
import_path = [survey for survey in self.config.surveys if survey in os.listdir(self.config.data.externals)]

import_path = [survey for survey in self.config.surveys if survey in os.listdir(self.config['environment']['data']['externals'])]
if len(import_path) == 0:
raise ValueError(f"ERROR: survey path {self.config['export_path']} does not exists")
for survey_name in import_path:
if os.path.isdir(os.path.join(self.config.data.externals, survey_name)):
if os.path.isdir(os.path.join(self.config['environment']['data']['externals'], survey_name)):
self.file_dict[survey_name] = self.file_dict.get(survey_name, {})

survey_path = os.path.join(self.config.data.externals, survey_name)
survey_path = os.path.join(self.config['environment']['data']['externals'], survey_name)
for filename in os.listdir(survey_path):
if filename.endswith('.zip'):

Expand All @@ -515,12 +516,18 @@ def get_survey_version(self):
"""
self.get_files()
if self.config.surveys != 'all':
if self.config.survey_version != 'all':
self.file_dict = {k: {nk: v for nk, v in nested_dict.items() if nk in self.config.survey_version} for
k, nested_dict in self.file_dict.items() if k in self.config.surveys}
else:
if self.config.survey_version is None:
if len(self.file_dict[self.config.surveys[0]]) > 1:
raise ValueError(f"There are multiple versions in {self.config['export_path']}. "
f"Either specify survey_version=all in python main.py i.e. \n"
f"python main.py export_path={self.config['export_path']} output_file={self.config['output_file']} survey_version=all "
f"\n OR provide a path with only one version.")
elif self.config.survey_version == 'all':
self.file_dict = {survey: survey_data for survey, survey_data in self.file_dict.items() if
survey in self.config.surveys}
else:
self.file_dict = {k: {nk: v for nk, v in nested_dict.items() if nk in self.config.survey_version} for
k, nested_dict in self.file_dict.items() if k in self.config.surveys}

def extract(self, overwrite_dir=False):
"""
Expand All @@ -529,9 +536,9 @@ def extract(self, overwrite_dir=False):
Parameters:
overwrite_dir: A boolean indicating whether to overwrite the existing directory.
"""
if self.config['extract']:
if self.config['environment']['extract']:
for survey_name, survey in self.file_dict.items():
target_dir = os.path.join(self.config.data.raw, survey_name)
target_dir = os.path.join(self.config['environment']['data']['raw'], survey_name)
if overwrite_dir and os.path.exists(target_dir):
shutil.rmtree(target_dir)
# Create a new target directory if it does not yet exist
Expand Down Expand Up @@ -589,7 +596,7 @@ def get_dataframes(self, save_to_disk=True, reload=False):
dfs_questionnaires = []
dfs_microdata = []
for survey_name, survey in self.file_dict.items():
target_dir = os.path.join(self.config.data.raw, survey_name)
target_dir = os.path.join(self.config['environment']['data']['raw'], survey_name)

for survey_version, files in survey.items():
print(f"IMPORTING: {survey_name} with version {survey_version}. ")
Expand Down
8 changes: 4 additions & 4 deletions src/unit_proccessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def __init__(self, config):
def df_unit_score(self):
for method_name in self.get_make_methods(method_type='score', level='unit'):
feature_name = method_name.replace('make_score_unit__', 'f__')

score_name = self.rename_feature(feature_name)
if feature_name in self._allowed_features and self._score_columns is None:
try:
print('Processing Score {}...'.format(feature_name))
print('Processing Score {}...'.format(score_name))
getattr(self, method_name)(feature_name)
# print('Score{} Processed'.format(feature_name))
except Exception as e:
print("WARNING: SCORE: {} won't be used in further calculation".format(feature_name))
print("WARNING: SCORE: {} won't be used in further calculation".format(score_name))

score_columns = [col for col in self._df_unit if
col.startswith('s__')] # and col.replace('s__','f__') in self._allowed_features]
Expand Down Expand Up @@ -100,7 +100,7 @@ def save(self):
df['unit_risk_score'] = df['unit_risk_score'].round(2)
df.sort_values('unit_risk_score', inplace=True)
file_name = "_".join([self.config.surveys[0], self.config.survey_version[0], 'unit_risk_score']) + ".csv"
output_path = self.config.output_file.split('.')[0] + '.csv'
output_path = self.config['output_file'].split('.')[0] + '.csv'
df.to_csv(output_path, index=False)
print(f'SUCCESS! you can find the unit_risk_score output file in {output_path}')

Expand Down

0 comments on commit 302ff93

Please sign in to comment.