@@ -116,6 +116,7 @@ void BartDecoding<T>::allocateBuffer(
116
116
117
117
start_ids_buf_ = (int *)(allocator_->reMalloc (start_ids_buf_, sizeof (int ) * batch_size, false ));
118
118
end_ids_buf_ = (int *)(allocator_->reMalloc (end_ids_buf_, sizeof (int ) * batch_size, false ));
119
+ forced_bos_ids_buf_ = (int *)(allocator_->reMalloc (forced_bos_ids_buf_, sizeof (int ) * batch_size, false ));
119
120
120
121
output_ids_buf_ =
121
122
(int *)(allocator_->reMalloc (output_ids_buf_, sizeof (int ) * batchxbeam * (max_seq_len + 1 ), false ));
@@ -182,6 +183,7 @@ void BartDecoding<T>::freeBuffer()
182
183
allocator_->free ((void **)(&tiled_encoder_sequence_length_));
183
184
184
185
allocator_->free ((void **)(&start_ids_buf_));
186
+ allocator_->free ((void **)(&forced_bos_ids_buf_));
185
187
allocator_->free ((void **)(&end_ids_buf_));
186
188
187
189
allocator_->free ((void **)(&output_ids_buf_));
@@ -343,6 +345,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
343
345
// stop_words_list [batch_size, 2, stop_words_length], optional
344
346
// bad_words_list [batch_size, 2, bad_words_length], optional
345
347
// start_id [batch_size] on cpu, optional
348
+ // forced_bos_id [batch_size] on cpu, optional
346
349
// end_id [batch_size] on cpu, optional
347
350
// runtime_top_k [1] or [batch_size] on cpu, optional, uint.
348
351
// runtime_top_p [1] or [batch_size] on cpu, optional, float.
@@ -382,6 +385,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
382
385
dynamic_decode_layer_->setup (batch_size, beam_width, &input_map);
383
386
handleOptArg (&input_map, " start_id" , start_ids_buf_, start_id_, batch_size);
384
387
handleOptArg (&input_map, " end_id" , end_ids_buf_, end_id_, batch_size);
388
+ handleOptArg (&input_map, " forced_bos_id" , forced_bos_ids_buf_, -1 , batch_size);
385
389
}
386
390
387
391
FT_CHECK_WITH_INFO (input_tensors->at (" encoder_output" ).shape [2 ] == d_model_,
@@ -792,6 +796,32 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
792
796
dynamic_decode_output_tensors.insert (*t);
793
797
}
794
798
dynamic_decode_layer_->forward (&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
799
+ if (step == 1 && input_tensors->isExist (" forced_bos_id" )) {
800
+ invokeForceId (output_ids_buf_,
801
+ forced_bos_ids_buf_,
802
+ batch_size,
803
+ beam_width,
804
+ step,
805
+ stream_);
806
+ sync_check_cuda_error ();
807
+ }
808
+ // {
809
+ // for (auto t = dynamic_decode_output_tensors.begin(); t != dynamic_decode_output_tensors.end(); ++t) {
810
+ // printf("step: %d, t->first: %s\n", step, t->first.c_str());
811
+ // // printf("%s\n", t->second.toString().c_str());
812
+ // {
813
+ // int* buf;
814
+ // int st = t->second.size();
815
+ // buf = new int[st];
816
+ // cudaMemcpy(buf, t->second.data, sizeof(int) * t->second.size(), cudaMemcpyDeviceToHost);
817
+ // for (int i=0; i<st; i++) {
818
+ // printf("%d ", buf[i]);
819
+ // }
820
+ // printf("\n");
821
+ // }
822
+ // }
823
+ // printf("\n\n");
824
+ // }
795
825
}
796
826
}
797
827
0 commit comments