From 2bb0efd0f4d2dd813e8ce919b4f590d1019512b5 Mon Sep 17 00:00:00 2001 From: Paul Haesler Date: Fri, 19 Apr 2024 11:25:07 +1000 Subject: [PATCH] Fixing typehints, cleaning up import structure. --- datacube_ows/cfg_parser_impl.py | 3 +- datacube_ows/config_utils.py | 119 ++++++++++++++++++- datacube_ows/cube_pool.py | 12 +- datacube_ows/data.py | 3 +- datacube_ows/loading.py | 18 +-- datacube_ows/mv_index.py | 70 +++++------ datacube_ows/ogc_utils.py | 185 ++++++------------------------ datacube_ows/ows_configuration.py | 5 +- datacube_ows/resource_limits.py | 4 +- datacube_ows/styles/base.py | 3 +- datacube_ows/styles/component.py | 3 +- datacube_ows/styles/expression.py | 2 +- datacube_ows/styles/hybrid.py | 3 +- datacube_ows/styles/ramp.py | 3 +- datacube_ows/tile_matrix_sets.py | 3 +- datacube_ows/utils.py | 31 ++--- datacube_ows/wcs1_utils.py | 2 +- datacube_ows/wms_utils.py | 3 +- mypy.ini | 4 + setup.py | 7 +- tests/test_cfg_bandidx.py | 3 +- tests/test_cfg_cache_ctrl.py | 2 +- tests/test_cfg_global.py | 2 +- tests/test_cfg_inclusion.py | 4 +- tests/test_cfg_layer.py | 2 +- tests/test_cfg_metadata_types.py | 2 +- tests/test_cfg_tile_matrix_set.py | 2 +- tests/test_cfg_wcs.py | 2 +- tests/test_multidate_handler.py | 2 +- tests/test_ows_configuration.py | 22 ++-- tests/test_style_api.py | 2 +- tests/test_styles.py | 3 +- 32 files changed, 273 insertions(+), 258 deletions(-) create mode 100644 mypy.ini diff --git a/datacube_ows/cfg_parser_impl.py b/datacube_ows/cfg_parser_impl.py index a2fdcc22f..353aca80c 100755 --- a/datacube_ows/cfg_parser_impl.py +++ b/datacube_ows/cfg_parser_impl.py @@ -15,8 +15,9 @@ from deepdiff import DeepDiff from datacube_ows import __version__ -from datacube_ows.ows_configuration import (ConfigException, OWSConfig, +from datacube_ows.ows_configuration import (OWSConfig, OWSFolder, read_config) +from datacube_ows.config_utils import ConfigException @click.group(invoke_without_command=True) diff --git a/datacube_ows/config_utils.py b/datacube_ows/config_utils.py index 8a331a554..db7e33fd2 100644 --- a/datacube_ows/config_utils.py +++ b/datacube_ows/config_utils.py @@ -7,8 +7,9 @@ import logging import os from importlib import import_module +from itertools import chain from typing import (Any, Callable, Iterable, List, Mapping, MutableMapping, - Optional, Sequence, Set, Union, cast) + Optional, Sequence, Set, TypeVar, Union, cast) from urllib.parse import urlparse import fsspec @@ -17,7 +18,6 @@ from xarray import DataArray from datacube_ows.config_toolkit import deepinherit -from datacube_ows.ogc_utils import ConfigException, FunctionWrapper _LOG = logging.getLogger(__name__) @@ -32,6 +32,9 @@ CFG_DICT = MutableMapping[str, RAW_CFG] +F = TypeVar('F', bound=Callable[..., Any]) + + # inclusions defaulting to an empty list is dangerous, but note that it is never modified. # If modification of inclusions is a required, a copy (ninclusions) is made and modified instead. @@ -137,6 +140,12 @@ def import_python_obj(path: str) -> RAW_CFG: return cast(RAW_CFG, obj) +class ConfigException(Exception): + """ + General exception for OWS Configuration issues. + """ + + class OWSConfigNotReady(ConfigException): """ Exception raised when someone tries to use an OWSConfigEntry that isn't fully initialised yet. @@ -920,3 +929,109 @@ def create_mask(self, data: DataArray) -> DataArray: if mask is not None and self.invert: mask = ~mask # pylint: disable=invalid-unary-operand-type return mask + + +# Function wrapper for configurable functional elements +class FunctionWrapper: + """ + Function wrapper for configurable functional elements + """ + + def __init__(self, + product_or_style_cfg: OWSExtensibleConfigEntry, + func_cfg: F | Mapping[str, Any], + stand_alone: bool = False) -> None: + """ + + :param product_or_style_cfg: An instance of either NamedLayer or Style, + the context in which the wrapper operates. + :param func_cfg: A function or a configuration dictionary representing a function. + :param stand_alone: Optional boolean. + If False (the default) then only configuration dictionaries will be accepted. + """ + self.style_or_layer_cfg = product_or_style_cfg + if callable(func_cfg): + if not stand_alone: + raise ConfigException( + "Directly including callable objects in configuration is no longer supported. Please reference callables by fully qualified name.") + self._func = func_cfg + self._args = [] + self._kwargs = {} + self.band_mapper = None + self.pass_layer_cfg = False + elif isinstance(func_cfg, str): + self._func = get_function(func_cfg) + self._args = [] + self._kwargs = {} + self.band_mapper = None + self.pass_layer_cfg = False + else: + if stand_alone and callable(func_cfg["function"]): + self._func = func_cfg["function"] + elif callable(func_cfg["function"]): + raise ConfigException( + "Directly including callable objects in configuration is no longer supported. Please reference callables by fully qualified name.") + else: + self._func = get_function(func_cfg["function"]) + self._args = func_cfg.get("args", []) + self._kwargs = func_cfg.get("kwargs", {}).copy() + self.pass_layer_cfg = func_cfg.get("pass_layer_cfg", False) + if "pass_product_cfg" in func_cfg: + _LOG.warning("WARNING: pass_product_cfg in function wrapper definitions has been renamed " + "'mapped_bands'. Please update your config accordingly") + if func_cfg.get("mapped_bands", func_cfg.get("pass_product_cfg", False)): + if hasattr(product_or_style_cfg, "band_idx"): + # NamedLayer + from datacube_ows.ows_configuration import OWSNamedLayer + named_layer = cast(OWSNamedLayer, product_or_style_cfg) + b_idx = named_layer.band_idx + self.band_mapper = b_idx.band + else: + # Style + from datacube_ows.styles import StyleDef + style = cast(StyleDef, product_or_style_cfg) + b_idx = style.product.band_idx + delocaliser = style.local_band + self.band_mapper = lambda b: b_idx.band(delocaliser(b)) + else: + self.band_mapper = None + + def __call__(self, *args, **kwargs) -> Any: + if args and self._args: + calling_args: Iterable[Any] = chain(args, self._args) + elif args: + calling_args = args + else: + calling_args = self._args + if kwargs and self._kwargs: + calling_kwargs = self._kwargs.copy() + calling_kwargs.update(kwargs) + elif kwargs: + calling_kwargs = kwargs.copy() + else: + calling_kwargs = self._kwargs.copy() + + if self.band_mapper: + calling_kwargs["band_mapper"] = self.band_mapper + + if self.pass_layer_cfg: + calling_kwargs['layer_cfg'] = self.style_or_layer_cfg + + return self._func(*calling_args, **calling_kwargs) + + +def get_function(func: F | str) -> F: + """Converts a config entry to a function, if necessary + + :param func: Either a Callable object or a fully qualified function name str, or None + :return: a Callable object, or None + """ + if func is not None and not callable(func): + mod_name, func_name = func.rsplit('.', 1) + try: + mod = import_module(mod_name) + func = getattr(mod, func_name) + except (ImportError, ModuleNotFoundError, ValueError, AttributeError): + raise ConfigException(f"Could not import python object: {func}") + assert callable(func) + return cast(F, func) diff --git a/datacube_ows/cube_pool.py b/datacube_ows/cube_pool.py index c46f6f122..2739a68aa 100644 --- a/datacube_ows/cube_pool.py +++ b/datacube_ows/cube_pool.py @@ -6,7 +6,7 @@ import logging from contextlib import contextmanager from threading import Lock -from typing import Generator, MutableMapping, Optional +from typing import Generator from datacube import Datacube @@ -28,11 +28,11 @@ class CubePool: A Cube pool is a thread-safe resource pool for managing Datacube objects (which map to database connections). """ # _instances, global mapping of CubePools by app name - _instances: MutableMapping[str, "CubePool"] = {} + _instances: dict[str, "CubePool"] = {} _cubes_lock_: bool = False - _instance: Optional[Datacube] = None + _instance: Datacube | None = None def __new__(cls, app: str) -> "CubePool": """ @@ -54,7 +54,7 @@ def __init__(self, app: str) -> None: self._cubes_lock: Lock = Lock() self._cubes_lock_ = True - def get_cube(self) -> Optional[Datacube]: + def get_cube(self) -> Datacube | None: """ Return a Datacube object. Either generating a new Datacube, or recycling an unassigned one already in the pool. @@ -77,7 +77,7 @@ def _new_cube(self) -> Datacube: # Lowlevel CubePool API -def get_cube(app: str = "ows") -> Optional[Datacube]: +def get_cube(app: str = "ows") -> Datacube | None: """ Obtain a Datacube object from the appropriate pool @@ -90,7 +90,7 @@ def get_cube(app: str = "ows") -> Optional[Datacube]: # High Level Cube Pool API @contextmanager -def cube(app: str = "ows") -> Generator[Optional["datacube.api.core.Datacube"], None, None]: +def cube(app: str = "ows") -> Generator[Datacube | None, None, None]: """ Context manager for using a Datacube object from a pool. diff --git a/datacube_ows/data.py b/datacube_ows/data.py index 0bf8d0e33..54f69920f 100644 --- a/datacube_ows/data.py +++ b/datacube_ows/data.py @@ -25,9 +25,10 @@ from datacube_ows.loading import DataStacker from datacube_ows.mv_index import MVSelectOpts from datacube_ows.ogc_exceptions import WMSException -from datacube_ows.ogc_utils import (ConfigException, dataset_center_time, +from datacube_ows.ogc_utils import (dataset_center_time, solar_date, tz_for_geometry, xarray_image_as_png) +from datacube_ows.config_utils import ConfigException from datacube_ows.ows_configuration import get_config from datacube_ows.query_profiler import QueryProfiler from datacube_ows.resource_limits import ResourceLimited diff --git a/datacube_ows/loading.py b/datacube_ows/loading.py index 966413d63..1d68200b8 100644 --- a/datacube_ows/loading.py +++ b/datacube_ows/loading.py @@ -1,6 +1,7 @@ from collections import OrderedDict import datetime +import logging from typing import Iterable import datacube @@ -11,7 +12,6 @@ from odc.geo.geom import Geometry from odc.geo.geobox import GeoBox from odc.geo.warp import Resampling -from datacube_ows.data import _LOG from datacube_ows.mv_index import MVSelectOpts, mv_search from datacube_ows.ogc_exceptions import WMSException from datacube_ows.ows_configuration import OWSNamedLayer @@ -20,6 +20,8 @@ from datacube_ows.utils import log_call from datacube_ows.wms_utils import solar_correct_data +_LOG: logging.Logger = logging.getLogger(__name__) + class ProductBandQuery: def __init__(self, @@ -104,7 +106,7 @@ def full_layer_queries(cls, def simple_layer_query(cls, layer: OWSNamedLayer, bands: list[datacube.model.Measurement], manual_merge: bool = False, - fuse_func: datacube.api.FuserFunction | None = None, + fuse_func: datacube.api.core.FuserFunction | None = None, resource_limited: bool = False) -> "ProductBandQuery": if resource_limited: main_products = layer.low_res_products @@ -125,7 +127,7 @@ def __init__(self, self._product = product self.cfg = product.global_cfg self._geobox = geobox - self._resampling = resampling if resampling is not None else Resampling.nearest + self._resampling = resampling if resampling is not None else "nearest" self.style = style if style: self._needed_bands = list(style.needed_bands) @@ -335,7 +337,7 @@ def manual_data_stack(self, datasets, measurements, bands, skip_corrections, fus # Read data for given datasets and measurements per the output_geobox # TODO: Make skip_broken passed in via config @log_call - def read_data(self, datasets, measurements, geobox, skip_broken = True, resampling=Resampling.nearest, fuse_func=None): + def read_data(self, datasets, measurements, geobox, skip_broken = True, resampling="nearest", fuse_func=None): CredentialManager.check_cred() try: return datacube.Datacube.load_data( @@ -344,14 +346,15 @@ def read_data(self, datasets, measurements, geobox, skip_broken = True, resampli measurements=measurements, fuse_func=fuse_func, skip_broken_datasets=skip_broken, - patch_url=self._product.patch_url) + patch_url=self._product.patch_url, + resampling=resampling) except Exception as e: _LOG.error("Error (%s) in load_data: %s", e.__class__.__name__, str(e)) raise # Read data for single datasets and measurements per the output_geobox # TODO: Make skip_broken passed in via config @log_call - def read_data_for_single_dataset(self, dataset, measurements, geobox, skip_broken = True, resampling=Resampling.nearest, fuse_func=None): + def read_data_for_single_dataset(self, dataset, measurements, geobox, skip_broken = True, resampling="nearest", fuse_func=None): datasets = [dataset] dc_datasets = datacube.Datacube.group_datasets(datasets, self._product.time_resolution.dataset_groupby()) CredentialManager.check_cred() @@ -362,7 +365,8 @@ def read_data_for_single_dataset(self, dataset, measurements, geobox, skip_broke measurements=measurements, fuse_func=fuse_func, skip_broken_datasets=skip_broken, - patch_url=self._product.patch_url) + patch_url=self._product.patch_url, + resampling=resampling) except Exception as e: _LOG.error("Error (%s) in load_data: %s", e.__class__.__name__, str(e)) raise diff --git a/datacube_ows/mv_index.py b/datacube_ows/mv_index.py index 482f8eaa9..5c9cfda45 100644 --- a/datacube_ows/mv_index.py +++ b/datacube_ows/mv_index.py @@ -17,12 +17,18 @@ from sqlalchemy.dialects.postgresql import TSTZRANGE, UUID from sqlalchemy.sql.functions import count, func +from datacube.index import Index +from datacube.model import Product, Dataset +from sqlalchemy.engine.base import Engine +from sqlalchemy.sql.elements import ClauseElement + + from datacube_ows.utils import default_to_utc -def get_sqlalc_engine(index: "datacube.index.Index") -> "sqlalchemy.engine.base.Engine": +def get_sqlalc_engine(index: Index) -> Engine: # pylint: disable=protected-access - return index._db._engine + return index._db._engine # type: ignore[attr-defined] def get_st_view(meta: MetaData) -> Table: @@ -32,6 +38,8 @@ def get_st_view(meta: MetaData) -> Table: Column('spatial_extent', Geometry(from_text='ST_GeomFromGeoJSON', name='geometry')), Column('temporal_extent', TSTZRANGE()) ) + + _meta = MetaData() st_view = get_st_view(_meta) @@ -53,33 +61,28 @@ class MVSelectOpts(Enum): DATASETS = 4 INVALID = 9999 - def sel(self, stv: Table) -> Iterable["sqlalchemy.sql.elements.ClauseElement"]: + def sel(self, stv: Table) -> list[ClauseElement]: if self == self.ALL: return [stv] if self == self.IDS or self == self.DATASETS: return [stv.c.id] if self == self.COUNT: - return [cast("sqlalchemy.sql.elements.ClauseElement", count(stv.c.id))] + return [cast(ClauseElement, count(stv.c.id))] if self == self.EXTENT: return [text("ST_AsGeoJSON(ST_Union(spatial_extent))")] - assert False + raise Exception("Invalid selection option") -TimeSearchTerm = Union[ - Tuple[datetime.datetime, datetime.datetime], - datetime.datetime, -] -def mv_search(index: "datacube.index.Index", +DateOrDateTime = datetime.datetime | datetime.date +TimeSearchTerm = tuple[datetime.datetime, datetime.datetime] | tuple[datetime.date, datetime.date] | DateOrDateTime + +MVSearchResult = Iterable[Iterable[Any]] | Iterable[str] | Iterable[Dataset] | int | None | ODCGeom + +def mv_search(index: Index, sel: MVSelectOpts = MVSelectOpts.IDS, - times: Optional[Iterable[TimeSearchTerm]] = None, - geom: Optional[ODCGeom] = None, - products: Optional[Iterable["datacube.model.DatasetType"]] = None) -> Union[ - Iterable[Iterable[Any]], - Iterable[str], - Iterable["datacube.model.Dataset"], - int, - None, - ODCGeom]: + times: Iterable[TimeSearchTerm] | None = None, + geom: ODCGeom | None = None, + products: Iterable[Product] | None = None) -> MVSearchResult: """ Perform a dataset query via the space_time_view @@ -98,16 +101,16 @@ def mv_search(index: "datacube.index.Index", raise Exception("Must filter by product/layer") prod_ids = [p.id for p in products] - s = select(*sel.sel(stv)).where(stv.c.dataset_type_ref.in_(prod_ids)) + s = select(*sel.sel(stv)).where(stv.c.dataset_type_ref.in_(prod_ids)) # type: ignore[call-overload] if times is not None: or_clauses = [] for t in times: if isinstance(t, datetime.datetime): - t = datetime.datetime(t.year, t.month, t.day, t.hour, t.minute, t.second) - t = default_to_utc(t) - if not t.tzinfo: - t = t.replace(tzinfo=pytz.utc) - tmax = t + datetime.timedelta(seconds=1) + st: datetime.datetime = datetime.datetime(t.year, t.month, t.day, t.hour, t.minute, t.second) + st = default_to_utc(t) + if not st.tzinfo: + st = st.replace(tzinfo=pytz.utc) + tmax = st + datetime.timedelta(seconds=1) or_clauses.append( and_( func.lower(stv.c.temporal_extent) >= t, @@ -115,11 +118,11 @@ def mv_search(index: "datacube.index.Index", ) ) elif isinstance(t, datetime.date): - t = datetime.datetime(t.year, t.month, t.day, tzinfo=pytz.utc) - tmax = t + datetime.timedelta(days=1) + st = datetime.datetime(t.year, t.month, t.day, tzinfo=pytz.utc) + tmax = st + datetime.timedelta(days=1) or_clauses.append( and_( - func.lower(stv.c.temporal_extent) >= t, + func.lower(stv.c.temporal_extent) >= st, func.lower(stv.c.temporal_extent) < tmax, ) ) @@ -139,13 +142,13 @@ def mv_search(index: "datacube.index.Index", with engine.connect() as conn: if sel == MVSelectOpts.ALL: return conn.execute(s) - if sel == MVSelectOpts.IDS: + elif sel == MVSelectOpts.IDS: return [r[0] for r in conn.execute(s)] - if sel in (MVSelectOpts.COUNT, MVSelectOpts.EXTENT): + elif sel in (MVSelectOpts.COUNT, MVSelectOpts.EXTENT): for r in conn.execute(s): if sel == MVSelectOpts.COUNT: return r[0] - if sel == MVSelectOpts.EXTENT: + else: # MVSelectOpts.EXTENT geojson = r[0] if geojson is None: return None @@ -159,6 +162,9 @@ def mv_search(index: "datacube.index.Index", else: intersect = uniongeom return intersect - if sel == MVSelectOpts.DATASETS: + elif sel == MVSelectOpts.DATASETS: ids = [r[0] for r in conn.execute(s)] return index.datasets.bulk_get(ids) + else: + raise Exception("Invalid Selection Option") + raise Exception("Unreachable code reached") diff --git a/datacube_ows/ogc_utils.py b/datacube_ows/ogc_utils.py index d15867fc1..1eb3aca00 100644 --- a/datacube_ows/ogc_utils.py +++ b/datacube_ows/ogc_utils.py @@ -5,32 +5,33 @@ # SPDX-License-Identifier: Apache-2.0 import datetime import logging -from importlib import import_module from io import BytesIO -from itertools import chain -from typing import (Any, Callable, Mapping, MutableMapping, Optional, Sequence, - Tuple, TypeVar, Union, cast) +from typing import (Any, Mapping, Optional, Sequence, cast) from urllib.parse import urlparse import numpy +import xarray from affine import Affine from dateutil.parser import parse -from flask import request +from flask import request, Request from odc.geo.geobox import GeoBox from odc.geo.geom import CRS, Geometry from PIL import Image from pytz import timezone, utc from timezonefinder import TimezoneFinder +from datacube.model import Dataset +from datacube_ows.config_utils import OWSExtensibleConfigEntry + _LOG: logging.Logger = logging.getLogger(__name__) tf = TimezoneFinder(in_memory=True) -def dataset_center_time(dataset: "datacube.model.Dataset") -> datetime.datetime: +def dataset_center_time(dataset: Dataset) -> datetime.datetime: """ Determine a center_time for the dataset - Use metadata time if possible as this is what WMS uses to calculate it's temporal extents + Use metadata time if possible as this is what WMS uses to calculate its temporal extents datacube-core center time accessed through the dataset API is calculated and may not agree with the metadata document. @@ -43,7 +44,7 @@ def dataset_center_time(dataset: "datacube.model.Dataset") -> datetime.datetime: center_time = parse(metadata_time) except KeyError: try: - metadata_time: str = dataset.metadata_doc['properties']['dtr:start_datetime'] + metadata_time = dataset.metadata_doc['properties']['dtr:start_datetime'] center_time = parse(metadata_time) except KeyError: pass @@ -65,7 +66,7 @@ def solar_date(dt: datetime.datetime, tz: datetime.tzinfo) -> datetime.date: return dt.astimezone(tz).date() -def local_date(ds: "datacube.model.Dataset", tz: Optional[datetime.tzinfo] = None) -> datetime.date: +def local_date(ds: Dataset, tz: datetime.tzinfo | None = None) -> datetime.date: """ Calculate the local (solar) date for a dataset. @@ -79,7 +80,7 @@ def local_date(ds: "datacube.model.Dataset", tz: Optional[datetime.tzinfo] = Non return solar_date(dt_utc, tz) -def tz_for_dataset(ds: "datacube.model.Dataset") -> datetime.tzinfo: +def tz_for_dataset(ds: Dataset) -> datetime.tzinfo: """ Determine the timezone for a dataset (using it's extent) @@ -89,7 +90,7 @@ def tz_for_dataset(ds: "datacube.model.Dataset") -> datetime.tzinfo: return tz_for_geometry(ds.extent) -def tz_for_coord(lon: Union[float, int], lat: Union[float, int]) -> datetime.tzinfo: +def tz_for_coord(lon: float | int, lat: float | int) -> datetime.tzinfo: """ Determine the Timezone for given lat/long coordinates @@ -109,7 +110,7 @@ def tz_for_coord(lon: Union[float, int], lat: Union[float, int]) -> datetime.tzi return timezone(tzn) -def local_solar_date_range(geobox: GeoBox, date: datetime.date) -> Tuple[datetime.datetime, datetime.datetime]: +def local_solar_date_range(geobox: GeoBox, date: datetime.date) -> tuple[datetime.datetime, datetime.datetime]: """ Converts a date to a local solar date datetime range. @@ -123,7 +124,7 @@ def local_solar_date_range(geobox: GeoBox, date: datetime.date) -> Tuple[datetim return (start.astimezone(utc), end.astimezone(utc)) -def month_date_range(date: datetime.date) -> Tuple[datetime.datetime, datetime.datetime]: +def month_date_range(date: datetime.date) -> tuple[datetime.datetime, datetime.datetime]: """ Take a month from a date and convert to a one month long UTC datetime range encompassing the month. @@ -142,7 +143,7 @@ def month_date_range(date: datetime.date) -> Tuple[datetime.datetime, datetime.d return start, end -def year_date_range(date: datetime.date) -> Tuple[datetime.datetime, datetime.datetime]: +def year_date_range(date: datetime.date) -> tuple[datetime.datetime, datetime.datetime]: """ Convert a date to a UTC datetime range encompassing the calendar year including the date. @@ -156,7 +157,7 @@ def year_date_range(date: datetime.date) -> Tuple[datetime.datetime, datetime.da return start, end -def day_summary_date_range(date: datetime.date) -> Tuple[datetime.datetime, datetime.datetime]: +def day_summary_date_range(date: datetime.date) -> tuple[datetime.datetime, datetime.datetime]: """ Convert a date to a UTC datetime range encompassing the calendar date. @@ -208,26 +209,6 @@ def resp_headers(d: Mapping[str, str]) -> Mapping[str, str]: return get_config().response_headers(d) -F = TypeVar('F', bound=Callable[..., Any]) - - -def get_function(func: Union[F, str]) -> F: - """Converts a config entry to a function, if necessary - - :param func: Either a Callable object or a fully qualified function name str, or None - :return: a Callable object, or None - """ - if func is not None and not callable(func): - mod_name, func_name = func.rsplit('.', 1) - try: - mod = import_module(mod_name) - func = getattr(mod, func_name) - except (ImportError, ModuleNotFoundError, ValueError, AttributeError): - raise ConfigException(f"Could not import python object: {func}") - assert callable(func) - return cast(F, func) - - def parse_for_base_url(url: str) -> str: """ Extract the base URL from a URL @@ -240,7 +221,7 @@ def parse_for_base_url(url: str) -> str: return parsed -def get_service_base_url(allowed_urls: Union[Sequence[str], str], request_url: str) -> str: +def get_service_base_url(allowed_urls: list[str] | str, request_url: str) -> str: """ Choose the base URL to advertise in XML. @@ -263,9 +244,8 @@ def get_service_base_url(allowed_urls: Union[Sequence[str], str], request_url: s # Collects additional headers from flask request objects -def capture_headers(req: "flask.Request", - args_dict: MutableMapping[str, Optional[str]]) \ - -> MutableMapping[str, Optional[str]]: +def capture_headers(req: Request, + args_dict: dict[str, str | None]) -> dict[str, Optional[str]]: """ Capture significant flask metadata into the args dictionary @@ -282,104 +262,7 @@ def capture_headers(req: "flask.Request", return args_dict -class ConfigException(Exception): - """ - General exception for OWS Configuration issues. - """ - - -# Function wrapper for configurable functional elements - - -class FunctionWrapper: - """ - Function wrapper for configurable functional elements - """ - - def __init__(self, - product_or_style_cfg: Union[ - "datacube_ows.ows_configuration.OWSNamedLayer", "datacube_ows.styles.StyleDef"], - func_cfg: Union[F, Mapping[str, Any]], - stand_alone: bool = False) -> None: - """ - - :param product_or_style_cfg: An instance of either NamedLayer or Style, - the context in which the wrapper operates. - :param func_cfg: A function or a configuration dictionary representing a function. - :param stand_alone: Optional boolean. - If False (the default) then only configuration dictionaries will be accepted. - """ - self.style_or_layer_cfg = product_or_style_cfg - if callable(func_cfg): - if not stand_alone: - raise ConfigException( - "Directly including callable objects in configuration is no longer supported. Please reference callables by fully qualified name.") - self._func = func_cfg - self._args = [] - self._kwargs = {} - self.band_mapper = None - self.pass_layer_cfg = False - elif isinstance(func_cfg, str): - self._func = get_function(func_cfg) - self._args = [] - self._kwargs = {} - self.band_mapper = None - self.pass_layer_cfg = False - else: - if stand_alone and callable(func_cfg["function"]): - self._func = func_cfg["function"] - elif callable(func_cfg["function"]): - raise ConfigException( - "Directly including callable objects in configuration is no longer supported. Please reference callables by fully qualified name.") - else: - self._func = get_function(func_cfg["function"]) - self._args = func_cfg.get("args", []) - self._kwargs = func_cfg.get("kwargs", {}).copy() - self.pass_layer_cfg = func_cfg.get("pass_layer_cfg", False) - if "pass_product_cfg" in func_cfg: - _LOG.warning("WARNING: pass_product_cfg in function wrapper definitions has been renamed " - "'mapped_bands'. Please update your config accordingly") - if func_cfg.get("mapped_bands", func_cfg.get("pass_product_cfg", False)): - if hasattr(product_or_style_cfg, "band_idx"): - # NamedLayer - named_layer = cast("datacube_ows.ows_configuration.OWSNamedLayer", - product_or_style_cfg) - b_idx = named_layer.band_idx - self.band_mapper = b_idx.band - else: - # Style - style = cast("datacube_ows.styles.StyleDef", product_or_style_cfg) - b_idx = style.product.band_idx - delocaliser = style.local_band - self.band_mapper = lambda b: b_idx.band(delocaliser(b)) - else: - self.band_mapper = None - - def __call__(self, *args, **kwargs) -> Any: - if args and self._args: - calling_args = chain(args, self._args) - elif args: - calling_args = args - else: - calling_args = self._args - if kwargs and self._kwargs: - calling_kwargs = self._kwargs.copy() - calling_kwargs.update(kwargs) - elif kwargs: - calling_kwargs = kwargs.copy() - else: - calling_kwargs = self._kwargs.copy() - - if self.band_mapper: - calling_kwargs["band_mapper"] = self.band_mapper - - if self.pass_layer_cfg: - calling_kwargs['layer_cfg'] = self.style_or_layer_cfg - - return self._func(*calling_args, **calling_kwargs) - - -def cache_control_headers(max_age: int) -> str: +def cache_control_headers(max_age: int) -> dict[str, str]: if max_age <= 0: return {"cache-control": "no-cache"} else: @@ -388,7 +271,7 @@ def cache_control_headers(max_age: int) -> str: # Extent Mask Functions -def mask_by_val(data: "xarray.Dataset", band: str, val: Optional[Any] = None) -> "xarray.DataArray": +def mask_by_val(data: xarray.Dataset, band: str, val: Any = None) -> xarray.DataArray: """ Mask by value. Value to mask by may be supplied, or is taken from 'nodata' metadata by default. @@ -401,7 +284,7 @@ def mask_by_val(data: "xarray.Dataset", band: str, val: Optional[Any] = None) -> return data[band] != val -def mask_by_val2(data: "xarray.Dataset", band: str) -> "xarray.DataArray": +def mask_by_val2(data: xarray.Dataset, band: str) -> xarray.DataArray: """ Mask by value, using ODC canonical nodata value @@ -410,14 +293,14 @@ def mask_by_val2(data: "xarray.Dataset", band: str) -> "xarray.DataArray": return data[band] != data[band].nodata -def mask_by_bitflag(data: "xarray.Dataset", band: str) -> "xarray.DataArray": +def mask_by_bitflag(data: xarray.Dataset, band: str) -> xarray.DataArray: """ Mask by ODC metadata nodata value, as a bitflag """ return ~data[band] & data[band].attrs['nodata'] -def mask_by_val_in_band(data: "xarray.Dataset", band: str, mask_band: str, val: Any = None) -> "xarray.DataArray": +def mask_by_val_in_band(data: xarray.Dataset, band: str, mask_band: str, val: Any = None) -> xarray.DataArray: """ Mask all bands by a value in a particular band @@ -427,7 +310,7 @@ def mask_by_val_in_band(data: "xarray.Dataset", band: str, mask_band: str, val: return mask_by_val(data, mask_band, val) -def mask_by_quality(data: "xarray.Dataset", band: str) -> "xarray.DataArray": +def mask_by_quality(data: xarray.Dataset, band: str) -> xarray.DataArray: """ Mask by a quality band. @@ -439,7 +322,7 @@ def mask_by_quality(data: "xarray.Dataset", band: str) -> "xarray.DataArray": return mask_by_val(data, "quality") -def mask_by_extent_flag(data: "xarray.Dataset", band: str) -> "xarray.DataArray": +def mask_by_extent_flag(data: xarray.Dataset, band: str) -> xarray.DataArray: """ Mask by extent. @@ -448,7 +331,7 @@ def mask_by_extent_flag(data: "xarray.Dataset", band: str) -> "xarray.DataArray" return data["extent"] == 1 -def mask_by_extent_val(data: "xarray.Dataset", band: str) -> "xarray.DataArray": +def mask_by_extent_val(data: xarray.Dataset, band: str) -> xarray.DataArray: """ Mask by extent value using metadata nodata. @@ -457,7 +340,7 @@ def mask_by_extent_val(data: "xarray.Dataset", band: str) -> "xarray.DataArray": return mask_by_val(data, "extent") -def mask_by_nan(data: "xarray.Dataset", band: str) -> "numpy.NDArray": +def mask_by_nan(data: xarray.Dataset, band: str) -> numpy.ndarray: """ Mask by nan, for bands with floating point data """ @@ -467,8 +350,8 @@ def mask_by_nan(data: "xarray.Dataset", band: str) -> "numpy.NDArray": # Example mosaic date function def rolling_window_ndays( available_dates: Sequence[datetime.datetime], - layer_cfg: "datacube_ows.ows_configuration.OWSNamedLayer", - ndays: int = 6) -> Tuple[datetime.datetime, datetime.datetime]: + layer_cfg: OWSExtensibleConfigEntry, + ndays: int = 6) -> tuple[datetime.datetime, datetime.datetime]: idx = -ndays days = available_dates[idx:] start, _ = layer_cfg.search_times(days[idx]) @@ -486,7 +369,7 @@ def rolling_window_ndays( # Method for formatting urls, e.g. for use in feature_info custom inclusions. -def lower_get_args() -> MutableMapping[str, str]: +def lower_get_args() -> dict[str, str]: """ Return Flask request arguments, with argument names converted to lower case. @@ -504,9 +387,9 @@ def lower_get_args() -> MutableMapping[str, str]: def create_geobox( crs: CRS, - minx: Union[float, int], miny: Union[float, int], - maxx: Union[float, int], maxy: Union[float, int], - width: Optional[int] = None, height: Optional[int] = None, + minx: float | int, miny: float | int, + maxx: float | int, maxy: float | int, + width: int | None = None, height: int | None = None, ) -> GeoBox: """ Create an ODC Geobox. diff --git a/datacube_ows/ows_configuration.py b/datacube_ows/ows_configuration.py index 162d0d081..ef54ed79e 100644 --- a/datacube_ows/ows_configuration.py +++ b/datacube_ows/ows_configuration.py @@ -33,10 +33,9 @@ OWSExtensibleConfigEntry, OWSFlagBand, OWSMetadataConfig, cfg_expand, get_file_loc, import_python_obj, - load_json_obj) + load_json_obj, ConfigException, FunctionWrapper) from datacube_ows.cube_pool import ODCInitException, cube, get_cube -from datacube_ows.ogc_utils import (ConfigException, FunctionWrapper, - create_geobox, local_solar_date_range) +from datacube_ows.ogc_utils import (create_geobox, local_solar_date_range) from datacube_ows.resource_limits import (OWSResourceManagementRules, parse_cache_age) from datacube_ows.styles import StyleDef diff --git a/datacube_ows/resource_limits.py b/datacube_ows/resource_limits.py index ed8a7b530..bccb83282 100644 --- a/datacube_ows/resource_limits.py +++ b/datacube_ows/resource_limits.py @@ -11,8 +11,8 @@ from odc.geo.geobox import GeoBox from odc.geo.geom import CRS, polygon -from datacube_ows.config_utils import CFG_DICT, RAW_CFG, OWSConfigEntry -from datacube_ows.ogc_utils import (ConfigException, cache_control_headers, +from datacube_ows.config_utils import CFG_DICT, RAW_CFG, OWSConfigEntry, ConfigException +from datacube_ows.ogc_utils import (cache_control_headers, create_geobox) diff --git a/datacube_ows/styles/base.py b/datacube_ows/styles/base.py index b5e713130..ca35a4912 100644 --- a/datacube_ows/styles/base.py +++ b/datacube_ows/styles/base.py @@ -21,10 +21,9 @@ OWSExtensibleConfigEntry, OWSFlagBandStandalone, OWSIndexedConfigEntry, - OWSMetadataConfig) + OWSMetadataConfig, ConfigException, FunctionWrapper) from datacube_ows.legend_utils import get_image_from_url from datacube_ows.ogc_exceptions import WMSException -from datacube_ows.ogc_utils import ConfigException, FunctionWrapper _LOG: logging.Logger = logging.getLogger(__name__) diff --git a/datacube_ows/styles/component.py b/datacube_ows/styles/component.py index 686149b28..0aa0a9964 100644 --- a/datacube_ows/styles/component.py +++ b/datacube_ows/styles/component.py @@ -9,8 +9,7 @@ import numpy as np from xarray import DataArray, Dataset -from datacube_ows.config_utils import CFG_DICT -from datacube_ows.ogc_utils import ConfigException, FunctionWrapper +from datacube_ows.config_utils import CFG_DICT, ConfigException, FunctionWrapper from datacube_ows.styles.base import StyleDefBase # pylint: disable=abstract-method diff --git a/datacube_ows/styles/expression.py b/datacube_ows/styles/expression.py index dee244226..27a86c136 100644 --- a/datacube_ows/styles/expression.py +++ b/datacube_ows/styles/expression.py @@ -8,7 +8,7 @@ import lark from datacube.virtual.expr import formula_parser -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException # Lark stuff. diff --git a/datacube_ows/styles/hybrid.py b/datacube_ows/styles/hybrid.py index 9020dd783..f13ff53da 100644 --- a/datacube_ows/styles/hybrid.py +++ b/datacube_ows/styles/hybrid.py @@ -7,8 +7,7 @@ from xarray import DataArray, Dataset -from datacube_ows.config_utils import CFG_DICT -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import CFG_DICT, ConfigException from datacube_ows.styles.base import StyleDefBase from datacube_ows.styles.component import ComponentStyleDef from datacube_ows.styles.ramp import ColorRampDef diff --git a/datacube_ows/styles/ramp.py b/datacube_ows/styles/ramp.py index 9363d9d85..2f9e448eb 100644 --- a/datacube_ows/styles/ramp.py +++ b/datacube_ows/styles/ramp.py @@ -24,8 +24,7 @@ from numpy import ubyte from xarray import Dataset -from datacube_ows.config_utils import CFG_DICT, OWSMetadataConfig -from datacube_ows.ogc_utils import ConfigException, FunctionWrapper +from datacube_ows.config_utils import CFG_DICT, OWSMetadataConfig, ConfigException, FunctionWrapper from datacube_ows.styles.base import StyleDefBase from datacube_ows.styles.expression import Expression diff --git a/datacube_ows/tile_matrix_sets.py b/datacube_ows/tile_matrix_sets.py index f322432fc..a6707e0f9 100644 --- a/datacube_ows/tile_matrix_sets.py +++ b/datacube_ows/tile_matrix_sets.py @@ -3,8 +3,7 @@ # # Copyright (c) 2017-2023 OWS Contributors # SPDX-License-Identifier: Apache-2.0 -from datacube_ows.config_utils import OWSConfigEntry -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import OWSConfigEntry, ConfigException # Scale denominators for WebMercator QuadTree Scale Set, starting from zoom level 0. # Currently goes to zoom level 14, where the pixel size at the equator is ~10m (i.e. Sentinel2 resolution) diff --git a/datacube_ows/utils.py b/datacube_ows/utils.py index b056b73ac..d6f28d66c 100644 --- a/datacube_ows/utils.py +++ b/datacube_ows/utils.py @@ -7,12 +7,18 @@ import logging from functools import wraps from time import monotonic -from typing import Any, Callable, List, Optional, TypeVar +from typing import Any, Callable, TypeVar, cast import pytz from numpy import datetime64 from numpy import datetime64 as npdt64 +from sqlalchemy.engine.base import Connection + +from datacube import Datacube +from datacube.api.query import GroupBy, solar_day +from datacube.model import Dataset + F = TypeVar('F', bound=Callable[..., Any]) def log_call(func: F) -> F: @@ -27,7 +33,7 @@ def log_wrapper(*args, **kwargs): _LOG = logging.getLogger() _LOG.debug("%s args: %s kwargs: %s", func.__name__, args, kwargs) return func(*args, **kwargs) - return log_wrapper + return cast(F, log_wrapper) def time_call(func: F) -> F: @@ -40,23 +46,22 @@ def time_call(func: F) -> F: For debugging or optimisation research only. Should not occur in mainline code. """ @wraps(func) - def timing_wrapper(*args, **kwargs): + def timing_wrapper(*args, **kwargs) -> Any: start: float = monotonic() result: Any = func(*args, **kwargs) stop: float = monotonic() _LOG = logging.getLogger() _LOG.debug("%s took: %d ms", func.__name__, int((stop - start) * 1000)) return result - return timing_wrapper + return cast(F, timing_wrapper) -def group_by_begin_datetime(pnames: Optional[List[str]] = None, - truncate_dates: bool = True) -> "datacube.api.query.GroupBy": +def group_by_begin_datetime(pnames: list[str] | None = None, + truncate_dates: bool = True) -> GroupBy: """ Returns an ODC GroupBy object, suitable for daily/monthly/yearly/etc statistical/summary data. (Or for sub-day time resolution data) """ - from datacube.api.query import GroupBy base_sort_key = lambda ds: ds.time.begin if pnames: index = { @@ -87,8 +92,7 @@ def group_by_begin_datetime(pnames: Optional[List[str]] = None, ) -def group_by_solar(pnames: Optional[List[str]] = None) -> "datacube.api.query.GroupBy": - from datacube.api.query import GroupBy, solar_day +def group_by_solar(pnames: list[str] | None = None) -> GroupBy: base_sort_key = lambda ds: ds.time.begin if pnames: index = { @@ -106,15 +110,14 @@ def group_by_solar(pnames: Optional[List[str]] = None) -> "datacube.api.query.Gr ) -def group_by_mosaic(pnames: Optional[List[str]] = None) -> "datacube.api.query.GroupBy": - from datacube.api.query import GroupBy, solar_day +def group_by_mosaic(pnames: list[str] | None = None) -> GroupBy: base_sort_key = lambda ds: ds.time.begin if pnames: index = { pn: i for i, pn in enumerate(pnames) } - sort_key = lambda ds: (solar_day(ds), index.get(ds.type.name), base_sort_key(ds)) + sort_key: Callable[[Dataset], tuple] = lambda ds: (solar_day(ds), index.get(ds.type.name), base_sort_key(ds)) else: sort_key = lambda ds: (solar_day(ds), base_sort_key(ds)) return GroupBy( @@ -125,7 +128,7 @@ def group_by_mosaic(pnames: Optional[List[str]] = None) -> "datacube.api.query.G ) -def get_sqlconn(dc: "datacube.Datacube") -> "sqlalchemy.engine.base.Connection": +def get_sqlconn(dc: Datacube) -> Connection: """ Extracts a SQLAlchemy database connection from a Datacube object. @@ -133,7 +136,7 @@ def get_sqlconn(dc: "datacube.Datacube") -> "sqlalchemy.engine.base.Connection": :return: A SQLAlchemy database connection object. """ # pylint: disable=protected-access - return dc.index._db._engine.connect() + return dc.index._db._engine.connect() # type: ignore[attr-defined] def find_matching_date(dt, dates) -> bool: diff --git a/datacube_ows/wcs1_utils.py b/datacube_ows/wcs1_utils.py index 7783380e9..c17b562a1 100644 --- a/datacube_ows/wcs1_utils.py +++ b/datacube_ows/wcs1_utils.py @@ -17,7 +17,7 @@ from datacube_ows.loading import DataStacker from datacube_ows.mv_index import MVSelectOpts from datacube_ows.ogc_exceptions import WCS1Exception -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.ows_configuration import get_config from datacube_ows.resource_limits import ResourceLimited from datacube_ows.wcs_utils import get_bands_from_styles diff --git a/datacube_ows/wms_utils.py b/datacube_ows/wms_utils.py index 38b016b43..168631bcd 100644 --- a/datacube_ows/wms_utils.py +++ b/datacube_ows/wms_utils.py @@ -17,7 +17,8 @@ from rasterio.warp import Resampling from datacube_ows.ogc_exceptions import WMSException -from datacube_ows.ogc_utils import ConfigException, create_geobox +from datacube_ows.ogc_utils import create_geobox +from datacube_ows.config_utils import ConfigException from datacube_ows.ows_configuration import get_config, OWSNamedLayer from datacube_ows.resource_limits import RequestScale from datacube_ows.styles import StyleDef diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..bb3e1a2d0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] +python_version = 3.10 +ignore_missing_imports = True +allow_redefinition = True diff --git a/setup.py b/setup.py index 46e27a6cd..ef5cb8cee 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,12 @@ 'pylint', 'sphinx_click', 'pre-commit', - 'pipdeptree' + 'pipdeptree', + 'mypy', + 'types-pytz', + 'types-python-dateutil', + 'types-requests', + ] operational_requirements = [ diff --git a/tests/test_cfg_bandidx.py b/tests/test_cfg_bandidx.py index be311491c..1611be831 100644 --- a/tests/test_cfg_bandidx.py +++ b/tests/test_cfg_bandidx.py @@ -7,8 +7,7 @@ import pytest -from datacube_ows.config_utils import OWSConfigNotReady -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import OWSConfigNotReady, ConfigException from datacube_ows.ows_configuration import BandIndex diff --git a/tests/test_cfg_cache_ctrl.py b/tests/test_cfg_cache_ctrl.py index 21fa6d4b0..95d0b1462 100644 --- a/tests/test_cfg_cache_ctrl.py +++ b/tests/test_cfg_cache_ctrl.py @@ -7,7 +7,7 @@ import pytest -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.resource_limits import CacheControlRules diff --git a/tests/test_cfg_global.py b/tests/test_cfg_global.py index 4ef73f19d..ccf0c38b7 100644 --- a/tests/test_cfg_global.py +++ b/tests/test_cfg_global.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.ows_configuration import ContactInfo, OWSConfig diff --git a/tests/test_cfg_inclusion.py b/tests/test_cfg_inclusion.py index 10feeb061..422ff65a2 100644 --- a/tests/test_cfg_inclusion.py +++ b/tests/test_cfg_inclusion.py @@ -8,8 +8,8 @@ import pytest -from datacube_ows.config_utils import get_file_loc -from datacube_ows.ows_configuration import ConfigException, read_config +from datacube_ows.config_utils import get_file_loc, ConfigException +from datacube_ows.ows_configuration import read_config src_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if src_dir not in sys.path: diff --git a/tests/test_cfg_layer.py b/tests/test_cfg_layer.py index 87e0e951d..cfd8e58d7 100644 --- a/tests/test_cfg_layer.py +++ b/tests/test_cfg_layer.py @@ -9,7 +9,7 @@ import pytest -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.ows_configuration import OWSFolder, OWSLayer, parse_ows_layer from datacube_ows.resource_limits import ResourceLimited diff --git a/tests/test_cfg_metadata_types.py b/tests/test_cfg_metadata_types.py index 34f62d7d1..ecf46ca66 100644 --- a/tests/test_cfg_metadata_types.py +++ b/tests/test_cfg_metadata_types.py @@ -7,7 +7,7 @@ import pytest -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.ows_configuration import AttributionCfg, SuppURL diff --git a/tests/test_cfg_tile_matrix_set.py b/tests/test_cfg_tile_matrix_set.py index 5e641dc8e..85b010010 100644 --- a/tests/test_cfg_tile_matrix_set.py +++ b/tests/test_cfg_tile_matrix_set.py @@ -7,7 +7,7 @@ import pytest -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.tile_matrix_sets import TileMatrixSet diff --git a/tests/test_cfg_wcs.py b/tests/test_cfg_wcs.py index 396e02bc0..fe7b2aad7 100644 --- a/tests/test_cfg_wcs.py +++ b/tests/test_cfg_wcs.py @@ -7,7 +7,7 @@ import pytest -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.ows_configuration import WCSFormat, parse_ows_layer diff --git a/tests/test_multidate_handler.py b/tests/test_multidate_handler.py index fc6b6b3e6..ae584371e 100644 --- a/tests/test_multidate_handler.py +++ b/tests/test_multidate_handler.py @@ -8,7 +8,7 @@ import pytest import xarray as xr -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.styles.base import StyleDefBase diff --git a/tests/test_ows_configuration.py b/tests/test_ows_configuration.py index 24373a743..5ede2d99c 100644 --- a/tests/test_ows_configuration.py +++ b/tests/test_ows_configuration.py @@ -16,14 +16,14 @@ def test_function_wrapper_lyr(): lyr = MagicMock() func_cfg = "tests.utils.a_function" - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, func_cfg) + f = datacube_ows.config_utils.FunctionWrapper(lyr, func_cfg) assert f(7)[0] == "a7 b2 c3" assert f(5, c=4)[0] == "a5 b2 c4" assert f.band_mapper is None func_cfg = { "function": "tests.utils.a_function", } - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, func_cfg) + f = datacube_ows.config_utils.FunctionWrapper(lyr, func_cfg) assert f(7, 8)[0] == "a7 b8 c3" func_cfg = { "function": "tests.utils.a_function", @@ -32,12 +32,12 @@ def test_function_wrapper_lyr(): "c": "ouple" } } - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, func_cfg) + f = datacube_ows.config_utils.FunctionWrapper(lyr, func_cfg) result = f("pple", "eagle") assert result[0] == "apple beagle couple" assert result[1]["foo"] == "bar" assert f.band_mapper is None - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, func_cfg) + f = datacube_ows.config_utils.FunctionWrapper(lyr, func_cfg) result = f(a="pple", b="eagle") assert result[0] == "apple beagle couple" assert result[1]["foo"] == "bar" @@ -46,11 +46,11 @@ def test_function_wrapper_lyr(): "function": "tests.utils.a_function", "args": ["bar", "ouple"] } - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, func_cfg) + f = datacube_ows.config_utils.FunctionWrapper(lyr, func_cfg) result = f("pple") assert result[0] == "apple bbar couple" assert f.band_mapper is None - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, func_cfg) + f = datacube_ows.config_utils.FunctionWrapper(lyr, func_cfg) result = f() assert result[0] == "abar bouple c3" assert f.band_mapper is None @@ -59,25 +59,25 @@ def test_function_wrapper_lyr(): "args": ["bar", "ouple"] } with pytest.raises(datacube_ows.config_utils.ConfigException) as e: - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, func_cfg) + f = datacube_ows.config_utils.FunctionWrapper(lyr, func_cfg) assert "Could not import python object" in str(e.value) assert "so_fake.not_real.not_a_function" in str(e.value) def test_func_naked(): lyr = MagicMock() with pytest.raises(datacube_ows.config_utils.ConfigException) as e: - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, { + f = datacube_ows.config_utils.FunctionWrapper(lyr, { "function": a_function, }) assert str("Directly including callable objects in configuration is no longer supported.") with pytest.raises(datacube_ows.config_utils.ConfigException) as e: - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, a_function) + f = datacube_ows.config_utils.FunctionWrapper(lyr, a_function) assert str("Directly including callable objects in configuration is no longer supported.") - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, { + f = datacube_ows.config_utils.FunctionWrapper(lyr, { "function": a_function, }, stand_alone=True) assert f("ardvark", "bllbbll")[0] == "aardvark bbllbbll c3" - f = datacube_ows.ogc_utils.FunctionWrapper(lyr, a_function, stand_alone=True) + f = datacube_ows.config_utils.FunctionWrapper(lyr, a_function, stand_alone=True) assert f("ardvark", "bllbbll")[0] == "aardvark bbllbbll c3" diff --git a/tests/test_style_api.py b/tests/test_style_api.py index 461eb9384..c950d453c 100644 --- a/tests/test_style_api.py +++ b/tests/test_style_api.py @@ -7,7 +7,7 @@ import pytest -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import ConfigException from datacube_ows.styles.api import ( # noqa: F401 isort:skip StandaloneStyle, apply_ows_style, diff --git a/tests/test_styles.py b/tests/test_styles.py index 409523d71..136772284 100644 --- a/tests/test_styles.py +++ b/tests/test_styles.py @@ -11,8 +11,7 @@ from xarray import DataArray, Dataset, concat import datacube_ows.styles -from datacube_ows.config_utils import OWSEntryNotFound -from datacube_ows.ogc_utils import ConfigException +from datacube_ows.config_utils import OWSEntryNotFound, ConfigException from datacube_ows.ows_configuration import BandIndex, OWSProductLayer