|
1 |
| -import time |
2 | 1 | import numpy as np
|
3 | 2 | import pandas as pd
|
4 | 3 | import sys
|
5 |
| -import json |
6 | 4 | import matplotlib.pyplot as plt
|
7 | 5 | from models.physics.cfe import CFE
|
8 | 6 | import torch
|
9 | 7 | from torch import Tensor
|
10 |
| -import torch.nn as nn |
11 | 8 |
|
12 | 9 | torch.set_default_dtype(torch.float64)
|
13 | 10 |
|
@@ -108,6 +105,16 @@ def __init__(
|
108 | 105 | # None
|
109 | 106 |
|
110 | 107 | def load_cfe_params(self):
|
| 108 | + for param in self.cfe_params.values(): |
| 109 | + if torch.is_tensor(param): |
| 110 | + if param.grad is not None: |
| 111 | + param.grad = None |
| 112 | + |
| 113 | + for param in self.cfe_params["soil_params"].values(): |
| 114 | + if torch.is_tensor(param): |
| 115 | + if param.grad is not None: |
| 116 | + param.grad = None |
| 117 | + |
111 | 118 | # GET VALUES FROM Data class.
|
112 | 119 |
|
113 | 120 | # Catchment area
|
@@ -188,12 +195,6 @@ def initialize(self, current_time_step=0):
|
188 | 195 | # Set these values now that we have the information from the configuration file.
|
189 | 196 | self.num_giuh_ordinates = self.giuh_ordinates.size(1)
|
190 | 197 | self.num_lateral_flow_nash_reservoirs = self.nash_storage.size(1)
|
191 |
| - # ________________________________________________ |
192 |
| - # ----------- The output is area normalized, this is needed to un-normalize it |
193 |
| - # mm->m km2 -> m2 hour->s |
194 |
| - self.output_factor_cms = ( |
195 |
| - (1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600) |
196 |
| - ) |
197 | 198 |
|
198 | 199 | # ________________________________________________
|
199 | 200 | # The configuration should let the BMI know what mode to run in (framework vs standalone)
|
@@ -259,10 +260,10 @@ def reset_flux_and_states(self):
|
259 | 260 | self.gw_reservoir_storage_deficit_m = torch.zeros(
|
260 | 261 | (1, self.num_basins), dtype=torch.float64
|
261 | 262 | ) # the available space in the conceptual groundwater reservoir
|
262 |
| - self.primary_flux = torch.zeros( |
| 263 | + self.primary_flux_m = torch.zeros( |
263 | 264 | (1, self.num_basins), dtype=torch.float64
|
264 | 265 | ) # temporary vars.
|
265 |
| - self.secondary_flux = torch.zeros( |
| 266 | + self.secondary_flux_m = torch.zeros( |
266 | 267 | (1, self.num_basins), dtype=torch.float64
|
267 | 268 | ) # temporary vars.
|
268 | 269 | self.primary_flux_from_gw_m = torch.zeros(
|
@@ -293,6 +294,13 @@ def reset_flux_and_states(self):
|
293 | 294 | (1, self.num_basins), dtype=torch.float64
|
294 | 295 | )
|
295 | 296 |
|
| 297 | + # ________________________________________________ |
| 298 | + # ----------- The output is area normalized, this is needed to un-normalize it |
| 299 | + # mm->m km2 -> m2 hour->s |
| 300 | + self.output_factor_cms = ( |
| 301 | + (1 / 1000) * (self.catchment_area_km2 * 1000 * 1000) * (1 / 3600) |
| 302 | + ) |
| 303 | + |
296 | 304 | # ________________________________________________
|
297 | 305 | # ________________________________________________
|
298 | 306 | # SOIL RESERVOIR CONFIGURATION
|
@@ -366,6 +374,8 @@ def reset_flux_and_states(self):
|
366 | 374 | self.volstart = self.volstart.add(self.gw_reservoir["storage_m"])
|
367 | 375 | self.vol_in_gw_start = self.gw_reservoir["storage_m"]
|
368 | 376 |
|
| 377 | + # TODO: update soil parameter |
| 378 | + |
369 | 379 | self.soil_reservoir = {
|
370 | 380 | "is_exponential": False,
|
371 | 381 | "wilting_point_m": self.soil_params["wltsmc"] * self.soil_params["D"],
|
@@ -404,6 +414,15 @@ def reset_flux_and_states(self):
|
404 | 414 | self.giuh_ordinates.shape[0], self.num_giuh_ordinates + 1
|
405 | 415 | )
|
406 | 416 |
|
| 417 | + # __________________________________________________________ |
| 418 | + self.surface_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 419 | + self.streamflow_cmh = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 420 | + self.flux_nash_lateral_runoff_m = torch.zeros( |
| 421 | + (1, self.num_basins), dtype=torch.float64 |
| 422 | + ) |
| 423 | + self.flux_giuh_runoff_m = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 424 | + self.flux_Qout_m = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 425 | + |
407 | 426 | def update_params(self, refkdt, satdk):
|
408 | 427 | """Update dynamic parameters"""
|
409 | 428 | self.refkdt = refkdt.unsqueeze(dim=0)
|
@@ -492,6 +511,9 @@ def reset_volume_tracking(self):
|
492 | 511 | self.vol_et_from_soil = torch.zeros((1, self.num_basins), dtype=torch.float64)
|
493 | 512 | self.vol_et_from_rain = torch.zeros((1, self.num_basins), dtype=torch.float64)
|
494 | 513 | self.vol_PET = torch.zeros((1, self.num_basins), dtype=torch.float64)
|
| 514 | + |
| 515 | + self.vol_in_gw_start = torch.zeros((1, self.num_basins), dtype=torch.float64) |
| 516 | + |
495 | 517 | return
|
496 | 518 |
|
497 | 519 | # ________________________________________________________
|
|
0 commit comments