-
Notifications
You must be signed in to change notification settings - Fork 15
Expanded readme setup steps. #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -363,13 +363,49 @@ gcloud alpha compute tpus tpu-vm create "$TPU_NAME" --zone="$TPU_ZONE" \ | |
--project="$PROJECT" --accelerator-type="$ACCELERATOR" --version="$IMAGE" | ||
``` | ||
|
||
### Setting up code & data | ||
### Utilities and package installation | ||
|
||
```bash | ||
# Util function to start remote processes on TPU machines. | ||
TPU_NAME="..." | ||
TPU_ZONE="..." | ||
TPU_PROJECT="..." | ||
|
||
tpu_exec() { | ||
local workers=$(seq $1 $2 | tr '\n' ',') | ||
gcloud alpha compute tpus tpu-vm ssh --zone="$TPU_ZONE" --project="$TPU_PROJECT" \ | ||
"$TPU_NAME" --worker="$workers" --command="$3" | ||
} | ||
``` | ||
|
||
```bash | ||
# Install required packages and virtualenv. | ||
INSTALL_COMMAND=$(cat << EOM | ||
sudo apt update | ||
sudo apt install -y nfs-common nfs-kernel-server nfs-server net-tools tmux python3-ipyparallel | ||
curl -LsSf https://astral.sh/uv/install.sh | sh && source ~/.local/bin/env | ||
uv python install 3.10 && uv venv --python 3.10 | ||
cd ~/ | ||
uv pip install -U "jax[tpu]" ipyparallel | ||
if [ ! -d jax-llm-examples/deepseek_r1_jax ]; then | ||
git clone https://github.com/jax-ml/jax-llm-examples.git | ||
fi | ||
cd jax-llm-examples/deepseek_r1_jax | ||
uv pip install -e . | ||
EOM | ||
) | ||
|
||
tpu_exec 0 15 "$INSTALL_COMMAND" | ||
``` | ||
|
||
### Setting up code & data on the TPU-VM. | ||
|
||
#### 1. [gcsfuse](https://cloud.google.com/storage/docs/cloud-storage-fuse/install#install-source-code) | ||
|
||
For datasets and checkpoints. | ||
|
||
``` | ||
```bash | ||
mkdir {local_folder} | ||
gcsfuse --implicit-dirs {bucket_name_no_gs://} {local_folder} | ||
``` | ||
#### 2. NFS | ||
|
@@ -378,8 +414,7 @@ For code consistency between hosts in the TPU Pod / Cluster. | |
|
||
```bash | ||
# on worker 0 | ||
WORKER0_IP="..." | ||
sudo apt install -y nfs-server nfs-common net-tools tmux | ||
WORKER0_IP="..." # Internal IP address | ||
mkdir -p ~/nfs; sudo umount ~/nfs | ||
echo "$HOME/nfs $WORKER0_IP/24(rw,sync,no_subtree_check)" | sudo tee /etc/exports | ||
sudo exportfs -a | ||
|
@@ -388,10 +423,14 @@ sudo chown $USER:$USER -R ~/nfs | |
``` | ||
|
||
```bash | ||
# on all other workers (!= 0) | ||
SERVER_IP="..." | ||
mkdir -p ~/nfs | ||
sudo umount ~/nfs; sudo mount -t nfs $SERVER_IP:/home/$USER/nfs ~/nfs | ||
MOUNT_COMMAND=$(cat << EOM | ||
mkdir -p ~/nfs | ||
sudo umount ~/nfs | ||
# VM_USER should be username in your TPU VM and should be the same across all VM workers. | ||
sudo mount -t nfs WORKER0_IP:/home/VM_USER/nfs ~/nfs | ||
EOM | ||
) | ||
tpu_exec 1 15 "$MOUNT_COMMAND" | ||
``` | ||
|
||
#### (Optionally) 3. [sshfs](https://github.com/libfuse/sshfs) | ||
|
@@ -402,21 +441,6 @@ For a quick preview from a local machine. | |
sshfs ~/local_folder TPU_WORKER_0_IP:~/remote_folder | ||
``` | ||
|
||
### Utilities | ||
|
||
```bash | ||
TPU_NAME="..." | ||
TPU_ZONE="..." | ||
TPU_PROJECT="..." | ||
|
||
tpu_exec() { | ||
local workers=$(seq $1 $2 | tr '\n' ',') | ||
gcloud alpha compute tpus tpu-vm ssh --zone="$TPU_ZONE" --project="$TPU_PROJECT" \ | ||
"$TPU_NAME" --worker="$workers" --command="$2" | ||
} | ||
tpu_exec all 'pip install -U "jax[tpu]"' | ||
``` | ||
|
||
### Starting the `ipyparallel` cluster | ||
|
||
Start $N - 1$ workers (ipyparallel calls them `engines`) because we want worker 0 to execute interactively. | ||
|
@@ -426,21 +450,21 @@ SERVER_IP="..." | |
CONTROLLER_SETUP=$(cat << EOM | ||
tmux kill-session -t controller; pkill -9 python | ||
tmux new -d -s controller '\ | ||
. ~/venv/bin/activate && ipcontroller --profile-dir=~/nfs --ip=$SERVER_IP' | ||
. ~/.venv/bin/activate && ipcontroller --profile-dir=~/nfs --ip=$SERVER_IP' | ||
EOM | ||
) | ||
|
||
ENGINE_SETUP=$(cat << EOM | ||
tmux kill-session -t engine; pkill -9 ipengine | ||
tmux new -d -s engine '. ~/venv/bin/activate && ipengine --profile-dir=~/nfs' | ||
tmux new -d -s engine '. ~/.venv/bin/activate && ipengine --profile-dir=~/nfs' | ||
EOM | ||
) | ||
|
||
tpu_exec 0 0 "$CONTROLLER_CMD" # only worker 0 | ||
tpu_exec 1 15 "$ENGINE_CMD" # all workers except worker 0 | ||
tpu_exec 0 0 "$CONTROLLER_SETUP" # only worker 0 | ||
tpu_exec 1 15 "$ENGINE_SETUP" # all workers except worker 0 | ||
``` | ||
|
||
#### Jupyter Notebook | ||
#### Confirm ipyparallel setup works by ssh'ing into worker 0. | ||
> Cell 0: | ||
```python | ||
import ipyparallel as ip | ||
|
@@ -462,7 +486,22 @@ print(f"Hello from {socket.gethostname()}") | |
|
||
> Note: "--local" argument means "also run on this process", it's necessary to | ||
> get easy access to the output of computations on worker 0 | ||
|
||
> After you've confirmed this setup works, you can utilize main.ipynb to run inference. | ||
|
||
### Troubleshooting | ||
- Runing into import errors/package incompatibility errors when running on the engines.\ | ||
Solution: check if you have activated your virtualenv before running jupyter and on all of the engines.\ | ||
E.g.\ | ||
%%px --local\ | ||
print(sys.executable) | ||
- Your notebook is hanging and eventually times out when running jax.distributed.initialize().\ | ||
Solution: you are likely running a pre-existing jupyter session and have not cleared the Engines.\ | ||
Run client[:].abort() to clear.\ | ||
Worst case you may need to restart the engines via:\ | ||
tpu_exec 0 0 "$CONTROLLER_CMD". | ||
- You've encountered OOM on a run that shouldn't have run out of memory.\ | ||
Solution: you have again likely not cleared pre-existing sessions and still have weights loaded in memory.\ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it'll actually fail to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting, do you have any other thoughts on why I encountered OOMs but clearing the engines fixed the issue? |
||
Please run the above suggestions and try again. | ||
|
||
## Next steps | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's an awesome section!