Skip to content

Expanded readme setup steps. #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 68 additions & 29 deletions deepseek_r1_jax/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,49 @@ gcloud alpha compute tpus tpu-vm create "$TPU_NAME" --zone="$TPU_ZONE" \
--project="$PROJECT" --accelerator-type="$ACCELERATOR" --version="$IMAGE"
```

### Setting up code & data
### Utilities and package installation

```bash
# Util function to start remote processes on TPU machines.
TPU_NAME="..."
TPU_ZONE="..."
TPU_PROJECT="..."

tpu_exec() {
local workers=$(seq $1 $2 | tr '\n' ',')
gcloud alpha compute tpus tpu-vm ssh --zone="$TPU_ZONE" --project="$TPU_PROJECT" \
"$TPU_NAME" --worker="$workers" --command="$3"
}
```

```bash
# Install required packages and virtualenv.
INSTALL_COMMAND=$(cat << EOM
sudo apt update
sudo apt install -y nfs-common nfs-kernel-server nfs-server net-tools tmux python3-ipyparallel
curl -LsSf https://astral.sh/uv/install.sh | sh && source ~/.local/bin/env
uv python install 3.10 && uv venv --python 3.10
cd ~/
uv pip install -U "jax[tpu]" ipyparallel
if [ ! -d jax-llm-examples/deepseek_r1_jax ]; then
git clone https://github.com/jax-ml/jax-llm-examples.git
fi
cd jax-llm-examples/deepseek_r1_jax
uv pip install -e .
EOM
)

tpu_exec 0 15 "$INSTALL_COMMAND"
```

### Setting up code & data on the TPU-VM.

#### 1. [gcsfuse](https://cloud.google.com/storage/docs/cloud-storage-fuse/install#install-source-code)

For datasets and checkpoints.

```
```bash
mkdir {local_folder}
gcsfuse --implicit-dirs {bucket_name_no_gs://} {local_folder}
```
#### 2. NFS
Expand All @@ -378,8 +414,7 @@ For code consistency between hosts in the TPU Pod / Cluster.

```bash
# on worker 0
WORKER0_IP="..."
sudo apt install -y nfs-server nfs-common net-tools tmux
WORKER0_IP="..." # Internal IP address
mkdir -p ~/nfs; sudo umount ~/nfs
echo "$HOME/nfs $WORKER0_IP/24(rw,sync,no_subtree_check)" | sudo tee /etc/exports
sudo exportfs -a
Expand All @@ -388,10 +423,14 @@ sudo chown $USER:$USER -R ~/nfs
```

```bash
# on all other workers (!= 0)
SERVER_IP="..."
mkdir -p ~/nfs
sudo umount ~/nfs; sudo mount -t nfs $SERVER_IP:/home/$USER/nfs ~/nfs
MOUNT_COMMAND=$(cat << EOM
mkdir -p ~/nfs
sudo umount ~/nfs
# VM_USER should be username in your TPU VM and should be the same across all VM workers.
sudo mount -t nfs WORKER0_IP:/home/VM_USER/nfs ~/nfs
EOM
)
tpu_exec 1 15 "$MOUNT_COMMAND"
```

#### (Optionally) 3. [sshfs](https://github.com/libfuse/sshfs)
Expand All @@ -402,21 +441,6 @@ For a quick preview from a local machine.
sshfs ~/local_folder TPU_WORKER_0_IP:~/remote_folder
```

### Utilities

```bash
TPU_NAME="..."
TPU_ZONE="..."
TPU_PROJECT="..."

tpu_exec() {
local workers=$(seq $1 $2 | tr '\n' ',')
gcloud alpha compute tpus tpu-vm ssh --zone="$TPU_ZONE" --project="$TPU_PROJECT" \
"$TPU_NAME" --worker="$workers" --command="$2"
}
tpu_exec all 'pip install -U "jax[tpu]"'
```

### Starting the `ipyparallel` cluster

Start $N - 1$ workers (ipyparallel calls them `engines`) because we want worker 0 to execute interactively.
Expand All @@ -426,21 +450,21 @@ SERVER_IP="..."
CONTROLLER_SETUP=$(cat << EOM
tmux kill-session -t controller; pkill -9 python
tmux new -d -s controller '\
. ~/venv/bin/activate && ipcontroller --profile-dir=~/nfs --ip=$SERVER_IP'
. ~/.venv/bin/activate && ipcontroller --profile-dir=~/nfs --ip=$SERVER_IP'
EOM
)

ENGINE_SETUP=$(cat << EOM
tmux kill-session -t engine; pkill -9 ipengine
tmux new -d -s engine '. ~/venv/bin/activate && ipengine --profile-dir=~/nfs'
tmux new -d -s engine '. ~/.venv/bin/activate && ipengine --profile-dir=~/nfs'
EOM
)

tpu_exec 0 0 "$CONTROLLER_CMD" # only worker 0
tpu_exec 1 15 "$ENGINE_CMD" # all workers except worker 0
tpu_exec 0 0 "$CONTROLLER_SETUP" # only worker 0
tpu_exec 1 15 "$ENGINE_SETUP" # all workers except worker 0
```

#### Jupyter Notebook
#### Confirm ipyparallel setup works by ssh'ing into worker 0.
> Cell 0:
```python
import ipyparallel as ip
Expand All @@ -462,7 +486,22 @@ print(f"Hello from {socket.gethostname()}")

> Note: "--local" argument means "also run on this process", it's necessary to
> get easy access to the output of computations on worker 0

> After you've confirmed this setup works, you can utilize main.ipynb to run inference.

### Troubleshooting
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an awesome section!

- Runing into import errors/package incompatibility errors when running on the engines.\
Solution: check if you have activated your virtualenv before running jupyter and on all of the engines.\
E.g.\
%%px --local\
print(sys.executable)
- Your notebook is hanging and eventually times out when running jax.distributed.initialize().\
Solution: you are likely running a pre-existing jupyter session and have not cleared the Engines.\
Run client[:].abort() to clear.\
Worst case you may need to restart the engines via:\
tpu_exec 0 0 "$CONTROLLER_CMD".
- You've encountered OOM on a run that shouldn't have run out of memory.\
Solution: you have again likely not cleared pre-existing sessions and still have weights loaded in memory.\
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

libtpu is an exclusive process, you won't be able to run two jax processes with tpu memory usage

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it'll actually fail to jax.distributed.initialize I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, do you have any other thoughts on why I encountered OOMs but clearing the engines fixed the issue?

Please run the above suggestions and try again.

## Next steps

Expand Down