# Copyright 2024 The JAX SC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 14 | +"""Adagrad with momentum optimizer for sparse dense matmul backward pass. |
| 15 | +
|
| 16 | +This implements the Jax primitive for the Adagrad with momentum optimizer for |
| 17 | +the |
| 18 | +sparse dense matmul backward pass, as a custom call to the |
| 19 | +SparseDenseMatmulGradOpWithOptimizerUpdate op. This op takes the preprocessed |
| 20 | +input tensors, embedding table, accumulator, momentum buffer, the grad, the |
| 21 | +learning rate and momentum hyperparameter as inputs and returns the updated |
| 22 | +embedding table, accumulator, and momentum buffer. |
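
Example (illustrative sketch; the preprocessed COO inputs, sharded tables and
hyperparameter scalars normally come from the surrounding embedding lookup
machinery, so the names below are placeholders):

  updated_table, updated_accumulator, updated_momentum = (
      tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive.bind(
          lhs_row_pointers,
          lhs_local_embedding_ids,
          lhs_local_sample_ids,
          lhs_gains,
          embedding_table,
          accumulator,
          momentum,
          activations_grad,
          learning_rate,
          momentum_param,
          beta2,
          epsilon,
          exponent,
          use_nesterov,
          max_ids_per_partition=max_ids_per_partition,
          max_unique_ids_per_partition=max_unique_ids_per_partition,
          computation_name="adagrad_momentum_optimizer_update",
          sharding_strategy=1,
      )
  )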
| 23 | +""" |
| 24 | + |
| 25 | +import functools |
| 26 | +import json |
| 27 | +from typing import Tuple |
| 28 | + |
| 29 | +from jax._src.lib.mlir import ir |
| 30 | +from jax._src.lib.mlir.dialects import func as func_dialect |
| 31 | +from jax._src.lib.mlir.dialects import hlo |
| 32 | +import jax.extend as jex |
| 33 | +from jax.interpreters import mlir |
| 34 | +from jax.interpreters import xla |
| 35 | +from jax_tpu_embedding.sparsecore.lib.core import constants |
| 36 | +from jax_tpu_embedding.sparsecore.lib.core.primitives import utils |
| 37 | +import numpy as np |
| 38 | + |
tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive = (
    jex.core.Primitive("sparse_dense_matmul_grad_with_adagrad_momentum")
)
tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive.multiple_results = (
    True
)

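# Outside of tracing, evaluating the primitive falls back to compiling and
# running it through XLA via xla.apply_primitive.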
tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive.def_impl(
    functools.partial(
        xla.apply_primitive,
        tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive,
    )
)


def _annotate_sparse_compute_type(op: ir.OpView) -> ir.OpView:
  """Annotates an op so that XLA schedules it as SparseCore compute."""
  op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      {"_xla_compute_type": ir.StringAttr.get("sparse")}
  )
  return op


def _hlo_const(arr: np.ndarray) -> ir.Value:
  """Return an HLO constant from a NumPy array (any rank)."""
  return hlo.constant(
      ir.DenseElementsAttr.get(arr, type=mlir.dtype_to_ir_type(arr.dtype))
  )


def _hlo_f32(x: float, emb_dim: int) -> ir.Value:
  """Return a <1 x emb_dim> f32 constant filled with x."""
  return _hlo_const(np.full((1, emb_dim), x, dtype=np.float32))


def _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_abstract_eval(
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    accumulator: np.ndarray,
    momentum: np.ndarray,
    activations_grad: np.ndarray,
    learning_rate: np.float32,
    momentum_param: np.float32,
    beta2: np.float32,
    epsilon: np.float32,
    exponent: np.float32,
    use_nesterov: np.bool_ = np.bool_(False),
    *_,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adagrad_momentum_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Abstract eval for sparse_dense_matmul_grad_with_adagrad_momentum."""
  utils.validate_abstract_eval_params(
      lhs_row_pointers,
      lhs_local_embedding_ids,
      lhs_local_sample_ids,
      lhs_gains,
      embedding_table,
      activations_grad,
      max_ids_per_partition,
      max_unique_ids_per_partition,
      computation_name,
      sharding_strategy,
  )
  utils.ensure_dtype(accumulator, np.float32, "accumulator")
  utils.ensure_dtype(momentum, np.float32, "momentum")
  utils.ensure_dtype(learning_rate, np.float32, "learning_rate")
  utils.ensure_dtype(momentum_param, np.float32, "momentum_param")
  utils.ensure_dtype(beta2, np.float32, "beta2")
  utils.ensure_dtype(epsilon, np.float32, "epsilon")
  utils.ensure_dtype(exponent, np.float32, "exponent")
  utils.ensure_dtype(use_nesterov, np.bool_, "use_nesterov")

  if (
      embedding_table.shape != accumulator.shape
      or embedding_table.shape != momentum.shape
  ):
    raise ValueError(
        "embedding_table, accumulator and momentum must have identical shapes: "
        f"got {embedding_table.shape}, {accumulator.shape}, {momentum.shape}"
    )
  return embedding_table, accumulator, momentum


tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive.def_abstract_eval(
    _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_abstract_eval
)


def _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_lowering(
    ctx: mlir.LoweringRuleContext,
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    accumulator: np.ndarray,
    momentum: np.ndarray,
    activations_grad: np.ndarray,
    learning_rate: np.ndarray,
    momentum_param: np.ndarray,
    beta2: np.ndarray,
    epsilon: np.ndarray,
    exponent: np.ndarray,
    use_nesterov: np.ndarray,
    *,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adagrad_momentum_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[ir.Value, ir.Value, ir.Value]:
  """Lowering for sparse_dense_matmul_grad_with_adagrad_momentum."""
  sdmm_config = {
      "max_ids_per_partition": max_ids_per_partition,
      "max_unique_ids_per_partition": max_unique_ids_per_partition,
      "pad_value": constants.PADDING_VALUE,
      "sharding_strategy": sharding_strategy,
  }
  backend_config = json.dumps({
      "sparse_dense_matmul_config": sdmm_config,
      "device_type": "DEVICE_TYPE_SPARSECORE",
  })
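  # The JSON backend config is forwarded to the
  # SparseDenseMatmulGradOpWithOptimizerUpdate custom call emitted below.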

  optimizer_update_computation_name = computation_name

  emb_dim_size = ir.ShapedType(embedding_table.type).get_dim_size(1)
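  # Build the per-row optimizer update computation. Its operands are, in
  # order: gradient, embedding table row, accumulator row, momentum row,
  # learning rate, momentum parameter, beta2, epsilon, exponent (all f32 of
  # shape [1, emb_dim]), and the use_nesterov flag (i1, same shape). It
  # returns a tuple of the updated (table, accumulator, momentum) rows.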
  f32_row_type = ir.RankedTensorType.get([1, emb_dim_size], ir.F32Type.get())
  i1_row_type = ir.RankedTensorType.get(
      [1, emb_dim_size], ir.IntegerType.get_signless(1)
  )
  optimizer_update = func_dialect.FuncOp(
      computation_name,
      (
          [f32_row_type] * 9 + [i1_row_type],
          [ir.TupleType.get_tuple([f32_row_type] * 3)],
      ),
      ip=ctx.module_context.ip,
      visibility="private",
  )

  entry_block = optimizer_update.add_entry_block()
  with ir.InsertionPoint(entry_block):
    (
        grad_,  # gradient row
        embedding_table_,  # embedding table row (old)
        accumulator_,  # accumulator row (old)
        momentum_,  # momentum buffer row (old)
        lr_param_,  # learning rate
        momentum_param_,  # momentum coefficient
        beta2_,  # beta2
        epsilon_param_,  # epsilon
        exponent_param_,  # exponent
        use_nesterov_flag_,  # use_nesterov flag
    ) = entry_block.arguments

    one_ = _hlo_f32(1.0, emb_dim_size)
    neg_one_ = _hlo_f32(-1.0, emb_dim_size)

    # Accumulator
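    # accum_new = accum + g**2                        if beta2 == 1
    #           = beta2 * accum + (1 - beta2) * g**2  otherwise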
    grad_sq_ = hlo.multiply(grad_, grad_)
    beta2_eq_1_ = hlo.compare(
        beta2_,
        one_,
        comparison_direction=hlo.ComparisonDirectionAttr.get("EQ"),
        compare_type=hlo.ComparisonTypeAttr.get("FLOAT"),
    )
    accum_plus_ = hlo.add(accumulator_, grad_sq_)
    one_minus_beta2_ = hlo.subtract(one_, beta2_)
    scaled_accum_ = hlo.add(
        hlo.multiply(beta2_, accumulator_),
        hlo.multiply(one_minus_beta2_, grad_sq_),
    )
    accum_new_ = hlo.select(beta2_eq_1_, accum_plus_, scaled_accum_)

    # Scaled gradient
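    # scaled_grad = (accum_new + epsilon) ** (-1 / exponent) * g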
    neg_inv_exp_ = hlo.divide(neg_one_, exponent_param_)
    accum_eps_ = hlo.add(accum_new_, epsilon_param_)
    p_new_ = hlo.power(accum_eps_, neg_inv_exp_)
    scaled_grad_ = hlo.multiply(p_new_, grad_)

    # Momentum
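    # m_new = momentum_param * m + scaled_grad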
    m_new_ = hlo.add(hlo.multiply(momentum_param_, momentum_), scaled_grad_)

    # Delta E
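    # update = momentum_param * m_new + scaled_grad  if use_nesterov
    #        = m_new                                 otherwise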
    nesterov_update_ = hlo.add(
        hlo.multiply(momentum_param_, m_new_), scaled_grad_
    )
    update_ = hlo.select(use_nesterov_flag_, nesterov_update_, m_new_)
    lr_update_ = hlo.multiply(lr_param_, update_)

    # Weight update
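    # w_new = w - learning_rate * update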
    w_new_ = hlo.subtract(embedding_table_, lr_update_)

    out_tuple = hlo.tuple([w_new_, accum_new_, m_new_])
    func_dialect.ReturnOp([out_tuple])

  table_tuple_op = _annotate_sparse_compute_type(
      hlo.TupleOp([embedding_table, accumulator, momentum])
  )

  hyperparams_tuple_op = _annotate_sparse_compute_type(
      hlo.TupleOp([
          learning_rate,
          momentum_param,
          beta2,
          epsilon,
          exponent,
          use_nesterov,
      ])
  )
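  # The table and hyperparameter tuples are passed as single operands; their
  # element order mirrors the argument order of the optimizer update
  # computation defined above.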

  custom_call_op = mlir.custom_call(
      "SparseDenseMatmulGradOpWithOptimizerUpdate",
      result_types=[
          ir.TupleType.get_tuple([
              embedding_table.type,
              accumulator.type,
              momentum.type,
          ])
      ],
      operands=[
          lhs_row_pointers,
          lhs_local_embedding_ids,
          lhs_local_sample_ids,
          lhs_gains,
          activations_grad,
          table_tuple_op.result,
          hyperparams_tuple_op.result,
      ],
      backend_config=backend_config,
      called_computations=[optimizer_update_computation_name],
  )

  updated_table_op = _annotate_sparse_compute_type(
      hlo.GetTupleElementOp(custom_call_op, 0)
  )
  updated_accumulator_op = _annotate_sparse_compute_type(
      hlo.GetTupleElementOp(custom_call_op, 1)
  )
  updated_momentum_op = _annotate_sparse_compute_type(
      hlo.GetTupleElementOp(custom_call_op, 2)
  )

  return (
      updated_table_op.results,
      updated_accumulator_op.results,
      updated_momentum_op.results,
  )


mlir.register_lowering(
    tpu_sparse_dense_matmul_grad_with_adagrad_momentum_primitive,
    _tpu_sparse_dense_matmul_grad_with_adagrad_momentum_lowering,
)