Skip to content

Commit

Permalink
Fixing typehints, cleaning up import structure.
Browse files Browse the repository at this point in the history
  • Loading branch information
SpacemanPaul committed Apr 19, 2024
1 parent e6c571a commit 2bb0efd
Show file tree
Hide file tree
Showing 32 changed files with 273 additions and 258 deletions.
3 changes: 2 additions & 1 deletion datacube_ows/cfg_parser_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from deepdiff import DeepDiff

from datacube_ows import __version__
from datacube_ows.ows_configuration import (ConfigException, OWSConfig,
from datacube_ows.ows_configuration import (OWSConfig,
OWSFolder, read_config)
from datacube_ows.config_utils import ConfigException


@click.group(invoke_without_command=True)
Expand Down
119 changes: 117 additions & 2 deletions datacube_ows/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import logging
import os
from importlib import import_module
from itertools import chain
from typing import (Any, Callable, Iterable, List, Mapping, MutableMapping,
Optional, Sequence, Set, Union, cast)
Optional, Sequence, Set, TypeVar, Union, cast)
from urllib.parse import urlparse

import fsspec
Expand All @@ -17,7 +18,6 @@
from xarray import DataArray

from datacube_ows.config_toolkit import deepinherit
from datacube_ows.ogc_utils import ConfigException, FunctionWrapper

_LOG = logging.getLogger(__name__)

Expand All @@ -32,6 +32,9 @@

CFG_DICT = MutableMapping[str, RAW_CFG]

F = TypeVar('F', bound=Callable[..., Any])



# inclusions defaulting to an empty list is dangerous, but note that it is never modified.
# If modification of inclusions is a required, a copy (ninclusions) is made and modified instead.
Expand Down Expand Up @@ -137,6 +140,12 @@ def import_python_obj(path: str) -> RAW_CFG:
return cast(RAW_CFG, obj)


class ConfigException(Exception):
"""
General exception for OWS Configuration issues.
"""


