Skip to content

Commit 674a560

Browse files
committed
GPU memory improvements for very big datasets and ensembling
1 parent 32c4189 commit 674a560

File tree

2 files changed

+50
-16
lines changed

2 files changed

+50
-16
lines changed

hyperfast/hyperfast.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import pandas as pd
99
from torch import Tensor
10+
import torch.nn as nn
1011
import torch.nn.functional as F
1112
from torch.nn.modules.container import Sequential
1213
from hyperfast.utils import TorchPCA
@@ -247,6 +248,7 @@ def _initialize_fit_attributes(self) -> None:
247248
self._rfs = []
248249
self._pcas = []
249250
self._main_networks = []
251+
self._nnbias = []
250252
self._X_preds = []
251253
self._y_preds = []
252254
if self.feature_bagging:
@@ -284,18 +286,42 @@ def _sample_data(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
284286
X_pred = torch.repeat_interleave(X_pred, n_repeats, axis=0)
285287
y_pred = torch.repeat_interleave(y_pred, n_repeats, axis=0)
286288
return X_pred, y_pred
289+
290+
def _move_to_device(self, data, device=None):
291+
if device is None:
292+
device = self.device
293+
if isinstance(data, list):
294+
return [(mat.to(device), bi.to(device)) for mat, bi in data]
295+
elif isinstance(data, TorchPCA):
296+
data.mean_, data.components_ = data.mean_.to(device), data.components_.to(device)
297+
return data
298+
elif isinstance(data, PCA): # scikit-learn PCA
299+
return data
300+
return data.to(device)
301+
302+
def _move_to_cpu(self, data):
303+
return self._move_to_device(data, "cpu")
287304

288305
def _store_network(
    self,
    rf: Sequential,
    pca: PCA | TorchPCA,
    main_network: list,
    nnbias: nn.Parameter,
    X_pred: Tensor,
    y_pred: Tensor,
) -> None:
    """Move one network's components to the CPU and record them.

    Each argument is passed through ``self._move_to_cpu`` and the result
    is appended to its matching list on the instance: ``_rfs``,
    ``_pcas``, ``_main_networks``, ``_nnbias``, ``_X_preds`` and
    ``_y_preds``. All six moves complete before anything is appended.

    Args:
        rf: Appended to ``self._rfs``.
        pca: Appended to ``self._pcas``.
        main_network: Appended to ``self._main_networks``.
        nnbias: Appended to ``self._nnbias``.
        X_pred: Appended to ``self._X_preds``.
        y_pred: Appended to ``self._y_preds``.
    """
    components = (rf, pca, main_network, nnbias, X_pred, y_pred)
    # Relocate everything first so a failed move leaves the stores untouched.
    cpu_components = [self._move_to_cpu(item) for item in components]

    stores = (
        self._rfs,
        self._pcas,
        self._main_networks,
        self._nnbias,
        self._X_preds,
        self._y_preds,
    )
    for store, item in zip(stores, cpu_components):
        store.append(item)
301327

@@ -317,24 +343,23 @@ def fit(
317343
X_pred, y_pred = self._sample_data(X, y)
318344
X_pred, y_pred = X_pred.to(self.device), y_pred.to(self.device)
319345
self.n_classes_ = len(torch.unique(y_pred).cpu().numpy())
320-
321-
rf, pca, main_network = self._model(X_pred, y_pred, self.n_classes_)
322-
346+
with torch.no_grad():
347+
rf, pca, main_network, nnbias = self._model(X_pred, y_pred, self.n_classes_)
323348
if self.optimization == "ensemble_optimize":
324-
rf, pca, main_network, self._model.nn_bias = fine_tune_main_network(
349+
rf, pca, main_network, nn_bias = fine_tune_main_network(
325350
self._cfg,
326351
X_pred,
327352
y_pred,
328353
self.n_classes_,
329354
rf,
330355
pca,
331356
main_network,
332-
self._model.nn_bias,
357+
nnbias,
333358
self.device,
334359
self.optimize_steps,
335360
self.batch_size,
336361
)
337-
self._store_network(rf, pca, main_network, X_pred, y_pred)
362+
self._store_network(rf, pca, main_network, nnbias, X_pred, y_pred)
338363

339364
if self.optimization == "optimize" and self.optimize_steps > 0:
340365
assert (
@@ -344,7 +369,7 @@ def fit(
344369
self._rfs[0],
345370
self._pcas[0],
346371
self._main_networks[0],
347-
self._model.nn_bias,
372+
self._nnbias[0],
348373
) = fine_tune_main_network(
349374
self._cfg,
350375
X,
@@ -353,7 +378,7 @@ def fit(
353378
self._rfs[0],
354379
self._pcas[0],
355380
self._main_networks[0],
356-
self._model.nn_bias,
381+
self._nnbias[0],
357382
self.device,
358383
self.optimize_steps,
359384
self.batch_size,
@@ -373,11 +398,13 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
373398
orig_X = X_batch
374399
yhats = []
375400
for jj in range(len(self._main_networks)):
376-
main_network = self._main_networks[jj]
377-
rf = self._rfs[jj]
378-
pca = self._pcas[jj]
379-
X_pred = self._X_preds[jj]
380-
y_pred = self._y_preds[jj]
401+
main_network = self._move_to_device(self._main_networks[jj])
402+
rf = self._move_to_device(self._rfs[jj])
403+
pca = self._move_to_device(self._pcas[jj])
404+
nnbias = self._move_to_device(self._nnbias[jj])
405+
X_pred = self._move_to_device(self._X_preds[jj])
406+
y_pred = self._move_to_device(self._y_preds[jj])
407+
381408
if self.feature_bagging:
382409
X_ = X_batch[:, self.selected_features[jj]]
383410
orig_X_ = orig_X[:, self.selected_features[jj]]
@@ -399,7 +426,7 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
399426
outputs_pred, intermediate_activations_pred = forward_main_network(
400427
X_pred_, main_network
401428
)
402-
for bb, bias in enumerate(self._model.nn_bias):
429+
for bb, bias in enumerate(nnbias):
403430
if bb == 0:
404431
outputs = nn_bias_logits(
405432
outputs, orig_X_, X_pred, y_pred, bias, self.n_classes_, self.nn_bias_mini_batches
@@ -411,7 +438,14 @@ def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
411438

412439
predicted = F.softmax(outputs, dim=1)
413440
yhats.append(predicted)
414-
441+
442+
for data in [rf, pca, main_network, nnbias, X_pred, y_pred,
443+
X_transformed, outputs, intermediate_activations]:
444+
data = self._move_to_cpu(data)
445+
if self.nn_bias:
446+
for data in [X_pred_, outputs_pred, intermediate_activations_pred]:
447+
data = self._move_to_cpu(data)
448+
415449
yhats = torch.stack(yhats)
416450
yhats = torch.mean(yhats, axis=0)
417451
yhats = yhats.cpu().numpy()

hyperfast/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,4 @@ def forward(self, X, y, n_classes):
137137
out, last_linear_layer = forward_linear_layer(out, weights, n_classes)
138138
main_network.append(last_linear_layer)
139139

140-
return rf, self.pca, main_network
140+
return rf, self.pca, main_network, self.nn_bias

0 commit comments

Comments
 (0)