Commit b0703e5

Add the Adagrad with Momentum optimizer.
This implementation matches the one described in [Duchi et al., 2011](https://arxiv.org/abs/1103.4296) with the momentum integration discussed in [Sutskever et al., 2013](https://proceedings.mlr.press/v28/sutskever13.pdf).

Features / Params in `AdagradMomentumOptimizerSpec`:

- **learning_rate**: The base learning rate.
- **momentum**: The momentum parameter (exponential decay for the momentum buffer).
- **beta2**: The decay rate for the running average of squared gradients (accumulator).
- **epsilon**: A small constant added for numerical stability.
- **exponent**: The power to which the accumulator is raised (often referred to as `k_power` in some contexts, typically 0.5 for Adagrad).
- **use_nesterov**: Boolean flag; if true, Nesterov momentum is used.

The optimizer maintains two slot variables for each trainable embedding parameter:

- **accumulator**: Stores the running average of squared gradients.
- **momentum_buffer**: Stores the first-order momentum term.

PiperOrigin-RevId: 765046416
1 parent c02c8ea commit b0703e5
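
For quick reference, the per-row update that the new primitive lowers to (see `sparse_dense_matmul_grad_with_adagrad_momentum.py` below) can be summarized with a small NumPy sketch. This is illustrative only; the function name and the dense formulation are not part of the library API.

```python
import numpy as np


def adagrad_momentum_reference(table, accum, mom, grad, learning_rate,
                               momentum, beta2, epsilon, exponent,
                               use_nesterov):
  """Dense NumPy transcription of the per-row optimizer update (illustrative)."""
  grad_sq = grad * grad
  # beta2 == 1 keeps the classic Adagrad sum of squared gradients;
  # otherwise the accumulator is an exponential moving average.
  accum_new = accum + grad_sq if beta2 == 1.0 else (
      beta2 * accum + (1.0 - beta2) * grad_sq)
  # Precondition the gradient by (accumulator + epsilon) ** (-1 / exponent).
  scaled_grad = np.power(accum_new + epsilon, -1.0 / exponent) * grad
  mom_new = momentum * mom + scaled_grad
  # Nesterov applies the momentum step to the freshly updated buffer.
  update = momentum * mom_new + scaled_grad if use_nesterov else mom_new
  table_new = table - learning_rate * update
  return table_new, accum_new, mom_new
```

The `accumulator` and `momentum_buffer` slot variables described above correspond to `accum_new` and `mom_new` in this sketch.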

File tree: 8 files changed, +1029 -11 lines

jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD

Lines changed: 16 additions & 0 deletions
@@ -107,6 +107,21 @@ pytype_strict_library(
     ],
 )

+pytype_strict_library(
+    name = "sparse_dense_matmul_grad_with_adagrad_momentum",
+    srcs = [
+        "sparse_dense_matmul_grad_with_adagrad_momentum.py",
+    ],
+    deps = [
+        ":utils",
+        "//jax_tpu_embedding/sparsecore/lib/core:constants",
+        pypi_requirement("jax"),
+        pypi_requirement("jax/_src/lib"),
+        pypi_requirement("jax/extend"),
+        pypi_requirement("numpy"),
+    ],
+)
+
 pytype_strict_library(
     name = "sparse_dense_matmul_grad_with_laprop",
     srcs = ["sparse_dense_matmul_grad_with_laprop.py"],
@@ -185,6 +200,7 @@ pytype_strict_library(
         ":optimizers_computation",  # buildcleaner: keep
         ":sparse_dense_matmul_csr",  # buildcleaner: keep
         ":sparse_dense_matmul_grad_with_adagrad",  # buildcleaner: keep
+        ":sparse_dense_matmul_grad_with_adagrad_momentum",  # buildcleaner: keep
         ":sparse_dense_matmul_grad_with_ftrl",  # buildcleaner: keep
         ":sparse_dense_matmul_grad_with_laprop",  # buildcleaner: keep
         ":sparse_dense_matmul_grad_with_sgd",  # buildcleaner: keep
jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_adagrad_momentum.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
# Copyright 2024 The JAX SC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adagrad with momentum optimizer for sparse dense matmul backward pass.

This implements the Jax primitive for the Adagrad with momentum optimizer for
the sparse dense matmul backward pass, as a custom call to the
SparseDenseMatmulGradOpWithOptimizerUpdate op. This op takes the preprocessed
input tensors, embedding table, accumulator, momentum buffer, the grad, the
learning rate and momentum hyperparameter as inputs and returns the updated
embedding table, accumulator, and momentum buffer.
"""

import functools
import json
from typing import Tuple

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
import jax.extend as jex
from jax.interpreters import mlir
from jax.interpreters import xla
from jax_tpu_embedding.sparsecore.lib.core import constants
from jax_tpu_embedding.sparsecore.lib.core.primitives import utils
import numpy as np

tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive = (
    jex.core.Primitive("sparse_dense_matmul_grad_with_adagrad_momentum")
)
tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive.multiple_results = (
    True
)

tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive.def_impl(
    functools.partial(
        xla.apply_primitive,
        tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive,
    )
)


def _annotate_sparse_compute_type(op: ir.OpView):
  op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      {"_xla_compute_type": ir.StringAttr.get("sparse")}
  )
  return op


def _hlo_const(arr: np.ndarray) -> ir.Value:
  """Return an HLO constant from a NumPy array (any rank)."""
  return hlo.constant(
      ir.DenseElementsAttr.get(arr, type=mlir.dtype_to_ir_type(arr.dtype))
  )


def _hlo_f32(x: float, emb_dim: int) -> ir.Value:
  """Return a <1 x emb_dim> f32 constant filled with x."""
  return _hlo_const(np.full((1, emb_dim), x, dtype=np.float32))


def _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_abstract_eval(
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    accumulator: np.ndarray,
    momentum: np.ndarray,
    activations_grad: np.ndarray,
    learning_rate: np.float32,
    momentum_param: np.float32,
    beta2: np.float32,
    epsilon: np.float32,
    exponent: np.float32,
    use_nesterov: np.bool = np.bool(False),
    *_,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adagrad_momentum_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Abstract eval for sparse_dense_matmul_adagrad_momentum."""
  utils.validate_abstract_eval_params(
      lhs_row_pointers,
      lhs_local_embedding_ids,
      lhs_local_sample_ids,
      lhs_gains,
      embedding_table,
      activations_grad,
      max_ids_per_partition,
      max_unique_ids_per_partition,
      computation_name,
      sharding_strategy,
  )
  utils.ensure_dtype(accumulator, np.float32, "accumulator")
  utils.ensure_dtype(momentum, np.float32, "momentum")
  utils.ensure_dtype(learning_rate, np.float32, "learning_rate")
  utils.ensure_dtype(momentum_param, np.float32, "momentum_param")
  utils.ensure_dtype(beta2, np.float32, "beta2")
  utils.ensure_dtype(epsilon, np.float32, "epsilon")
  utils.ensure_dtype(exponent, np.float32, "exponent")
  utils.ensure_dtype(use_nesterov, np.bool, "use_nesterov")

  if (
      embedding_table.shape != accumulator.shape
      or embedding_table.shape != momentum.shape
  ):
    raise ValueError(
        "embedding_table, accumulator and momentum must have identical shapes: "
        f"got {embedding_table.shape}, {accumulator.shape}, {momentum.shape}"
    )
  return embedding_table, accumulator, momentum


tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive.def_abstract_eval(
    _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_abstract_eval
)


def _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_lowering(
    ctx: mlir.LoweringRuleContext,
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    accumulator: np.ndarray,
    momentum: np.ndarray,
    activations_grad: np.ndarray,
    learning_rate: np.ndarray,
    momentum_param: np.ndarray,
    beta2: np.ndarray,
    epsilon: np.ndarray,
    exponent: np.ndarray,
    use_nesterov: np.ndarray,
    *,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adagrad_momentum_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[ir.Value, ir.Value, ir.Value]:
  """Lowering for sparse_dense_matmul_grad_with_adagrad_momentum."""
  sdmm_config = {
      "max_ids_per_partition": max_ids_per_partition,
      "max_unique_ids_per_partition": max_unique_ids_per_partition,
      "pad_value": constants.PADDING_VALUE,
      "sharding_strategy": sharding_strategy,
  }
  backend_config = json.dumps({
      "sparse_dense_matmul_config": sdmm_config,
      "device_type": "DEVICE_TYPE_SPARSECORE",
  })

  optimizer_update_computation_name = computation_name

  emb_dim_size = ir.ShapedType(embedding_table.type).get_dim_size(1)
  optimizer_update = func_dialect.FuncOp(
      computation_name,
      (
          [
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(
                  [1, emb_dim_size], ir.IntegerType.get_signless(1)
              ),
          ],
          [
              ir.TupleType.get_tuple([
                  ir.RankedTensorType.get(
                      [1, emb_dim_size],
                      ir.F32Type.get(),
                  ),
                  ir.RankedTensorType.get(
                      [1, emb_dim_size],
                      ir.F32Type.get(),
                  ),
                  ir.RankedTensorType.get(
                      [1, emb_dim_size],
                      ir.F32Type.get(),
                  ),
              ]),
          ],
      ),
      ip=ctx.module_context.ip,
      visibility="private",
  )

  entry_block = optimizer_update.add_entry_block()
  with ir.InsertionPoint(entry_block):
    (
        grad_,  # g
        embedding_table_,  # Ē_o
        accumulator_,  # Ā_o
        momentum_,  # L_o
        lr_param_,  # λ
        momentum_param_,  # k
        beta2_,  # β2
        epsilon_param_,  # ε
        exponent_param_,
        use_nesterov_flag_,
    ) = entry_block.arguments

    one_ = _hlo_f32(1.0, emb_dim_size)
    neg_one_ = _hlo_f32(-1.0, emb_dim_size)

    # Accumulator
    grad_sq_ = hlo.multiply(grad_, grad_)
    beta2_eq_1_ = hlo.compare(
        beta2_,
        one_,
        comparison_direction=hlo.ComparisonDirectionAttr.get("EQ"),
        compare_type=hlo.ComparisonTypeAttr.get("FLOAT"),
    )
    accum_plus_ = hlo.add(accumulator_, grad_sq_)
    one_minus_beta2_ = hlo.subtract(one_, beta2_)
    scaled_accum_ = hlo.add(
        hlo.multiply(beta2_, accumulator_),
        hlo.multiply(one_minus_beta2_, grad_sq_),
    )
    accum_new_ = hlo.select(beta2_eq_1_, accum_plus_, scaled_accum_)

    # Scaled gradient
    neg_inv_exp_ = hlo.divide(neg_one_, exponent_param_)
    accum_eps_ = hlo.add(accum_new_, epsilon_param_)
    p_new_ = hlo.power(accum_eps_, neg_inv_exp_)
    scaled_grad_ = hlo.multiply(p_new_, grad_)

    # Momentum
    m_new_ = hlo.add(hlo.multiply(momentum_param_, momentum_), scaled_grad_)

    # Delta E
    nesterov_update_ = hlo.add(
        hlo.multiply(momentum_param_, m_new_), scaled_grad_
    )
    update_ = hlo.select(use_nesterov_flag_, nesterov_update_, m_new_)
    lr_update_ = hlo.multiply(lr_param_, update_)

    # Weight update
    w_new_ = hlo.subtract(embedding_table_, lr_update_)

    out_tuple = hlo.tuple([w_new_, accum_new_, m_new_])
    func_dialect.ReturnOp([out_tuple])

  table_tuple_op = _annotate_sparse_compute_type(
      hlo.TupleOp([embedding_table, accumulator, momentum])
  )

  hyperparams_tuple_op = _annotate_sparse_compute_type(
      hlo.TupleOp([
          learning_rate,
          momentum_param,
          beta2,
          epsilon,
          exponent,
          use_nesterov,
      ])
  )

  custom_call_op = mlir.custom_call(
      "SparseDenseMatmulGradOpWithOptimizerUpdate",
      result_types=[
          ir.TupleType.get_tuple([
              embedding_table.type,
              accumulator.type,
              momentum.type,
          ])
      ],
      operands=[
          lhs_row_pointers,
          lhs_local_embedding_ids,
          lhs_local_sample_ids,
          lhs_gains,
          activations_grad,
          table_tuple_op.result,
          hyperparams_tuple_op.result,
      ],
      backend_config=backend_config,
      called_computations=[optimizer_update_computation_name],
  )

  updated_table_op = _annotate_sparse_compute_type(
      hlo.GetTupleElementOp(custom_call_op, 0)
  )
  updated_accumulator_op = _annotate_sparse_compute_type(
      hlo.GetTupleElementOp(custom_call_op, 1)
  )
  updated_momentum_op = _annotate_sparse_compute_type(
      hlo.GetTupleElementOp(custom_call_op, 2)
  )

  return (
      updated_table_op.results,
      updated_accumulator_op.results,
      updated_momentum_op.results,
  )


mlir.register_lowering(
    tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive,
    _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_lowering,
)
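
For orientation, the sketch below shows the operand order and the required static parameters when binding the new primitive directly. The wrapper name `apply_adagrad_momentum_grad` is hypothetical and not part of this change; in practice the primitive is reached through the library's higher-level embedding APIs and requires a SparseCore-enabled TPU backend with preprocessed CSR inputs.

```python
from jax_tpu_embedding.sparsecore.lib.core.primitives import (
    sparse_dense_matmul_grad_with_adagrad_momentum as adagrad_momentum,
)


def apply_adagrad_momentum_grad(
    lhs_row_pointers, lhs_local_embedding_ids, lhs_local_sample_ids, lhs_gains,
    embedding_table, accumulator, momentum, activations_grad,
    learning_rate, momentum_param, beta2, epsilon, exponent, use_nesterov,
    *, max_ids_per_partition, max_unique_ids_per_partition):
  """Hypothetical wrapper; returns (updated_table, accumulator, momentum)."""
  primitive = (
      adagrad_momentum
      .tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive
  )
  # Operand order mirrors the abstract eval / lowering signatures above;
  # computation_name and sharding_strategy fall back to their defaults.
  return primitive.bind(
      lhs_row_pointers, lhs_local_embedding_ids, lhs_local_sample_ids,
      lhs_gains, embedding_table, accumulator, momentum, activations_grad,
      learning_rate, momentum_param, beta2, epsilon, exponent, use_nesterov,
      max_ids_per_partition=max_ids_per_partition,
      max_unique_ids_per_partition=max_unique_ids_per_partition,
  )
```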
