Skip to content

Commit

Permalink
added code to write forcings into edges store (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
taddyb authored Aug 31, 2024
1 parent 47afeb7 commit 750cc55
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 1 deletion.
9 changes: 9 additions & 0 deletions marquette/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def run_extensions(cfg: DictConfig, edges: zarr.Group) -> None:
log.info("q_prime_sum statistics already exists in zarr format")
else:
calculate_q_prime_sum_stats(cfg, edges)

if "lstm_stats" in cfg.extensions:
from marquette.merit.extensions import format_lstm_forcings

log.info("Adding lstm statistics from global LSTM to your MERIT River Graph")
if "mean_precip" in edges:
log.info("q_prime_sum statistics already exists in zarr format")
else:
format_lstm_forcings(cfg, edges)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion marquette/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ create_TMs:
TM: ${data_path}/zarr/TMs/sparse_MERIT_FLOWLINES_${zone}
shp_files: ${data_path}/raw/basins/cat_pfaf_${zone}_MERIT_Hydro_v07_Basins_v01_bugfix1.shp
create_streamflow:
version: merit_conus_v1.0
version: merit_conus_v6.14
data_store: ${data_path}/streamflow/zarr/${create_streamflow.version}/${zone}
obs_attributes: ${data_path}/gage_information/MERIT_basin_area_info
predictions: /projects/mhpi/yxs275/DM_output/water_loss_model/dPL_local_daymet_new_attr_RMSEloss_with_log_2800
Expand All @@ -38,6 +38,7 @@ extensions:
- incremental_drainage_area
- q_prime_sum
- q_prime_sum_stats
- lstm_stats
# Hydra Config ------------------------------------------------------------------------#
hydra:
help:
Expand Down
68 changes: 68 additions & 0 deletions marquette/merit/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,71 @@ def calculate_q_prime_sum_stats(cfg: DictConfig, edges: zarr.Group) -> None:
name="summed_q_prime_p10",
data=np.percentile(summed_q_prime, 10, axis=0),
)

def format_lstm_forcings(cfg: DictConfig, edges: zarr.Group) -> None:
forcings_store = zarr.open(Path("/projects/mhpi/data/global/zarr_sub_zone") / f"{cfg.zone}")

edge_comids = np.unique(edges.merit_basin[:]) # already sorted
log.info(msg="Reading Zarr Store")
zone_keys = [
key for key in forcings_store.keys() if str(cfg.zone) in key
]
zone_comids = []
zone_precip = []
zone_pet = []
# zone_temp = []
zone_ndvi = []
zone_aridity = []
for key in zone_keys:
zone_comids.append(forcings_store[key].COMID[:])
zone_precip.append(forcings_store[key].P[:])
zone_pet.append(forcings_store[key].PET[:])
# zone_temp.append(streamflow_predictions_root[key].Temp[:])
zone_ndvi.append(forcings_store[key]["attrs"]["NDVI"])
zone_aridity.append(forcings_store[key]["attrs"]["aridity"])

streamflow_comids = np.concatenate(zone_comids).astype(int)
file_precip = np.transpose(np.concatenate(zone_precip))
file_pet = np.transpose(np.concatenate(zone_pet))
# file_temp = np.transpose(np.concatenate(zone_temp))
file_ndvi = np.concatenate(zone_ndvi)
file_aridity = np.concatenate(zone_aridity)
del zone_comids
del zone_precip
del zone_pet
# del zone_temp
del zone_ndvi
del zone_aridity

log.info("Mapping to zone COMIDs")
precip_full_zone = np.zeros((file_precip.shape[0], edge_comids.shape[0]))
pet_full_zone = np.zeros((file_precip.shape[0], edge_comids.shape[0]))
ndvi_full_zone = np.zeros((edge_comids.shape[0]))
aridity_full_zone = np.zeros((edge_comids.shape[0]))


indices = np.searchsorted(edge_comids, streamflow_comids)
precip_full_zone[:, indices] = file_precip
pet_full_zone[:, indices] = file_pet
ndvi_full_zone[indices] = file_ndvi
aridity_full_zone[indices] = file_aridity

log.info("Writing outputs to zarr")
edges.array(
name="precip_comid",
data=np.median(precip_full_zone, axis=0),
)
edges.array(
name="pet_comid",
data=np.std(pet_full_zone, axis=0),
)
edges.array(
name="ndvi_comid",
data=np.percentile(ndvi_full_zone, 90, axis=0),
)
edges.array(
name="aridity_comid",
data=np.percentile(aridity_full_zone, 10, axis=0),
)


0 comments on commit 750cc55

Please sign in to comment.