Skip to content

Commit 3623645

Browse files
authored
chat mode improvements (#244)
* chat mode improvements * disable int4 on macos/x86 because of old nightlies * typo * typo * typo * convert runtime error to warning * wording of option texts
1 parent 5450e3e commit 3623645

File tree

4 files changed

+45
-11
lines changed

4 files changed

+45
-11
lines changed

.github/workflows/eager-dtype.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,12 @@ jobs:
7777
echo "******************************************"
7878
echo "******** INT4 group-wise quantized *******"
7979
echo "******************************************"
80-
81-
python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
82-
cat ./output_eager
80+
81+
echo "INT4 should work on MacOS on x86, but cannot be tested"
82+
echo "because nightlies are too old!"
83+
84+
# python generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
85+
# cat ./output_eager
8386

8487
echo "tests complete for ${DTYPE}"
8588
done

build/builder.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class BuilderArgs:
3535
precision: torch.dtype = torch.float32
3636
setup_caches: bool = False
3737
use_tp: bool = False
38-
38+
is_chat_model: bool = False
39+
3940
def __post_init__(self):
4041
if not (
4142
(self.checkpoint_path and self.checkpoint_path.is_file())
@@ -66,6 +67,24 @@ def __post_init__(self):
6667

6768
@classmethod
6869
def from_args(cls, args): # -> BuilderArgs:
70+
is_chat_model = False
71+
if args.is_chat_model:
72+
is_chat_model = True
73+
else:
74+
for path in [
75+
args.checkpoint_path,
76+
args.checkpoint_dir,
77+
args.dso_path,
78+
args.pte_path,
79+
args.gguf_path
80+
]:
81+
path = str(path)
82+
if path.endswith('/'):
83+
path = path[:-1]
84+
path_basename = os.path.basename(path)
85+
if "chat" in path_basename:
86+
is_chat_model = True
87+
6988
return cls(
7089
checkpoint_path=args.checkpoint_path,
7190
checkpoint_dir=args.checkpoint_dir,
@@ -78,6 +97,7 @@ def from_args(cls, args): # -> BuilderArgs:
7897
precision=name_to_dtype(args.dtype),
7998
setup_caches=(args.output_dso_path or args.output_pte_path),
8099
use_tp=False,
100+
is_chat_model=is_chat_model,
81101
)
82102

83103
@classmethod

cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,12 @@ def _add_arguments_common(parser):
7676
parser.add_argument(
7777
"--chat",
7878
action="store_true",
79-
help="Use torchchat to for an interactive chat session.",
79+
help="Use torchchat for an interactive chat session.",
80+
)
81+
parser.add_argument(
82+
"--is-chat-model",
83+
action="store_true",
84+
help="Indicate that the model was trained to support chat functionality.",
8085
)
8186
parser.add_argument(
8287
"--gui",

generate.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from cli import add_arguments_for_generate, arg_init, check_args
2828
from quantize import set_precision
2929

30+
B_INST, E_INST = "[INST]", "[/INST]"
3031

3132
@dataclass
3233
class GeneratorArgs:
@@ -343,11 +344,16 @@ def _main(
343344
set_precision(builder_args.precision)
344345
is_speculative = speculative_builder_args.checkpoint_path is not None
345346

346-
is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path))
347-
if is_chat:
348-
raise RuntimeError(
349-
"need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!"
350-
)
347+
if generator_args.chat_mode and not builder_args.is_chat_model:
348+
print("""
349+
*******************************************************
350+
This model is not known to support the chat function.
351+
We will enable chat mode based on your instructions.
352+
If the model is not trained to support chat, it will
353+
produce nonsensical or false output.
354+
*******************************************************
355+
""")
356+
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
351357

352358
tokenizer = _initialize_tokenizer(tokenizer_args)
353359

@@ -410,7 +416,7 @@ def _main(
410416
device_sync(device=builder_args.device)
411417
if i >= 0 and generator_args.chat_mode:
412418
prompt = input("What is your prompt? ")
413-
if is_chat:
419+
if builder_args.is_chat_model:
414420
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
415421
encoded = encode_tokens(
416422
tokenizer, prompt, bos=True, device=builder_args.device

0 commit comments

Comments
 (0)