This is the implementation for our paper Steering Your Generalists: Improving Robotic Foundation Models via Value Guidance (CoRL 2024).
- paper link: https://arxiv.org/abs/2410.13816
- project page: https://nakamotoo.github.io/V-GPS/
- video: https://youtu.be/d5Yd_gJoZo0
This repository includes the code for training the language-conditioned Cal-QL value function, as well as the code for combining it with the Octo model for test-time action sampling and evaluation on the SIMPLER simulated environment. We also provide our value function checkpoint, pre-trained on the WidowX (Bridge) and Google Robot (Fractal) datasets, so you can directly run the evaluation without training your own model.
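At test time, V-GPS re-ranks actions sampled from the generalist policy (e.g. Octo) using the learned language-conditioned value function. Below is a minimal sketch of that idea only; `policy.sample_action` and `q_function` are placeholder names for illustration rather than this repository's actual API, and the exact selection rule may differ from the paper.

```python
# Sketch of value-guided action selection (placeholder names, not the repo's API).
import numpy as np

def vgps_select_action(policy, q_function, observation, instruction, num_samples=10):
    # 1. Sample candidate actions from the generalist policy.
    actions = [policy.sample_action(observation, instruction) for _ in range(num_samples)]
    # 2. Score each candidate with the language-conditioned Q-function.
    q_values = np.array([q_function(observation, instruction, a) for a in actions])
    # 3. Execute the highest-value candidate (the actual selection rule may differ).
    return actions[int(np.argmax(q_values))]
```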
If you find this repository useful for your research, please cite:
@article{nakamoto2024steering,
  author  = {Mitsuhiko Nakamoto and Oier Mees and Aviral Kumar and Sergey Levine},
  title   = {Steering Your Generalists: Improving Robotic Foundation Models via Value Guidance},
  journal = {Conference on Robot Learning (CoRL)},
  year    = {2024},
}
- Create a conda environment:
conda create -n vgps python=3.10
conda activate vgps
- Clone this repository with all its submodules:
git clone https://github.com/nakamotoo/V-GPS --recurse-submodules
cd V-GPS
- Install all packages and dependencies:
pip install -e .
pip install -e octo
pip install -e SimplerEnv
pip install -e SimplerEnv/ManiSkill2_real2sim
pip install -r requirements.txt
For GPU:
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For TPU:
pip install --upgrade "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
We use the pre-training dataset from the Open X-Embodiment Dataset, which you need to download and pre-process into the RLDS format. Please refer to this instruction for more details.
Once you have prepared the dataset, you can run experiments using the following command. Be sure to set data_dir to the correct path of your dataset.
bash experiments/scripts/launch_calql.sh
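If training fails to find your data, a quick sanity check is to load one episode directly from data_dir. This is only a sketch; the path, dataset name, and version below are placeholders for whatever RLDS builder directory you actually downloaded:

```python
# Hypothetical sanity check for an RLDS dataset directory; adjust the path,
# dataset name, and version to match your own data_dir layout.
import tensorflow_datasets as tfds

builder = tfds.builder_from_directory("/path/to/data_dir/bridge_dataset/1.0.0")
ds = builder.as_dataset(split="train")
episode = next(iter(ds))
print(episode.keys())  # an RLDS episode contains a "steps" sequence plus metadata
```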
To run the evaluation on the SIMPLER environments:
bash experiments/scripts/eval_vgps.sh
To evaluate the base policy without V-GPS:
bash experiments/scripts/eval_baseline.sh
To enable proper rendering, you might need to install Vulkan:
apt-get install -yqq --no-install-recommends libvulkan-dev vulkan-tools
If you run into issues setting up the SIMPLER environments, please refer to SimplerEnv.
We provide a pre-trained checkpoint here. This checkpoint was trained with a batch size of 256 for 500k steps on the Bridge and Fractal datasets.
The offline RL training code is built upon bridge_data_v2 and Dibya Ghosh's jaxrl_m repositories. We also thank Paul Zhou for his initial implementation of Cal-QL in this repository. The dataloader is built upon octo, and the evaluation code is built upon SimplerEnv.
For any questions, bug reports, suggestions, or improvements, please feel free to contact me at nakamoto[at]berkeley[dot]edu