class OWSConfigNotReady(ConfigException):
"""
Exception raised when someone tries to use an OWSConfigEntry that isn't fully initialised yet.
Expand Down Expand Up @@ -920,3 +929,109 @@ def create_mask(self, data: DataArray) -> DataArray:
if mask is not None and self.invert:
mask = ~mask # pylint: disable=invalid-unary-operand-type
return mask


# Function wrapper for configurable functional elements
class FunctionWrapper:
"""
Function wrapper for configurable functional elements
"""

def __init__(self,
product_or_style_cfg: OWSExtensibleConfigEntry,
func_cfg: F | Mapping[str, Any],
stand_alone: bool = False) -> None:
"""
:param product_or_style_cfg: An instance of either NamedLayer or Style,
the context in which the wrapper operates.
:param func_cfg: A function or a configuration dictionary representing a function.
:param stand_alone: Optional boolean.
If False (the default) then only configuration dictionaries will be accepted.
"""
self.style_or_layer_cfg = product_or_style_cfg
if callable(func_cfg):
if not stand_alone:
raise ConfigException(
"Directly including callable objects in configuration is no longer supported. Please reference callables by fully qualified name.")
self._func = func_cfg
self._args = []
self._kwargs = {}
self.band_mapper = None
self.pass_layer_cfg = False
elif isinstance(func_cfg, str):
self._func = get_function(func_cfg)
self._args = []
self._kwargs = {}
self.band_mapper = None
self.pass_layer_cfg = False
else:
if stand_alone and callable(func_cfg["function"]):
self._func = func_cfg["function"]
elif callable(func_cfg["function"]):
raise ConfigException(
"Directly including callable objects in configuration is no longer supported. Please reference callables by fully qualified name.")
else:
self._func = get_function(func_cfg["function"])
self._args = func_cfg.get("args", [])
self._kwargs = func_cfg.get("kwargs", {}).copy()
self.pass_layer_cfg = func_cfg.get("pass_layer_cfg", False)
if "pass_product_cfg" in func_cfg:
_LOG.warning("WARNING: pass_product_cfg in function wrapper definitions has been renamed "
"'mapped_bands'. Please update your config accordingly")
if func_cfg.get("mapped_bands", func_cfg.get("pass_product_cfg", False)):
if hasattr(product_or_style_cfg, "band_idx"):
# NamedLayer
from datacube_ows.ows_configuration import OWSNamedLayer
named_layer = cast(OWSNamedLayer, product_or_style_cfg)
b_idx = named_layer.band_idx
self.band_mapper = b_idx.band
else:
# Style
from datacube_ows.styles import StyleDef
style = cast(StyleDef, product_or_style_cfg)
b_idx = style.product.band_idx
delocaliser = style.local_band
self.band_mapper = lambda b: b_idx.band(delocaliser(b))
else:
self.band_mapper = None

def __call__(self, *args, **kwargs) -> Any:
if args and self._args:
calling_args: Iterable[Any] = chain(args, self._args)
elif args:
calling_args = args
else:
calling_args = self._args
if kwargs and self._kwargs:
calling_kwargs = self._kwargs.copy()
calling_kwargs.update(kwargs)
elif kwargs:
calling_kwargs = kwargs.copy()
else:
calling_kwargs = self._kwargs.copy()

if self.band_mapper:
calling_kwargs["band_mapper"] = self.band_mapper

if self.pass_layer_cfg:
calling_kwargs['layer_cfg'] = self.style_or_layer_cfg

return self._func(*calling_args, **calling_kwargs)


def get_function(func: F | str) -> F:
"""Converts a config entry to a function, if necessary
:param func: Either a Callable object or a fully qualified function name str, or None
:return: a Callable object, or None
"""
if func is not None and not callable(func):
mod_name, func_name = func.rsplit('.', 1)
try:
mod = import_module(mod_name)
func = getattr(mod, func_name)
except (ImportError, ModuleNotFoundError, ValueError, AttributeError):
raise ConfigException(f"Could not import python object: {func}")
assert callable(func)
return cast(F, func)
12 changes: 6 additions & 6 deletions datacube_ows/cube_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
from contextlib import contextmanager
from threading import Lock
from typing import Generator, MutableMapping, Optional
from typing import Generator

from datacube import Datacube

Expand All @@ -28,11 +28,11 @@ class CubePool:
A Cube pool is a thread-safe resource pool for managing Datacube objects (which map to database connections).
"""
# _instances, global mapping of CubePools by app name
_instances: MutableMapping[str, "CubePool"] = {}
_instances: dict[str, "CubePool"] = {}

_cubes_lock_: bool = False

_instance: Optional[Datacube] = None
_instance: Datacube | None = None

def __new__(cls, app: str) -> "CubePool":
"""
Expand All @@ -54,7 +54,7 @@ def __init__(self, app: str) -> None:
self._cubes_lock: Lock = Lock()
self._cubes_lock_ = True

def get_cube(self) -> Optional[Datacube]:
def get_cube(self) -> Datacube | None:
"""
Return a Datacube object. Either generating a new Datacube, or recycling an unassigned one already in the pool.
Expand All @@ -77,7 +77,7 @@ def _new_cube(self) -> Datacube:


