Skip to content

Commit

Permalink
sgd nan fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nirjacoby committed Jan 8, 2024
1 parent a3109e5 commit a8e8f9c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
13 changes: 13 additions & 0 deletions neighbors/_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def sgd(
error_history = np.zeros((n_iterations))
converged = False
last_e = 0
e = 0
error_is_nan = False
norm_rmse = np.inf
delta = np.inf
np.random.seed(seed)
Expand Down Expand Up @@ -69,6 +71,11 @@ def sgd(

# Use changes in e to determine tolerance
e = data[u, i] - prediction # error
# Check if predictions have exploded resulting in NaN errors
# and prevent propagation but breaking early
if np.isnan(e):
error_is_nan = True
break

# Update biases
user_bias[u] += learning_rate * (e - user_bias_reg * user_bias[u])
Expand All @@ -85,6 +92,11 @@ def sgd(
# Keep track of total squared error
total_error += np.power(e, 2)

# Check if error was nan
if error_is_nan:
converged = False
break

# Force non-negativity. Surprise does this per-epoch via re-initialization. We do this per sweep over all training data, e.g. see: https://github.com/NicolasHug/Surprise/blob/master/surprise/prediction_algorithms/matrix_factorization.pyx#L671
user_vecs = np.maximum(user_vecs, 0)
item_vecs = np.maximum(item_vecs, 0)
Expand All @@ -107,6 +119,7 @@ def sgd(
return (
error_history,
converged,
error_is_nan,
this_iter,
delta,
norm_rmse,
Expand Down
5 changes: 5 additions & 0 deletions neighbors/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def fit(
(
error_history,
converged,
error_is_nan,
n_iter,
delta,
norm_rmse,
Expand Down Expand Up @@ -569,11 +570,15 @@ def fit(
self._delta = delta
self._norm_rmse = norm_rmse
self.converged = converged
self.error_is_nan = error_is_nan
if verbose:
if self.converged:
print("\n\tCONVERGED!")
print(f"\n\tFinal Iteration: {self._n_iter}")
print(f"\tFinal Delta: {np.round(self._delta)}")
elif self.error_is_nan:
print("\tFAILED TO CONVERGE (predictions are NaN)")
print(f"\n\tFinal Iteration: {self._n_iter}")
else:
print("\tFAILED TO CONVERGE (n_iter reached)")
print(f"\n\tFinal Iteration: {self._n_iter}")
Expand Down

0 comments on commit a8e8f9c

Please sign in to comment.