Skip to content

API changes: particlefile.py and other touchups #1727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 22, 2024
13 changes: 0 additions & 13 deletions parcels/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,6 @@ class CStructuredGrid(Structure):
)
return self._cstruct

def lon_grid_to_target(self):
"""Replace ``_lon`` with ``lon_remapping.to_target(self.lon)``; does nothing when ``lon_remapping`` is falsy."""
if self.lon_remapping:
self._lon = self.lon_remapping.to_target(self.lon)

def lon_grid_to_source(self):
"""Replace ``_lon`` with ``lon_remapping.to_source(self.lon)``; does nothing when ``lon_remapping`` is falsy."""
if self.lon_remapping:
self._lon = self.lon_remapping.to_source(self.lon)

def lon_particle_to_target(self, lon):
"""Return ``lon_remapping.particle_to_target(lon)``, or ``lon`` unchanged when ``lon_remapping`` is falsy."""
if self.lon_remapping:
return self.lon_remapping.particle_to_target(lon)
return lon

@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def check_zonal_periodic(self, *args, **kwargs):
"""Forward all arguments to ``_check_zonal_periodic`` and return its result (public name scheduled for removal, see TODO)."""
return self._check_zonal_periodic(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def check_fieldsets_in_kernels(self, pyfunc):
)
elif pyfunc is AdvectionAnalytical:
if self.fieldset.particlefile is not None:
self.fieldset.particlefile.analytical = True
self.fieldset.particlefile._is_analytical = True
if self._ptype.uses_jit:
raise NotImplementedError("Analytical Advection only works in Scipy mode")
if self._fieldset.U.interp_method != "cgrid_velocity":
Expand Down
6 changes: 5 additions & 1 deletion parcels/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ class Variable:
"""

def __init__(self, name, dtype=np.float32, initial=0, to_write: bool | Literal["once"] = True):
self.name = name
self._name = name
self.dtype = dtype
self.initial = initial
self.to_write = to_write

@property
def name(self):
"""Return the variable name held in ``_name``."""
return self._name

def __get__(self, instance, cls):
if instance is None:
return self
Expand Down
8 changes: 6 additions & 2 deletions parcels/particledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from parcels._compat import MPI, KMeans
from parcels.tools._helpers import deprecated
from parcels.tools.statuscodes import StatusCode


Expand Down Expand Up @@ -228,12 +229,15 @@
"""Return the length, in terms of number of elements, of a ParticleData instance."""
return self._ncount

@deprecated(
"Use iter(...) instead, or just use the object in an iterator context (e.g. for p in particledata: ...)."
) # TODO: Remove 6 months after v3.1.0 (or 9 months; doesn't contribute to code debt)
def iterator(self):
return ParticleDataIterator(self)
return iter(self)

Check warning on line 236 in parcels/particledata.py

View check run for this annotation

Codecov / codecov/patch

parcels/particledata.py#L236

Added line #L236 was not covered by tests

def __iter__(self):
"""Return an Iterator that allows for forward iteration over the elements in the ParticleData (e.g. `for p in pset:`)."""
return self.iterator()
return ParticleDataIterator(self)

def __getitem__(self, index):
"""Get a particle object from the ParticleData instance based on its index."""
Expand Down
152 changes: 106 additions & 46 deletions parcels/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import parcels
from parcels._compat import MPI
from parcels.tools._helpers import deprecated, deprecated_made_private
from parcels.tools.warnings import FileWarning

__all__ = ["ParticleFile"]
Expand Down Expand Up @@ -46,31 +47,24 @@
ParticleFile object that can be used to write particle data to file
"""

outputdt = None
particleset = None
parcels_mesh = None
time_origin = None
lonlatdepth_dtype = None

def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
self.outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
self.chunks = chunks
self.particleset = particleset
self.parcels_mesh = "spherical"
self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
self._chunks = chunks
self._particleset = particleset
self._parcels_mesh = "spherical"
if self.particleset.fieldset is not None:
self.parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh
self.time_origin = self.particleset.time_origin
self._parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh
self.lonlatdepth_dtype = self.particleset.particledata.lonlatdepth_dtype
self.maxids = 0
self.pids_written = {}
self.create_new_zarrfile = create_new_zarrfile
self.vars_to_write = {}
self._maxids = 0
self._pids_written = {}
self._create_new_zarrfile = create_new_zarrfile
self._vars_to_write = {}
for var in self.particleset.particledata.ptype.variables:
if var.to_write:
self.vars_to_write[var.name] = var.dtype
self.mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
self._mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
self.particleset.fieldset._particlefile = self
self.analytical = False # Flag to indicate if ParticleFile is used for analytical trajectories
self._is_analytical = False # Flag to indicate if ParticleFile is used for analytical trajectories

# Reset obs_written of each particle, in case new ParticleFile created for a ParticleSet
particleset.particledata.setallvardata("obs_written", 0)
Expand All @@ -80,11 +74,11 @@
"Conventions": "CF-1.6/CF-1.7",
"ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0",
"parcels_version": parcels.__version__,
"parcels_mesh": self.parcels_mesh,
"parcels_mesh": self._parcels_mesh,

