Skip to content

Commit 2673492

Browse files
committed
Fix multidim context merge bug (closes #93)
1 parent 4382957 commit 2673492

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

deepsensor/data/task.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -279,10 +279,10 @@ def mask_nans_numpy(self):
279279

280280
def f(arr):
281281
if isinstance(arr, deepsensor.backend.nps.Masked):
282-
# Ignore nps.Masked objects
283282
nps_mask = arr.mask == 0
284283
nan_mask = np.isnan(arr.y)
285284
mask = np.logical_or(nps_mask, nan_mask)
285+
mask = np.any(mask, axis=1, keepdims=True)
286286
data = arr.y
287287
data[nan_mask] = 0.0
288288
arr = deepsensor.backend.nps.Masked(data, mask)

tests/test_training.py

Lines changed: 43 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,3 +113,46 @@ def test_training(self):
113113
# Check for NaNs in the loss
114114
loss = np.mean(epoch_losses)
115115
self.assertFalse(np.isnan(loss))
116+
117+
def test_training_multidim(self):
    """Check the training loop gives a non-NaN loss with multidimensional context sets.

    A copy of the ``air`` variable is added to the tutorial dataset as
    ``air2`` so the processed dataset passed as both context and target
    holds two variables. NaNs are written into the context and target
    observations of every task before training, and the test passes when
    the mean of the per-epoch losses is not NaN.
    """
    # Load raw data
    ds_raw = xr.tutorial.open_dataset("air_temperature")

    # Add extra dim by duplicating the existing variable under a new name
    ds_raw["air2"] = ds_raw["air"].copy()

    # Normalise data
    dp = DataProcessor(x1_name="lat", x2_name="lon")
    ds = dp(ds_raw)

    # Set up task loader
    tl = TaskLoader(context=ds, target=ds)

    # Set up model
    model = ConvNP(dp, tl)

    # Generate training tasks
    n_train_tasks = 10
    train_tasks = []
    for _ in range(n_train_tasks):
        # Dates are drawn from the fixture's time axis (self.da)
        date = np.random.choice(self.da.time.values)
        task = tl(date, 10, 10)
        task["Y_c"][0][:, 0] = np.nan  # Add NaN to context
        task["Y_t"][0][:, 0] = np.nan  # Add NaN to target
        train_tasks.append(task)

    # Train
    trainer = Trainer(model, lr=5e-5)
    batch_size = 5
    n_epochs = 10
    epoch_losses = []
    for _ in tqdm(range(n_epochs)):
        batch_losses = trainer(train_tasks, batch_size=batch_size)
        epoch_losses.append(np.mean(batch_losses))

    # Check for NaNs in the loss; a NaN in any epoch propagates through the mean
    loss = np.mean(epoch_losses)
    self.assertFalse(np.isnan(loss))

0 commit comments

Comments
 (0)