
Commit 8733c4c

Merge pull request #810 from DHI/generic_concat_average
support average in generic concat
2 parents: 4c0d5d9 + a75371e


2 files changed: +96, -5 lines

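The merged branch adds a third keep option, "average", to mikeio.generic.concat: where the time axes of consecutive input files overlap, the overlapping timesteps are written as an equal-weight blend of the two files rather than taken from only the first or the last file. A minimal usage sketch, with hypothetical file names:

```python
import mikeio.generic

# Hypothetical dfs files whose time axes overlap where one run ends
# and the next one begins
infiles = ["sim_part1.dfs1", "sim_part2.dfs1"]

# keep="average" writes the overlapping period as the 50/50 mean of the
# two files; keep="first" or keep="last" would take it from one file only.
mikeio.generic.concat(infiles, "sim_combined.dfs1", keep="average")
```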

mikeio/generic.py

Lines changed: 53 additions & 4 deletions
@@ -417,10 +417,7 @@ def concat(
         if keep == "last":
             if i < (len(infilenames) - 1):
                 dfs_n = DfsFileFactory.DfsGenericOpen(str(infilenames[i + 1]))
-                nf = dfs_n.FileInfo.TimeAxis.StartDateTime
-                next_start_time = datetime(
-                    nf.year, nf.month, nf.day, nf.hour, nf.minute, nf.second
-                )
+                next_start_time = dfs_n.FileInfo.TimeAxis.StartDateTime
                 dfs_n.Close()
 
             for timestep in range(n_time_steps):
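The first hunk above replaces the field-by-field reconstruction of the next file's start time with the value itself, which relies on FileInfo.TimeAxis.StartDateTime already behaving as a Python datetime. A small stand-alone sketch of why the removed lines were redundant under that assumption (`start` is a plain datetime stand-in, not read from a dfs file):

```python
from datetime import datetime

# Stand-in for dfs_n.FileInfo.TimeAxis.StartDateTime (assumed to be a datetime)
start = datetime(2020, 1, 1, 12, 30, 0)

# Removed approach: rebuild the same value field by field
rebuilt = datetime(start.year, start.month, start.day, start.hour, start.minute, start.second)

# New approach: use the value directly; both compare equal
assert rebuilt == start
```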
@@ -478,6 +475,58 @@ def concat(
             )  # get end time from current file
             dfs_i.Close()
 
+        elif keep == "average":
+            ALPHA = 0.5  # averaging factor
+            last_file = i == (len(infilenames) - 1)
+            overlapping_with_next = False  # let's first assume no overlap
+
+            # Find the start time of the next file
+            if not last_file:
+                dfs_n = DfsFileFactory.DfsGenericOpen(str(infilenames[i + 1]))
+                next_start_time = dfs_n.FileInfo.TimeAxis.StartDateTime
+            else:
+                next_start_time = datetime.max  # end of time ...
+
+            if i == 0:
+                timestep_n = 0  # have not read anything before
+
+            # let's start where we left off (if the last file overlapped)
+            timestep = timestep_n
+            while timestep < n_time_steps:
+                current_time = start_time + timedelta(seconds=timestep * dt)
+                if current_time >= next_start_time:  # False if last file
+                    overlapping_with_next = True
+                    break
+                for item in range(n_items):
+                    itemdata = dfs_i.ReadItemTimeStep(item + 1, timestep)
+                    d = itemdata.Data
+                    darray = d.astype(np.float32)
+                    dfs_o.WriteItemTimeStepNext(0, darray)
+
+                timestep += 1
+
+            timestep_n = 0  # have not read anything from the next file yet
+
+            if not overlapping_with_next:
+                dfs_n.Close()
+                continue  # next file
+
+            # Otherwise read the overlapping part
+            while timestep < n_time_steps:
+                for item in range(n_items):
+                    itemdata_i = dfs_i.ReadItemTimeStep(item + 1, timestep)
+                    itemdata_n = dfs_n.ReadItemTimeStep(item + 1, timestep_n)
+                    d_i = itemdata_i.Data
+                    d_n = itemdata_n.Data
+                    d_av = d_i * ALPHA + d_n * (1 - ALPHA)
+                    darray = d_av.astype(np.float32)
+                    dfs_o.WriteItemTimeStepNext(0, darray)
+                timestep += 1
+                timestep_n += 1
+
+            # Close the next file before opening it again
+            dfs_n.Close()
+
     dfs_o.Close()
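The new keep == "average" branch copies non-overlapping timesteps straight through and, for timesteps covered by both the current and the next file, writes a fixed 50/50 blend (ALPHA = 0.5). A self-contained NumPy sketch of just that blending rule, using illustrative arrays in place of dfs item data:

```python
import numpy as np

ALPHA = 0.5  # same fixed weighting as in the new branch

# Illustrative overlap: the current file's last two timesteps coincide with
# the next file's first two timesteps (shape: timesteps x spatial points)
d_current = np.array([[10.0, 12.0], [11.0, 13.0]])
d_next = np.array([[20.0, 22.0], [21.0, 23.0]])

# Mirrors d_av = d_i * ALPHA + d_n * (1 - ALPHA), cast to float32 before writing
d_blended = (d_current * ALPHA + d_next * (1 - ALPHA)).astype(np.float32)

print(d_blended)  # [[15. 17.]
                  #  [16. 18.]]
```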
tests/test_generic.py

Lines changed: 43 additions & 1 deletion
@@ -268,7 +268,7 @@ def test_concat_keep(tmp_path: Path) -> None:
     test keep arguments of concatenation function
     """
     # added keep arguments to test
-    keep_args = ["first", "last"]
+    keep_args = ["first", "last", "average"]
 
     infiles = [
         "tests/testdata/tide1.dfs1",
@@ -317,11 +317,53 @@ def test_concat_keep(tmp_path: Path) -> None:
             .all()
             .all()
         )
+        av_out = (
+            (
+                0.5 * (df_first.loc[overlap_idx] + df_last.loc[overlap_idx])
+                == df_o.loc[overlap_idx]
+            )
+            .eq(True)
+            .all()
+            .all()
+        )
 
         if keep_arg == "first":
             assert first_out, "overlap should be with first dataset"
         elif keep_arg == "last":
             assert last_out, "overlap should be with last dataset"
+        elif keep_arg == "average":
+            assert av_out, "overlap should be average of datasets"
+
+
+def test_concat_average(tmp_path: Path) -> None:
+    # Test for multiple items?
+    g = mikeio.Grid1D(x=range(5))
+    t = pd.date_range(start="2020-01-01", periods=5, freq="D")
+    d = np.zeros((5, 5))
+    # x x x o o
+    # o o x o o
+    # o o x x x
+    da_1 = mikeio.DataArray(data=d, time=t, geometry=g)
+    da_2 = mikeio.DataArray(data=d + 1, time=t + pd.DateOffset(days=3), geometry=g)
+    da_3 = mikeio.DataArray(data=d + 2, time=t + pd.DateOffset(days=6), geometry=g)
+
+    files = [tmp_path / "test1.dfs1", tmp_path / "test2.dfs1", tmp_path / "test3.dfs1"]
+    da_1.to_dfs(files[0])
+    da_2.to_dfs(files[1])
+    da_3.to_dfs(files[2])
+
+    # concat
+    fp = tmp_path / "concat.dfs1"
+
+    mikeio.generic.concat(files, fp, keep="average")
+    ds = mikeio.read(fp)
+    da_x0 = ds[0].isel(x=0)
+
+    assert np.allclose(
+        da_x0.values,
+        np.array([0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.5, 1.5, 2.0, 2.0, 2.0]),
+        atol=1e-6,
+    )
 
 
 def test_concat_non_equidistant_dfs0(tmp_path: Path) -> None:

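The expected series in test_concat_average follows directly from the equal weighting: a day covered by one file keeps that file's value, and a day covered by two files gets their mean. A short sketch reproducing the expected array from the coverage set up in the test (values 0, 1 and 2 over days 0-4, 3-7 and 6-10):

```python
# Value held by each file and the 0-based days it covers, as in the test setup
coverage = {0.0: range(0, 5), 1.0: range(3, 8), 2.0: range(6, 11)}

expected = []
for day in range(11):
    values = [v for v, days in coverage.items() if day in days]
    expected.append(sum(values) / len(values))  # one file's value, or the mean of two

print(expected)  # [0.0, 0.0, 0.0, 0.5, 0.5, 1.0, 1.5, 1.5, 2.0, 2.0, 2.0]
```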