
Commit d6565c8

cantonios authored and Google-ML-Automation committed
Add the Adam optimizer from [Kingma et al., 2014](http://arxiv.org/abs/1412.6980).
Some specific design decisions were made that differ from Keras/Optax.

- Keras ignores the step-dependent bias correction for epsilon (google-deepmind/optax#571), which differs from the original paper. We _do_ correct for the bias, consistent with optax/pytorch (see the sketch below).
- Keras/pytorch support `amsgrad: bool` as an option, which changes how the variable is updated, keeping track of the maximum velocity encountered. However, this would lead to an additional state parameter (`v_max`) and conditionally change the number of slot variables. Slot variables are particularly expensive in large embedding lookups (each is the size of the entire sharded table), and this would require a different underlying primitive anyway. If we need the option, we can create a new optimizer. This is consistent with optax, which has a separate `optax.amsgrad` optimizer.
- Optax supports a `nesterov: bool` option. Similar to `amsgrad`, this modifies the update rule. Technically, the Nesterov modification also adds a step-dependent `beta_1` parameter and requires an additional state variable to keep track of the accumulated product, something Optax currently ignores. Keras handles this with a different optimizer, `keras.optimizers.Nadam`, which does add the additional state variable. PyTorch also has a separate `torch.optim.NAdam` specifically for this.

PiperOrigin-RevId: 769287009
1 parent 15da518 commit d6565c8
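As a reading aid for the first design point above (keeping the step-dependent bias correction for epsilon): the correction terms from Kingma & Ba (2014), Sec. 2, can be folded into the two scalar hyperparameters that the new primitive consumes, `alpha_t` and `epsilon_hat`. The helper below is an illustrative sketch, not code from this commit; its function and argument names are made up for the example.

import numpy as np


def adam_folded_hyperparams(learning_rate, beta_1, beta_2, epsilon, step):
  """Illustrative sketch: fold Adam's bias correction into alpha_t/epsilon_hat.

  Uses the reformulation from Kingma & Ba (2014), Sec. 2, so the per-row
  update stays theta -= alpha_t * m / (sqrt(v) + epsilon_hat).
  """
  bias_1 = 1.0 - beta_1**step  # first-moment bias correction
  bias_2 = 1.0 - beta_2**step  # second-moment bias correction
  alpha_t = learning_rate * np.sqrt(bias_2) / bias_1
  epsilon_hat = epsilon * np.sqrt(bias_2)
  return np.float32(alpha_t), np.float32(epsilon_hat)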

File tree

7 files changed: +1114 -1 lines changed


jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD

Lines changed: 15 additions & 0 deletions
@@ -120,6 +120,21 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "sparse_dense_matmul_grad_with_adam",
+    srcs = [
+        "sparse_dense_matmul_grad_with_adam.py",
+    ],
+    deps = [
+        ":utils",
+        "//jax_tpu_embedding/sparsecore/lib/core:constants",
+        pypi_requirement("jax"),
+        pypi_requirement("jax/_src/lib"),
+        pypi_requirement("jax/extend"),
+        pypi_requirement("numpy"),
+    ],
+)
+
 pytype_strict_library(
     name = "optimizers_computation",
     srcs = [
jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_grad_with_adam.py

Lines changed: 354 additions & 0 deletions
@@ -0,0 +1,354 @@
# Copyright 2024 The JAX SC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adam optimizer for the sparse dense matmul backward pass.

This implements the Jax primitive for the Adam optimizer for the sparse dense
matmul backward pass, as a custom call to the
SparseDenseMatmulGradOpWithOptimizerUpdate op. This op takes the preprocessed
input tensors, embedding table, Adam hyperparameters
(alpha_t, beta_1, beta_2, epsilon_hat), Adam states (momentum, velocity), and
the grad as inputs, and returns the updated embedding table and the new
momentum and velocity values.
"""

import functools
import json
from typing import Tuple

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
import jax.extend as jex
from jax.interpreters import mlir
from jax.interpreters import xla
from jax_tpu_embedding.sparsecore.lib.core import constants
from jax_tpu_embedding.sparsecore.lib.core.primitives import utils
import numpy as np

tpu_sparse_dense_matmul_grad_with_adam_primitive = jex.core.Primitive(
    "sparse_dense_matmul_grad_with_adam_primitive",
)

tpu_sparse_dense_matmul_grad_with_adam_primitive.multiple_results = True


tpu_sparse_dense_matmul_grad_with_adam_primitive.def_impl(
    functools.partial(
        xla.apply_primitive,
        tpu_sparse_dense_matmul_grad_with_adam_primitive,
    )
)


def _annotate_sparse_compute_type(op: ir.OpView):
  op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      {"_xla_compute_type": ir.StringAttr.get("sparse")}
  )
  return op


def _hlo_const(x: np.ndarray) -> ir.Value:
  return hlo.constant(
      ir.DenseElementsAttr.get(x, type=mlir.dtype_to_ir_type(x.dtype))
  )


def _hlo_f32(x: float, emb_dim: int):
  return _hlo_const(
      np.array(emb_dim * [x], dtype=np.float32).reshape((1, emb_dim))
  )


def _tpu_sparse_dense_matmul_grad_with_adam_abstract_eval(
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    velocity: np.ndarray,
    momentum: np.ndarray,
    activations_grad: np.ndarray,
    alpha_t: np.float32,
    beta_1: np.float32,
    beta_2: np.float32,
    epsilon_hat: np.float32,
    *_,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adam_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Abstract eval for sparse_dense_matmul_adam."""

  utils.validate_abstract_eval_params(
      lhs_row_pointers,
      lhs_local_embedding_ids,
      lhs_local_sample_ids,
      lhs_gains,
      embedding_table,
      activations_grad,
      max_ids_per_partition,
      max_unique_ids_per_partition,
      computation_name,
      sharding_strategy,
  )

  utils.ensure_dtype(alpha_t, np.float32, "alpha_t")
  utils.ensure_dtype(beta_1, np.float32, "beta_1")
  utils.ensure_dtype(beta_2, np.float32, "beta_2")
  utils.ensure_dtype(epsilon_hat, np.float32, "epsilon_hat")
  utils.ensure_dtype(velocity, np.float32, "velocity")
  utils.ensure_dtype(momentum, np.float32, "momentum")

  if embedding_table.shape != velocity.shape:
    raise ValueError(
        "embedding_table and velocity must have equal shapes, got"
        f" {embedding_table.shape} and {velocity.shape}"
    )
  elif embedding_table.shape != momentum.shape:
    raise ValueError(
        "embedding_table and momentum must have equal shapes, got"
        f" {embedding_table.shape} and {momentum.shape}"
    )

  return embedding_table, velocity, momentum


tpu_sparse_dense_matmul_grad_with_adam_primitive.def_abstract_eval(
    _tpu_sparse_dense_matmul_grad_with_adam_abstract_eval
)


def _tpu_sparse_dense_matmul_grad_with_adam_lowering(
    ctx: mlir.LoweringRuleContext,
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    momentum: np.ndarray,
    velocity: np.ndarray,
    activations_grad: np.ndarray,
    alpha_t: np.ndarray,
    beta_1: np.ndarray,
    beta_2: np.ndarray,
    epsilon_hat: np.ndarray,
    *_,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adam_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Lowering for sparse_dense_matmul_grad_with_adam."""

  sdmm_adam_config = {
      "max_ids_per_partition": max_ids_per_partition,
      "max_unique_ids_per_partition": max_unique_ids_per_partition,
      "pad_value": constants.PADDING_VALUE,
      "sharding_strategy": sharding_strategy,
  }
  backend_config = json.dumps({
      "sparse_dense_matmul_config": sdmm_adam_config,
      "device_type": "DEVICE_TYPE_SPARSECORE",
  })

  optimizer_update_computation_name = computation_name

  # Define the optimizer update function mlir.
  # The expected signature is:
  #   func @adam_optimizer_update(
  #       %arg0: tensor<1xNxf32>,
  #       %arg1: tuple<tensor<1xNxf32>>,
  #       %arg2: tuple<tensor<1xNxf32>>,
  #       %arg3: tuple<tensor<1xNxf32>>,
  #       %arg4: tuple<tensor<1xNxf32>>,
  #       %arg5: tuple<tensor<1xNxf32>>,
  #       %arg6: tuple<tensor<1xNxf32>>,
  #       %arg7: tuple<tensor<1xNxf32>>,
  #   )
  #   -> tuple<tensor<1xNxf32>, tensor<1xNxf32>, tensor<1xNxf32>>
  # where N is the embedding dimension size.
  # The input arguments are:
  #   %arg0: the gradient vector.
  #   %arg1: the embedding tables before the update.
  #   %arg2: the optimizer states (momentum).
  #   %arg3: the optimizer states (velocity).
  #   %arg4: the hyperparameters (alpha_t).
  #   %arg5: the hyperparameters (beta_1).
  #   %arg6: the hyperparameters (beta_2).
  #   %arg7: the hyperparameters (epsilon_hat).
  # The output is a tuple containing the updated embedding tables and optimizer
  # states.

  embedding_table_dim_size = embedding_table.type.get_dim_size(1)
  hlo_f32 = functools.partial(_hlo_f32, emb_dim=embedding_table_dim_size)

  optimizer_update = func_dialect.FuncOp(
      optimizer_update_computation_name,
      (
          [
              ir.RankedTensorType.get(  # grad
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # embedding_table
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # momentum
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # velocity
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # alpha_t
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # beta_1
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # beta_2
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # epsilon_hat
                  [1, embedding_table.type.get_dim_size(1)],
                  ir.F32Type.get(),
              ),
          ],
          [
              ir.TupleType.get_tuple([
                  ir.RankedTensorType.get(  # embedding_table
                      [1, embedding_table_dim_size],
                      ir.F32Type.get(),
                  ),
                  ir.RankedTensorType.get(  # momentum
                      [1, embedding_table_dim_size],
                      ir.F32Type.get(),
                  ),
                  ir.RankedTensorType.get(  # velocity
                      [1, embedding_table_dim_size],
                      ir.F32Type.get(),
                  ),
              ]),
          ],
      ),
      ip=ctx.module_context.ip,
      visibility="private",
  )

  # This is the row-wise implementation of the optimizer.
  entry_block = optimizer_update.add_entry_block()
  with ir.InsertionPoint(entry_block):
    # Get parameters.
    grad_ = entry_block.arguments[0]
    embedding_table_ = entry_block.arguments[1]
    momentum_ = entry_block.arguments[2]
    velocity_ = entry_block.arguments[3]
    alpha_t_ = entry_block.arguments[4]
    beta_1_ = entry_block.arguments[5]
    beta_2_ = entry_block.arguments[6]
    epsilon_hat_ = entry_block.arguments[7]

    grad_square = hlo.multiply(
        grad_,
        grad_,
    )

    # momentum: m = beta_1 * m + (1 - beta_1) * grad
    #             = m + (1 - beta_1) * (grad - m)
    momentum_new = hlo.add(
        momentum_,
        hlo.multiply(
            hlo.subtract(hlo_f32(1.0), beta_1_),
            hlo.subtract(grad_, momentum_),
        ),
    )

    # velocity: v = beta_2 * v + (1 - beta_2) * grad^2
    #             = v + (1 - beta_2) * (grad^2 - v)
    velocity_new = hlo.add(
        velocity_,
        hlo.multiply(
            hlo.subtract(hlo_f32(1.0), beta_2_),
            hlo.subtract(grad_square, velocity_),
        ),
    )

    # theta = theta - alpha_t * m / (sqrt(v) + epsilon_hat)
    update = hlo.divide(
        hlo.multiply(
            alpha_t_,
            momentum_new,
        ),
        hlo.add(
            hlo.sqrt(velocity_new),
            epsilon_hat_,
        ),
    )

    theta = hlo.subtract(embedding_table_, update)

    updated_tables = hlo.tuple([theta, momentum_new, velocity_new])

    # return the updated embedding table, mu, nu
    func_dialect.ReturnOp([updated_tables])

  table_tuple_op = hlo.TupleOp([embedding_table, momentum, velocity])
  table_tuple_op = _annotate_sparse_compute_type(table_tuple_op)
  hyperparams_tuple_op = hlo.TupleOp([alpha_t, beta_1, beta_2, epsilon_hat])
  hyperparams_tuple_op = _annotate_sparse_compute_type(hyperparams_tuple_op)

  op = mlir.custom_call(
      "SparseDenseMatmulGradOpWithOptimizerUpdate",
      result_types=[
          ir.TupleType.get_tuple(
              [embedding_table.type, momentum.type, velocity.type]
          )
      ],
      operands=[
          lhs_row_pointers,
          lhs_local_embedding_ids,
          lhs_local_sample_ids,
          lhs_gains,
          activations_grad,
          table_tuple_op.result,
          hyperparams_tuple_op.result,
      ],
      backend_config=backend_config,
      called_computations=[optimizer_update_computation_name],
  )

  table_tuple_op = hlo.GetTupleElementOp(op, 0)
  table_tuple_op = _annotate_sparse_compute_type(table_tuple_op)
  momentum_tuple_op = hlo.GetTupleElementOp(op, 1)
  momentum_tuple_op = _annotate_sparse_compute_type(momentum_tuple_op)
  velocity_tuple_op = hlo.GetTupleElementOp(op, 2)
  velocity_tuple_op = _annotate_sparse_compute_type(velocity_tuple_op)

  return (
      table_tuple_op.results,
      momentum_tuple_op.results,
      velocity_tuple_op.results,
  )


mlir.register_lowering(
    tpu_sparse_dense_matmul_grad_with_adam_primitive,
    _tpu_sparse_dense_matmul_grad_with_adam_lowering,
)
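For readers of the lowering above: the `adam_optimizer_update` body built in HLO corresponds to the following per-row NumPy reference. This is an illustrative sketch for checking the algebra, not code from the commit; the function name and signature are made up for the example.

import numpy as np


def adam_row_update(table_row, momentum, velocity, grad,
                    alpha_t, beta_1, beta_2, epsilon_hat):
  # Mirrors the HLO body above:
  #   m <- m + (1 - beta_1) * (grad - m)
  #   v <- v + (1 - beta_2) * (grad^2 - v)
  #   theta <- theta - alpha_t * m / (sqrt(v) + epsilon_hat)
  momentum_new = momentum + (1.0 - beta_1) * (grad - momentum)
  velocity_new = velocity + (1.0 - beta_2) * (grad * grad - velocity)
  table_new = table_row - alpha_t * momentum_new / (
      np.sqrt(velocity_new) + epsilon_hat
  )
  return table_new, momentum_new, velocity_new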
