|
2 | 2 |
|
3 | 3 | """Defines trainer config modifiers, which will be used in model definitions."""
|
4 | 4 |
|
5 |
| -from typing import Dict, Sequence, Union |
| 5 | +from typing import Callable, Dict, NamedTuple, Sequence, Union |
6 | 6 |
|
7 | 7 | from axlearn.common import config
|
8 | 8 | from axlearn.common.base_layer import RematSpec
|
9 | 9 | from axlearn.common.config import (
|
10 | 10 | REQUIRED,
|
| 11 | + ConfigBase, |
11 | 12 | ConfigModifier,
|
12 | 13 | ConfigOr,
|
13 | 14 | Required,
|
|
17 | 18 | from axlearn.common.gradient_accumulation import with_minibatch_steps
|
18 | 19 | from axlearn.common.metrics import MetricAccumulator
|
19 | 20 | from axlearn.common.trainer import SpmdTrainer
|
20 |
| -from axlearn.common.utils import HybridMeshShape, MeshShape |
| 21 | +from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec |
| 22 | + |
| 23 | + |
class _FoundModule(NamedTuple):
    """Module config found in a recursive search of a matching module path."""

    # The module config that was found.
    module: ConfigModifier.Config
    # The parent config of the found module.
    parent_module: ConfigModifier.Config
    # Attribute name under which the found module lives in its parent
    # (it is assigned from a split of the dotted module path, hence a str).
    key_in_parent: str
| 33 | + |
| 34 | + |
def _find_target_module(module_name: str, cfg: SpmdTrainer.Config) -> _FoundModule:
    """Recursively searches for the target module matching `module_name` in `cfg`.

    Args:
        module_name: Dot-separated path of the target module,
            e.g. `model.decoder.transformer.layer`.
        cfg: The trainer config to be searched for `module_name`.

    Raises:
        ValueError: The module_name is not found.

    Returns:
        A `_FoundModule(module, parent_module, key_in_parent)` where:
            module: The module found.
            parent_module: The parent config of the found module.
            key_in_parent: Attribute name of the found module in its parent.
    """
    curr_module = cfg
    parent_module = None
    key_in_parent = None

    # Walk the dotted path one attribute at a time, tracking the parent so the
    # caller can later replace the found module in place.
    for part in module_name.split("."):
        if not hasattr(curr_module, part):
            raise ValueError(f"{part} is not found in {curr_module}.")
        parent_module = curr_module
        key_in_parent = part
        curr_module = getattr(curr_module, part)
    return _FoundModule(
        module=curr_module, parent_module=parent_module, key_in_parent=key_in_parent
    )
21 | 68 |
|
22 | 69 |
|
23 | 70 | class GradientAccumulationModifier(ConfigModifier):
|
@@ -100,18 +147,11 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
|
100 | 147 | """
|
101 | 148 |
|
102 | 149 | for module_name, remat_spec in self._remat_policies.items():
|
103 |
| - # Here we assume x.y.z format. |
104 |
| - # One example would be model.decoder.transformer.layer. |
105 |
| - target_modules = module_name.split(".") |
106 |
| - curr_module = cfg |
107 |
| - for target_module in target_modules: |
108 |
| - if not hasattr(curr_module, target_module): |
109 |
| - raise ValueError(f"{target_module} is not found in {curr_module}.") |
110 |
| - curr_module = getattr(curr_module, target_module) |
| 150 | + found_module = _find_target_module(module_name, cfg) |
111 | 151 | # Here we assume all modules have remat_spec attribute.
|
112 |
| - if not hasattr(curr_module, "remat_spec"): |
113 |
| - raise ValueError(f"{curr_module} does not have remat_spec attribute") |
114 |
| - curr_module.remat_spec = remat_spec |
| 152 | + if not hasattr(found_module.module, "remat_spec"): |
| 153 | + raise ValueError(f"{found_module.module} does not have remat_spec attribute") |
| 154 | + found_module.module.remat_spec = remat_spec |
115 | 155 | return cfg
|
116 | 156 |
|
117 | 157 |
|
@@ -146,6 +186,113 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
|
146 | 186 | return cfg
|
147 | 187 |
|
148 | 188 |
|
class ModelConfigModifier(ConfigModifier):
    """Update the model config for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure ModelConfigModifier.

        Attributes:
            model_cfg_modifications: A mapping from module path
                (e.g. `model.decoder.transformer.layer`) to a Config.
        """

        model_cfg_modifications: Dict[str, Callable[[ConfigBase], ConfigBase]] = {}

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._model_cfg_modifications = self.config.model_cfg_modifications

    def _merge_configs(self, target_cfg: ConfigBase, found_module: _FoundModule) -> ConfigBase:
        """Merges configurations from the config being replaced on a best-effort basis.

        Only keys present on both the new `target_cfg` and the existing
        `found_module.module` are carried over; `klass` is kept from
        `target_cfg` since that is the class being swapped in.

        Args:
            target_cfg: Configuration that will replace `found_module.module`.
            found_module: Existing configuration whose class will be replaced
                but whose configuration will be merged with `target_cfg`.

        Returns:
            The modified config.
        """
        for key in target_cfg.keys():
            # Skip `klass`: copying it back would undo the class replacement.
            if key == "klass":
                continue
            if hasattr(found_module.module, key) and hasattr(target_cfg, key):
                setattr(target_cfg, key, getattr(found_module.module, key))
        return target_cfg

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the model config of the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.

        Returns:
            The modified trainer config.
        """
        # Iterate over modules sorted by module-name length, so a parent path is
        # modified before its children and child modifications are not lost.
        for module_name, model_cfg in sorted(
            self._model_cfg_modifications.items(), key=lambda item: len(item[0])
        ):
            found_module = _find_target_module(module_name, cfg)

            model_cfg = self._merge_configs(model_cfg, found_module)
            # Replace the module in its parent config.
            setattr(found_module.parent_module, found_module.key_in_parent, model_cfg)
        return cfg
| 250 | + |
| 251 | + |
class PartitionSpecModifier(ConfigModifier):
    """Update the partition spec attribute for the specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        """Configure PartitionSpecModifier.

        Attributes:
            partition_specs: A nested mapping from module path
                (e.g. `model.decoder.transformer.layer`) to another
                mapping of model attribute to PartitionSpec.
        """

        # Nested dict: module path -> {attribute name -> PartitionSpec}, matching
        # the per-module `partition_spec_dict.items()` iteration in __call__.
        partition_specs: Required[Dict[str, Dict[str, PartitionSpec]]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self._attribute_dicts = self.config.partition_specs

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the partition_spec attributes for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: The target module is not found.
            ValueError: The partition_spec attribute is not found.

        Returns:
            The modified trainer config.
        """
        for module_name, partition_spec_dict in self._attribute_dicts.items():
            found_module = _find_target_module(module_name, cfg)
            for partition_spec_name, partition_spec in partition_spec_dict.items():
                # Every attribute named in the per-module dict must already exist.
                if not hasattr(found_module.module, partition_spec_name):
                    raise ValueError(
                        f"{found_module.module} does not have {partition_spec_name} attribute"
                    )
                setattr(found_module.module, partition_spec_name, partition_spec)

        return cfg
| 294 | + |
| 295 | + |
149 | 296 | class ChainConfigModifier(ConfigModifier):
|
150 | 297 | """Chain multiple config modifiers together."""
|
151 | 298 |
|
|
0 commit comments