@@ -7,6 +7,7 @@
 import numpy as np
 import pandas as pd
 from torch import Tensor
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.container import Sequential
 from hyperfast.utils import TorchPCA
@@ -247,6 +248,7 @@ def _initialize_fit_attributes(self) -> None:
         self._rfs = []
         self._pcas = []
         self._main_networks = []
+        self._nnbias = []
         self._X_preds = []
         self._y_preds = []
         if self.feature_bagging:
@@ -284,18 +286,42 @@ def _sample_data(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
         X_pred = torch.repeat_interleave(X_pred, n_repeats, axis=0)
         y_pred = torch.repeat_interleave(y_pred, n_repeats, axis=0)
         return X_pred, y_pred
+
+    def _move_to_device(self, data, device=None):
+        if device is None:
+            device = self.device
+        if isinstance(data, list):
+            return [(mat.to(device), bi.to(device)) for mat, bi in data]
+        elif isinstance(data, TorchPCA):
+            data.mean_, data.components_ = data.mean_.to(device), data.components_.to(device)
+            return data
+        elif isinstance(data, PCA):  # scikit-learn PCA
+            return data
+        return data.to(device)
+
+    def _move_to_cpu(self, data):
+        return self._move_to_device(data, "cpu")
 
     def _store_network(
         self,
         rf: Sequential,
         pca: PCA | TorchPCA,
         main_network: list,
+        nnbias: nn.Parameter,
         X_pred: Tensor,
         y_pred: Tensor,
     ) -> None:
+        rf = self._move_to_cpu(rf)
+        pca = self._move_to_cpu(pca)
+        main_network = self._move_to_cpu(main_network)
+        nnbias = self._move_to_cpu(nnbias)
+        X_pred = self._move_to_cpu(X_pred)
+        y_pred = self._move_to_cpu(y_pred)
+
         self._rfs.append(rf)
         self._pcas.append(pca)
         self._main_networks.append(main_network)
+        self._nnbias.append(nnbias)
         self._X_preds.append(X_pred)
         self._y_preds.append(y_pred)
 
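The new _move_to_device / _move_to_cpu helpers above let _store_network park every fitted ensemble member on the CPU and pull it back onto the accelerator only when it is needed. A minimal standalone sketch of that round trip, assuming (as in the isinstance(data, list) branch) that a stored main network is a list of (weight, bias) tensor pairs; the names here are illustrative, not part of hyperfast:

    import torch

    def move_pairs(pairs, device):
        # Mirrors the list branch of _move_to_device: each element is a
        # (weight, bias) tuple, moved tensor by tensor.
        return [(w.to(device), b.to(device)) for w, b in pairs]

    main_network = [(torch.randn(4, 4), torch.randn(4)) for _ in range(3)]
    stored = move_pairs(main_network, "cpu")                 # parked after fit
    device = "cuda" if torch.cuda.is_available() else "cpu"
    active = move_pairs(stored, device)                      # pulled back for prediction
    print(active[0][0].device)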
@@ -317,24 +343,23 @@ def fit(
         X_pred, y_pred = self._sample_data(X, y)
         X_pred, y_pred = X_pred.to(self.device), y_pred.to(self.device)
         self.n_classes_ = len(torch.unique(y_pred).cpu().numpy())
-
-        rf, pca, main_network = self._model(X_pred, y_pred, self.n_classes_)
-
+        with torch.no_grad():
+            rf, pca, main_network, nnbias = self._model(X_pred, y_pred, self.n_classes_)
         if self.optimization == "ensemble_optimize":
-            rf, pca, main_network, self._model.nn_bias = fine_tune_main_network(
+            rf, pca, main_network, nnbias = fine_tune_main_network(
                 self._cfg,
                 X_pred,
                 y_pred,
                 self.n_classes_,
                 rf,
                 pca,
                 main_network,
-                self._model.nn_bias,
+                nnbias,
                 self.device,
                 self.optimize_steps,
                 self.batch_size,
             )
-        self._store_network(rf, pca, main_network, X_pred, y_pred)
+        self._store_network(rf, pca, main_network, nnbias, X_pred, y_pred)
 
         if self.optimization == "optimize" and self.optimize_steps > 0:
             assert (
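Wrapping the self._model call in torch.no_grad() keeps PyTorch from recording an autograd graph for the hypernetwork forward pass, which would otherwise hold on to activations that are never needed unless fine-tuning runs afterwards. A minimal sketch of the effect, using a toy module rather than the actual hypernetwork:

    import torch
    import torch.nn as nn

    net = nn.Linear(8, 2)          # stand-in for the hypernetwork call
    x = torch.randn(16, 8)

    with torch.no_grad():          # no autograd graph is recorded in this block
        out = net(x)

    print(out.requires_grad)       # False: safe to store without autograd state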
@@ -344,7 +369,7 @@ def fit(
                 self._rfs[0],
                 self._pcas[0],
                 self._main_networks[0],
-                self._model.nn_bias,
+                self._nnbias[0],
             ) = fine_tune_main_network(
                 self._cfg,
                 X,
@@ -353,7 +378,7 @@ def fit(
                 self._rfs[0],
                 self._pcas[0],
                 self._main_networks[0],
-                self._model.nn_bias,
+                self._nnbias[0],
                 self.device,
                 self.optimize_steps,
                 self.batch_size,
@@ -373,11 +398,13 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
             orig_X = X_batch
             yhats = []
             for jj in range(len(self._main_networks)):
-                main_network = self._main_networks[jj]
-                rf = self._rfs[jj]
-                pca = self._pcas[jj]
-                X_pred = self._X_preds[jj]
-                y_pred = self._y_preds[jj]
+                main_network = self._move_to_device(self._main_networks[jj])
+                rf = self._move_to_device(self._rfs[jj])
+                pca = self._move_to_device(self._pcas[jj])
+                nnbias = self._move_to_device(self._nnbias[jj])
+                X_pred = self._move_to_device(self._X_preds[jj])
+                y_pred = self._move_to_device(self._y_preds[jj])
+
                 if self.feature_bagging:
                     X_ = X_batch[:, self.selected_features[jj]]
                     orig_X_ = orig_X[:, self.selected_features[jj]]
@@ -399,7 +426,7 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
                 outputs_pred, intermediate_activations_pred = forward_main_network(
                     X_pred_, main_network
                 )
-                for bb, bias in enumerate(self._model.nn_bias):
+                for bb, bias in enumerate(nnbias):
                     if bb == 0:
                         outputs = nn_bias_logits(
                             outputs, orig_X_, X_pred, y_pred, bias, self.n_classes_, self.nn_bias_mini_batches
@@ -411,7 +438,14 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
 
                 predicted = F.softmax(outputs, dim=1)
                 yhats.append(predicted)
-
+
+                for data in [rf, pca, main_network, nnbias, X_pred, y_pred,
+                             X_transformed, outputs, intermediate_activations]:
+                    data = self._move_to_cpu(data)
+                if self.nn_bias:
+                    for data in [X_pred_, outputs_pred, intermediate_activations_pred]:
+                        data = self._move_to_cpu(data)
+
             yhats = torch.stack(yhats)
             yhats = torch.mean(yhats, axis=0)
             yhats = yhats.cpu().numpy()
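With these changes, predict_proba moves one stored ensemble member onto the device per iteration and pushes its tensors back to the CPU before the next one, so peak accelerator memory is bounded by a single member rather than the whole ensemble. A rough sketch of that pattern, with toy linear members and names invented for illustration:

    import torch

    def predict_offloaded(members, x, device):
        # members live on the CPU; only one occupies the device at a time.
        yhats = []
        for weight, bias in members:
            weight, bias = weight.to(device), bias.to(device)
            yhats.append(torch.softmax(x.to(device) @ weight + bias, dim=1))
            weight, bias = weight.cpu(), bias.cpu()   # release device memory
        return torch.stack(yhats).mean(dim=0).cpu()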