@@ -85,10 +85,14 @@ function generate_compiled_model_output() {
8585 echo " ******************************************"
8686 echo " ******** INT4 group-wise quantized *******"
8787 echo " ******************************************"
88- python -W ignore generate.py --dtype ${DTYPE} --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_eager" || exit 1
89- cat " $MODEL_DIR /output_eager"
90- python -W ignore generate.py --dtype ${DTYPE} --compile --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_compiled" || exit 1
91- cat " $MODEL_DIR /output_compiled"
88+ if [ " $DTYPE " = float16 ]; then
89+ echo " Skipping INT4 groupwise quantization for float16 because torch.compile fails"
90+ else
91+ python -W ignore generate.py --dtype ${DTYPE} --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_eager" || exit 1
92+ cat " $MODEL_DIR /output_eager"
93+ python -W ignore generate.py --dtype ${DTYPE} --compile --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --device " $TARGET_DEVICE " > " $MODEL_DIR /output_compiled" || exit 1
94+ cat " $MODEL_DIR /output_compiled"
95+ fi
9296 done
9397}
9498
@@ -153,12 +157,16 @@ function generate_aoti_model_output() {
153157 echo " ******************************************"
154158 echo " ******** INT4 group-wise quantized *******"
155159 echo " ******************************************"
156- if [ $( uname -s) == " Linux" ]; then
157- echo " Skipping INT4 groupwise quantization because AOTI fails"
158- else
159- python -W ignore export.py --dtype ${DTYPE} --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --output-dso-path ${MODEL_DIR} /${MODEL_NAME} .so --device " $TARGET_DEVICE " || exit 1
160- python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --dso-path ${MODEL_DIR} /${MODEL_NAME} .so --device " $TARGET_DEVICE " > " $MODEL_DIR /output_aoti" || exit 1
161- cat " $MODEL_DIR /output_aoti"
160+ if [ " $DTYPE " = float16 ]; then
161+ echo " Skipping INT4 groupwise quantization for float16 because AOTI fails"
162+ else
163+ if [ $( uname -s) == " Linux" ]; then
164+ echo " Skipping INT4 groupwise quantization because AOTI fails"
165+ else
166+ python -W ignore export.py --dtype ${DTYPE} --quant ' {"linear:int4" : {"groupsize": 32}}' --checkpoint-path " $CHECKPOINT_PATH " --output-dso-path ${MODEL_DIR} /${MODEL_NAME} .so --device " $TARGET_DEVICE " || exit 1
167+ python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path " $CHECKPOINT_PATH " --temperature 0 --dso-path ${MODEL_DIR} /${MODEL_NAME} .so --device " $TARGET_DEVICE " > " $MODEL_DIR /output_aoti" || exit 1
168+ cat " $MODEL_DIR /output_aoti"
169+ fi
162170 fi
163171 done
164172}
0 commit comments