Skip to content

Commit

Permalink
Merge pull request #72 from GeoscienceAustralia/elevation_improvements
Browse files Browse the repository at this point in the history
Elevation and tide modelling improvements
  • Loading branch information
vnewey authored Mar 8, 2024
2 parents 76bbdd6 + 1ad09bb commit 878f343
Show file tree
Hide file tree
Showing 10 changed files with 1,512 additions and 2,520 deletions.
151 changes: 90 additions & 61 deletions intertidal/elevation.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def rolling_tide_window(
window_spacing,
window_radius,
tide_min,
min_count=5,
statistic="median",
):
"""
Expand All @@ -181,6 +182,10 @@ def rolling_tide_window(
(e.g. metres).
tide_min : float
Bottom edge of the rolling window in tide units (e.g. metres).
min_count : int, optional
The minimum number of valid datapoints required to calculate the
rolling statistic. Outputs with fewer observations will be set to
NaN. Defaults to 5.
statistic : str, optional
Statistic to apply on the values within each window. One of
["median", "mean", "quantile"]. Default is "median".
Expand Down Expand Up @@ -215,11 +220,21 @@ def rolling_tide_window(
elif statistic == "mean":
ds_agg = masked_ds.mean(dim="time")

# Optionally mask out observations with less than n valid datapoints.
if min_count:
clear_count = masked_ds.notnull().sum(dim="time")
ds_agg = ds_agg.where(clear_count > min_count)

return ds_agg


def pixel_rolling_median(
flat_ds, windows_n=100, window_prop_tide=0.15, window_offset=5, max_workers=None
flat_ds,
windows_n=100,
window_prop_tide=0.15,
window_offset=5,
min_count=5,
max_workers=None,
):
"""
Calculate rolling medians for each pixel in an xarray.Dataset from
Expand Down Expand Up @@ -250,6 +265,11 @@ def pixel_rolling_median(
first rolling window beneath the lowest tide, although at the
risk of introducing noisy data due to the rolling medians
containing fewer total satellite observations. Defaults to 5.
min_count : int, optional
The minimum number of cloud free observations required to
calculate the rolling statistic. Defaults to 5; higher values
will produce cleaner results but with potentially reduced
intertidal coverage.
max_workers : int, optional
Maximum number of worker processes to use for parallel
execution, by default 64
Expand Down Expand Up @@ -292,7 +312,13 @@ def pixel_rolling_median(
rolling_intervals,
*(
repeat(i, len(rolling_intervals))
for i in [flat_ds, window_spacing_tide, window_radius_tide, tide_min]
for i in [
flat_ds,
window_spacing_tide,
window_radius_tide,
tide_min,
min_count,
]
),
)

Expand All @@ -311,7 +337,14 @@ def pixel_rolling_median(
return interval_ds


def pixel_dem(interval_ds, ndwi_thresh=0.1, interp_intervals=200, smooth_radius=20):
def pixel_dem(
interval_ds,
ndwi_thresh=0.1,
interp_intervals=200,
smooth_radius=20,
min_periods=5,
debug=False,
):
"""
Calculates an estimate of intertidal elevation based on satellite
imagery and tide data. Elevation is modelled by identifying the
Expand Down Expand Up @@ -345,6 +378,10 @@ def pixel_dem(interval_ds, ndwi_thresh=0.1, interp_intervals=200, smooth_radius=
tide interval dimension. This produces smoother DEM surfaces
than using the rolling median directly. Defaults to 20; set to
None to deactivate.
min_periods : int or string, optional
Minimum number of valid datapoints required to calculate rolling
mean if `smooth_radius` is set. Defaults to 5; "auto" will use
`int(smooth_radius / 2.0)`; `None` will use the size of the window.

Returns
-------
Expand All @@ -367,7 +404,9 @@ def pixel_dem(interval_ds, ndwi_thresh=0.1, interp_intervals=200, smooth_radius=
smoothed_ds = interval_ds.rolling(
interval=smooth_radius,
center=False,
min_periods=1, # int(smooth_radius / 2.0),
min_periods=int(smooth_radius / 2.0)
if min_periods == "auto"
else min_periods,
).mean()
else:
smoothed_ds = interval_ds
Expand All @@ -388,8 +427,14 @@ def pixel_dem(interval_ds, ndwi_thresh=0.1, interp_intervals=200, smooth_radius=
always_wet = tide_thresh <= tide_min
dem_flat = tide_thresh.where(~always_wet & ~always_dry)

# Export as xr.Dataset
return dem_flat.to_dataset(name="elevation")
# Convert to xr.Dataset
dem_ds = dem_flat.to_dataset(name="elevation")

# If debug is True, return smoothed data as well
if debug:
return dem_ds, smoothed_ds

return dem_ds


def pixel_dem_debug(
Expand All @@ -400,6 +445,7 @@ def pixel_dem_debug(
ndwi_thresh=0.1,
interp_intervals=200,
smooth_radius=20,
min_periods=5,
certainty_method="mad",
plot_style=None,
):
Expand All @@ -413,33 +459,21 @@ def pixel_dem_debug(
flat_pixel = flat_unstacked.sel(x=x, y=y, method="nearest")
interval_pixel = interval_unstacked.sel(x=x, y=y, method="nearest")

# Apply interval interpolation and rolling mean
interval_clean_pixel = (
interval_pixel.interp(
interval=np.linspace(0, interval_ds.interval.max(), interp_intervals),
method="linear",
)[["tide_m", "ndwi"]]
.rolling(
interval=smooth_radius,
center=False,
min_periods=int(smooth_radius / 2.0),
)
.mean()
)

if not isinstance(ndwi_thresh, float):
# Experiment with variable threshold
ndwi_thresh = xr.DataArray(
np.linspace(ndwi_thresh[0], ndwi_thresh[-1], interp_intervals),
coords={"interval": interval_clean_pixel.interval},
)
# # Experimental feature: support for variable threshold
# if not isinstance(ndwi_thresh, float):
# ndwi_thresh = xr.DataArray(
# np.linspace(ndwi_thresh[0], ndwi_thresh[-1], interp_intervals),
# coords={"interval": interval_clean_pixel.interval},
# )

# Calculate DEM
flat_dem_pixel = pixel_dem(
interval_clean_pixel,
flat_dem_pixel, interval_smoothed_pixel = pixel_dem(
interval_pixel,
ndwi_thresh=ndwi_thresh,
interp_intervals=None,
smooth_radius=None,
interp_intervals=interp_intervals,
smooth_radius=smooth_radius,
min_periods=min_periods,
debug=True,
)

# Calculate certainty
Expand All @@ -462,19 +496,22 @@ def pixel_dem_debug(
else:
sns.scatterplot(data=flat_pixel_df, x="tide_m", y="ndwi", color="black", s=10)

interval_pixel.to_dataframe().rename({"ndwi": "rolling median"}, axis=1).plot(
x="tide_m", y="rolling median", ax=plt.gca()
# Convert to dataframes and plot
interval_pixel_df = interval_pixel.to_dataframe().rename(
{"ndwi": "rolling median"}, axis=1
)
interval_clean_pixel.to_dataframe().rename({"ndwi": "smoothed"}, axis=1).plot(
x="tide_m", y="smoothed", ax=plt.gca()
interval_smoothed_pixel_df = interval_smoothed_pixel.to_dataframe().rename(
{"ndwi": "smoothed"}, axis=1
)
interval_pixel_df.plot(x="tide_m", y="rolling median", ax=plt.gca())
interval_smoothed_pixel_df.plot(x="tide_m", y="smoothed", ax=plt.gca())

if not isinstance(ndwi_thresh, float):
plt.plot(
interval_clean_pixel.tide_m.sel(
interval=~interval_clean_pixel.tide_m.isnull()
interval_smoothed_pixel.tide_m.sel(
interval=~interval_smoothed_pixel.tide_m.isnull()
),
ndwi_thresh.sel(interval=~interval_clean_pixel.tide_m.isnull()),
ndwi_thresh.sel(interval=~interval_smoothed_pixel.tide_m.isnull()),
color="black",
linestyle="--",
lw=1,
Expand All @@ -491,6 +528,8 @@ def pixel_dem_debug(
)
plt.gca().set_ylim(-1, 1)

return interval_pixel, interval_smoothed_pixel


def pixel_uncertainty(
flat_ds,
Expand Down Expand Up @@ -799,27 +838,15 @@ def elevation(
run_id = "Processing" if run_id is None else run_id

# Model tides into every pixel in the three-dimensional satellite
# dataset (x by y by time)
# dataset (x by y by time). If `tide_model` is "ensemble", this will
# model tides by combining the best local tide models.
log.info(f"{run_id}: Modelling tide heights for each pixel")
if (tide_model[0] == "ensemble") or (tide_model == "ensemble"):
# Use ensemble model combining multiple input ocean tide models
tide_m, _ = pixel_tides_ensemble(
satellite_ds,
directory=tide_model_dir,
ancillary_points="data/raw/tide_correlations_2017-2019.geojson",
top_n=3,
reduce_method="mean",
resolution=3000,
)

else:
# Use single input ocean tide model
tide_m, _ = pixel_tides(
satellite_ds,
resample=True,
model=tide_model,
directory=tide_model_dir,
)
tide_m, _ = pixel_tides_ensemble(
ds=satellite_ds,
ancillary_points="data/raw/tide_correlations_2017-2019.geojson",
model=tide_model,
directory=tide_model_dir,
)

# Set tide array pixels to nodata if the satellite data array pixels
# contain nodata. This ensures that we ignore any tide observations
Expand Down Expand Up @@ -1162,19 +1189,21 @@ def intertidal_cli(
if exposure_offsets:
log.info(f"{run_id}: Calculating Intertidal Exposure")

# Set time range
all_timerange = pd.date_range(
# Select times used for exposure modelling
all_times = pd.date_range(
start=round_date_strings(start_date, round_type="start"),
end=round_date_strings(end_date, round_type="end"),
freq=modelled_freq,
)

# Calculate exposure (use only until exposure PR is accepted/merged)
# Calculate exposure
ds["exposure"], tide_cq = exposure(
dem=ds.elevation,
time_range=all_timerange,
times=all_times,
tide_model=tide_model,
tide_model_dir=tide_model_dir,
run_id=run_id,
log=log,
)

# Calculate spread, offsets and HAT/LAT/LOT/HOT
Expand Down
Loading

0 comments on commit 878f343

Please sign in to comment.