Skip to content

Commit 91c4cdb

Browse files
authored
Merge pull request #25 from NWC-CUAHSI-Summer-Institute/debug_ML
debug: reset the attributes that weren't reset. ML working
2 parents 0597924 + eb03b44 commit 91c4cdb

File tree

4 files changed

+50
-17
lines changed

4 files changed

+50
-17
lines changed

src/agents/DifferentiableCFE.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ def train(self) -> None:
104104
for epoch in range(1, self.cfg.models.hyperparameters.epochs + 1):
105105
log.info(f"Epoch #: {epoch}/{self.cfg.models.hyperparameters.epochs}")
106106
self.loss_record[epoch - 1] = self.train_one_epoch()
107-
print("Start mlp forward")
108-
self.model.mlp_forward()
109-
print("End mlp forward")
107+
# print("Start mlp forward")
108+
# self.model.mlp_forward()
109+
# print("End mlp forward")
110110
self.current_epoch += 1
111111

112112
def train_one_epoch(self):

src/data/config/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ compare_results_file: ${data.data_dir}\{}-usgs-hourly.csv
99
json_params_dir: ${data.data_dir}\cat_{}_bmi_config_cfe.json
1010
partition_scheme: Schaake # Select from "Xinanjiang" or "Schaake"
1111
start_time: '1990-10-06 00:00:00'
12-
end_time: '1992-12-30 23:00:00'
12+
end_time: '1990-12-30 23:00:00'

src/models/dCFE.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,22 @@ def __init__(self, cfg: DictConfig, Data) -> None:
6363

6464
def initialize(self):
6565
# Initialize the CFE model with the dynamic parameter
66-
self.cfe_instance.refkdt = self.refkdt[:, 0]
67-
self.cfe_instance.satdk = self.satdk[:, 0]
66+
67+
# Reset dCFE attributes
68+
self.reset_instance_attributes()
69+
70+
# Reset CFE parameters, states, fluxes, and volume tracking
71+
self.cfe_instance.load_cfe_params()
6872
self.cfe_instance.reset_flux_and_states()
6973
self.cfe_instance.reset_volume_tracking()
7074

75+
# Update parameters
76+
self.cfe_instance.update_params(self.refkdt[:, 0], self.satdk[:, 0])
77+
78+
def reset_instance_attributes(self):
79+
self.cfe_instance.refkdt = torch.zeros_like(self.cfe_instance.refkdt)
80+
self.cfe_instance.satdk = torch.zeros_like(self.cfe_instance.satdk)
81+
7182
def forward(self, x, t): # -> (Tensor, Tensor):
7283
"""
7384
The forward function to model runoff through CFE model

src/models/physics/bmi_cfe.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
import time
21
import numpy as np
32
import pandas as pd
43
import sys
5-
import json
64
import matplotlib.pyplot as plt
75
from models.physics.cfe import CFE
86
import torch
97
from torch import Tensor
10-
import torch.nn as nn
118

129
torch.set_default_dtype(torch.float64)
1310

@@ -108,6 +105,16 @@ def __init__(
108105
# None
109106

110107
def load_cfe_params(self):
108+
for param in self.cfe_params.values():
109+
if torch.is_tensor(param):
110+
if param.grad is not None:
111+
param.grad = None
112+
113+
for param in self.cfe_params["soil_params"].values():
114+
if torch.is_tensor(param):
115+
if param.grad is not None:
116+
param.grad = None
117+
111118
# GET VALUES FROM Data class.
112119

113120
# Catchment area
@@ -188,12 +195,6 @@ def initialize(self, current_time_step=0):
188195
# Set these values now that we have the information from the configuration file.
189196
self.num_giuh_ordinates = self.giuh_ordinates.size(1)
190197
self.num_lateral_flow_nash_reservoirs = self.nash_storage.size(1)
191-
# ________________________________________________
192-
# ----------- The output is area normalized, this is needed to un-normalize it
193-
# mm->m km2 -> m2 hour->s
194-
self.output_factor_cms = (
195-
(1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600)
196-
)
197198

198199
# ________________________________________________
199200
# The configuration should let the BMI know what mode to run in (framework vs standalone)
@@ -259,10 +260,10 @@ def reset_flux_and_states(self):
259260
self.gw_reservoir_storage_deficit_m = torch.zeros(
260261
(1, self.num_basins), dtype=torch.float64
261262
) # the available space in the conceptual groundwater reservoir
262-
self.primary_flux = torch.zeros(
263+
self.primary_flux_m = torch.zeros(
263264
(1, self.num_basins), dtype=torch.float64
264265
) # temporary vars.
265-
self.secondary_flux = torch.zeros(
266+
self.secondary_flux_m = torch.zeros(
266267
(1, self.num_basins), dtype=torch.float64
267268
) # temporary vars.
268269
self.primary_flux_from_gw_m = torch.zeros(
@@ -293,6 +294,13 @@ def reset_flux_and_states(self):
293294
(1, self.num_basins), dtype=torch.float64
294295
)
295296

297+
# ________________________________________________
298+
# ----------- The output is area normalized, this is needed to un-normalize it
299+
# mm->m km2 -> m2 hour->s
300+
self.output_factor_cms = (
301+
(1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600)
302+
)
303+
296304
# ________________________________________________
297305
# ________________________________________________
298306
# SOIL RESERVOIR CONFIGURATION
@@ -366,6 +374,8 @@ def reset_flux_and_states(self):
366374
self.volstart = self.volstart.add(self.gw_reservoir["storage_m"])
367375
self.vol_in_gw_start = self.gw_reservoir["storage_m"]
368376

377+
# TODO: update soil parameter
378+
369379
self.soil_reservoir = {
370380
"is_exponential": False,
371381
"wilting_point_m": self.soil_params["wltsmc"] * self.soil_params["D"],
@@ -404,6 +414,15 @@ def reset_flux_and_states(self):
404414
self.giuh_ordinates.shape[0], self.num_giuh_ordinates + 1
405415
)
406416

417+
# __________________________________________________________
418+
self.surface_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64)
419+
self.streamflow_cmh = torch.zeros((1, self.num_basins), dtype=torch.float64)
420+
self.flux_nash_lateral_runoff_m = torch.zeros(
421+
(1, self.num_basins), dtype=torch.float64
422+
)
423+
self.flux_giuh_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64)
424+
self.flux_Qout_m = torch.zeros((1, self.num_basins), dtype=torch.float64)
425+
407426
def update_params(self, refkdt, satdk):
408427
"""Update dynamic parameters"""
409428
self.refkdt = refkdt.unsqueeze(dim=0)
@@ -492,6 +511,9 @@ def reset_volume_tracking(self):
492511
self.vol_et_from_soil = torch.zeros((1, self.num_basins), dtype=torch.float64)
493512
self.vol_et_from_rain = torch.zeros((1, self.num_basins), dtype=torch.float64)
494513
self.vol_PET = torch.zeros((1, self.num_basins), dtype=torch.float64)
514+
515+
self.vol_in_gw_start = torch.zeros((1, self.num_basins), dtype=torch.float64)
516+
495517
return
496518

497519
# ________________________________________________________

0 commit comments

Comments
 (0)