Commit 7339f3f
committed
Squashed commit of the following:
commit 336c75d
Author: Mark Lee <[email protected]>
Date: Mon Mar 3 09:04:07 2025 -0800
Supports arbitrary uniform partitioning in host-global array conversions. (apple#1029)
* Allows specifying PartitionSpec to host_to_global_device_array.
* Generalizes to arbitrary uniform partitioning.
* Addresses comments and adds mixed shape test.
commit 0881412
Author: Dongseong Hwang <[email protected]>
Date: Sat Mar 1 15:41:38 2025 -0800
Refactor Mask in Attention (apple#1028)
Currently, the attention code is **hardcoded** to handle either `causal_mask`
or an arbitrary `mask_fn`.
To support **sliding window masks**, we previously used a **hack** by injecting
the `_sliding_window_size` attribute into functions.
This refactor **makes the masking logic more flexible** by allowing arbitrary
`MaskFnAttentionBias`.
- If downstream requires a **new mask pattern**, they can simply:
1. Implement a **subclass of `MaskFnAttentionBias`**.
2. Set `attention.mask` accordingly.
commit f67d3f9
Author: Dongseong Hwang <[email protected]>
Date: Fri Feb 28 08:53:00 2025 -0800
Flash Attention now explicitly checks whether it is in decoding mode. (apple#1026)
Currently, Flash Attention infers decoding implicitly based on circumstantial
evidence. This PR makes the check explicit.
commit f8d2c66
Author: qdavid1 <[email protected]>
Date: Thu Feb 27 15:26:18 2025 -0800
External KV input for _update_layer_kwargs (apple#1025)
commit a3bf5e2
Author: Hanzhi Zhou <[email protected]>
Date: Wed Feb 26 17:23:40 2025 -0800
Minor changes to Checkpointer (apple#1024)
commit 55e1841
Author: Wentao Wu <[email protected]>
Date: Wed Feb 26 15:45:51 2025 -0800
Add an option to break ties for top_k_logits when k = 1 (apple#1022)
* Add an option to support stable top_k = 1.
* address comments
* address comments
* address comments
* Update axlearn/common/logit_modifiers.py
Co-authored-by: Mark Lee <[email protected]>
* Update axlearn/common/logit_modifiers.py
Co-authored-by: Mark Lee <[email protected]>
* Update axlearn/common/logit_modifiers.py
Co-authored-by: Mark Lee <[email protected]>
* Update axlearn/common/logit_modifiers.py
Co-authored-by: Mark Lee <[email protected]>
* address comments
---------
Co-authored-by: Mark Lee <[email protected]>
commit fbca3fc
Author: Meng (Ethan) Li <[email protected]>
Date: Wed Feb 26 14:05:25 2025 -0800
Add priority_class as a launch flag (apple#1020)
commit b26bd74
Author: Meng (Ethan) Li <[email protected]>
Date: Wed Feb 26 14:04:47 2025 -0800
Fix TypeError in calcualte_goodput.py (apple#1023)
commit f8191e1
Author: Dongseong Hwang <[email protected]>
Date: Wed Feb 26 11:03:44 2025 -0800
Emulate flash attentnion unittests on CPU. (apple#1021)
utils.py codebase is not well covered by CI because it branches different
backend.
This PR introduces new CPU test, utils_test.py.
This test is expected to run on CPU and is designed to validate GPU/TPU code
from a CPU environment by fake mesh.
It allows quick verification in CI and local environments to ensure that code
changes do not break GPU/TPU Flash Attention.
commit daec8c5
Author: Chang Liu <[email protected]>
Date: Tue Feb 25 12:38:43 2025 -0800
Add additional_network and additional_subnetwork config to support multi-nic for v6e (apple#1019)
Co-authored-by: Chang Liu <[email protected]>
commit ac642ea
Author: Dongseong Hwang <[email protected]>
Date: Tue Feb 25 12:02:57 2025 -0800
Fix crash in log-mel frontend when waveform samples are integers. (apple#1017)
After updating JAX, this existing hidden bug started causing CI failures.
When the sample dtype is int32 (which is valid), `jnp.finfo` returns None,
even though `jnp.iinfo` is available.
The previous JAX version seemed to handle this case more forgivingly.
```
../axlearn/axlearn/audio/frontend_utils.py:297: in linear_to_log_spectrogram
return jnp.log(jnp.maximum(x, jnp.finfo(x.dtype).tiny))
```
commit 7c64b55
Author: Meng (Ethan) Li <[email protected]>
Date: Tue Feb 25 10:52:50 2025 -0800
Add LoadBalancer to GKE replicatedJob (apple#1015)
Co-authored-by: Liang (SPG) He <[email protected]>
commit 8e8a41b
Author: Chang Lan <[email protected]>
Date: Tue Feb 25 10:26:59 2025 -0800
Expose jax.lax.scan's unroll option to Repeat layer (apple#1016)
* Expose jax.lax.scan's unroll option to Repeat layer.
* Defaults to None to avoid golden config changes
commit 682bce6
Author: Dongseong Hwang <[email protected]>
Date: Tue Feb 25 10:09:41 2025 -0800
Handle None bias in BiasAndResidual (apple#1018)
commit f053318
Author: Ruoming Pang <[email protected]>
Date: Mon Feb 24 11:12:23 2025 -0500
Allows a required value in a config_for_{function,class} to be specified via **kwargs in instantiate(). (apple#1013)
* Allows a required value in a ClassConfigBase to be specified via **kwargs in instantiate().
* Allows a required value in a FunctionConfigBase to be specified via **kwargs in instantiate().
commit a93cd1b
Author: Luzy <[email protected]>
Date: Sat Feb 22 19:55:01 2025 -0500
fix dtype in frontend pre emphasis (apple#1014)
commit c1fe2e9
Author: Maggie Zhang <[email protected]>
Date: Fri Feb 21 19:07:39 2025 -0800
GoodPut minor fix: only process 0 should start goodput uploader (apple#984)
* only process 0 will start goodput uploader
* Add unit test
commit 4b1fbf0
Author: Chang Lan <[email protected]>
Date: Fri Feb 21 13:36:45 2025 -0800
Async context invocation for checkpointing (apple#1012)
* Async context invocation supoprt for checkpointer
* Add comment
* Add comments
commit d4cd158
Author: Ruoming Pang <[email protected]>
Date: Fri Feb 21 14:22:24 2025 -0500
Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase. (apple#1011)
* Allows the kwargs given in `cfg.instantiate(**kwargs)` override field values in `cfg` for FunctionConfigBase and ClassConfigBase.
This makes it easier for `config_for_function` and `config_for_class` to be used for functions and classes that take args of types not allowed by Config fields, e.g., Tensor.
* Fixes pytype.
* Addresses review.
commit 2ae6e66
Author: Meng (Ethan) Li <[email protected]>
Date: Fri Feb 21 09:22:35 2025 -0800
Enable megascale abort on hang or error (apple#1010)
* Enable megascale_error_reporter_abort on hang and error by default
* Increase threshold to 10m
commit ce4b2fb
Author: Chunyang Wen <[email protected]>
Date: Fri Feb 21 23:12:38 2025 +0800
Add GPU monitor (apple#1006)
commit baf8ad7
Author: Dongseong Hwang <[email protected]>
Date: Thu Feb 20 19:35:07 2025 -0800
Clarify setting sliding_window_size = 8 results in a window size of 9, including itself. (apple#1009)
commit cf41112
Author: Hanzhi Zhou <[email protected]>
Date: Thu Feb 20 16:29:13 2025 -0800
Partially reverts "gRPC Checkpointer (apple#1005)" (apple#1008)
* Revert "gRPC Checkpointer (apple#1005)"
This reverts commit d27c562.
* Keep some changes
commit 454bdba
Author: Matthew Hopkins <[email protected]>
Date: Thu Feb 20 15:10:51 2025 -0800
upgrade jax 0.4.38 (apple#1007)
commit d27c562
Author: Hanzhi Zhou <[email protected]>
Date: Tue Feb 18 18:54:53 2025 -0800
gRPC Checkpointer (apple#1005)
commit fb90620
Author: Ruoming Pang <[email protected]>
Date: Tue Feb 18 21:08:38 2025 -0500
Makes file_system.glob support multiple patterns. (apple#1003)
* Makes file_system.glob support multiple patterns.
* Makes file_system.glob support multiple patterns.
* Makes file_system.glob support multiple patterns.
* Makes file_system.glob support multiple patterns.
commit 334f421
Author: Mark Lee <[email protected]>
Date: Tue Feb 18 17:03:39 2025 -0800
Reverts sliding window attention changes. (apple#1004)
* Revert "Fix flash decoding in GPU. (apple#999)"
This reverts commit fdadfd8.
* Revert "Supports TPU context parallel training (apple#981)"
This reverts commit e151d69.
* Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)"
This reverts commit 67645d0.
* Retain model/decoder asr changes.
commit 3dacc6b
Author: Chang Lan <[email protected]>
Date: Mon Feb 17 19:45:00 2025 -0800
Refactor aot_compilation for reuse (apple#1000)
commit c44fe18
Author: Ruoming Pang <[email protected]>
Date: Mon Feb 17 21:38:17 2025 -0500
Makes checkpointer_test.py use file_system. (apple#1001)
commit fdadfd8
Author: Dongseong Hwang <[email protected]>
Date: Mon Feb 17 18:23:21 2025 -0800
Fix flash decoding in GPU. (apple#999)
target_positions used to be time_step, but after PR apple#995, it now represents the
actual target positions with shape [batch, step_len].
apple#995
Updating the GPU decoding code to align with this change.
CI did not cover GPU unit tests.
TEST=test_extend_step10 of axlearn/common/flash_attention/layer_test.py in GPU
commit 9e64388
Author: Ruoming Pang <[email protected]>
Date: Mon Feb 17 16:40:03 2025 -0500
Makes axlearn/cloud/ use file_system. (apple#998)
* Makes bastion.py use file_system. This is a first step towards removing the tf.io.gfile dependency.
* Adds testing for file_system.readfile.
* Fixes pytype.
* Makes axlearn/cloud use file_system instead of gfile.
commit 5fba4ce
Author: Chang Lan <[email protected]>
Date: Mon Feb 17 09:44:10 2025 -0800
AOT compilation support for inference (apple#997)
* Add optional `devices` init argument to InferenceRunner for passing
fake devices during AOT compilation.
* Add more v5e slice types.
commit e151d69
Author: Hanzhi Zhou <[email protected]>
Date: Sun Feb 16 13:01:15 2025 -0800
Supports TPU context parallel training (apple#981)
Fix
Fix tests
commit 67645d0
Author: Dongseong Hwang <[email protected]>
Date: Sat Feb 15 13:26:51 2025 -0800
Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (apple#995)
* Revert "Transpose kv cache for better decode performance (apple#979)"
This reverts commit b130416.
* Update golden configs
* Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding.
Currently, when using `MultiheadAttention` or `GroupedQueryAttention` for
sliding window attention, the KV cache is kept for the full sequence length
(`seq_len`) instead of the window length (`window_len`).
For example, a model with `window_len=1k` and `seq_len=2M` keeps a KV cache
for the full 2M tokens. It then biases 1999k invalid KV tokens before
calculating attention, resulting in a computational complexity of **O(2M²)**
instead of the desired **O(1k²)**.
This issue persists even when using flash attention. Flash attention uses the
KV cache allocated in HBM as its input. While unnecessary blocks are discarded
during computation, the KV cache still occupies HBM inefficiently for the full
2M tokens.
To address this, when `MultiheadAttention` detects a sliding window mask, it
stores the key-value (KV) cache in a ring buffer inside the input linear layer.
As a result, downstream projects using `MultiheadAttention` automatically
benefit from efficient KV cache handling in `init_states` and `extend_step`.
Additionally, for use cases like local-global attention in LLMs, it is
recommended to use sliding window masks for even the global attention as well.
For example, if you want to train an LLM with a context length of 8k, you can
set the sliding window size to 8k during training. This enables functionally
infinite decoding during inference. Accuracy wouldn't be good tho.
Note:
* query_positions in QKVLinear.forward() was introduced by
apple#914. Now it returns to the caller.
This PR actually moves from downstream speech/streaming/sliding_window_attention.py
* transpose
commit 272a4d2
Author: Chang Lan <[email protected]>
Date: Fri Feb 14 10:48:21 2025 -0800
Add v5e-8 (apple#994)
commit debb46a
Author: Mark Lee <[email protected]>
Date: Thu Feb 13 22:27:28 2025 -0800
Decouples jobsets from replicated jobs. (apple#991)
* Decouples jobsets from replicated jobs.
* Address comments.
commit 8f2b99d
Author: Maggie Zhang <[email protected]>
Date: Thu Feb 13 20:49:48 2025 -0800
Add Goodput documentation (apple#989)
* Temporarily change checkpointing to every 5 steps
* revert local changes
* Add example command for goodput usage
commit 31e8da0
Author: Alexander Pivovarov <[email protected]>
Date: Thu Feb 13 15:44:51 2025 -0800
Fix Missing return statement in base_layer_test.py::ExplicitFanLayer::_compute_fan_axes (apple#987)
commit 7f2dd9e
Author: Apoorv Gupta <[email protected]>
Date: Thu Feb 13 14:36:11 2025 -0800
Flash Attention for Neuron (apple#939)
commit 6ca4f56
Author: Philipp Dufter <[email protected]>
Date: Thu Feb 13 23:09:28 2025 +0100
pass on log_warning in input_tf_data.skip_on_error (apple#990)
* make log_warnings customizable in tfds skip error
* address comments
commit 1a8a0eb
Author: Hanzhi Zhou <[email protected]>
Date: Thu Feb 13 13:32:34 2025 -0800
Integrate Orbax's emergency checkpoint. (apple#820)
* Integrate Orbax emergency checkpoint
* Address comments
* comment
* Address comments
* Upgrade orbax
* Improve comments
* Improve comments
* Update for new orbax versions
* Better timer
* Address comments
* Add step test
* Fix
* Add comment
commit 42fd715
Author: Apoorv Gupta <[email protected]>
Date: Thu Feb 13 09:13:30 2025 -0800
TRN2 Meshes and Configurations (apple#916)
* TRN2 Meshes and Configurations
* Add get_recursive and set_recursive to ConfigBase.
* Use loops inside get/set_recursively
+ address comments
* Update partition spec
* Use get_recursively inside set
* Move trn2 configs to a helper function.
+ Fix modifier tests
* TRN2 partitionspec supports DP over FSDP and TP
* Use for loop in get_recursively
* Update Golden Configs
commit d47d5ce
Author: Haoshuo Huang <[email protected]>
Date: Tue Feb 11 18:13:13 2025 -0800
Add support to slice dataset based on proportions. (apple#982)
commit ed8f382
Author: Mark Lee <[email protected]>
Date: Tue Feb 11 13:22:44 2025 -0800
Allow metrics layers to have state. (apple#978)
* Allow metrics layers to have state.
* Move BaseLossMetrics to a new file.
commit b130416
Author: Chang Lan <[email protected]>
Date: Tue Feb 11 00:01:28 2025 -0800
Transpose kv cache for better decode performance (apple#979)
commit 48bf488
Author: Haoshuo Huang <[email protected]>
Date: Mon Feb 10 22:25:18 2025 -0800
Add support for grain.IterDataset in sampling (apple#980)
commit d4b563c
Author: Alexander Pivovarov <[email protected]>
Date: Mon Feb 10 15:36:22 2025 -0800
Replace jnp.ndarray with Tensor from axlearn.common.utils (apple#973)
commit 0666d80
Author: Alexander Pivovarov <[email protected]>
Date: Mon Feb 10 15:35:23 2025 -0800
Fix membership checks in tool_use_execution.py (apple#974)
commit 2f4763c
Author: Alexander Pivovarov <[email protected]>
Date: Mon Feb 10 15:31:59 2025 -0800
Remove redundant import logging (apple#975)
commit 58dcf33
Author: Hanzhi Zhou <[email protected]>
Date: Mon Feb 10 13:41:33 2025 -0800
Enable cudnn dropout (apple#913)
commit ae855ed
Author: Mark Lee <[email protected]>
Date: Mon Feb 10 12:43:50 2025 -0800
Ensures that cache_dtype is respected. (apple#977)
commit cfef38b
Author: Daniel Swann <[email protected]>
Date: Mon Feb 10 10:56:10 2025 -0800
:sparkles: Add cache for CloudBuild API location queries (apple#967)
commit 8fd9137
Author: Wei Liu <[email protected]>
Date: Sun Feb 9 15:33:53 2025 -0800
Add segment_ids option in DiTAttentionLayer (apple#976)
commit e55a404
Author: Chang Lan <[email protected]>
Date: Sun Feb 9 04:38:49 2025 -0800
Use broadcasting trick for KV update (apple#972)
* Use vmap and dynamic_update_slice for KV update
* Broadcasting trick
* Simplify the impl per @markblee's suggestion
* comments
commit b955187
Author: Dongseong Hwang <[email protected]>
Date: Fri Feb 7 14:12:48 2025 -0800
Don't keep initial key/value inputs in the KV cache. (apple#968)
The current code is weird. It stores the input key/value in the KV cache, but
this doesn’t make sense in either init_states or prefill:
* init_states: This is not prefill, so key/value should not be stored in the KV cache.
* prefill: The extend_step() function overrides this part anyway.
Thus, this PR removes this unnecessary and confusing logic.
The logic was introduced in apple#860
commit c3d656d
Author: zhengdong-zhang <[email protected]>
Date: Fri Feb 7 10:18:42 2025 -0800
Refactorization. (apple#963)
commit 1c883d8
Author: Zhao Xu <[email protected]>
Date: Fri Feb 7 10:02:56 2025 -0800
Support system role when calling the Gemini API. (apple#971)
commit ceab4f4
Author: Haoshuo Huang <[email protected]>
Date: Thu Feb 6 20:41:07 2025 -0800
Making shared_memory configurable (apple#969)
* Making shared_memory configurable
* fix eol space
commit 323faa3
Author: Meng (Ethan) Li <[email protected]>
Date: Thu Feb 6 12:11:28 2025 -0800
Use env id for gcp settings (apple#957)
* Use env_id to replace zone as gcp_settings key to support multiple env under the same zone
* fall back to zone
* address comments
* Suppport project in the label filter; always get zone from gcp_setting value instead of return it directly
commit 2ec3a02
Author: Chang Lan <[email protected]>
Date: Wed Feb 5 22:25:58 2025 -0800
Fix incorrect number of formatting arguments (apple#966)
commit d131d3b
Author: Nan Du <[email protected]>
Date: Mon Feb 3 11:44:00 2025 -0800
Reduce the verbosity of variable norm summaries (apple#965)
commit c1c6e29
Author: Kelvin Zou <[email protected]>
Date: Fri Jan 31 22:24:39 2025 -0800
Sliding window support for GPU flash attention (apple#962)
* snapshot
* snapshot
* snapshot
* remove unexpected change
* adding shape commenbt
* fix pylint
* snapshot
commit 0936a17
Author: Mark Lee <[email protected]>
Date: Fri Jan 31 13:59:12 2025 -0800
Supports loss_weights and live_targets in metrics. (apple#960)
* Supports loss_weights, live_targets, and module sharing in metrics.
* Addresses comments.
* Explicitly test flatten_metrics=True.
commit 7a40f91
Author: Dipannita Shaw <[email protected]>
Date: Fri Jan 31 11:45:33 2025 -0800
Add Goodput & Badput recording and monitoring support. (apple#783)
* Code clean up
* Add more testing
* Fix docstrings
* Remove recorder calls from trainer for now
* Code cleanup gcp/measurement.py
Co-authored-by: Ruoming Pang <[email protected]>
* Code cleanup common/measurement.py
Co-authored-by: Ruoming Pang <[email protected]>
* Fix pre commit errors
* Adding more tests
* Further clean up
* Fix a test error
---------
Co-authored-by: Ruoming Pang <[email protected]>
commit 031a7f3
Author: Mark Lee <[email protected]>
Date: Thu Jan 30 20:19:12 2025 -0800
Skipping empty grain batches during unbatch. (apple#961)
* Skipping empty grain batches during unbatch.
* Use a loop instead of recursion.
commit 795da33
Author: Hanzhi Zhou <[email protected]>
Date: Thu Jan 30 07:17:16 2025 -0800
Optimizer offloading through weight-only offload (apple#867)
* Optimizer offloading
* Style fix
* Type fix
commit b1a1a5a
Author: Haoshuo Huang <[email protected]>
Date: Wed Jan 29 21:44:15 2025 -0800
Improve gcsfuse io (apple#959)
commit d76ef6f
Author: Hanzhi Zhou <[email protected]>
Date: Wed Jan 29 15:10:13 2025 -0800
SplashAttention performance tuning for v6e (apple#958)
* SplashAttention tuning for v6e
* Add import to fix pytype errors
commit 2d002e3
Author: Hanzhi Zhou <[email protected]>
Date: Wed Jan 29 12:07:56 2025 -0800
Use InputDispatcher for fuji models (apple#956)
* Use dispatcher
* Update golden configs
* Remove logical feed indices
commit fad264b
Author: Mark Lee <[email protected]>
Date: Tue Jan 28 10:41:54 2025 -0800
Explicitly pass module outputs to metrics. (apple#953)
* Explicitly pass module outputs to metrics.
* Support and add checks for module/state updates.
* Only flatten summaries.
commit 59508e3
Author: Hanzhi Zhou <[email protected]>
Date: Tue Jan 28 10:34:52 2025 -0800
Add v6e PCIe overload workaround flag (apple#955)
commit 028ecfd
Author: Haoshuo Huang <[email protected]>
Date: Mon Jan 27 20:54:28 2025 -0800
Fix GCSFUSE flags by setting resource limit. (apple#954)
commit 3e2c6dd
Author: Matthew Hopkins <[email protected]>
Date: Mon Jan 27 14:56:42 2025 -0800
update jax to 0.4.37 (apple#948)
update BlockSpec usage in tpu_attention
use TYPE_CHECKING for BuildDatasetFn in input_fake
add todo for BuildDatasetFn
commit b125f00
Author: Hanzhi Zhou <[email protected]>
Date: Mon Jan 27 11:29:23 2025 -0800
Add v6e special meshes (apple#952)
* Add v6e special mesh
* Add v6e special mesh
* Fix
* Fix
commit a854738
Author: Firenze11 <[email protected]>
Date: Mon Jan 27 09:17:46 2025 -0800
Allow external positions to be inputed in RoPE embedding layer (apple#926)
* Allow external positions to be inputed in RoPE embedding layer
Use case: In RoPE embedding, position embeddings are applied to Q, K, V values after `i_proj`. Unlike the implementation of current `RoFormerQKVLinear`, in MaskedDiT we need to customize positions to indicate masked versus non-masked positions in the position embedding. When we convert this masked roformer attention module to flash attention, we need its signature to be supported by `MultiheadAttention`.
* Update attention_test.py
* Update dit.py
* Update attention.py
* Update attention_test.py
* Update attention.py
* Update dit.py
* Update axlearn/common/attention.py
Co-authored-by: Mark Lee <[email protected]>
* respond to comments.
Co-authored-by: Ruoming Pang <[email protected]>
* Update attention.py
* Update attention.py
* Update attention.py
---------
Co-authored-by: Mark Lee <[email protected]>
Co-authored-by: Ruoming Pang <[email protected]>
commit 999401a
Author: qdavid1 <[email protected]>
Date: Mon Jan 27 09:11:17 2025 -0800
Update LoraFusedQKVLinear (apple#949)
commit 1c22688
Author: Mark Lee <[email protected]>
Date: Sun Jan 26 04:51:02 2025 -0800
Workaround module outputs being dropped. (apple#951)
commit 94c81cb
Author: Meng (Ethan) Li <[email protected]>
Date: Fri Jan 24 11:01:45 2025 -0800
Add link to github issue regarding kubernetes-32.0.0 (apple#947)
commit a6e0f4a
Author: Meng (Ethan) Li <[email protected]>
Date: Fri Jan 24 08:40:25 2025 -0800
Pin kubernetes pip version to 31.0.0 to fix client authentication error (apple#946)
commit 076521a
Author: Mark Lee <[email protected]>
Date: Thu Jan 23 15:11:00 2025 -0800
Forward input keys to decoder. (apple#944)
commit 30284c8
Author: Hanzhi Zhou <[email protected]>
Date: Thu Jan 23 10:33:54 2025 -0800
Legacy flash remat fix (apple#943)
* Fix the same problem for legacy tpu attn
* Fix
commit 6a9f980
Author: Mark Lee <[email protected]>
Date: Thu Jan 23 09:20:46 2025 -0800
Adds mesh rule for a3-megagpu-8g. (apple#936)
commit ac7a3ed
Author: Dongseong Hwang <[email protected]>
Date: Thu Jan 23 08:15:27 2025 -0800
Enabled running Pallas Flash Attention on CPU. (apple#922)
Pallas supports CPU simulation (`interpret=True`), so we can use the same
TPU Pallas kernel on CPU — making code debugging easier.
This change lets the following unittests run on CPU as if they were on TPU,
enabling easier testing and debugging:
- `axlearn/common/flash_attention/tpu_attention_test.py`
Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU.
- `axlearn/common/flash_attention/gpu_attention_test.py`
Now CI covers those tests on CPU as well.
In M3 Max MacBook Pro, test coverages and processing time are as follows,
* axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
* axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s
commit 8ea85bd
Author: Hanzhi Zhou <[email protected]>
Date: Wed Jan 22 09:51:15 2025 -0800
Some fixes for flash remat (apple#942)
commit 185b1b5
Author: Chang Lan <[email protected]>
Date: Tue Jan 21 11:21:08 2025 -0800
Repeat KV heads in Flash Attention (apple#938)
* Roll back '_repeat_kv_heads' change in Flash Attention
Recent PR removed _repeat_kv_heads from Flash Attention for GQA optimization,
in the hope to reduce HBM usage. However the actual HBM saving would be limited
in the model-parallel setting, as the heads are already sharded across devices.
It also introduces some limitation which breaks some of the existing sharding
configurations.
For example, let's say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads,
we can set the model axis as 8 so that each device will have only one Q, K, V head;
Without repeat_kv_heads, the max value of model axis is 4, and each device will have
2 Q heads as a result, increasing the actual HBM usage.
* Repeat kv as necessary for sharding
* Unit tests
* Address comments.
commit 4678740
Author: Chang Lan <[email protected]>
Date: Mon Jan 20 20:36:44 2025 -0800
AOT compilation for v6e (apple#937)
commit 357bef6
Author: Mark Lee <[email protected]>
Date: Mon Jan 20 20:23:39 2025 -0800
Makes causal lm metrics configurable. (apple#934)
* Makes causal lm metrics configurable.
* Address review comments.
* Make metrics required.
* Update golden configs.
* Removes PredictModel.
commit 16ca0c2
Author: Mark Lee <[email protected]>
Date: Sun Jan 19 14:19:20 2025 -0800
Supports flexible input partition specs. (apple#933)
* Supports flexible input partition specs in causal lm.
* Moves the input partitioning to Input.
* Adds missing pytest marker.
* Address review comments.
* Rebase and update golden configs.
* Fixes batch axis names and adds a test.
commit 9b75ef1
Author: Mark Lee <[email protected]>
Date: Sun Jan 19 07:43:19 2025 -0800
Avoid a top-level import of tokenizers. (apple#935)
commit 9996f34
Author: sychen52 <[email protected]>
Date: Sat Jan 18 09:44:04 2025 -0800
Add llama 3 tokenizer (apple#850)
* Add llama 3 tokenizer
add a new version called V3_TIKTOKEN.
other edits based on suggestions.
* Handle special tokens like other vocabularies.
* use encode instead of encode_batch
commit ad14de3
Author: Haoshuo Huang <[email protected]>
Date: Fri Jan 17 14:19:24 2025 -0800
Add ReadOptions args to _make_autoregressive_inputs (apple#931)
* Add ReadOptions args to _make_autoregressive_inputs
* use read_options as args instead
commit 4858070
Author: Sam Stoelinga <[email protected]>
Date: Fri Jan 17 13:54:05 2025 -0800
improve GCS perf: Change resource limit to request (apple#851)
commit b0ee05e
Author: Bailin <[email protected]>
Date: Fri Jan 17 22:53:00 2025 +0800
Add Mamab2 and its Jamba variant (apple#839)
* add mamab2
* merge
* unify init and prefill
* adapt final changes
---------
Co-authored-by: bailin_wang <[email protected]>
commit 1e25e4a
Author: Hanzhi Zhou <[email protected]>
Date: Thu Jan 16 11:25:24 2025 -0800
Cache AoT compilation result (apple#927)
* Cache AoT compilation result
* Fix comments
* Fix
* Fix
* Fix
* Fix1 parent 5e19503 commit 7339f3f
File tree
283 files changed
+22542
-5643
lines changed- .axlearn
- axlearn
- audio
- cloud
- common
- gcp
- jobs
- monitoring
- common
- flash_attention
- monitoring
- ssm_kernels
- experiments
- testdata
- axlearn.experiments.text.gpt.c4_trainer
- axlearn.experiments.text.gpt.deterministic_trainer
- axlearn.experiments.text.gpt.pajama_sigmoid_trainer
- axlearn.experiments.text.gpt.pajama_trainer
- axlearn.experiments.text.gpt.param_converter_test
- text/gpt
- open_api
- metrics
- vision
- docs
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
283 files changed
+22542
-5643
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
3 | 13 | | |
4 | 14 | | |
5 | 15 | | |
6 | | - | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
7 | 20 | | |
8 | 21 | | |
9 | 22 | | |
10 | 23 | | |
11 | | - | |
| 24 | + | |
12 | 25 | | |
13 | 26 | | |
14 | 27 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
13 | | - | |
| 13 | + | |
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| |||
1619 | 1619 | | |
1620 | 1620 | | |
1621 | 1621 | | |
1622 | | - | |
| 1622 | + | |
1623 | 1623 | | |
1624 | 1624 | | |
1625 | 1625 | | |
| |||
1698 | 1698 | | |
1699 | 1699 | | |
1700 | 1700 | | |
| 1701 | + | |
| 1702 | + | |
| 1703 | + | |
| 1704 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
115 | 115 | | |
116 | 116 | | |
117 | 117 | | |
118 | | - | |
119 | | - | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
120 | 132 | | |
121 | 133 | | |
122 | 134 | | |
| |||
210 | 222 | | |
211 | 223 | | |
212 | 224 | | |
213 | | - | |
| 225 | + | |
214 | 226 | | |
215 | 227 | | |
216 | 228 | | |
| |||
225 | 237 | | |
226 | 238 | | |
227 | 239 | | |
228 | | - | |
| 240 | + | |
229 | 241 | | |
230 | 242 | | |
231 | 243 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
20 | | - | |
| 20 | + | |
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
| |||
180 | 180 | | |
181 | 181 | | |
182 | 182 | | |
183 | | - | |
| 183 | + | |
| 184 | + | |
184 | 185 | | |
185 | 186 | | |
186 | 187 | | |
| |||
189 | 190 | | |
190 | 191 | | |
191 | 192 | | |
192 | | - | |
| 193 | + | |
193 | 194 | | |
194 | 195 | | |
195 | 196 | | |
| |||
334 | 335 | | |
335 | 336 | | |
336 | 337 | | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
337 | 363 | | |
338 | 364 | | |
339 | 365 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
247 | 247 | | |
248 | 248 | | |
249 | 249 | | |
250 | | - | |
| 250 | + | |
251 | 251 | | |
252 | 252 | | |
253 | 253 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
5 | | - | |
6 | | - | |
7 | 5 | | |
8 | 6 | | |
9 | 7 | | |
| |||
130 | 128 | | |
131 | 129 | | |
132 | 130 | | |
133 | | - | |
134 | | - | |
135 | | - | |
136 | | - | |
137 | | - | |
138 | | - | |
139 | | - | |
140 | | - | |
141 | | - | |
142 | | - | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
143 | 141 | | |
144 | | - | |
145 | | - | |
146 | | - | |
| 142 | + | |
147 | 143 | | |
148 | 144 | | |
149 | 145 | | |
| |||
171 | 167 | | |
172 | 168 | | |
173 | 169 | | |
174 | | - | |
| 170 | + | |
175 | 171 | | |
176 | 172 | | |
177 | 173 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
23 | 23 | | |
24 | 24 | | |
25 | 25 | | |
26 | | - | |
27 | | - | |
| 26 | + | |
| 27 | + | |
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
0 commit comments