@@ -25,7 +25,7 @@ function generate_compiled_model_output() {
25
25
local MODEL_DIR=" ${CHECKPOINT_PATH%/* } "
26
26
local MODEL_NAME=$( basename " $CHECKPOINT_PATH " | sed ' s/\.[^.]*$//' )
27
27
28
- for DTYPE in float32 bfloat16; do
28
+ for DTYPE in float32 bfloat16 float16 ; do
29
29
echo " " # ############## Run inference with torch.compile for dtype $DTYPE "###############"
30
30
echo " "
31
31
echo " ******************************************"
@@ -85,10 +85,14 @@ function generate_compiled_model_output() {
85
85
echo " ******************************************"
86
86
echo " ******** INT4 group-wise quantized *******"
87
87
echo " ******************************************"
88
- python -W ignore generate.py --dtype ${DTYPE} --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_eager" || exit 1
89
- cat " $MODEL_DIR /output_eager"
90
- python -W ignore generate.py --dtype ${DTYPE} --compile --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_compiled" || exit 1
91
- cat " $MODEL_DIR /output_compiled"
88
+ if [ " $DTYPE " = float16 ]; then
89
+ echo " Skipping INT4 groupwise quantization for float16 because compile fails"
90
+ else
91
+ python -W ignore generate.py --dtype ${DTYPE} --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_eager" || exit 1
92
+ cat " $MODEL_DIR /output_eager"
93
+ python -W ignore generate.py --dtype ${DTYPE} --compile --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_compiled" || exit 1
94
+ cat " $MODEL_DIR /output_compiled"
95
+ fi
92
96
done
93
97
}
94
98
@@ -98,7 +102,7 @@ function generate_aoti_model_output() {
98
102
local MODEL_DIR=" ${CHECKPOINT_PATH%/* } "
99
103
local MODEL_NAME=$( basename " $CHECKPOINT_PATH " | sed ' s/\.[^.]*$//' )
100
104
101
- for DTYPE in float32 bfloat16; do
105
+ for DTYPE in float32 bfloat16 float16 ; do
102
106
echo " " # ############## Run inference with AOT Inductor for dtype $DTYPE "###############"
103
107
echo " "
104
108
echo " ******************************************"
@@ -150,12 +154,20 @@ function generate_aoti_model_output() {
150
154
python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --dso-path ${MODEL_DIR} /${MODEL_NAME} .so --device " $TARGET_DEVICE " > " $MODEL_DIR /output_aoti" || exit 1
151
155
cat " $MODEL_DIR /output_aoti"
152
156
153
- # echo "******************************************"
154
- # echo "******** INT4 group-wise quantized *******"
155
- # echo "******************************************"
156
- # python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
157
- # python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
158
- # cat "$MODEL_DIR/output_aoti"
157
+ echo " ******************************************"
158
+ echo " ******** INT4 group-wise quantized *******"
159
+ echo " ******************************************"
160
+ if [ " $DTYPE " = float16 ]; then
161
+ echo " Skipping INT4 groupwise quantization for float16 because AOTI fails"
162
+ else
163
+ if [ $( uname -s) == " Linux" ]; then
164
+ echo " Skipping INT4 groupwise quantization because AOTI fails"
165
+ else
166
+ python -W ignore export.py --dtype ${DTYPE} --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --output-dso-path ${MODEL_DIR} /${MODEL_NAME} .so --device " $TARGET_DEVICE " || exit 1
167
+ python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --dso-path ${MODEL_DIR} /${MODEL_NAME} .so --device " $TARGET_DEVICE " > " $MODEL_DIR /output_aoti" || exit 1
168
+ cat " $MODEL_DIR /output_aoti"
169
+ fi
170
+ fi
159
171
done
160
172
}
161
173
0 commit comments