
Commit 45b68c3

Author: George Hong
Clean up logging for generate subcommand and set up log levels
1 parent 3623645 commit 45b68c3

File tree

2 files changed (+18 -12 lines)

generate.py (+15 -12)

@@ -27,6 +27,9 @@
 from cli import add_arguments_for_generate, arg_init, check_args
 from quantize import set_precision

+import logging
+logger = logging.getLogger(__name__)
+
 B_INST, E_INST = "[INST]", "[/INST]"

 @dataclass
@@ -66,7 +69,7 @@ def device_sync(device):
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
-        print(f"device={ device } is not yet suppported")
+        logging.error(f"device={ device } is not yet suppported")


 torch._inductor.config.coordinate_descent_tuning = True
@@ -106,15 +109,15 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 def prefill(
     model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> torch.Tensor:
-    print(f"x: {x}, input_pos: {input_pos}")
+    logging.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
     assert input_pos.size(0) == width
     sequential_prefill = True

     if sequential_prefill:
         for i in range(width):
             x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-            print(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+            logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
             logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
     else:
         # input_pos: [B, S]
@@ -340,12 +343,12 @@ def _main(
     # # only print on rank 0
     # print = lambda *args, **kwargs: None

-    print(f"Using device={builder_args.device}")
+    logging.info(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None

     if generator_args.chat_mode and not builder_args.is_chat_model:
-        print("""
+        logging.warning("""
 *******************************************************
 This model is not known to support the chat function.
 We will enable chat mode based on your instructions.
@@ -374,7 +377,7 @@ def _main(
     encoded = encode_tokens(
         tokenizer, generator_args.prompt, bos=True, device=builder_args.device
     )
-    print(encoded)
+    logging.debug(encoded)
     prompt_length = encoded.size(0)

     model_size = sum(
@@ -469,7 +472,7 @@ def callback(x):
         )
         aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
         if i == -1:
-            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+            logging.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
         if hasattr(prof, "export_chrome_trace"):
             if use_tp:
@@ -486,23 +489,23 @@ def callback(x):
         tokens_generated = y.size(0) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        print(
+        logging.info(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
         print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
         print("==========")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics["accept_counts"])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
-        print(f"Acceptance probs: {acceptance_probs}")
-        print(
+        logging.info(f"Acceptance probs: {acceptance_probs}")
+        logging.info(
             f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}"
         )

-    print(
+    logging.info(
         f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
     )
-    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+    logging.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


 def main(args):
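Note that the first hunk defines a module-level logger (logger = logging.getLogger(__name__)) but the converted call sites invoke the root-level logging.* helpers instead, so the named logger goes unused in the hunks shown. A minimal sketch of the difference (the message here is illustrative, not from the commit):

import logging

logging.basicConfig(format="%(name)s: %(message)s", level=logging.INFO)

# Module-level logger, as added in the first hunk of generate.py.
logger = logging.getLogger(__name__)

# Root-logger helper, as the converted call sites use: record is named "root".
logging.info("Using device=cpu")

# Named-logger equivalent: record carries the module's name ("__main__" when
# run as a script, "generate" when imported), so its verbosity could later be
# tuned independently of the root logger.
logger.info("Using device=cpu")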

torchchat.py (+3)

@@ -14,6 +14,8 @@
     check_args,
 )

+import logging
+
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'


@@ -35,6 +37,7 @@

     args = parser.parse_args()
     args = arg_init(args)
+    logging.basicConfig(format='%(message)s', level=logging.DEBUG if args.verbose else logging.INFO)

     if args.subcommand == "generate":
         check_args(args, "generate")
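This basicConfig call is what gives the converted calls in generate.py their levels: with --verbose the root logger accepts DEBUG records (the per-token prefill dumps), otherwise only INFO and above (timings, warnings, errors) get through. A minimal sketch of the gating, with verbose standing in for args.verbose:

import logging

verbose = False  # stand-in for args.verbose

logging.basicConfig(
    format="%(message)s",
    level=logging.DEBUG if verbose else logging.INFO,
)

logging.debug("x: ..., input_pos: ...")  # suppressed unless verbose
logging.info("Using device=cpu")         # emitted at INFO and below
logging.warning("This model is not known to support the chat function.")  # always emitted

One ordering subtlety: logging.basicConfig() is a no-op once the root logger already has handlers, so it must run before the first logging call that would auto-configure the root logger. Placing it right after arg_init(args) and before dispatching to the generate subcommand satisfies that.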
