Skip to content

Commit

Permalink
Merge pull request #8 from NWC-CUAHSI-Summer-Institute/add-NN
Browse files Browse the repository at this point in the history
Add nn
  • Loading branch information
RY4GIT authored Jul 20, 2023
2 parents 891647d + 75612a6 commit aaed81a
Show file tree
Hide file tree
Showing 16 changed files with 612,339 additions and 80 deletions.
259,881 changes: 259,881 additions & 0 deletions data/01137500-usgs-hourly.csv

Large diffs are not rendered by default.

352,369 changes: 352,369 additions & 0 deletions data/01137500_hourly_nldas.csv

Large diffs are not rendered by default.

Binary file modified data/synthetic_case/01022500_synthetic_classic.npy
Binary file not shown.
Binary file not shown.
Binary file added data/synthetic_case/01137500_synthetic_ode.npy
Binary file not shown.
9 changes: 5 additions & 4 deletions __main__.py → src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from omegaconf import DictConfig
import time

from src.agents.DifferentiableCFE import DifferentiableCFE
from src.agents.SyntheticAgent import SyntheticAgent
from agents.DifferentiableCFE import DifferentiableCFE
from agents.SyntheticAgent import SyntheticAgent

log = logging.getLogger(__name__)

Expand All @@ -17,9 +17,10 @@
def main(cfg: DictConfig) -> None:
start = time.perf_counter()
print(f"Running in {cfg.run_type} mode")
if cfg.run_type == "ML":
log.info(f"{cfg.run_type}")
if (cfg.run_type == "ML") | (cfg.run_type == "ML_synthetic_test"):
agent = DifferentiableCFE(cfg) # For Running against Observed Data
elif cfg.run_type == "synthetic":
elif cfg.run_type == "generate_synthetic":
agent = SyntheticAgent(cfg)
agent.run()
agent.finalize()
Expand Down
24 changes: 13 additions & 11 deletions src/agents/DifferentiableCFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from src.agents.base import BaseAgent
from src.data.Data import Data
from src.data.metrics import calculate_nse
from src.models.dCFE import dCFE
from src.utils.ddp_setup import find_free_port, cleanup
from agents.base import BaseAgent
from data.Data import Data
from data.metrics import calculate_nse
from models.dCFE import dCFE
from utils.ddp_setup import find_free_port, cleanup

import numpy as np

Expand All @@ -29,7 +29,7 @@

import json

log = logging.getLogger("agents.DifferentiableLGAR")
log = logging.getLogger("agents.DifferentiableCFE")

# Set the RANK environment variable manually

Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, cfg: DictConfig) -> None:

self.criterion = torch.nn.MSELoss()
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=cfg["src\models"].hyperparameters.learning_rate
self.model.parameters(), lr=cfg.models.hyperparameters.learning_rate
)

self.current_epoch = 0
Expand All @@ -91,7 +91,7 @@ def run(self):
# :return:
# """
# self.model.train()
# for epoch in range(1, self.cfg["src\models"].hyperparameters.epochs + 1):
# for epoch in range(1, self.cfg.models.hyperparameters.epochs + 1):
# self.train_one_epoch()
# self.current_epoch += 1

Expand All @@ -114,7 +114,8 @@ def train(self) -> None:
# self.net = DDP(self.model.to(self.cfg.device), device_ids=None)