# Lowlevel CubePool API
def get_cube(app: str = "ows") -> Optional[Datacube]:
def get_cube(app: str = "ows") -> Datacube | None:
"""
Obtain a Datacube object from the appropriate pool
Expand All @@ -90,7 +90,7 @@ def get_cube(app: str = "ows") -> Optional[Datacube]:

# High Level Cube Pool API
@contextmanager
def cube(app: str = "ows") -> Generator[Optional["datacube.api.core.Datacube"], None, None]:
def cube(app: str = "ows") -> Generator[Datacube | None, None, None]:
"""
Context manager for using a Datacube object from a pool.
Expand Down
3 changes: 2 additions & 1 deletion datacube_ows/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
from datacube_ows.loading import DataStacker
from datacube_ows.mv_index import MVSelectOpts
from datacube_ows.ogc_exceptions import WMSException
from datacube_ows.ogc_utils import (ConfigException, dataset_center_time,
from datacube_ows.ogc_utils import (dataset_center_time,
solar_date, tz_for_geometry,
xarray_image_as_png)
from datacube_ows.config_utils import ConfigException
from datacube_ows.ows_configuration import get_config
from datacube_ows.query_profiler import QueryProfiler
from datacube_ows.resource_limits import ResourceLimited
Expand Down
18 changes: 11 additions & 7 deletions datacube_ows/loading.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import OrderedDict

import datetime
import logging
from typing import Iterable

import datacube
Expand All @@ -11,7 +12,6 @@
from odc.geo.geom import Geometry
from odc.geo.geobox import GeoBox
from odc.geo.warp import Resampling
from datacube_ows.data import _LOG
from datacube_ows.mv_index import MVSelectOpts, mv_search
from datacube_ows.ogc_exceptions import WMSException
from datacube_ows.ows_configuration import OWSNamedLayer
Expand All @@ -20,6 +20,8 @@
from datacube_ows.utils import log_call
from datacube_ows.wms_utils import solar_correct_data

_LOG: logging.Logger = logging.getLogger(__name__)


class ProductBandQuery:
def __init__(self,
Expand Down Expand Up @@ -104,7 +106,7 @@ def full_layer_queries(cls,
def simple_layer_query(cls, layer: OWSNamedLayer,
bands: list[datacube.model.Measurement],
manual_merge: bool = False,
fuse_func: datacube.api.FuserFunction | None = None,
fuse_func: datacube.api.core.FuserFunction | None = None,
resource_limited: bool = False) -> "ProductBandQuery":
if resource_limited:
main_products = layer.low_res_products
Expand All @@ -125,7 +127,7 @@ def __init__(self,
self._product = product
self.cfg = product.global_cfg
self._geobox = geobox
self._resampling = resampling if resampling is not None else Resampling.nearest
self._resampling = resampling if resampling is not None else "nearest"
self.style = style
if style:
self._needed_bands = list(style.needed_bands)
Expand Down Expand Up @@ -335,7 +337,7 @@ def manual_data_stack(self, datasets, measurements, bands, skip_corrections, fus
# Read data for given datasets and measurements per the output_geobox
# TODO: Make skip_broken passed in via config
@log_call
def read_data(self, datasets, measurements, geobox, skip_broken = True, resampling=Resampling.nearest, fuse_func=None):
def read_data(self, datasets, measurements, geobox, skip_broken = True, resampling="nearest", fuse_func=None):
CredentialManager.check_cred()
try:
return datacube.Datacube.load_data(
Expand All @@ -344,14 +346,15 @@ def read_data(self, datasets, measurements, geobox, skip_broken = True, resampli
measurements=measurements,
fuse_func=fuse_func,
skip_broken_datasets=skip_broken,
patch_url=self._product.patch_url)
patch_url=self._product.patch_url,
resampling=resampling)
except Exception as e:
_LOG.error("Error (%s) in load_data: %s", e.__class__.__name__, str(e))
raise
# Read data for single datasets and measurements per the output_geobox
# TODO: Make skip_broken passed in via config
@log_call
def read_data_for_single_dataset(self, dataset, measurements, geobox, skip_broken = True, resampling=Resampling.nearest, fuse_func=None):
def read_data_for_single_dataset(self, dataset, measurements, geobox, skip_broken = True, resampling="nearest", fuse_func=None):
datasets = [dataset]
dc_datasets = datacube.Datacube.group_datasets(datasets, self._product.time_resolution.dataset_groupby())
CredentialManager.check_cred()
Expand All @@ -362,7 +365,8 @@ def read_data_for_single_dataset(self, dataset, measurements, geobox, skip_broke
measurements=measurements,
fuse_func=fuse_func,
skip_broken_datasets=skip_broken,
patch_url=self._product.patch_url)
patch_url=self._product.patch_url,
resampling=resampling)
except Exception as e:
_LOG.error("Error (%s) in load_data: %s", e.__class__.__name__, str(e))
raise
Loading

0 comments on commit 2bb0efd

Please sign in to comment.