# Copyright 2024 The JAX SC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 14 | +"""Adam optimizer for sparse dense matmul backward pass. |
| 15 | +
|
| 16 | +This implements the Jax primitive for the Adam optimizer for the sparse dense |
| 17 | +matmul backward pass, as a custom call to the |
| 18 | +SparseDenseMatmulGradOpWithOptimizerUpdate op. This op takes the preprocessed |
| 19 | +input tensors, embedding table, Adam hyperparameters |
| 20 | +(alpha_t, beta_1, beta_2, epsilon_hat), Adam states (momentum, velocity), and |
| 21 | +the grad |
| 22 | +as inputs and returns the updated embedding table and momentum, velocity values. |
| 23 | +""" |

import functools
import json
from typing import Tuple

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
import jax.extend as jex
from jax.interpreters import mlir
from jax.interpreters import xla
from jax_tpu_embedding.sparsecore.lib.core import constants
from jax_tpu_embedding.sparsecore.lib.core.primitives import utils
import numpy as np

tpu_sparse_dense_matmul_grad_with_adam_primitive = jex.core.Primitive(
    "sparse_dense_matmul_grad_with_adam_primitive",
)

tpu_sparse_dense_matmul_grad_with_adam_primitive.multiple_results = True


tpu_sparse_dense_matmul_grad_with_adam_primitive.def_impl(
    functools.partial(
        xla.apply_primitive,
        tpu_sparse_dense_matmul_grad_with_adam_primitive,
    )
)
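# The primitive's concrete (eager) implementation defers to
# `xla.apply_primitive`, which compiles and runs the primitive through the
# lowering rule registered at the bottom of this file.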


def _annotate_sparse_compute_type(op: ir.OpView):
  """Annotates an op with the 'sparse' compute type frontend attribute."""
  op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
      {"_xla_compute_type": ir.StringAttr.get("sparse")}
  )
  return op


def _hlo_const(x: np.ndarray) -> ir.Value:
  """Creates an HLO constant from a numpy array."""
  return hlo.constant(
      ir.DenseElementsAttr.get(x, type=mlir.dtype_to_ir_type(x.dtype))
  )


def _hlo_f32(x: float, emb_dim: int):
  """Creates a (1, emb_dim) f32 HLO constant filled with the given value."""
  return _hlo_const(
      np.array(emb_dim * [x], dtype=np.float32).reshape((1, emb_dim))
  )

def _tpu_sparse_dense_matmul_grad_with_adam_abstract_eval(
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    momentum: np.ndarray,
    velocity: np.ndarray,
    activations_grad: np.ndarray,
    alpha_t: np.float32,
    beta_1: np.float32,
    beta_2: np.float32,
    epsilon_hat: np.float32,
    *_,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adam_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Abstract eval for sparse_dense_matmul_grad_with_adam."""

  utils.validate_abstract_eval_params(
      lhs_row_pointers,
      lhs_local_embedding_ids,
      lhs_local_sample_ids,
      lhs_gains,
      embedding_table,
      activations_grad,
      max_ids_per_partition,
      max_unique_ids_per_partition,
      computation_name,
      sharding_strategy,
  )

  utils.ensure_dtype(alpha_t, np.float32, "alpha_t")
  utils.ensure_dtype(beta_1, np.float32, "beta_1")
  utils.ensure_dtype(beta_2, np.float32, "beta_2")
  utils.ensure_dtype(epsilon_hat, np.float32, "epsilon_hat")
  utils.ensure_dtype(momentum, np.float32, "momentum")
  utils.ensure_dtype(velocity, np.float32, "velocity")

  if embedding_table.shape != momentum.shape:
    raise ValueError(
        "embedding_table and momentum must have equal shapes, got"
        f" {embedding_table.shape} and {momentum.shape}"
    )
  if embedding_table.shape != velocity.shape:
    raise ValueError(
        "embedding_table and velocity must have equal shapes, got"
        f" {embedding_table.shape} and {velocity.shape}"
    )

  return embedding_table, momentum, velocity


tpu_sparse_dense_matmul_grad_with_adam_primitive.def_abstract_eval(
    _tpu_sparse_dense_matmul_grad_with_adam_abstract_eval
)
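# The abstract eval only validates shapes and dtypes; the output avals for the
# updated embedding table, momentum, and velocity mirror the corresponding
# inputs.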


def _tpu_sparse_dense_matmul_grad_with_adam_lowering(
    ctx: mlir.LoweringRuleContext,
    lhs_row_pointers: np.ndarray,
    lhs_local_embedding_ids: np.ndarray,
    lhs_local_sample_ids: np.ndarray,
    lhs_gains: np.ndarray,
    embedding_table: np.ndarray,
    momentum: np.ndarray,
    velocity: np.ndarray,
    activations_grad: np.ndarray,
    alpha_t: np.ndarray,
    beta_1: np.ndarray,
    beta_2: np.ndarray,
    epsilon_hat: np.ndarray,
    *_,
    max_ids_per_partition: int,
    max_unique_ids_per_partition: int,
    computation_name: str = "adam_optimizer_update",
    sharding_strategy: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Lowering for sparse_dense_matmul_grad_with_adam."""

  sdmm_adam_config = {
      "max_ids_per_partition": max_ids_per_partition,
      "max_unique_ids_per_partition": max_unique_ids_per_partition,
      "pad_value": constants.PADDING_VALUE,
      "sharding_strategy": sharding_strategy,
  }
  backend_config = json.dumps({
      "sparse_dense_matmul_config": sdmm_adam_config,
      "device_type": "DEVICE_TYPE_SPARSECORE",
  })
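  # The backend config is attached to the custom call below as a JSON string;
  # it is expected to be interpreted by the SparseCore backend.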

  optimizer_update_computation_name = computation_name

  # Define the MLIR for the optimizer update function.
  # The expected signature is:
  #   func @adam_optimizer_update(
  #       %arg0: tensor<1xNxf32>,
  #       %arg1: tensor<1xNxf32>,
  #       %arg2: tensor<1xNxf32>,
  #       %arg3: tensor<1xNxf32>,
  #       %arg4: tensor<1xNxf32>,
  #       %arg5: tensor<1xNxf32>,
  #       %arg6: tensor<1xNxf32>,
  #       %arg7: tensor<1xNxf32>)
  #     -> tuple<tensor<1xNxf32>, tensor<1xNxf32>, tensor<1xNxf32>>
  # where N is the embedding dimension size.
  # The input arguments are:
  #   %arg0: the gradient vector.
  #   %arg1: the embedding table before the update.
  #   %arg2: the optimizer state (momentum).
  #   %arg3: the optimizer state (velocity).
  #   %arg4: the hyperparameter alpha_t.
  #   %arg5: the hyperparameter beta_1.
  #   %arg6: the hyperparameter beta_2.
  #   %arg7: the hyperparameter epsilon_hat.
  # The output is a tuple containing the updated embedding table and optimizer
  # states.

  embedding_table_dim_size = embedding_table.type.get_dim_size(1)
  hlo_f32 = functools.partial(_hlo_f32, emb_dim=embedding_table_dim_size)

  optimizer_update = func_dialect.FuncOp(
      optimizer_update_computation_name,
      (
          [
              ir.RankedTensorType.get(  # grad
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # embedding_table
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # momentum
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # velocity
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # alpha_t
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # beta_1
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # beta_2
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
              ir.RankedTensorType.get(  # epsilon_hat
                  [1, embedding_table_dim_size],
                  ir.F32Type.get(),
              ),
          ],
          [
              ir.TupleType.get_tuple([
                  ir.RankedTensorType.get(  # embedding_table
                      [1, embedding_table_dim_size],
                      ir.F32Type.get(),
                  ),
                  ir.RankedTensorType.get(  # momentum
                      [1, embedding_table_dim_size],
                      ir.F32Type.get(),
                  ),
                  ir.RankedTensorType.get(  # velocity
                      [1, embedding_table_dim_size],
                      ir.F32Type.get(),
                  ),
              ]),
          ],
      ),
      ip=ctx.module_context.ip,
      visibility="private",
  )

  # This is the row-wise implementation of the optimizer.
  entry_block = optimizer_update.add_entry_block()
  with ir.InsertionPoint(entry_block):
    # Get the parameters.
    grad_ = entry_block.arguments[0]
    embedding_table_ = entry_block.arguments[1]
    momentum_ = entry_block.arguments[2]
    velocity_ = entry_block.arguments[3]
    alpha_t_ = entry_block.arguments[4]
    beta_1_ = entry_block.arguments[5]
    beta_2_ = entry_block.arguments[6]
    epsilon_hat_ = entry_block.arguments[7]

    grad_square = hlo.multiply(
        grad_,
        grad_,
    )

    # momentum: m = beta_1 * m + (1 - beta_1) * grad
    #             = m + (1 - beta_1) * (grad - m)
    momentum_new = hlo.add(
        momentum_,
        hlo.multiply(
            hlo.subtract(hlo_f32(1.0), beta_1_),
            hlo.subtract(grad_, momentum_),
        ),
    )

    # velocity: v = beta_2 * v + (1 - beta_2) * grad^2
    #             = v + (1 - beta_2) * (grad^2 - v)
    velocity_new = hlo.add(
        velocity_,
        hlo.multiply(
            hlo.subtract(hlo_f32(1.0), beta_2_),
            hlo.subtract(grad_square, velocity_),
        ),
    )

    # theta = theta - alpha_t * m / (sqrt(v) + epsilon_hat)
    update = hlo.divide(
        hlo.multiply(
            alpha_t_,
            momentum_new,
        ),
        hlo.add(
            hlo.sqrt(velocity_new),
            epsilon_hat_,
        ),
    )

    theta = hlo.subtract(embedding_table_, update)

    updated_tables = hlo.tuple([theta, momentum_new, velocity_new])

    # Return the updated embedding table, momentum, and velocity.
    func_dialect.ReturnOp([updated_tables])
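
  # Pack the embedding table and optimizer states, and the Adam
  # hyperparameters, into tuples; both tuple ops are annotated as sparse
  # compute (see _annotate_sparse_compute_type above).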
  table_tuple_op = hlo.TupleOp([embedding_table, momentum, velocity])
  table_tuple_op = _annotate_sparse_compute_type(table_tuple_op)
  hyperparams_tuple_op = hlo.TupleOp([alpha_t, beta_1, beta_2, epsilon_hat])
  hyperparams_tuple_op = _annotate_sparse_compute_type(hyperparams_tuple_op)

  op = mlir.custom_call(
      "SparseDenseMatmulGradOpWithOptimizerUpdate",
      result_types=[
          ir.TupleType.get_tuple(
              [embedding_table.type, momentum.type, velocity.type]
          )
      ],
      operands=[
          lhs_row_pointers,
          lhs_local_embedding_ids,
          lhs_local_sample_ids,
          lhs_gains,
          activations_grad,
          table_tuple_op.result,
          hyperparams_tuple_op.result,
      ],
      backend_config=backend_config,
      called_computations=[optimizer_update_computation_name],
  )
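
  # Unpack the custom call's result tuple into the updated embedding table,
  # momentum, and velocity.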
  table_tuple_op = hlo.GetTupleElementOp(op, 0)
  table_tuple_op = _annotate_sparse_compute_type(table_tuple_op)
  momentum_tuple_op = hlo.GetTupleElementOp(op, 1)
  momentum_tuple_op = _annotate_sparse_compute_type(momentum_tuple_op)
  velocity_tuple_op = hlo.GetTupleElementOp(op, 2)
  velocity_tuple_op = _annotate_sparse_compute_type(velocity_tuple_op)

  return (
      table_tuple_op.results,
      momentum_tuple_op.results,
      velocity_tuple_op.results,
  )


mlir.register_lowering(
    tpu_sparse_dense_matmul_grad_with_adam_primitive,
    _tpu_sparse_dense_matmul_grad_with_adam_lowering,
)
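

# For reference, a minimal sketch of how this primitive might be bound
# directly. It is illustrative only: the input variable names and the
# hyperparameter values are assumptions, and callers are normally expected to
# reach this primitive through the higher-level embedding lookup/update APIs
# that produce the preprocessed CSR-style inputs.
#
#   updated_table, updated_momentum, updated_velocity = (
#       tpu_sparse_dense_matmul_grad_with_adam_primitive.bind(
#           lhs_row_pointers,
#           lhs_local_embedding_ids,
#           lhs_local_sample_ids,
#           lhs_gains,
#           embedding_table,
#           momentum,
#           velocity,
#           activations_grad,
#           np.float32(alpha_t),
#           np.float32(beta_1),
#           np.float32(beta_2),
#           np.float32(epsilon_hat),
#           max_ids_per_partition=max_ids_per_partition,
#           max_unique_ids_per_partition=max_unique_ids_per_partition,
#           sharding_strategy=1,
#       )
#   )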