Skip to content

Commit fa6466b

Browse files
Takaya Uchidarabernat
authored andcommitted
Added Linear detrend function for data that has more than one dimension (xgcm#15)
* Added Linear detrend function for data that has more than one dimension * Travis needs to install scipy * Updated the requirements for xrft * Added the shift argument for the test function of cross spectrum of dsar * Let's see if this works * updated appveyor.yml * I think I have something working now. Added a wrapper function for detrending. After detrending though, the array no longer becomes a dask array * Added one more test for Parseval * Cleaned up the reshape argument in `detrend` * Added test to improve codecov * Added another test to improve codecov
1 parent 73dffbf commit fa6466b

File tree

5 files changed

+304
-88
lines changed

5 files changed

+304
-88
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ python:
2020
- 3.6
2121

2222
env:
23-
- CONDA_DEPS="pip flake8 pytest coverage numpy pandas xarray dask" PIP_DEPS="codecov pytest-cov"
23+
- CONDA_DEPS="pip flake8 pytest coverage numpy scipy pandas xarray dask" PIP_DEPS="codecov pytest-cov"
2424

2525
before_install:
2626
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then

appveyor.yml

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,11 @@ environment:
77
PYTHON_ARCH: "32"
88
MINICONDA: C:\Miniconda
99

10-
- PYTHON: "C:\\Python33"
11-
PYTHON_VERSION: "3.3.5"
10+
- PYTHON: "C:\\Python36"
11+
PYTHON_VERSION: "3.6.1"
1212
PYTHON_ARCH: "32"
1313
MINICONDA: C:\Miniconda3
1414

15-
- PYTHON: "C:\\Python34"
16-
PYTHON_VERSION: "3.4.1"
17-
PYTHON_ARCH: "32"
18-
MINICONDA: C:\Miniconda3
19-
20-
- PYTHON: "C:\\Python35"
21-
PYTHON_VERSION: "3.5.1"
22-
PYTHON_ARCH: "32"
23-
MINICONDA: C:\Miniconda35
24-
2515
init:
2616
- "ECHO %PYTHON% %PYTHON_VERSION% %PYTHON_ARCH% %MINICONDA%"
2717

xrft/tests/test_xrft.py

Lines changed: 175 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import pandas as pd
33
import xarray as xr
44
import numpy.testing as npt
5+
import scipy.signal as sps
6+
import scipy.linalg as spl
57
import pytest
68
import xrft
79

@@ -15,6 +17,96 @@ def sample_data_1d():
1517
"""Create one dimensional test DataArray."""
1618
pass
1719

20+
def numpy_detrend(da):
21+
"""
22+
Detrend a 2D field by subtracting out the least-square plane fit.
23+
24+
Parameters
25+
----------
26+
da : `numpy.array`
27+
The data to be detrended
28+
29+
Returns
30+
-------
31+
da : `numpy.array`
32+
The detrended input data
33+
"""
34+
N = da.shape
35+
36+
G = np.ones((N[0]*N[1],3))
37+
for i in range(N[0]):
38+
G[N[1]*i:N[1]*i+N[1], 1] = i+1
39+
G[N[1]*i:N[1]*i+N[1], 2] = np.arange(1, N[1]+1)
40+
41+
d_obs = np.reshape(da.copy(), (N[0]*N[1],1))
42+
m_est = np.dot(np.dot(spl.inv(np.dot(G.T, G)), G.T), d_obs)
43+
d_est = np.dot(G, m_est)
44+
45+
lin_trend = np.reshape(d_est, N)
46+
47+
return da - lin_trend
48+
49+
def test_detrend_1d():
50+
N = 16
51+
x = np.arange(N+1)
52+
y = np.arange(N-1)
53+
t = np.linspace(-int(N/2), int(N/2), N-6)
54+
z = np.arange(int(N/2))
55+
d4d = (t[:,np.newaxis,np.newaxis,np.newaxis]
56+
+ z[np.newaxis,:,np.newaxis,np.newaxis]
57+
+ y[np.newaxis,np.newaxis,:,np.newaxis]
58+
+ x[np.newaxis,np.newaxis,np.newaxis,:]
59+
)
60+
da4d = xr.DataArray(d4d, dims=['time','z','y','x'],
61+
coords={'time':range(len(t)),'z':range(len(z)),'y':range(len(y)),
62+
'x':range(len(x))}
63+
)
64+
func = xrft._detrend_wrap(xrft.detrend2)
65+
da = da4d.chunk({'time': 1})
66+
da_prime = func(da.data, axes=[2]).compute()
67+
npt.assert_allclose(da_prime[0,0], sps.detrend(d4d[0,0], axis=0))
68+
69+
def test_detrend_2d():
70+
N = 16
71+
x = np.arange(N+1)
72+
y = np.arange(N-1)
73+
t = np.linspace(-int(N/2), int(N/2), N-6)
74+
z = np.arange(int(N/2))
75+
d4d = (t[:,np.newaxis,np.newaxis,np.newaxis]
76+
+ z[np.newaxis,:,np.newaxis,np.newaxis]
77+
+ y[np.newaxis,np.newaxis,:,np.newaxis]
78+
+ x[np.newaxis,np.newaxis,np.newaxis,:]
79+
)
80+
da4d = xr.DataArray(d4d, dims=['time','z','y','x'],
81+
coords={'time':range(len(t)),'z':range(len(z)),
82+
'y':range(len(y)),'x':range(len(x))}
83+
)
84+
func = xrft._detrend_wrap(xrft.detrend2)
85+
da = da4d.chunk({'time':1})
86+
with pytest.raises(ValueError):
87+
func(da.data, axes=[0]).compute()
88+
da = da4d.chunk({'time':1, 'z':1})
89+
with pytest.raises(ValueError):
90+
func(da.data, axes=[1,2]).compute()
91+
with pytest.raises(ValueError):
92+
func(da.data, axes=[2,2]).compute()
93+
da_prime = func(da.data, axes=[2,3]).compute()
94+
npt.assert_allclose(da_prime[0,0], numpy_detrend(d4d[0,0]))
95+
96+
s = np.arange(2)
97+
d5d = d4d[np.newaxis,:,:,:,:] + s[:,np.newaxis,np.newaxis,
98+
np.newaxis,np.newaxis]
99+
da5d = xr.DataArray(d5d, dims=['s','time','z','y','x'],
100+
coords={'s':range(len(s)),'time':range(len(t)),
101+
'z':range(len(z)),'y':range(len(y)),
102+
'x':range(len(x))}
103+
)
104+
da = da5d.chunk({'time':1})
105+
with pytest.raises(ValueError):
106+
func(da.data).compute()
107+
with pytest.raises(ValueError):
108+
func(da.data, axes=[2,3,4]).compute()
109+
18110
def test_dft_1d():
19111
"""Test the discrete Fourier transform function on one-dimensional data."""
20112
Nx = 16
@@ -24,7 +116,7 @@ def test_dft_1d():
24116
da = xr.DataArray(np.random.rand(Nx), coords=[x], dims=['x'])
25117

26118
# defaults with no keyword args
27-
ft = xrft.dft(da)
119+
ft = xrft.dft(da, detrend='constant')
28120
# check that the frequency dimension was created properly
29121
assert ft.dims == ('freq_x',)
30122
# check that the coords are correct
@@ -40,10 +132,16 @@ def test_dft_1d():
40132
npt.assert_allclose(ft_data_expected, ft.values, atol=1e-14)
41133

42134
# redo without removing mean
43-
ft = xrft.dft(da, remove_mean=False)
135+
ft = xrft.dft(da)
44136
ft_data_expected = np.fft.fftshift(np.fft.fft(da))
45137
npt.assert_allclose(ft_data_expected, ft.values)
46138

139+
# redo with detrending linear least-square fit
140+
ft = xrft.dft(da, detrend='linear')
141+
da_prime = sps.detrend(da.values)
142+
ft_data_expected = np.fft.fftshift(np.fft.fft(da_prime))
143+
npt.assert_allclose(ft_data_expected, ft.values, atol=1e-14)
144+
47145
# modify data to be non-evenly spaced
48146
da2 = da.copy()
49147
da2[-1] = np.nan
@@ -73,61 +171,88 @@ def test_dft_2d():
73171
da = xr.DataArray(np.random.rand(N,N), dims=['x','y'],
74172
coords={'x':range(N),'y':range(N)}
75173
)
76-
ft = xrft.dft(da, shift=False, remove_mean=False)
174+
ft = xrft.dft(da, shift=False)
77175
npt.assert_almost_equal(ft.values, np.fft.fftn(da.values))
78176

79-
ft = xrft.dft(da, shift=False, window=True)
177+
ft = xrft.dft(da, shift=False, window=True, detrend='constant')
80178
dim = da.dims
81179
window = np.hanning(N) * np.hanning(N)[:, np.newaxis]
82180
da_prime = (da - da.mean(dim=dim)).values
83181
npt.assert_almost_equal(ft.values, np.fft.fftn(da_prime*window))
84182

183+
with pytest.raises(ValueError):
184+
xrft.dft(da, shift=False, window=True, detrend='linear')
185+
85186
def test_dft_3d_dask():
86187
"""Test the discrete Fourier transform on 3D dask array data"""
87188
N=16
88-
da = xr.DataArray(np.random.rand(2,N,N), dims=['time','x','y'],
89-
coords={'time':range(2),'x':range(N),
90-
'y':range(N)}).chunk({'time': 1}
189+
da = xr.DataArray(np.random.rand(N,N,N), dims=['time','x','y'],
190+
coords={'time':range(N),'x':range(N),
191+
'y':range(N)}
91192
)
92-
daft = xrft.dft(da, dim=['x','y'], shift=False, remove_mean=False)
93-
assert hasattr(daft.data, 'dask')
94-
npt.assert_almost_equal(daft.values, np.fft.fftn(da.values, axes=[1,2]))
193+
daft = xrft.dft(da.chunk({'time': 1}), dim=['x','y'], shift=False)
194+
# assert hasattr(daft.data, 'dask')
195+
npt.assert_almost_equal(daft.values,
196+
np.fft.fftn(da.chunk({'time': 1}).values, axes=[1,2])
197+
)
95198

96199
with pytest.raises(ValueError):
97200
xrft.dft(da.chunk({'time': 1, 'x': 1}), dim=['x'])
98201

202+
daft = xrft.dft(da.chunk({'x': 1}), dim=['time'],
203+
shift=False, detrend='linear')
204+
# assert hasattr(daft.data, 'dask')
205+
da_prime = sps.detrend(da.chunk({'x': 1}), axis=0)
206+
npt.assert_almost_equal(daft.values,
207+
np.fft.fftn(da_prime, axes=[0])
208+
)
209+
99210
def test_power_spectrum():
100211
"""Test the power spectrum function"""
101212
N = 16
102213
da = xr.DataArray(np.random.rand(N,N), dims=['x','y'],
103214
coords={'x':range(N),'y':range(N)}
104215
)
105-
ps = xrft.power_spectrum(da, window=True, density=False)
216+
ps = xrft.power_spectrum(da, window=True, density=False,
217+
detrend='constant')
106218
daft = xrft.dft(da,
107-
dim=None, shift=True, remove_mean=True,
219+
dim=None, shift=True, detrend='constant',
108220
window=True)
109221
npt.assert_almost_equal(ps.values, np.real(daft*np.conj(daft)))
110222
npt.assert_almost_equal(np.ma.masked_invalid(ps).mask.sum(), 0.)
111223

112224
### Normalized
113225
dim = da.dims
114-
ps = xrft.power_spectrum(da, window=True)
115-
daft = xrft.dft(da, window=True)
226+
ps = xrft.power_spectrum(da, window=True, detrend='constant')
227+
daft = xrft.dft(da, window=True, detrend='constant')
116228
coord = list(daft.coords)
117229
test = np.real(daft*np.conj(daft))/N**4
118230
for i in range(len(dim)):
119231
test /= daft[coord[-i-1]].values
120232
npt.assert_almost_equal(ps.values, test)
121233
npt.assert_almost_equal(np.ma.masked_invalid(ps).mask.sum(), 0.)
122234

235+
### Remove mean
123236
da = xr.DataArray(np.random.rand(5,20,30),
124237
dims=['time', 'y', 'x'],
125238
coords={'time': np.arange(5),
126239
'y': np.arange(20), 'x': np.arange(30)})
127240
ps = xrft.power_spectrum(da, dim=['y', 'x'],
128-
window=True, density=False
241+
window=True, density=False, detrend='constant'
242+
)
243+
daft = xrft.dft(da, dim=['y','x'], window=True, detrend='constant')
244+
npt.assert_almost_equal(ps.values, np.real(daft*np.conj(daft)))
245+
npt.assert_almost_equal(np.ma.masked_invalid(ps).mask.sum(), 0.)
246+
247+
### Remove least-square fit
248+
da_prime = np.zeros_like(da.values)
249+
for t in range(5):
250+
da_prime[t] = numpy_detrend(da[t].values)
251+
da_prime = xr.DataArray(da_prime, dims=da.dims, coords=da.coords)
252+
ps = xrft.power_spectrum(da_prime, dim=['y', 'x'],
253+
window=True, density=False, detrend='constant'
129254
)
130-
daft = xrft.dft(da, dim=['y','x'], window=True)
255+
daft = xrft.dft(da_prime, dim=['y','x'], window=True, detrend='constant')
131256
npt.assert_almost_equal(ps.values, np.real(daft*np.conj(daft)))
132257
npt.assert_almost_equal(np.ma.masked_invalid(ps).mask.sum(), 0.)
133258

@@ -143,8 +268,8 @@ def test_power_spectrum_dask():
143268
daft = xrft.dft(da, dim=['x','y'])
144269
npt.assert_almost_equal(ps.values, (daft * np.conj(daft)).real.values)
145270

146-
ps = xrft.power_spectrum(da, dim=dim, window=True)
147-
daft = xrft.dft(da, dim=dim, window=True)
271+
ps = xrft.power_spectrum(da, dim=dim, window=True, detrend='constant')
272+
daft = xrft.dft(da, dim=dim, window=True, detrend='constant')
148273
coord = list(daft.coords)
149274
test = (daft * np.conj(daft)).real/N**4
150275
for i in dim:
@@ -161,12 +286,13 @@ def test_cross_spectrum():
161286
da2 = xr.DataArray(np.random.rand(N,N), dims=['x','y'],
162287
coords={'x':range(N),'y':range(N)}
163288
)
164-
cs = xrft.cross_spectrum(da, da2, window=True, density=False)
289+
cs = xrft.cross_spectrum(da, da2, window=True, density=False,
290+
detrend='constant')
165291
daft = xrft.dft(da,
166-
dim=None, shift=True, remove_mean=True,
292+
dim=None, shift=True, detrend='constant',
167293
window=True)
168294
daft2 = xrft.dft(da2,
169-
dim=None, shift=True, remove_mean=True,
295+
dim=None, shift=True, detrend='constant',
170296
window=True)
171297
npt.assert_almost_equal(cs.values, np.real(daft*np.conj(daft2)))
172298
npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.)
@@ -188,13 +314,21 @@ def test_cross_spectrum_dask():
188314
daft2 = xrft.dft(da2, dim=dim)
189315
npt.assert_almost_equal(cs.values, (daft * np.conj(daft2)).real.values)
190316

191-
cs = xrft.cross_spectrum(da, da2, dim=dim, window=True)
192-
daft = xrft.dft(da, dim=dim, window=True)
193-
daft2 = xrft.dft(da2, dim=dim, window=True)
317+
cs = xrft.cross_spectrum(da, da2,
318+
dim=dim, shift=True, window=True,
319+
detrend='constant')
320+
daft = xrft.dft(da,
321+
dim=dim, shift=True, window=True,
322+
detrend='constant')
323+
daft2 = xrft.dft(da2,
324+
dim=dim, shift=True, window=True,
325+
detrend='constant')
194326
coord = list(daft.coords)
195-
test = (daft * np.conj(daft2)).real/N**4
196-
for i in dim:
197-
test /= daft['freq_' + i + '_spacing']
327+
test = (daft * np.conj(daft2)).real.values/N**4
328+
# for i in dim:
329+
# test /= daft['freq_' + i + '_spacing']
330+
dk = np.diff(np.fft.fftfreq(N, 1.))[0]
331+
test /= dk**2
198332
npt.assert_almost_equal(cs.values, test)
199333
npt.assert_almost_equal(np.ma.masked_invalid(cs).mask.sum(), 0.)
200334

@@ -212,36 +346,19 @@ def test_parseval():
212346
for d in dim:
213347
coord = da[d]
214348
diff = np.diff(coord)
215-
# if pd.core.common.is_timedelta64_dtype(diff):
216-
# # convert to seconds so we get hertz
217-
# diff = diff.astype('timedelta64[s]').astype('f8')
218349
delta = diff[0]
219350
delta_x.append(delta)
220351

221352
window = np.hanning(N) * np.hanning(N)[:, np.newaxis]
222-
ps = xrft.power_spectrum(da, window=True)
353+
ps = xrft.power_spectrum(da, window=True, detrend='constant')
223354
da_prime = da.values - da.mean(dim=dim).values
224355
npt.assert_almost_equal(ps.values.sum(),
225356
(np.asarray(delta_x).prod()
226357
* ((da_prime*window)**2).sum()
227358
), decimal=5
228359
)
229360

230-
# da = xr.DataArray(np.random.rand(5,2,20,30),
231-
# dims=['time', 'z', 'y', 'x'],
232-
# coords={'time': np.arange(5),'z':range(2),
233-
# 'y': np.arange(20),
234-
# 'x': np.arange(30)}).chunk({'z':1})
235-
# dim = ['y','x']
236-
# ps = xrft.power_spectrum(da, dim=dim, window=True)
237-
# da_prime = da.values - da.mean(dim=dim).values
238-
# npt.assert_almost_equal(ps.values.sum(),
239-
# (np.asarray(delta_x).prod()
240-
# * ((da_prime*window)**2).sum()
241-
# ), decimal=5
242-
# )
243-
244-
cs = xrft.cross_spectrum(da, da2, window=True)
361+
cs = xrft.cross_spectrum(da, da2, window=True, detrend='constant')
245362
da2_prime = da2.values - da2.mean(dim=dim).values
246363
npt.assert_almost_equal(cs.values.sum(),
247364
(np.asarray(delta_x).prod()
@@ -250,6 +367,17 @@ def test_parseval():
250367
), decimal=5
251368
)
252369

370+
d3d = xr.DataArray(np.random.rand(N,N,N),
371+
dims=['time','y','x'],
372+
coords={'time':range(N), 'y':range(N), 'x':range(N)}
373+
).chunk({'time':1})
374+
ps = xrft.power_spectrum(d3d, dim=['x','y'], window=True, detrend='linear')
375+
npt.assert_almost_equal(ps[0].values.sum(),
376+
(np.asarray(delta_x).prod()
377+
* ((numpy_detrend(d3d[0].values)*window)**2).sum()
378+
), decimal=5
379+
)
380+
253381
def _synthetic_field(N, dL, amp, s):
254382
"""
255383
Generate a synthetic series of size N by N
@@ -348,7 +476,7 @@ def test_isotropic_ps_slope(N=512, dL=1., amp=1e1, s=-3.):
348476
theta = xr.DataArray(_synthetic_field(N, dL, amp, s),
349477
dims=['y', 'x'],
350478
coords={'y':range(N), 'x':range(N)})
351-
iso_ps = xrft.isotropic_powerspectrum(theta, remove_mean=True,
479+
iso_ps = xrft.isotropic_powerspectrum(theta, detrend='constant',
352480
density=True)
353481
npt.assert_almost_equal(np.ma.masked_invalid(iso_ps[1:]).mask.sum(), 0.)
354482
y_fit, a, b = xrft.fit_loglog(iso_ps.freq_r.values[4:],

xrft/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@
6161
MICRO = _version_micro
6262
VERSION = __version__
6363
PACKAGE_DATA = {'xrft': [pjoin('data', '*')]}
64-
REQUIRES = ["numpy", "xarray"]
64+
REQUIRES = ["numpy", "scipy", "xarray", "dask"]

0 commit comments

Comments
 (0)