@@ -104,17 +104,22 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):


def prefill(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
+    model: Transformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    *,
+    sequential_prefill=True,
+    **sampling_kwargs
) -> torch.Tensor:
-    print(f"x: {x}, input_pos: {input_pos}")
+    # print(f"x: {x}, input_pos: {input_pos}")
    width = x.size(1)
    assert input_pos.size(0) == width
    sequential_prefill = True

    if sequential_prefill:
        for i in range(width):
            x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-            print(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+            # print(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
            logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
    else:
        # input_pos: [B, S]
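
The signature change makes the prefill strategy explicit at call sites via a keyword-only argument instead of a hard-coded local. A minimal call sketch, assuming a built `model` and the `encoded`/`prompt_length`/`builder_args` names from `_main` below (reused here for illustration, not part of this hunk):

    seq = prefill(
        model,
        encoded.view(1, -1),                                           # [B=1, S] prompt tokens
        torch.arange(0, prompt_length, device=builder_args.device),    # one position per token
        sequential_prefill=True,                                       # decode the prompt token by token
    )

Note that the unchanged `sequential_prefill = True` context line still shadows the new parameter inside the function body, so callers cannot actually disable sequential prefill yet.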
@@ -157,13 +162,6 @@ def decode_n_tokens(
    return new_tokens, new_probs


-# try:
-#     from .thin_wrapper import model_forward
-#
-# except:
-#     print("compiled model load not successful, running eager model")
-
-
def model_forward(model, x, input_pos):
    return model(x, input_pos)

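
With the commented-out `thin_wrapper` fallback gone, `model_forward` stays a plain top-level function, leaving a single stable callable that a compiled variant can later replace. A hypothetical sketch of that swap (not part of this commit; assumes PyTorch 2.x `torch.compile` is available and the model supports full-graph capture):

    # Hypothetical: replace the eager wrapper with a compiled one at setup time.
    model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)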
@@ -374,7 +372,7 @@ def _main(
    encoded = encode_tokens(
        tokenizer, generator_args.prompt, bos=True, device=builder_args.device
    )
-    print(encoded)
+    # print(encoded)
    prompt_length = encoded.size(0)

    model_size = sum(