
Commit 214e140

Add the FTRL (Follow The Regularized Leader) optimizer.
This implementation is based on the FTRL algorithm of [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).

Features / params in `FTRLOptimizerSpec` (as used in the primitive):

- **learning_rate**: The base learning rate.
- **learning_rate_power**: Controls the per-coordinate learning-rate decay (typically -0.5).
- **l1_regularization_strength**: Applies L1 regularization, which can lead to sparsity in the model weights.
- **l2_regularization_strength**: Applies L2 regularization.
- **beta**: An additional smoothing term.
- **clip_weight_min**, **clip_weight_max**: Optional bounds for clipping the updated embedding weights.
- **weight_decay_factor**: Factor for applying weight decay to the gradients.
- **multiply_weight_decay_factor_by_learning_rate**: Boolean flag; if true, the `weight_decay_factor` is multiplied by the `learning_rate` before applying decay.
- **multiply_linear_by_learning_rate**: Boolean flag; if true, the linear term update incorporates the `learning_rate` differently.
- **allow_zero_accumulator**: Boolean flag; if true, allows the accumulator to be exactly zero. Otherwise, a small epsilon is added for numerical stability when the `accumulator` is zero.

The optimizer maintains two slot variables for each trainable embedding parameter:

- **accumulator**: Stores the sum of squared gradients, used to adapt the learning rate on a per-coordinate basis.
- **linear**: Stores a linear combination related to the gradients, which is central to the FTRL weight update rule.

PiperOrigin-RevId: 764794731
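For reference, the following is a minimal JAX sketch of the classic FTRL-Proximal update from McMahan et al. (2013), using the `accumulator` and `linear` slots and a subset of the parameters listed above. The function name, signature, and defaults are illustrative only and are not the primitive added by this commit, which additionally handles weight clipping, weight decay, `multiply_linear_by_learning_rate`, and `allow_zero_accumulator`.

```python
import jax.numpy as jnp


def ftrl_update(grad, weight, accumulator, linear,
                learning_rate=0.1, learning_rate_power=-0.5,
                l1=0.0, l2=0.0, beta=0.0):
  """Illustrative per-coordinate FTRL-Proximal step (McMahan et al., 2013)."""
  # accumulator: running sum of squared gradients.
  new_accumulator = accumulator + grad * grad
  # With learning_rate_power = -0.5 these powers reduce to sqrt(accumulator),
  # giving the usual Adagrad-style per-coordinate learning-rate scaling.
  old_pow = jnp.power(accumulator, -learning_rate_power)
  new_pow = jnp.power(new_accumulator, -learning_rate_power)
  # linear: accumulates gradients minus the learning-rate-scaled weight shift.
  new_linear = linear + grad - (new_pow - old_pow) / learning_rate * weight
  quadratic = (new_pow + beta) / learning_rate + 2.0 * l2
  # L1 shrinkage: coordinates with |linear| <= l1 are set exactly to zero,
  # which is what produces sparsity in the embedding weights.
  new_weight = jnp.where(
      jnp.abs(new_linear) > l1,
      (jnp.sign(new_linear) * l1 - new_linear) / quadratic,
      jnp.zeros_like(weight))
  return new_weight, new_accumulator, new_linear


# Example: one step on a small embedding row, starting from zero slots.
w = jnp.zeros(4)
acc = jnp.zeros(4)
lin = jnp.zeros(4)
g = jnp.array([0.1, -0.2, 0.0, 0.3])
w, acc, lin = ftrl_update(g, w, acc, lin, l1=0.001, l2=0.01)
```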
1 parent d6565c8 commit 214e140

File tree

8 files changed: +1179 −10 lines changed


jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD

Lines changed: 16 additions & 0 deletions
```diff
@@ -135,6 +135,21 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "sparse_dense_matmul_grad_with_ftrl",
+    srcs = [
+        "sparse_dense_matmul_grad_with_ftrl.py",
+    ],
+    deps = [
+        ":utils",
+        "//jax_tpu_embedding/sparsecore/lib/core:constants",
+        pypi_requirement("jax"),
+        pypi_requirement("jax/_src/lib"),
+        pypi_requirement("jax/extend"),
+        pypi_requirement("numpy"),
+    ],
+)
+
 pytype_strict_library(
     name = "optimizers_computation",
     srcs = [
@@ -170,6 +185,7 @@ pytype_strict_library(
         ":optimizers_computation", # buildcleaner: keep
         ":sparse_dense_matmul_csr", # buildcleaner: keep
         ":sparse_dense_matmul_grad_with_adagrad", # buildcleaner: keep
+        ":sparse_dense_matmul_grad_with_ftrl", # buildcleaner: keep
         ":sparse_dense_matmul_grad_with_laprop", # buildcleaner: keep
         ":sparse_dense_matmul_grad_with_sgd", # buildcleaner: keep
         ":sparse_dense_matmul_optimizer_grad", # buildcleaner: keep
```

0 commit comments
