@@ -113,3 +113,46 @@ def test_training(self):
113
113
# Check for NaNs in the loss
114
114
loss = np .mean (epoch_losses )
115
115
self .assertFalse (np .isnan (loss ))
116
+
117
+ def test_training_multidim (self ):
118
+ """A basic test of the training loop with multidimensional context sets"""
119
+ # Load raw data
120
+ ds_raw = xr .tutorial .open_dataset ("air_temperature" )
121
+
122
+ # Add extra dim
123
+ ds_raw ["air2" ] = ds_raw ["air" ].copy ()
124
+
125
+ # Normalise data
126
+ dp = DataProcessor (x1_name = "lat" , x2_name = "lon" )
127
+ ds = dp (ds_raw )
128
+
129
+ # Set up task loader
130
+ tl = TaskLoader (context = ds , target = ds )
131
+
132
+ # Set up model
133
+ model = ConvNP (dp , tl )
134
+
135
+ # Generate training tasks
136
+ n_train_tasks = 10
137
+ train_tasks = []
138
+ for i in range (n_train_tasks ):
139
+ date = np .random .choice (self .da .time .values )
140
+ task = tl (date , 10 , 10 )
141
+ task ["Y_c" ][0 ][:, 0 ] = np .nan # Add NaN to context
142
+ task ["Y_t" ][0 ][:, 0 ] = np .nan # Add NaN to target
143
+ print (task )
144
+ train_tasks .append (task )
145
+
146
+ # Train
147
+ trainer = Trainer (model , lr = 5e-5 )
148
+ # batch_size = None
149
+ batch_size = 5
150
+ n_epochs = 10
151
+ epoch_losses = []
152
+ for epoch in tqdm (range (n_epochs )):
153
+ batch_losses = trainer (train_tasks , batch_size = batch_size )
154
+ epoch_losses .append (np .mean (batch_losses ))
155
+
156
+ # Check for NaNs in the loss
157
+ loss = np .mean (epoch_losses )
158
+ self .assertFalse (np .isnan (loss ))
0 commit comments