Add detrend function #45

Merged: 15 commits, Nov 20, 2024

83 changes: 65 additions & 18 deletions README.md
@@ -12,41 +12,88 @@ Future availability of a more sophisticated Python package for various end-user

We welcome external contributions to the package. Please feel free to submit an issue with any input or to join the core development team. Thank you!

## Active development
To use the module in the package at this stage:
1. Create a conda/mamba env based on the region_mom.yml
## Setting up the development environment

1. Fork this repository using the button in the upper right of the GitHub page. This will create a copy of the repository in your own GitHub profile, giving you full control over it.

2. Clone the repository to your local machine from your forked version.

```
conda env create -f regional_mom6.yml
git clone <fork-repo-url-under-your-github-account>
```
3. Activate the conda env `regional_mom6`
This creates a remote `origin` pointing to your forked version (not the NOAA-CEFI-Portal version)


1. Create a conda/mamba env based on the environment.yml

```
conda activate regional_mom6
cd regional_mom6/
conda env create -f environment.yml
```
4. Change your location to the top level of the cloned repo
3. Activate the conda env `regional-mom6`

```
cd <dir_path_to_regional_mom6>/regional_mom6/
conda activate regional-mom6
```

5. Install the package in develop mode with pip

```
pip install -e .
```
6. Set up the config file (data path for the local data directory)

```
cp config.json.template config.json
```

Open `config.json` and enter the absolute path to the top level of the regional MOM6 data

```
{
"data_path": "<your-mom6data-path-here>"
}
```

The current setup assumes the data directory structure is fixed, i.e. the historical run and forecast data subdirectories (e.g. `hist_run` and `forecast`) need to be located directly under this `data_path`.
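
For illustration, a data directory layout consistent with this assumption might look like the following (the `hist_run` and `forecast` subdirectory names come from the example above; the file contents shown are hypothetical):

```
<your-mom6data-path-here>/
├── hist_run/
│   └── ...   # historical run output
└── forecast/
    └── ...   # forecast output
```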

## Syncing with the NOAA-CEFI-Portal version
1. Create a remote `upstream` to track the changes on NOAA-CEFI-Portal

```
git remote add upstream git@github.com:NOAA-CEFI-Portal/regional_mom6.git
```
2. Create a feature branch to make code changes

```
git branch <feature-branch-name>
git checkout <feature-branch-name>
```
This prevents making direct changes to the `main` branch in your local repository.

3. Sync your local repository with the upstream changes regularly

```
git fetch upstream
git checkout main
git merge upstream/main
```
This updates your local `main` branch with the latest changes from the upstream repository.

4. Merge the updated local `main` branch into your local `<feature-branch-name>` branch to keep it up to date.

```
git checkout <feature-branch-name>
git merge main
```

5. Push your changes to your forked version on GitHub

```
git push origin <feature-branch-name>
```
Make sure you have included the `upstream/main` changes before creating a pull request to NOAA-CEFI-Portal/regional_mom6.


93 changes: 93 additions & 0 deletions mom6/mom6_module/mom6_detrend.py
@@ -0,0 +1,93 @@
"""
This is the module to implement the detrending

"""
from typing import Tuple
import xarray as xr

class ForecastDetrend:
"""Detrend class for forecast data"""
def __init__(
self,
da_data : xr.DataArray,
initialization_name : str = 'init',
member_name : str = 'member',
) -> None:
"""
Parameters
----------
da_data : xr.DataArray
The dataarray one want to use to
detrend.
initialization_name : str, optional
initialization dimension name, by default 'init'
member_name : str, optional
ensemble member dimension name, by default 'member'
"""
self.data = da_data
self.init = initialization_name
self.mem = member_name

def polyfit_coef(
self,
deg: int = 1
) -> xr.Dataset:
"""determine the polyfit coefficient based on
lead-time-dependent forecast ensemble mean anomalies

Parameters
----------
deg : int, optional
the order of polynomical fit to use for determining the
fit coefficient, by default 1

Returns
-------
xr.Dataset
coefficient of the polynomical fit
"""

# calculate the ensemble mean of the anomaly
da_ensmean = self.data.mean(dim=self.mem)
# use the ensemble mean anomaly to determine lead time dependent trend
ds_p = da_ensmean.polyfit(dim=self.init, deg=deg, skipna=True).compute()

return ds_p

def detrend_linear(
self,
precompute_coeff : bool = False,
ds_coeff : xr.Dataset = None,
in_place_memory_replace : bool = False
) -> Tuple[xr.DataArray,xr.Dataset]:
"""detrend the original data by using the
degree 1 ployfit coeff

Returns
-------
xr.DataArray
the data with linear trend removed
"""
if precompute_coeff:
ds_p = ds_coeff
else:
# get degree 1 polyfit coeff
ds_p = self.polyfit_coef(deg=1)

# # calculate linear trend based on polyfit coeff
# da_linear_trend = xr.polyval(self.data[self.init], ds_p.polyfit_coefficients)
# # remove the linear trend
# da_detrend = (self.data - da_linear_trend).persist()

if in_place_memory_replace:
self.data = (
self.data-
xr.polyval(self.data[self.init], ds_p.polyfit_coefficients)
).persist()
return self.data, ds_p
else:
da_detrend = (
self.data -
xr.polyval(self.data[self.init], ds_p.polyfit_coefficients)
).persist()
return da_detrend,ds_p
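
For context, here is a minimal usage sketch of the new class on synthetic data. The dimension sizes, the imposed 0.05 trend, and the `ssta` variable name are made up for illustration; the class and its import path are taken from the diff.

```python
import numpy as np
import xarray as xr

from mom6.mom6_module.mom6_detrend import ForecastDetrend

# synthetic forecast anomalies: 28 initializations x 10 ensemble members,
# with an artificial linear trend of 0.05 per initialization year added in
rng = np.random.default_rng(42)
init = np.arange(1993, 2021)
member = np.arange(10)
trend = 0.05 * (init - init[0])
da_anom = xr.DataArray(
    trend[:, None] + rng.normal(scale=0.5, size=(init.size, member.size)),
    coords={"init": init, "member": member},
    dims=("init", "member"),
    name="ssta",
)

# fit the lead-time-dependent linear trend to the ensemble mean and remove it
detrender = ForecastDetrend(da_anom, initialization_name="init", member_name="member")
da_detrended, ds_coeff = detrender.detrend_linear()

# the degree-1 coefficient should recover roughly the imposed 0.05 trend
print(float(ds_coeff["polyfit_coefficients"].sel(degree=1)))
```

Calling `detrend_linear(in_place_memory_replace=True)` instead replaces `detrender.data` with the detrended result, which is how the flag is used in `mom6_mhw.py` below.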
71 changes: 57 additions & 14 deletions mom6/mom6_module/mom6_mhw.py
Expand Up @@ -18,6 +18,7 @@
from mom6.mom6_module.mom6_types import (
TimeGroupByOptions
)
from mom6.mom6_module.mom6_detrend import ForecastDetrend

warnings.simplefilter("ignore")
xr.set_options(keep_attrs=True)
@@ -73,7 +74,8 @@ def generate_forecast_batch(
climo_end_year : int = 2020,
anom_start_year : int = 1993,
anom_end_year : int = 2020,
quantile_threshold : float = 90.
quantile_threshold : float = 90.,
detrend : bool = False
) -> xr.Dataset:
"""generate the MHW statistics and identify MHW

@@ -89,6 +91,8 @@
end year of anomaly that need to identify MHW, by default 2020
quantile_threshold : float, optional
quantile value that define the threshold, by default 90.
detrend : bool, optional
flag for whether the MHW statistics are based on detrended ssta, by default False

Returns
-------
@@ -98,27 +102,49 @@

# calculate anomaly based on climatology
class_forecast_climo = ForecastClimatology(self.dataset,self.varname)
dict_anom = class_forecast_climo.generate_anom_batch(
dict_anom_thres = class_forecast_climo.generate_anom_batch(
climo_start_year,
climo_end_year,
climo_start_year, # force the anom start year for the threshold to be the same as the climo period
climo_end_year, # force the anom end year for the threshold to be the same as the climo period
'persist'
)

# detrend or not
if detrend:
class_detrend_thres = ForecastDetrend(dict_anom_thres['anomaly'])
dict_anom_thres['anomaly'], ds_p = class_detrend_thres.detrend_linear(
precompute_coeff=False,
in_place_memory_replace=True
)

# anomaly used for the threshold
ds_anom = xr.Dataset()
ds_anom[f'{self.varname}_anom'] = dict_anom['anomaly']
ds_anom['lon'] = self.dataset['lon']
ds_anom['lat'] = self.dataset['lat']
ds_anom_thres = xr.Dataset()
ds_anom_thres[f'{self.varname}_anom'] = dict_anom_thres['anomaly']
ds_anom_thres['lon'] = self.dataset['lon']
ds_anom_thres['lat'] = self.dataset['lat']

# calculate threshold
class_forecast_quantile = ForecastQuantile(ds_anom,f'{self.varname}_anom')
# if detrend:
# ### in-memory result when creating the class
# class_forecast_quantile = ForecastQuantile(
# ds_anom_thres.compute(),
# f'{self.varname}_anom'
# )
# da_threshold = class_forecast_quantile.generate_quantile(
# climo_start_year,
# climo_end_year,
# quantile_threshold,
# dask_obj=False
# )
# else:
class_forecast_quantile = ForecastQuantile(ds_anom_thres,f'{self.varname}_anom')
### in-memory result, not lazy-loaded (same as climo period)
da_threshold = class_forecast_quantile.generate_quantile(
climo_start_year,
climo_end_year,
quantile_threshold
quantile_threshold,
dask_obj=True
)

# anomaly used to identify MHW
@@ -129,10 +155,18 @@
anom_end_year,
'persist',
precompute_climo = True,
da_climo = dict_anom['climatology']
da_climo = dict_anom_thres['climatology']
)
da_anom = dict_anom['anomaly']

if detrend:
class_detrend = ForecastDetrend(da_anom)
da_anom,_ = class_detrend.detrend_linear(
precompute_coeff=True,
ds_coeff=ds_p,
in_place_memory_replace=True
)

# calculate average mhw magnitude
da_mhw_mag = da_anom.where(da_anom.groupby(f'{self.init}.{self.tfreq}')>=da_threshold)
da_mhw_mag_ave = da_anom.mean(dim=f'{self.mem}').compute()
@@ -152,12 +186,21 @@

# output dataset
ds_mhw = xr.Dataset()
if detrend :
ds_mhw['polyfit_coefficients'] = ds_p['polyfit_coefficients']

ds_mhw[f'{self.varname}_threshold{quantile_threshold:02d}'] = da_threshold
ds_mhw[f'{self.varname}_threshold{quantile_threshold:02d}'].attrs['long_name'] = (
f'{self.varname} threshold{quantile_threshold:02d})'
f'{self.varname} threshold{quantile_threshold:02d}'
)
ds_mhw[f'{self.varname}_threshold{quantile_threshold:02d}'].attrs['units'] = 'degC'

ds_mhw[f'{self.varname}_climo'] = dict_anom_thres['climatology']
ds_mhw[f'{self.varname}_climo'].attrs['long_name'] = (
f'{self.varname} climatology'
)
ds_mhw[f'{self.varname}_climo'].attrs['units'] = 'degC'

ds_mhw[f'mhw_prob{quantile_threshold:02d}'] = da_prob
ds_mhw[f'mhw_prob{quantile_threshold:02d}'].attrs['long_name'] = (
f'marine heatwave probability (threshold{quantile_threshold:02d})'
@@ -170,11 +213,11 @@
)
ds_mhw['ssta_avg'].attrs['units'] = 'degC'

ds_mhw['mhw_mag_indentified_ens'] = da_mhw_mag
ds_mhw['mhw_mag_indentified_ens'].attrs['long_name'] = (
'marine heatwave magnitude in each ensemble'
ds_mhw['ssta'] = da_anom
ds_mhw['ssta'].attrs['long_name'] = (
'anomalous sea surface temperature'
)
ds_mhw['mhw_mag_indentified_ens'].attrs['units'] = 'degC'
ds_mhw['ssta'].attrs['units'] = 'degC'

ds_mhw.attrs['period_of_quantile'] = da_threshold.attrs['period_of_quantile']
ds_mhw.attrs['period_of_climatology'] = da_threshold.attrs['period_of_climatology']
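
Conceptually, the new `detrend=True` path boils down to the following plain-xarray sketch. This is a simplified illustration only: the `init`/`member` dimension names are the module defaults, the single quantile over all initializations and members stands in for the module's `ForecastQuantile`/climatology-period machinery, and the helper name is hypothetical.

```python
import xarray as xr

def sketch_detrended_mhw(da_anom: xr.DataArray, quantile_threshold: float = 90.0):
    """Simplified illustration of the detrend=True path: remove the linear
    trend fitted to the ensemble-mean anomaly, derive a quantile threshold
    from the detrended anomalies, and flag exceedances per member."""
    # lead-time-dependent linear trend, fitted to the ensemble-mean anomaly
    ds_p = da_anom.mean(dim="member").polyfit(dim="init", deg=1, skipna=True)
    da_detrended = da_anom - xr.polyval(da_anom["init"], ds_p.polyfit_coefficients)

    # quantile threshold from the detrended anomalies (the module instead
    # delegates this to ForecastQuantile over the climatology period)
    da_threshold = da_detrended.quantile(
        quantile_threshold / 100.0, dim=("init", "member")
    )

    # marine heatwave magnitude where the threshold is exceeded, and the
    # across-member exceedance probability
    da_mhw_mag = da_detrended.where(da_detrended >= da_threshold)
    da_mhw_prob = (da_detrended >= da_threshold).mean(dim="member")
    return da_mhw_mag, da_mhw_prob, ds_p
```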
22 changes: 0 additions & 22 deletions mom6/mom6_module/mom6_statistics.py
@@ -213,28 +213,6 @@ def generate_anom_batch(
"""Generate the anomaly based on the input
dataset covered period

Parameters
----------
climo_start_year : int, optional
start year to calculation the climatology, by default 1993
climo_end_year : int, optional
end year to calculation the climatology, by default 2020
dask_option : DaskOptions, optional
flag to determine one want the return result
to be 'compute', 'persist' or keep 'lazy' in anomaly, by default 'lazy'

Returns
-------
dict
anomaly: dataarray which represent the anomaly,
climatology: dataarray which represent the climatology

Raises
------
ValueError
when the kwarg anom_start_year & anom_end_year result in
empty array crop

Parameters
----------
climo_start_year : int, optional