Skip to content

Commit

Permalink
Add basic documentation pages
Browse files Browse the repository at this point in the history
  • Loading branch information
melissawm committed Oct 22, 2024
1 parent 3637fc5 commit 7206b6c
Show file tree
Hide file tree
Showing 18 changed files with 132 additions and 16 deletions.
33 changes: 17 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
limitations under the License.
-->

# MaxText

[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxtext/actions/workflows/UnitTests.yml)

# Overview
## Overview

MaxText is a **high performance**, **highly scalable**, **open-source** LLM written in pure Python/Jax and targeting Google Cloud TPUs and GPUs for **training** and **inference**. MaxText achieves [high MFUs](#runtime-performance-results) and scales from single host to very large clusters while staying simple and "optimization-free" thanks to the power of Jax and the XLA compiler.

Expand All @@ -30,15 +31,15 @@ Key supported features:
* Training and Inference (in preview)
* Models: Llama2, Mistral and Gemma

# Table of Contents
## Table of Contents

* [Getting Started](getting_started/First_run.md)
* [Runtime Performance Results](#runtime-performance-results)
* [Comparison To Alternatives](#comparison-to-alternatives)
* [Development](#development)
* [Features and Diagnostics](#features-and-diagnostics)

# Getting Started
## Getting Started

For your first time running MaxText, we provide specific [instructions](getting_started/First_run.md).

Expand All @@ -51,11 +52,11 @@ Some extra helpful guides:

In addition to the getting started guides, there are always other MaxText capabilities that are being constantly being added! The full suite of end-to-end tests is in [end_to_end](end_to_end). We run them with a nightly cadence. They can be a good source for understanding MaxText Alternatively you can see the continuous [unit tests](.github/workflows/UnitTests.yml) which are run almost continuously.

# Runtime Performance Results
## Runtime Performance Results

More details on reproducing these results can be found in [MaxText/configs/README.md](MaxText/configs/README.md).

## TPU v5p
### TPU v5p

| No. of params | Accelerator Type | TFLOP/chip/sec | Model flops utilization (MFU) |
|---|---|---|---|
Expand All @@ -70,7 +71,7 @@ More details on reproducing these results can be found in [MaxText/configs/READM
| 1160B | v5p-7680 | 2.95e+02 | 64.27% |
| 1160B | v5p-12288 | 3.04e+02 | 66.23% |

## TPU v5e
### TPU v5e

For 16B, 32B, 64B, and 128B models. See full run configs in [MaxText/configs/v5e/](MaxText/configs/v5e/) as `16b.sh`, `32b.sh`, `64b.sh`, `128b.sh`.

Expand All @@ -83,16 +84,16 @@ For 16B, 32B, 64B, and 128B models. See full run configs in [MaxText/configs/v5e
| 16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% |
| 32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% |

# Comparison to Alternatives
## Comparison to Alternatives

MaxText is heavily inspired by [MinGPT](https://github.com/karpathy/minGPT)/[NanoGPT](https://github.com/karpathy/nanoGPT), elegant standalone GPT implementations written in PyTorch and targeting Nvidia GPUs. MaxText is more complex, supporting more industry standard models and scaling to tens of thousands of chips. Ultimately MaxText has an MFU more than three times the [17%](https://twitter.com/karpathy/status/1613250489097027584?cxt=HHwWgIDUhbixteMsAAAA) reported most recently with that codebase, is massively scalable and implements a key-value cache for efficient auto-regressive decoding.

MaxText is more similar to [Nvidia/Megatron-LM](https://github.com/NVIDIA/Megatron-LM), a very well tuned LLM implementation targeting Nvidia GPUs. The two implementations achieve comparable MFUs. The difference in the codebases highlights the different programming strategies. MaxText is pure Python, relying heavily on the XLA compiler to achieve high performance. By contrast, Megatron-LM is a mix of Python and CUDA, relying on well-optimized CUDA kernels to achieve high performance.

MaxText is also comparable to [Pax](https://github.com/google/paxml). Like Pax, MaxText provides high-performance and scalable implementations of LLMs in Jax. Pax focuses on enabling powerful configuration parameters, enabling developers to change the model by editing config parameters. By contrast, MaxText is a simple, concrete implementation of various LLMs that encourages users to extend by forking and directly editing the source code.

# Features and Diagnostics
## Collect Stack Traces
## Features and Diagnostics
### Collect Stack Traces
When running a Single Program, Multiple Data (SPMD) job on accelerators, the overall process can hang if there is any error or any VM hangs/crashes for some reason. In this scenario, capturing stack traces will help to identify and troubleshoot the issues for the jobs running on TPU VMs.

The following configurations will help to debug a fault or when a program is stuck or hung somewhere by collecting stack traces. Change the parameter values accordingly in `MaxText/configs/base.yml`:
Expand All @@ -106,10 +107,10 @@ jsonPayload.verb="stacktraceanalyzer"

Here is the related PyPI package: https://pypi.org/project/cloud-tpu-diagnostics.

## Ahead of Time Compilation (AOT)
### Ahead of Time Compilation (AOT)
To compile your training run ahead of time, we provide a tool `train_compile.py`. This tool allows you to compile the main `train_step` in `train.py` for target hardware (e.g. a large number of v5e devices) without using the full cluster.

### TPU Support
#### TPU Support

You may use only a CPU or a single VM from a different family to pre-compile for a TPU cluster. This compilation helps with two main goals:

Expand All @@ -119,7 +120,7 @@ You may use only a CPU or a single VM from a different family to pre-compile for

The tool `train_compile.py` is tightly linked to `train.py` and uses the same configuration file `configs/base.yml`. Although you don't need to run on a TPU, you do need to install `jax[tpu]` in addition to other dependencies, so we recommend running `setup.sh` to install these if you have not already done so.

#### Example AOT 1: Compile ahead of time basics
##### Example AOT 1: Compile ahead of time basics
After installing the dependencies listed above, you are ready to compile ahead of time:
```
# Run the below on a single machine, e.g. a CPU
Expand All @@ -129,7 +130,7 @@ global_parameter_scale=16 per_device_batch_size=4

This will compile a 16B parameter MaxText model on 2 v5e pods.

#### Example AOT 2: Save compiled function, then load and run it
##### Example AOT 2: Save compiled function, then load and run it
Here is an example that saves then loads the compiled `train_step`, starting with the save:

**Step 1: Run AOT and save compiled function**
Expand All @@ -156,14 +157,14 @@ base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket

In the save step of example 2 above we included exporting the compiler flag `LIBTPU_INIT_ARGS` and `learning_rate` because those affect the compiled object `my_compiled_train.pickle.` The sizes of the model (e.g. `global_parameter_scale`, `max_sequence_length` and `per_device_batch`) are fixed when you initially compile via `compile_train.py`, you will see a size error if you try to run the saved compiled object with different sizes than you compiled with. However a subtle note is that the **learning rate schedule** is also fixed when you run `compile_train` - which is determined by both `steps` and `learning_rate`. The optimizer parameters such as `adam_b1` are passed only as shaped objects to the compiler - thus their real values are determined when you run `train.py`, not during the compilation. If you do pass in different shapes (e.g. `per_device_batch`), you will get a clear error message reporting that the compiled signature has different expected shapes than what was input. If you attempt to run on different hardware than the compilation targets requested via `compile_topology`, you will get an error saying there is a failure to map the devices from the compiled to your real devices. Using different XLA flags or a LIBTPU than what was compiled will probably run silently with the environment you compiled in without error. However there is no guaranteed behavior in this case; you should run in the same environment you compiled in.

### GPU Support
#### GPU Support
Ahead-of-time compilation is also supported for GPUs with some differences from TPUs:

1. GPU does not support compilation across hardware: A GPU host is still required to run AoT compilation, but a single GPU host can compile a program for a larger cluster of the same hardware.

1. For [A3 Cloud GPUs](https://cloud.google.com/compute/docs/gpus#h100-gpus), the maximum "slice" size is a single host, and the `compile_topology_num_slices` parameter represents the number of A3 machines to precompile for.

#### Example
##### Example
This example illustrates the flags to use for a multihost GPU compilation targeting a cluster of 4 A3 hosts:

**Step 1: Run AOT and save compiled function**
Expand Down Expand Up @@ -191,5 +192,5 @@ base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
As in the TPU case, note that the compilation environment must match the execution environment, in this case by setting the same `XLA_FLAGS`.


## Automatically Upload Logs to Vertex Tensorboard
### Automatically Upload Logs to Vertex Tensorboard
MaxText supports automatic upload of logs collected in a directory to a Tensorboard instance in Vertex AI. Follow [user guide](getting_started/Use_Vertex_AI_Tensorboard.md) to know more.
31 changes: 31 additions & 0 deletions docs/about.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# What is Maxtext?

MaxText is a Google initiated open source project for high performance, highly scalable, open-source LLM written in pure Python/[JAX](https://jax.readthedocs.io/en/latest/index.html) and targeting Google Cloud TPUs and GPUs for training and inference.

MaxText achieves very high MFUs (Model Flop Utilization) and scales from single host to very large clusters while staying simple and "optimization-free".

MaxText additionally provides an highly optimized reference implementations for popular Open Source models like:

- Llama 2, 3 and 3.1
- Mistral and Mixtral
- Gemma and Gemma2
- GPT

These reference implementations support pre-training and full fine tuning. Maxtext also allows you to create various sized models for benchmarking purposes.

The key value proposition of using MaxText for pre-training or full fine tuning is:

- Very high performance of average of 50% MFU
- Open code base - Code base can be found at the following github location.
- Easy to understand: MaxText is purely written in JAX and Python, which makes it accessible to ML developers interested in inspecting the implementation or stepping through it. It is written at the block-by-block level, with code for Embeddings, Attention, Normalization etc. Different Attention mechanisms like MQA and GQA are all present. For quantization, it uses the JAX AQT library. The implementation is suitable for both GPUs and TPUs.

MaxText aims to be a launching off point for ambitious LLM projects both in research and production. We encourage users to start by experimenting with MaxText out of the box and then fork and modify MaxText to meet their needs.

!!! note

Maxtext today only supports Pre-training and Full Fine Tuning of the models. It does not support PEFT/LoRA, Supervised Fine Tuning or RLHF

## Who is the target user of Maxtext?

- Any individual or a company that is interested in forking maxtext and seeing it as a reference implementation of a high performance Large Language Models and wants to build their own LLMs on TPU and GPU.
- Any individual or a company that is interested in performing a pre-training or Full Fine Tuning of the supported open source models, can use Maxtext as a blackbox to perform full fine tuning. Maxtext attains an extremely high MFU, resulting in large savings in training costs.
10 changes: 10 additions & 0 deletions docs/data_loading.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Data Loading

Maxtext supports input data pipelines in the following ways:
Tf.data*
Grain
Hugging Face Datasets

*Tf.data is the most performant way of loading large scale datasets.

You can read more about the pipelines in [](getting_started/Data_Input_Pipeline.md).
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@

### HuggingFace pipeline
The following data are collected using c4 data in Parquet format.

| Pipeline | seq_len | VM type | per_host_batch | # of host | # of batch | first step (s) | total time (s) |
| ----------- | ------- | ---------- | ----------------- | --------- | ---------- | ------------- | -------------- |
| HuggingFace | 2048 | TPU v4-8 | 32 (per_device=8) | 1 | 1000 | 6 | 72 |
| HuggingFace | 2048 | TPU v4-128 | 32 (per_device=8) | 16 | 1000 | 6 | 72 |

### Grain pipeline
The following data are collected using c4 data in ArrayRecord format.

| Pipeline | seq_len | VM type | per_host_batch | # of host | # of batch | worker | first step (s) | total time (s) |
| ----------- | ------- | ---------- | ----------------- | --------- | ---------- | ----- | -------------- | --------------- |
| Grain | 2048 | TPU v4-8 | 32 (per_device=8) | 1 | 1000 | 1 | 7 | 1200 |
Expand All @@ -29,6 +31,7 @@ The following data are collected using c4 data in ArrayRecord format.

### TFDS pipeline
The following data are collected using c4 data in TFRecord format.

| Pipeline | seq_len | VM type | per_host_batch | # of host | # of batch | first step (s) | total time (s) |
| ----------- | ------- | ---------- | ----------------- | --------- | ---------- | ------------- | -------------- |
| TFDS | 2048 | TPU v4-8 | 32 (per_device=8) | 1 | 1000 | 2 | 17 |
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file added docs/getting_started/build_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions docs/getting_started/steps_model.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Steps to build a Model

![](build_model.png)
_Fig1: Stages of LLM Model Development from pre-training to fine tuning and finally serving a model._

Model building starts with Pre-training a base model architecture. Pre-training is the process where you take a model architecture, which starts with random weights and train with a very large corpus in the scale of trillions of tokens. E.g. Google’s Gemma models were pre-trained on 6 Trillion tokens; LLama 3 was trained with 15 Trillion tokens

Post the pre-training most model producers will publish a checkpoint of the weights of the model. The corpus used for pre-training these models are usually a large public corpus like Common Crawl, public code bases, books etc.

Though these may be a great way to answer very general questions or prompts, they usually fail on very domain specific questions and answers like Medical and Life Sciences, Engineering, etc.

Customers and enterprises usually like to continue training a pre-trained model or performing a full fine tuning of the models using their own datasets. These datasets are usually in billions of tokens. This allows better prompt understanding when questions are asked on keywords and terms specific to their model or domain specific question.

Post a Full Fine Tuning, most models go through a process of Instruction Fine Tuning(PEFT/LoRA), Supervised Fine Tuning and RLHF to improve the model quality and follow prompt answers better.

PEFT/Lora, Supervised Finetuning are less expensive operations compared to full fine tuning.
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{% include "../README.md" %}
15 changes: 15 additions & 0 deletions docs/reference/code_organization.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# MaxText Code Organization

Maxtext is purely written in JAX and python. Below are some folders and files
that show a high-level organization of the code and some key files.

File/Folder | Description
---------|---------------------------------
`configs` | Folder contains all the config file, including model configs (llama2, mistral etc) , and pre-optimized configs for different model size on different TPUs
`input_pipelines` | Input training data related code
`layers` | Model layer implementation
`end_to_end` | Example scripts to run Maxtext
`Maxtext/train.py` | The main training script you will run directly
`Maxtext/config/base.yaml` | The base configuration file containing all the related info: checkpointing, model arch, sharding schema, data input, learning rate, profile, compilation, decode
`Maxtext/decode.py` | This is a script to run offline inference with a sample prompt
`setup.sh`| Bash script used to install all needed library dependencies.
1 change: 1 addition & 0 deletions docs/reference/config_options.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Configuration options
36 changes: 36 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
site_name: MaxText Documentation

theme:
name: material
features:
- navigation.tabs
- navigation.expand

docs_dir: docs

plugins:
- search
- include-markdown

markdown_extensions:
- tables

nav:
- Home: index.md
- about.md
- Getting started:
- getting_started/First_run.md
- getting_started/steps_model.md
- End-to-end example: https://www.kaggle.com/code/melissawm/maxtext-examples
- Advanced usage:
- getting_started/Run_MaxText_via_multihost_job.md
- getting_started/Run_MaxText_via_multihost_runner.md
- getting_started/Run_MaxText_via_xpk.md
- getting_started/Use_Vertex_AI_Tensorboard.md
- getting_started/Run_Llama2.md
- data_loading.md
- Reference:
- reference/code_organization.md
- reference/config_options.md
- getting_started/Data_Input_Pipeline.md
- getting_started/Data_Input_Perf.md
2 changes: 2 additions & 0 deletions requirements_docs.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mkdocs-material
mkdocs-include-markdown-plugin

0 comments on commit 7206b6c

Please sign in to comment.