Check warning on line 77 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L77

Added line #L77 was not covered by tests
}

# Create dictionary to translate datatypes and fill_values
self.fill_value_map = {
self._fill_value_map = {
np.float16: np.nan,
np.float32: np.nan,
np.float64: np.nan,
Expand All @@ -103,23 +97,82 @@
# But we need to handle incompatibility with MPI mode for now:
if MPI and MPI.COMM_WORLD.Get_size() > 1:
raise ValueError("Currently, MPI mode is not compatible with directly passing a Zarr store.")
self.fname = name
fname = name
else:
extension = os.path.splitext(str(name))[1]
if extension in [".nc", ".nc4"]:
raise RuntimeError(
"Output in NetCDF is not supported anymore. Use .zarr extension for ParticleFile name."
)
if MPI and MPI.COMM_WORLD.Get_size() > 1:
self.fname = os.path.join(name, f"proc{self.mpi_rank:02d}.zarr")
fname = os.path.join(name, f"proc{self._mpi_rank:02d}.zarr")

Check warning on line 108 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L108

Added line #L108 was not covered by tests
if extension in [".zarr"]:
warnings.warn(
f"The ParticleFile name contains .zarr extension, but zarr files will be written per processor in MPI mode at {self.fname}",
f"The ParticleFile name contains .zarr extension, but zarr files will be written per processor in MPI mode at {fname}",

Check warning on line 111 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L111

Added line #L111 was not covered by tests
FileWarning,
stacklevel=2,
)
else:
self.fname = name if extension in [".zarr"] else f"{name}.zarr"
fname = name if extension in [".zarr"] else f"{name}.zarr"
self._fname = fname

@property
def create_new_zarrfile(self):
"""Return the flag held in ``_create_new_zarrfile``."""
return self._create_new_zarrfile

@property
def outputdt(self):
"""Return the output interval held in ``_outputdt``."""
return self._outputdt

@property
def chunks(self):
"""Return the chunk sizes held in ``_chunks``."""
return self._chunks

@property
def particleset(self):
"""Return the particle set held in ``_particleset``."""
return self._particleset

@property
def fname(self):
"""Return the output name or store held in ``_fname``."""
return self._fname

@property
def vars_to_write(self):
"""Return the ``_vars_to_write`` dictionary (variable name -> dtype)."""
return self._vars_to_write

@property
def time_origin(self):
"""Return the ``time_origin`` of the attached particle set."""
return self.particleset.time_origin

@property
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def parcels_mesh(self):
"""Return ``_parcels_mesh`` (public name scheduled for removal, see TODO)."""
return self._parcels_mesh

Check warning on line 150 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L150

Added line #L150 was not covered by tests

@property
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def maxids(self):
"""Return ``_maxids`` (public name scheduled for removal, see TODO)."""
return self._maxids

Check warning on line 155 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L155

Added line #L155 was not covered by tests

@property
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def pids_written(self):
"""Return ``_pids_written`` (public name scheduled for removal, see TODO)."""
return self._pids_written

Check warning on line 160 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L160

Added line #L160 was not covered by tests

@property
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def mpi_rank(self):
"""Return ``_mpi_rank`` (public name scheduled for removal, see TODO)."""
return self._mpi_rank

Check warning on line 165 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L165

Added line #L165 was not covered by tests

@property
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def fill_value_map(self):
"""Return ``_fill_value_map`` (public name scheduled for removal, see TODO)."""
return self._fill_value_map

Check warning on line 170 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L170

Added line #L170 was not covered by tests

@property
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def analytical(self):
"""Return ``_is_analytical`` (public name scheduled for removal, see TODO)."""
return self._is_analytical

Check warning on line 175 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L175

Added line #L175 was not covered by tests

def _create_variables_attribute_dict(self):
"""Creates the dictionary with variable attributes.
Expand All @@ -133,7 +186,7 @@
"trajectory": {
"long_name": "Unique identifier for each particle",
"cf_role": "trajectory_id",
"_FillValue": self.fill_value_map[np.int64],
"_FillValue": self._fill_value_map[np.int64],

Check warning on line 189 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L189

Added line #L189 was not covered by tests
},
"time": {"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"},
"lon": {"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"},
Expand All @@ -147,14 +200,17 @@
for vname in self.vars_to_write:
if vname not in ["time", "lat", "lon", "depth", "id"]:
attrs[vname] = {
"_FillValue": self.fill_value_map[self.vars_to_write[vname]],
"_FillValue": self._fill_value_map[self.vars_to_write[vname]],

Check warning on line 203 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L203

Added line #L203 was not covered by tests
"long_name": "",
"standard_name": vname,
"units": "unknown",
}

return attrs

@deprecated(
"ParticleFile.metadata is a dictionary. Use `ParticleFile.metadata['key'] = ...` or other dictionary methods instead."
) # TODO: Remove 6 months after v3.1.0
def add_metadata(self, name, message):
"""Add metadata to :class:`parcels.particleset.ParticleSet`.

Expand All @@ -175,21 +231,25 @@
else:
return var

def write_once(self, var):
@deprecated_made_private # TODO: Remove 6 months after v3.1.0
def write_once(self, *args, **kwargs):
"""Forward all arguments to ``_write_once`` and return its result (public name scheduled for removal, see TODO)."""
return self._write_once(*args, **kwargs)

Check warning on line 236 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L236

Added line #L236 was not covered by tests

def _write_once(self, var):
"""Return True when the particle-type entry for ``var`` has ``to_write == "once"``."""
return self.particleset.particledata.ptype[var].to_write == "once"

def _extend_zarr_dims(self, Z, store, dtype, axis):
if axis == 1:
a = np.full((Z.shape[0], self.chunks[1]), self.fill_value_map[dtype], dtype=dtype)
a = np.full((Z.shape[0], self.chunks[1]), self._fill_value_map[dtype], dtype=dtype)
obs = zarr.group(store=store, overwrite=False)["obs"]
if len(obs) == Z.shape[1]:
obs.append(np.arange(self.chunks[1]) + obs[-1] + 1)
else:
extra_trajs = self.maxids - Z.shape[0]
extra_trajs = self._maxids - Z.shape[0]
if len(Z.shape) == 2:
a = np.full((extra_trajs, Z.shape[1]), self.fill_value_map[dtype], dtype=dtype)
a = np.full((extra_trajs, Z.shape[1]), self._fill_value_map[dtype], dtype=dtype)
else:
a = np.full((extra_trajs,), self.fill_value_map[dtype], dtype=dtype)
a = np.full((extra_trajs,), self._fill_value_map[dtype], dtype=dtype)
Z.append(a, axis=axis)
zarr.consolidate_metadata(store)

Expand Down Expand Up @@ -221,11 +281,11 @@

if len(indices_to_write) > 0:
pids = pset.particledata.getvardata("id", indices_to_write)
to_add = sorted(set(pids) - set(self.pids_written.keys()))
to_add = sorted(set(pids) - set(self._pids_written.keys()))
for i, pid in enumerate(to_add):
self.pids_written[pid] = self.maxids + i
ids = np.array([self.pids_written[p] for p in pids], dtype=int)
self.maxids = len(self.pids_written)
self._pids_written[pid] = self._maxids + i
ids = np.array([self._pids_written[p] for p in pids], dtype=int)
self._maxids = len(self._pids_written)

once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0]
if len(once_ids) > 0:
Expand All @@ -234,7 +294,7 @@

if self.create_new_zarrfile:
if self.chunks is None:
self.chunks = (len(ids), 1)
self._chunks = (len(ids), 1)
if pset._repeatpclass is not None and self.chunks[0] < 1e4:
warnings.warn(
f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
Expand All @@ -243,37 +303,37 @@
FileWarning,
stacklevel=2,
)
if (self.maxids > len(ids)) or (self.maxids > self.chunks[0]):
arrsize = (self.maxids, self.chunks[1])
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
arrsize = (self._maxids, self.chunks[1])

Check warning on line 307 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L307

Added line #L307 was not covered by tests
else:
arrsize = (len(ids), self.chunks[1])
ds = xr.Dataset(
attrs=self.metadata,
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
)
attrs = self._create_variables_attribute_dict()
obs = np.zeros((self.maxids), dtype=np.int32)
obs = np.zeros((self._maxids), dtype=np.int32)
for var in self.vars_to_write:
varout = self._convert_varout_name(var)
if varout not in ["trajectory"]: # because 'trajectory' is written as coordinate
if self.write_once(var):
if self._write_once(var):
data = np.full(
(arrsize[0],),
self.fill_value_map[self.vars_to_write[var]],
self._fill_value_map[self.vars_to_write[var]],

Check warning on line 322 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L322

Added line #L322 was not covered by tests
dtype=self.vars_to_write[var],
)
data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
dims = ["trajectory"]
else:
data = np.full(
arrsize, self.fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]

Check warning on line 329 in parcels/particlefile.py

View check run for this annotation

Codecov / codecov/patch

parcels/particlefile.py#L329

Added line #L329 was not covered by tests
)
data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
dims = ["trajectory", "obs"]
ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
ds[varout].encoding["chunks"] = self.chunks[0] if self.write_once(var) else self.chunks
ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
ds.to_zarr(self.fname, mode="w")
self.create_new_zarrfile = False
self._create_new_zarrfile = False
else:
# Either use the store that was provided directly or create a DirectoryStore:
if issubclass(type(self.fname), zarr.storage.Store):
Expand All @@ -284,9 +344,9 @@
obs = pset.particledata.getvardata("obs_written", indices_to_write)
for var in self.vars_to_write:
varout = self._convert_varout_name(var)
if self.maxids > Z[varout].shape[0]:
if self._maxids > Z[varout].shape[0]:
self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0)
if self.write_once(var):
if self._write_once(var):
if len(once_ids) > 0:
Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
else:
Expand Down
Loading
Loading