Skip to content

Commit

Permalink
Merge pull request #25 from NWC-CUAHSI-Summer-Institute/debug_ML
Browse files Browse the repository at this point in the history
debug: reset the attributes that weren't reset. ML working
  • Loading branch information
RY4GIT committed Nov 2, 2023
2 parents 0597924 + eb03b44 commit 91c4cdb
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/agents/DifferentiableCFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def train(self) -> None:
for epoch in range(1, self.cfg.models.hyperparameters.epochs + 1):
log.info(f"Epoch #: {epoch}/{self.cfg.models.hyperparameters.epochs}")
self.loss_record[epoch - 1] = self.train_one_epoch()
print("Start mlp forward")
self.model.mlp_forward()
print("End mlp forward")
# print("Start mlp forward")
# self.model.mlp_forward()
# print("End mlp forward")
self.current_epoch += 1

def train_one_epoch(self):
Expand Down
2 changes: 1 addition & 1 deletion src/data/config/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ compare_results_file: ${data.data_dir}\{}-usgs-hourly.csv
json_params_dir: ${data.data_dir}\cat_{}_bmi_config_cfe.json
partition_scheme: Schaake # Select from "Xinanjiang" or "Schaake"
start_time: '1990-10-06 00:00:00'
end_time: '1992-12-30 23:00:00'
end_time: '1990-12-30 23:00:00'
15 changes: 13 additions & 2 deletions src/models/dCFE.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,22 @@ def __init__(self, cfg: DictConfig, Data) -> None:

def initialize(self):
# Initialize the CFE model with the dynamic parameter
self.cfe_instance.refkdt = self.refkdt[:, 0]
self.cfe_instance.satdk = self.satdk[:, 0]

# Reset dCFE attributes
self.reset_instance_attributes()

# Reset CFE parameters, states, fluxes, and volume tracking
self.cfe_instance.load_cfe_params()
self.cfe_instance.reset_flux_and_states()
self.cfe_instance.reset_volume_tracking()

# Update parameters
self.cfe_instance.update_params(self.refkdt[:, 0], self.satdk[:, 0])

def reset_instance_attributes(self):
self.cfe_instance.refkdt = torch.zeros_like(self.cfe_instance.refkdt)
self.cfe_instance.satdk = torch.zeros_like(self.cfe_instance.satdk)

def forward(self, x, t): # -> (Tensor, Tensor):
"""
The forward function to model runoff through CFE model
Expand Down
44 changes: 33 additions & 11 deletions src/models/physics/bmi_cfe.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import time
import numpy as np
import pandas as pd
import sys
import json
import matplotlib.pyplot as plt
from models.physics.cfe import CFE
import torch
from torch import Tensor
import torch.nn as nn

torch.set_default_dtype(torch.float64)

Expand Down Expand Up @@ -108,6 +105,16 @@ def __init__(
# None

def load_cfe_params(self):
for param in self.cfe_params.values():
if torch.is_tensor(param):
if param.grad is not None:
param.grad = None

for param in self.cfe_params["soil_params"].values():
if torch.is_tensor(param):
if param.grad is not None:
param.grad = None

# GET VALUES FROM Data class.

# Catchment area
Expand Down Expand Up @@ -188,12 +195,6 @@ def initialize(self, current_time_step=0):
# Set these values now that we have the information from the configuration file.
self.num_giuh_ordinates = self.giuh_ordinates.size(1)
self.num_lateral_flow_nash_reservoirs = self.nash_storage.size(1)
# ________________________________________________
# ----------- The output is area normalized, this is needed to un-normalize it
# mm->m km2 -> m2 hour->s
self.output_factor_cms = (
(1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600)
)

# ________________________________________________
# The configuration should let the BMI know what mode to run in (framework vs standalone)
Expand Down Expand Up @@ -259,10 +260,10 @@ def reset_flux_and_states(self):
self.gw_reservoir_storage_deficit_m = torch.zeros(
(1, self.num_basins), dtype=torch.float64
) # the available space in the conceptual groundwater reservoir
self.primary_flux = torch.zeros(
self.primary_flux_m = torch.zeros(
(1, self.num_basins), dtype=torch.float64
) # temporary vars.
self.secondary_flux = torch.zeros(
self.secondary_flux_m = torch.zeros(
(1, self.num_basins), dtype=torch.float64
) # temporary vars.
self.primary_flux_from_gw_m = torch.zeros(
Expand Down Expand Up @@ -293,6 +294,13 @@ def reset_flux_and_states(self):
(1, self.num_basins), dtype=torch.float64
)

# ________________________________________________
# ----------- The output is area normalized, this is needed to un-normalize it
# mm->m km2 -> m2 hour->s
self.output_factor_cms = (
(1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600)
)

# ________________________________________________
# ________________________________________________
# SOIL RESERVOIR CONFIGURATION
Expand Down Expand Up @@ -366,6 +374,8 @@ def reset_flux_and_states(self):
self.volstart = self.volstart.add(self.gw_reservoir["storage_m"])
self.vol_in_gw_start = self.gw_reservoir["storage_m"]

# TODO: update soil parameter

self.soil_reservoir = {
"is_exponential": False,
"wilting_point_m": self.soil_params["wltsmc"] * self.soil_params["D"],
Expand Down Expand Up @@ -404,6 +414,15 @@ def reset_flux_and_states(self):
self.giuh_ordinates.shape[0], self.num_giuh_ordinates + 1
)

# __________________________________________________________
self.surface_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64)
self.streamflow_cmh = torch.zeros((1, self.num_basins), dtype=torch.float64)
self.flux_nash_lateral_runoff_m = torch.zeros(
(1, self.num_basins), dtype=torch.float64
)
self.flux_giuh_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64)
self.flux_Qout_m = torch.zeros((1, self.num_basins), dtype=torch.float64)

def update_params(self, refkdt, satdk):
"""Update dynamic parameters"""
self.refkdt = refkdt.unsqueeze(dim=0)
Expand Down Expand Up @@ -492,6 +511,9 @@ def reset_volume_tracking(self):
self.vol_et_from_soil = torch.zeros((1, self.num_basins), dtype=torch.float64)
self.vol_et_from_rain = torch.zeros((1, self.num_basins), dtype=torch.float64)
self.vol_PET = torch.zeros((1, self.num_basins), dtype=torch.float64)

self.vol_in_gw_start = torch.zeros((1, self.num_basins), dtype=torch.float64)

return

# ________________________________________________________
Expand Down

0 comments on commit 91c4cdb

Please sign in to comment.