2727from cli import add_arguments_for_generate , arg_init , check_args
2828from quantize import set_precision
2929
30+ import logging
31+ logger = logging .getLogger (__name__ )
32+
3033B_INST , E_INST = "[INST]" , "[/INST]"
3134
3235@dataclass
@@ -66,7 +69,7 @@ def device_sync(device):
6669 elif ("cpu" in device ) or ("mps" in device ):
6770 pass
6871 else :
69- print (f"device={ device } is not yet suppported" )
72+ logging . error (f"device={ device } is not yet suppported" )
7073
7174
7275torch ._inductor .config .coordinate_descent_tuning = True
@@ -106,15 +109,15 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
106109def prefill (
107110 model : Transformer , x : torch .Tensor , input_pos : torch .Tensor , ** sampling_kwargs
108111) -> torch .Tensor :
109- print (f"x: { x } , input_pos: { input_pos } " )
112+ logging . debug (f"x: { x } , input_pos: { input_pos } " )
110113 width = x .size (1 )
111114 assert input_pos .size (0 ) == width
112115 sequential_prefill = True
113116
114117 if sequential_prefill :
115118 for i in range (width ):
116119 x_sliced , ip_sliced = x [:, i ].view (- 1 , 1 ), input_pos [i ].view (- 1 )
117- print (f"<sliced> x: { x_sliced } , input_pos: { ip_sliced } " )
120+ logging . debug (f"<sliced> x: { x_sliced } , input_pos: { ip_sliced } " )
118121 logits = model (x_sliced , ip_sliced ) # (x[:, i], input_pos[i])
119122 else :
120123 # input_pos: [B, S]
@@ -340,12 +343,12 @@ def _main(
340343 # # only print on rank 0
341344 # print = lambda *args, **kwargs: None
342345
343- print (f"Using device={ builder_args .device } " )
346+ logging . info (f"Using device={ builder_args .device } " )
344347 set_precision (builder_args .precision )
345348 is_speculative = speculative_builder_args .checkpoint_path is not None
346349
347350 if generator_args .chat_mode and not builder_args .is_chat_model :
348- print ("""
351+ logging . warning ("""
349352*******************************************************
350353 This model is not known to support the chat function.
351354 We will enable chat mode based on your instructions.
@@ -374,7 +377,7 @@ def _main(
374377 encoded = encode_tokens (
375378 tokenizer , generator_args .prompt , bos = True , device = builder_args .device
376379 )
377- print (encoded )
380+ logging . debug (encoded )
378381 prompt_length = encoded .size (0 )
379382
380383 model_size = sum (
@@ -469,7 +472,7 @@ def callback(x):
469472 )
470473 aggregate_metrics ["accept_counts" ].append (metrics ["accept_counts" ])
471474 if i == - 1 :
472- print (f"Compilation time: { time .perf_counter () - t0 :.2f} seconds" )
475+ logging . info (f"Compilation time: { time .perf_counter () - t0 :.2f} seconds" )
473476 continue
474477 if hasattr (prof , "export_chrome_trace" ):
475478 if use_tp :
@@ -486,23 +489,23 @@ def callback(x):
486489 tokens_generated = y .size (0 ) - prompt_length
487490 tokens_sec = tokens_generated / t
488491 aggregate_metrics ["tokens_per_sec" ].append (tokens_sec )
489- print (
492+ logging . info (
490493 f"Time for inference { i + 1 } : { t :.02f} sec total, { tokens_sec :.02f} tokens/sec"
491494 )
492495 print (f"Bandwidth achieved: { model_size * tokens_sec / 1e9 :.02f} GB/s" )
493496 print ("==========" )
494497 if is_speculative :
495498 counts_aggregated = [sum (i ) for i in zip (* aggregate_metrics ["accept_counts" ])]
496499 acceptance_probs = [i / sum (counts_aggregated ) for i in counts_aggregated ]
497- print (f"Acceptance probs: { acceptance_probs } " )
498- print (
500+ logging . info (f"Acceptance probs: { acceptance_probs } " )
501+ logging . info (
499502 f"Mean Accepted: { sum ([idx * i for idx , i in enumerate (counts_aggregated )])/ sum (counts_aggregated )} "
500503 )
501504
502- print (
505+ logging . info (
503506 f"Average tokens/sec: { torch .mean (torch .tensor (aggregate_metrics ['tokens_per_sec' ])).item ():.2f} "
504507 )
505- print (f"Memory used: { torch .cuda .max_memory_reserved () / 1e9 :.02f} GB" )
508+ logging . info (f"Memory used: { torch .cuda .max_memory_reserved () / 1e9 :.02f} GB" )
506509
507510
508511def main (args ):
0 commit comments