We want open-instruct to better handle a world where rollouts are really slow (models with long inference times, agents making API calls).
One option is to keep the vLLM engines running at high throughput, update their weights in place as training proceeds, and gather finished generations as they come in, à la PipelineRL (without rebuilding the KV cache).
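A minimal sketch of that flow, using asyncio as a stand-in for the actual Ray/vLLM plumbing; `generate_one`, `train_step`, and `push_weights` are hypothetical placeholders rather than existing open-instruct or vLLM APIs, and in practice many rollout workers would run concurrently to keep the engines saturated:

```python
import asyncio

async def rollout_worker(prompts, finished, generate_one):
    # Queue each completion the moment it finishes instead of
    # waiting for the slowest prompt in a synchronous batch.
    for prompt in prompts:
        completion = await generate_one(prompt)
        await finished.put((prompt, completion))

async def trainer_loop(finished, batch_size, train_step, push_weights):
    while True:
        # Take the next batch_size finished generations, whichever
        # prompts (and policy versions) they came from.
        batch = [await finished.get() for _ in range(batch_size)]
        new_weights = train_step(batch)  # one RL update on the gathered batch
        # Broadcast updated weights to the running engines; in-flight
        # requests keep generating against their existing KV cache.
        await push_weights(new_weights)
```

The trade-off is that sequences still in flight when the weights change finish slightly off-policy, which is the price PipelineRL pays for never letting the engines idle.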
Command to reproduce long generations:
# Dataset mixes: alternating "dataset_name weight" pairs.
export integration_mix="hamishivi/saurabh5_rlvr_acecoder_all_filtered_qwen2_5_openthoughts2 10000 hamishivi/hamishivi_rlvr_orz_math_57k_collected_all_filtered_hamishivi_qwen2_5_openthoughts2 8500 hamishivi/tulu_3_rewritten_400k_string_f1_only_v2_all_filtered_qwen2_5_openthoughts2 10000 hamishivi/allenai_IF_multi_constraints_upto5_all_filtered_qwen2_5_openthoughts2 10000"
export orz_only_mix="hamishivi/hamishivi_rlvr_orz_math_57k_collected_all_filtered_hamishivi_qwen2_5_openthoughts2 1.0"

# testing:
#   logic data only
#   multi subject only
#   olmo2_lc, qwen3

for model in qwen2_5; do
    for split_var in orz_only_mix; do
        # Indirect expansion: resolve the mix named by $split_var.
        split_value="${!split_var}"
        exp_name=2606rl_update_fil_dapo_pol_${model}_${split_var}_${RANDOM}
        if [ "$model" == "qwen3" ]; then
            model_name_or_path=hamishivi/qwen3_openthoughts2
            chat_template_name=tulu_thinker
            add_bos=False
        elif [ "$model" == "olmo2_lc" ]; then
            model_name_or_path=hamishivi/olmo2_lc_ot2
            chat_template_name=tulu_thinker
            add_bos=True
        elif [ "$model" == "qwen2_5" ]; then
            model_name_or_path=hamishivi/qwen2_5_openthoughts2
            chat_template_name=tulu_thinker
            add_bos=False
        fi
        python mason.py \
            --cluster ai2/jupiter-cirrascale-2 --image hamishivi/open_instruct_hamish_update_2606 \
            --pure_docker_mode \
            --workspace ai2/tulu-thinker \
            --priority high \
            --preemptible \
            --num_nodes 6 \
            --max_retries 0 \
            --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
            --budget ai2/oe-adapt \
            --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& source configs/beaker_configs/code_api_setup.sh \&\& python open_instruct/grpo_fast.py \
            --exp_name ${exp_name} \
            --beta 0.0 \
            --num_samples_per_prompt_rollout 16 \
            --num_unique_prompts_rollout 256 \
            --num_mini_batches 1 \
            --num_epochs 1 \
            --learning_rate 5e-7 \
            --per_device_train_batch_size 1 \
            --output_dir /output \
            --kl_estimator kl3 \
            --dataset_mixer_list ${split_value} \
            --dataset_mixer_list_splits train \
            --dataset_mixer_eval_list hamishivi/tulu_3_rewritten_100k 32 \
            --dataset_mixer_eval_list_splits train \
            --max_token_length 10240 \
            --add_bos ${add_bos} \
            --max_prompt_token_length 2048 \
            --response_length 16384 \
            --pack_length 20480 \
            --model_name_or_path ${model_name_or_path} \
            --chat_template_name ${chat_template_name} \
            --stop_strings "</answer>" \
            --non_stop_penalty False \
            --temperature 1.0 \
            --ground_truths_key ground_truth \
            --sft_messages_key messages \
            --total_episodes 10000000 \
            --deepspeed_stage 2 \
            --num_learners_per_node 8 8 \
            --vllm_num_engines 32 \
            --vllm_tensor_parallel_size 1 \
            --lr_scheduler_type constant \
            --apply_verifiable_reward true \
            --seed 1 \
            --num_evals 5 \
            --save_freq 50 \
            --try_launch_beaker_eval_jobs_on_weka True \
            --gradient_checkpointing \
            --with_tracking \
            --vllm_enable_prefix_caching \
            --clip_higher 0.28 \
            --mask_truncated_completions True \
            --oe_eval_max_length 32768 \
            --oe_eval_tasks "minerva_math::hamish_zs_reasoning,gsm8k::zs_cot_latex,gsm8k::hamish_zs_reasoning,minerva_math_500::hamish_zs_reasoning,zebralogic::hamish_zs_reasoning,aime::hamish_zs_reasoning,agi_eval_english:0shot_cot::hamish_zs_reasoning,gpqa:0shot_cot::hamish_zs_reasoning,ifeval::hamish_zs_reasoning,popqa::hamish_zs_reasoning,mmlu:cot::hamish_zs_reasoning,alpaca_eval_v3::hamish_zs_reasoning,bbh:cot::hamish_zs_reasoning,mbppplus:0-shot-chat::tulu-thinker,codex_humanevalplus:0-shot-chat-v1::tulu-thinker"
    done
done
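For reference on one of the less self-explanatory flags: `--clip_higher 0.28` matches DAPO-style asymmetric clipping, which widens the upper clip bound so low-probability tokens can grow. A sketch of the clipped objective, assuming the flag sets $\varepsilon_{\text{high}}$ while the lower bound keeps the common default $\varepsilon_{\text{low}} = 0.2$:

$$
\mathcal{L}(\theta) = -\,\mathbb{E}_t\!\left[\min\!\Big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\, 1-\varepsilon_{\text{low}},\, 1+\varepsilon_{\text{high}}\big)\,\hat{A}_t\Big)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}
$$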