self.model.mlp_forward()
for epoch in range(1, self.cfg["src\models"].hyperparameters.epochs + 1):
for epoch in range(1, self.cfg.models.hyperparameters.epochs + 1):
log.info(f"Epoch #: {epoch}")
# self.data_loader.sampler.set_epoch(epoch)
self.train_one_epoch()
self.model.mlp_forward()
Expand Down Expand Up @@ -167,7 +168,7 @@ def validate(self, y_hat_: Tensor, y_t_: Tensor) -> None:
- y_t_ : The tensor containing actual values.
"""
y_t_ = y_t_.squeeze()
warmup = self.cfg["src\models"].hyperparameters.warmup
warmup = self.cfg.models.hyperparameters.warmup
y_hat = y_hat_[warmup:]
y_t = y_t_[warmup:]

Expand Down Expand Up @@ -249,7 +250,7 @@ def finalize(self):

print(self.model.finalize())

cleanup()
# cleanup()

except:
raise NotImplementedError
Expand Down Expand Up @@ -290,6 +291,7 @@ def save_result(self, y_hat, y_t, eval_metrics, out_filename):
axes.set_title(f"Classic (KGE={float(eval_metrics):.4})")
plt.legend()
plt.savefig(os.path.join(matching_folder[0], f"{out_filename}.png"))
plt.close()

# # Best param
# array_dict = {
Expand Down
15 changes: 8 additions & 7 deletions src/agents/SyntheticAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from src.agents.base import BaseAgent
from src.data.Data import Data
from src.data.metrics import calculate_nse
from src.models.dCFE import dCFE
from src.models.SyntheticCFE import SyntheticCFE
from src.utils.ddp_setup import find_free_port, cleanup
from agents.base import BaseAgent
from data.Data import Data
from data.metrics import calculate_nse
from models.dCFE import dCFE
from models.SyntheticCFE import SyntheticCFE
from utils.ddp_setup import find_free_port, cleanup

import numpy as np

Expand Down Expand Up @@ -71,7 +71,7 @@ def run(self):
for i, (x, y_t) in enumerate(
tqdm(self.data_loader, desc="Processing data")
):
runoff = self.model(x) #
runoff = self.model(x)
y_hat[i] = runoff

self.save_data(y_hat)
Expand Down Expand Up @@ -138,6 +138,7 @@ def finalize(self, interrupt=False):
:return:
"""
try:
self.model.cfe_instance.finalize(print_mass_balance=True)
print(f"Agend finished the job")
except:
raise NotImplementedError
14 changes: 8 additions & 6 deletions config.yaml → src/config.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
defaults:
- _self_
- src\data: \config\test
- src\models: \config\base
- data: \config\test
- models: \config\base

cwd: G:\Shared drives\SI_NextGen_Aridity\dCFE\
save_name: debugger
output_dir: ${cwd}\output\
device: cpu
num_processes: 1

run_type: synthetic # Choose between "synthetic" & "ML"
soil_scheme: ode # ode # classic
run_type: generate_synthetic # Choose between "generate_synthetic" & "ML" & "ML_synthetic_test"
soil_scheme: classic # ode # classic
basin_id: '01137500'

synthetic:
output_dir: ${cwd}\data\synthetic_case\
nams: 01022500_synthetic_${soil_scheme}
nams: ${basin_id}_synthetic_${soil_scheme}
param_nams: ${basin_id}_synthetic_params_${soil_scheme}
refkdt: 3
satdk: 0.0001
satdk: 0.00001

conversions:
cm_to_mm: 10.0
Expand Down
35 changes: 23 additions & 12 deletions src/data/Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@ def __init__(self, cfg: DictConfig) -> None:

# Read in start and end datetime
self.start_time = datetime.strptime(
cfg["src\data"]["start_time"], r"%Y-%m-%d %H:%M:%S"
cfg.data["start_time"], r"%Y-%m-%d %H:%M:%S"
)

self.end_time = datetime.strptime(
cfg["src\data"]["end_time"], r"%Y-%m-%d %H:%M:%S"
)
self.end_time = datetime.strptime(cfg.data["end_time"], r"%Y-%m-%d %H:%M:%S")

self.x = self.get_forcings(cfg)

self.basin_attributes = self.get_attributes(cfg)

self.y = self.get_observations(cfg)
if (cfg.run_type == "ML") | (cfg.run_type == "generate_synthetic"):
self.y = self.get_observations(cfg)
elif cfg.run_type == "ML_synthetic_test":
self.y = self.get_synthetic(cfg)

self.cfe_params = self.get_cfe_params(cfg)

Expand All @@ -55,7 +56,7 @@ def __len__(self):

def get_forcings(self, cfg: DictConfig):
# Read forcing data into pandas dataframe
forcing_df_ = pd.read_csv(cfg["src\data"]["forcing_file"])
forcing_df_ = pd.read_csv(cfg.data["forcing_file"])
forcing_df_.set_index(pd.to_datetime(forcing_df_["date"]), inplace=True)
forcing_df = forcing_df_[self.start_time : self.end_time].copy()

