Commit 5034db6

Add doc for script
1 parent 7f88c2d commit 5034db6

File tree

4 files changed: +33 -246 lines changed


.ci/scripts/validate.sh

+24-12
@@ -25,7 +25,7 @@ function generate_compiled_model_output() {
     local MODEL_DIR="${CHECKPOINT_PATH%/*}"
     local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
 
-    for DTYPE in float32 bfloat16; do
+    for DTYPE in float32 bfloat16 float16; do
         echo ""############### Run inference with torch.compile for dtype $DTYPE "###############"
         echo ""
         echo "******************************************"
@@ -85,10 +85,14 @@ function generate_compiled_model_output() {
         echo "******************************************"
         echo "******** INT4 group-wise quantized *******"
         echo "******************************************"
-        python -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
-        cat "$MODEL_DIR/output_eager"
-        python -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
-        cat "$MODEL_DIR/output_compiled"
+        if [ "$DTYPE" = float16 ]; then
+            echo "Skipping INT4 groupwise quantization for float16 because compile fails"
+        else
+            python -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
+            cat "$MODEL_DIR/output_eager"
+            python -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
+            cat "$MODEL_DIR/output_compiled"
+        fi
     done
 }
 
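With --temperature 0 both runs decode greedily, so the eager and compiled INT4 transcripts should match token for token, which is presumably why the script captures them in output_eager and output_compiled. A minimal manual reproduction for a single dtype; the checkpoint path is hypothetical, and generate.py is assumed to sit at the repository root:

    # Hypothetical checkpoint location; substitute your own download.
    CHECKPOINT_PATH=checkpoints/stories15M/stories15M.pt
    MODEL_DIR="${CHECKPOINT_PATH%/*}"

    # Eager INT4 group-wise (groupsize 32) run, greedy decoding.
    python -W ignore generate.py --dtype bfloat16 --quant '{"linear:int4" : {"groupsize": 32}}' \
        --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device cpu > "$MODEL_DIR/output_eager"

    # Same run under torch.compile.
    python -W ignore generate.py --dtype bfloat16 --compile --quant '{"linear:int4" : {"groupsize": 32}}' \
        --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device cpu > "$MODEL_DIR/output_compiled"

    # Under greedy decoding the two transcripts should be identical.
    diff "$MODEL_DIR/output_eager" "$MODEL_DIR/output_compiled"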
@@ -98,7 +102,7 @@ function generate_aoti_model_output() {
     local MODEL_DIR="${CHECKPOINT_PATH%/*}"
     local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
 
-    for DTYPE in float32 bfloat16; do
+    for DTYPE in float32 bfloat16 float16; do
         echo ""############### Run inference with AOT Inductor for dtype $DTYPE "###############"
         echo ""
         echo "******************************************"
@@ -150,12 +154,20 @@ function generate_aoti_model_output() {
         python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
         cat "$MODEL_DIR/output_aoti"
 
-        # echo "******************************************"
-        # echo "******** INT4 group-wise quantized *******"
-        # echo "******************************************"
-        # python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-        # python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
-        # cat "$MODEL_DIR/output_aoti"
+        echo "******************************************"
+        echo "******** INT4 group-wise quantized *******"
+        echo "******************************************"
+        if [ "$DTYPE" = float16 ]; then
+            echo "Skipping INT4 groupwise quantization for float16 because AOTI fails"
+        else
+            if [ $(uname -s) == "Linux" ]; then
+                echo "Skipping INT4 groupwise quantization because AOTI fails"
+            else
+                python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+                python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+                cat "$MODEL_DIR/output_aoti"
+            fi
+        fi
     done
 }
 
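The nested guards above amount to a single dtype/platform gate on the AOTI INT4 path. An equivalent flattened sketch (a behavior-preserving rewrite, not the committed code), which also swaps the bash-only == inside [ ] for the portable = and quotes the uname output:

    # Skip AOTI INT4 group-wise quantization for float16, and on Linux generally.
    if [ "$DTYPE" = float16 ]; then
        echo "Skipping INT4 groupwise quantization for float16 because AOTI fails"
    elif [ "$(uname -s)" = "Linux" ]; then
        echo "Skipping INT4 groupwise quantization because AOTI fails"
    else
        :  # export.py + generate.py + cat, as in the hunk above
    fi

Quoting "$(uname -s)" keeps the test well-formed even if the command substitution ever produced an empty string.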
.github/workflows/compile-dtype.yml

-118
This file was deleted.

.github/workflows/compile_t4-dtype.yml

-115
This file was deleted.

scripts/workflow.sh

+9-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+################################################################################
+# Usage:
+# bash script.sh [cpu|cuda] [model_repo] [optional_command]
+# Arguments:
+#   cpu|cuda: Specify the device to run validation on (cpu or cuda).
+#   model_repo: Model repository name to validate (e.g., tinyllamas/stories15M).
+#   optional_command: (optional) Specify additional command "compile", "aoti" or "executorch" to run the selected validation.
+################################################################################
 
 set -eu
 
@@ -75,7 +83,7 @@ MODEL_REPOS=(
   "mistralai/Mistral-7B-v0.1"
   "mistralai/Mistral-7B-Instruct-v0.1"
   "mistralai/Mistral-7B-Instruct-v0.2"
-  # "openlm-research/open_llama_7b"
+  "openlm-research/open_llama_7b"
   "codellama/CodeLlama-7b-Python-hf"
   "codellama/CodeLlama-34b-Python-hf"
   # "meta-llama/Llama-2-7b-chat-hf"
