 from cli import add_arguments_for_generate, arg_init, check_args
 from quantize import set_precision
 
+import logging
+logger = logging.getLogger(__name__)
+
 B_INST, E_INST = "[INST]", "[/INST]"
 
 @dataclass
@@ -66,7 +69,7 @@ def device_sync(device):
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
-        print(f"device={device} is not yet suppported")
+        logger.error(f"device={device} is not yet supported")
 
 
 torch._inductor.config.coordinate_descent_tuning = True
@@ -106,15 +109,15 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 def prefill(
     model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> torch.Tensor:
-    print(f"x: {x}, input_pos: {input_pos}")
+    logger.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
     assert input_pos.size(0) == width
     sequential_prefill = True
 
     if sequential_prefill:
         for i in range(width):
             x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-            print(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+            logger.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
             logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
     else:
         # input_pos: [B, S]
@@ -340,12 +343,12 @@ def _main(
     # # only print on rank 0
     # print = lambda *args, **kwargs: None
 
-    print(f"Using device={builder_args.device}")
+    logger.info(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None
 
     if generator_args.chat_mode and not builder_args.is_chat_model:
-        print("""
+        logger.warning("""
         *******************************************************
         This model is not known to support the chat function.
         We will enable chat mode based on your instructions.
@@ -374,7 +377,7 @@ def _main(
     encoded = encode_tokens(
         tokenizer, generator_args.prompt, bos=True, device=builder_args.device
     )
-    print(encoded)
+    logger.debug(encoded)
     prompt_length = encoded.size(0)
 
     model_size = sum(
@@ -469,7 +472,7 @@ def callback(x):
         )
         aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
         if i == -1:
-            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
         if hasattr(prof, "export_chrome_trace"):
             if use_tp:
@@ -486,23 +489,23 @@ def callback(x):
         tokens_generated = y.size(0) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        print(
+        logger.info(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
         print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
     print("==========")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics["accept_counts"])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
-        print(f"Acceptance probs: {acceptance_probs}")
-        print(
+        logger.info(f"Acceptance probs: {acceptance_probs}")
+        logger.info(
             f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}"
         )
 
-    print(
+    logger.info(
         f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
     )
-    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+    logger.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
 
 
 def main(args):
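
One practical note on the change above: `logging.getLogger(__name__)` inherits the root logger's level, which defaults to WARNING, so the new INFO and DEBUG records are silently dropped until logging is configured. A minimal sketch of that one-time setup, assuming it runs at process startup; the level and format string here are illustrative choices, not part of the diff:

# Minimal logging setup, typically done once in main() before any
# records are emitted. Only standard-library logging calls are used;
# the format string and INFO level are illustrative assumptions.
import logging

logging.basicConfig(
    level=logging.INFO,  # set to logging.DEBUG to see the prefill traces
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

logger = logging.getLogger(__name__)
logger.info("Using device=%s", "cuda")      # emitted at INFO and above
logger.debug("x: %s, input_pos: %s", 0, 0)  # suppressed unless level=DEBUG

Raising `level` to `logging.DEBUG` then surfaces the per-token traces in `prefill` without further code changes. Note that `basicConfig` is a no-op once the root logger has handlers, so this belongs in the application entry point rather than in library modules.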