Expand Down Expand Up @@ -95,7 +96,7 @@ def get_forcings(self, cfg: DictConfig):

def get_observations(self, cfg: DictConfig):
# # TODO FIND OBSERVATION DATA TO TRAIN AGAINST
obs_q_ = pd.read_csv(cfg["src\data"]["compare_results_file"])
obs_q_ = pd.read_csv(cfg.data["compare_results_file"])
obs_q_.set_index(pd.to_datetime(obs_q_["date"]), inplace=True)
self.obs_q = obs_q_[self.start_time : self.end_time].copy()

Expand All @@ -106,15 +107,25 @@ def get_observations(self, cfg: DictConfig):
self.n_timesteps = len(self.obs_q)
return torch.tensor(self.obs_q["QObs(mm/h)"].values, device=cfg.device)

def get_synthetic(self, cfg: DictConfig):
# Define the file path
dir_path = Path(cfg.synthetic.output_dir)
file_path = dir_path / (cfg.synthetic.nams + ".npy")
synthetic_q = np.load(file_path)
self.obs_q = synthetic_q
self.n_timesteps = len(self.obs_q)

return torch.tensor(synthetic_q, device=cfg.device)

def get_attributes(self, cfg: DictConfig):
"""
Reading attributes from the soil params file
"""
file_name = cfg["src\data"].attributes_file
basin_id = cfg["src\data"].basin_id
file_name = cfg.data.attributes_file
basin_id = cfg.data.basin_id
# Load the txt data into a DataFrame
data = pd.read_csv(file_name, sep=",")
data["gauge_id"] = data["gauge_id"].str.replace("Gage-", "")
data["gauge_id"] = data["gauge_id"].str.replace("Gage-", "").str.zfill(8)
# # Filter the DataFrame for the specified basin id
filtered_data = data[data["gauge_id"] == basin_id]
slope = filtered_data["slope_mean"].item()
Expand All @@ -134,7 +145,7 @@ def get_cfe_params(self, cfg: DictConfig):
"""
cfe_params = dict()

cfe_cfg = cfg["src\data"]
cfe_cfg = cfg.data

# GET VALUES FROM CONFIGURATION FILE.
cfe_params = {
Expand All @@ -149,7 +160,7 @@ def get_cfe_params(self, cfg: DictConfig):
"D": torch.tensor([cfe_cfg.D], dtype=torch.float),
"satpsi": torch.tensor([cfe_cfg.satpsi], dtype=torch.float),
"wltsmc": torch.tensor([cfe_cfg.wltsmc], dtype=torch.float),
"scheme": cfe_cfg.soil_scheme,
"scheme": cfg.soil_scheme,
},
"max_gw_storage": torch.tensor([cfe_cfg.max_gw_storage], dtype=torch.float),
"expon": torch.tensor([cfe_cfg.expon], dtype=torch.float),
Expand Down
24 changes: 11 additions & 13 deletions src/data/config/test.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
attributes_file: ${cwd}\data\lumped_soil_attributes_mean_soilindex_included.csv
basin_id: '01022500'
forcing_file: ${cwd}\data\01022500_hourly_nldas.csv
compare_results_file: ${cwd}\data\01022500-usgs-hourly.csv
basin_id: '01137500'
forcing_file: ${cwd}\data\01137500_hourly_nldas.csv
compare_results_file: ${cwd}\data\01137500-usgs-hourly.csv
catchment_area_km2: 50
alpha_fc : 0.33
bb : 5
bb : 2
D : 2
satdk : 0.00001
satpsi: 0.33
slop: 1
slop: 0.5
smcmax: 0.5
wltsmc: 0.1
K_lf: 0.5
max_gw_storage : 0.5
Cgw : 1
expon : 7
K_nash : 0.3
wltsmc: 0.0
K_lf: 0.23
max_gw_storage : 0.05
Cgw : 4.0e-3
expon : 4.8
K_nash : 0.5
nash_storage :
- 0
- 0
Expand All @@ -26,6 +25,5 @@ giuh_ordinates :
- 0.1200
- 0.0300
partition_scheme: Schaake # Select from "Xinanjiang" or "Schaake"
soil_scheme: ode # Select from "ode" or "classic"
start_time: '2008-01-01 00:00:00'
end_time: '2008-06-30 23:00:00'
14 changes: 6 additions & 8 deletions src/models/MLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# import utils.logger as logger
# from utils.read_yaml import config
from src.utils.transform import to_physical
from utils.transform import to_physical

# log = logger.get_logger("graphs.MLP")

Expand All @@ -31,9 +31,9 @@ def __init__(self, cfg: DictConfig) -> None:
# The size of out1 from MLP corresponds to output_size (so increase this when increasing parameters)

torch.manual_seed(0)
input_size = self.cfg["src\models"].mlp.input_size
hidden_size = self.cfg["src\models"].mlp.hidden_size
output_size = self.cfg["src\models"].mlp.output_size
input_size = self.cfg.models.mlp.input_size
hidden_size = self.cfg.models.mlp.hidden_size
output_size = self.cfg.models.mlp.output_size
self.lin1 = Linear(input_size, hidden_size)
self.lin2 = Linear(hidden_size, hidden_size)
self.lin3 = Linear(hidden_size, hidden_size)
Expand All @@ -55,8 +55,6 @@ def forward(self, x: Tensor) -> Tensor:
out1 = self.sigmoid(x4)
# Possibly, HardTanh? https://paperswithcode.com/method/hardtanh-activation
x_transpose = out1.transpose(0, 1)
refkdt = to_physical(
x=x_transpose[0], param="refkdt", cfg=self.cfg["src\models"]
)
satdk = to_physical(x=x_transpose[1], param="satdk", cfg=self.cfg["src\models"])
refkdt = to_physical(x=x_transpose[0], param="refkdt", cfg=self.cfg.models)
satdk = to_physical(x=x_transpose[1], param="satdk", cfg=self.cfg.models)
return refkdt, satdk
8 changes: 4 additions & 4 deletions src/models/SyntheticCFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
import torch
from torch import Tensor
import torch.nn as nn
from src.models.physics.bmi_cfe import BMI_CFE
from models.physics.bmi_cfe import BMI_CFE
import pandas as pd
import numpy as np
from src.utils.transform import normalization, to_physical
from src.models.MLP import MLP
from src.data.Data import Data
from utils.transform import normalization, to_physical
from models.MLP import MLP
from data.Data import Data

log = logging.getLogger("models.dCFE")

Expand Down
8 changes: 4 additions & 4 deletions src/models/config/base.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
hyperparameters:
epochs: 1 #50
learning_rate: 0.1 # 0.001
warmup: 100
epochs: 1000 #50
learning_rate: 0.001 # 0.001
warmup: 500

mlp:
hidden_size: 6
input_size: 4
output_size: 2

transformation:
refkdt: # https://www.sciencedirect.com/science/article/pii/S0022169420303620
refkdt: # refkdt range taken from https://www.sciencedirect.com/science/article/pii/S0022169420303620
- 0.5
- 5
satdk:
Expand Down
12 changes: 6 additions & 6 deletions src/models/dCFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
import torch
from torch import Tensor
import torch.nn as nn
from src.models.physics.bmi_cfe import BMI_CFE
from models.physics.bmi_cfe import BMI_CFE
import pandas as pd
import numpy as np
from src.utils.transform import normalization, to_physical
from src.models.MLP import MLP
from src.data.Data import Data
from utils.transform import normalization, to_physical
from models.MLP import MLP
from data.Data import Data

log = logging.getLogger("models.dCFE")

Expand Down Expand Up @@ -109,8 +109,8 @@ def finalize(self):
self.cfe_instance.finalize(print_mass_balance=True)

def print(self):
print(f"refkdt: {self.refkdt}")
print(f"satdk: {self.satdk}")
log.info(f"refkdt: {self.refkdt.tolist()[0]:.6f}")
log.info(f"satdk: {self.satdk.tolist()[0]:.6f}")
# for key, value in self.c.items():
# print(f"{key}: {value.item():.8f}")
# log.info(f"{key}: {value.item():.8f}")
Expand Down
Loading

0 comments on commit aaed81a

Please sign in to comment.