
Commit e0b124a

Authored by sfc-gh-zhwang, sfc-gh-ashankar, byshiue, zhang-ge-hao, and Ying1123
Add ability to force the BOS id for mBART (#22)
* Merge with main (#1)
  * Update beam_search_topk_kernels.cu — fix: fix bug of beam search
  * fix: change int of some kernels to int64_t to prevent overflow
  * fix: gpt tensor shapes inconsistency (NVIDIA#505)
  * Update gpt_guide.md (NVIDIA#529)
  * fix: fix bug of gpt buffer and gpt gemm overflow
  * Update T5DecodingWeight.cc — fix: fix loading bug of t5
  * [Enhancement] add pytorch backend support for gptneox (NVIDIA#550)
    * add pytorch backend support for gptneox
    * fix invalid early stopping
    * remove some unused parameters and logic, revert revisions that would affect pipeline parallelism, and make the code directly verifiable on TabbyML/NeoX-1.3B
    * rename classes, removing 'parallel' from their names
    * format the code
    * only print results when rank is 0
    * add dist.init_process_group()
    * update docs
  * Update cublasMMWrapper.cc — fix the CUBLAS_VERSION check in cublasMMWrapper
  * fix overflow in softmax_kernel when processing long sequence lengths and large batch sizes (NVIDIA#524)
  * Update unfused_attention_kernels.cu — fix bug of softmax kernel
  * [Enhancement] create huggingface_gptneox_convert.py (NVIDIA#569): add the converter, handle HF's multiple .bin files, update gptneox_guide.md
  * perf(bloom): improve performance of huggingface_bloom_convert.py, decreasing time cost and memory use (NVIDIA#568)
  * Fix/gpt early stop (NVIDIA#584) — fix bug of early stopping of gpt
  * [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug) (NVIDIA#672). FasterTransformer's 2-shot all reduce is implemented as a reduce-scatter + all-gather; an indexing bug in the all-gather step meant the result was correct only on device 0. Now all devices produce correct results.
* fix: swap tensor bug (NVIDIA#683)
* Support size_per_head=112 (NVIDIA#660): fix multi-gpu build; add support for size_per_head=112 for the gpt decoder
* remove mpi_cxx from multi-gpu build for now (NVIDIA#705)

Signed-off-by: AkiyamaYummy <[email protected]>
Co-authored-by: Asim Shankar <[email protected]>
Co-authored-by: byshiue <[email protected]>
Co-authored-by: _yummy_ <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: zhangxin81 <[email protected]>
Co-authored-by: 杨睿 <[email protected]>
Co-authored-by: r.yang <[email protected]>
Co-authored-by: Rahul Kindi <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Daya Khudia <[email protected]>
Co-authored-by: Dean Wyatte <[email protected]>
1 parent 3336e68 commit e0b124a

File tree

4 files changed: +68 −2 lines


src/fastertransformer/kernels/decoding_kernels.cu

Lines changed: 28 additions & 0 deletions
@@ -64,6 +64,34 @@ void invokeDecodingInitialize(bool* finished,
         finished, sequence_length, word_ids, cum_log_probs, sentence_ids, batch_size, beam_width, max_input_length);
 }
 
+__global__ void forceId(int*       word_ids,
+                        const int* force_bos_ids,
+                        const int  batch_size,
+                        const int  beam_width,
+                        const int  step)
+{
+    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
+         index += blockDim.x * gridDim.x) {
+        if (word_ids != nullptr) {
+            word_ids[index + step * batch_size * beam_width] = force_bos_ids[index / beam_width];
+        }
+    }
+}
+
+void invokeForceId(int*         word_ids,
+                   const int*   force_bos_ids,
+                   const int    batch_size,
+                   const int    beam_width,
+                   const int    step,
+                   cudaStream_t stream)
+{
+    dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256));
+    dim3 block(256);
+
+    forceId<<<grid, block, 0, stream>>>(word_ids, force_bos_ids, batch_size, beam_width, step);
+}
+
 template void invokeDecodingInitialize(bool* finished,
                                        int*  sequence_length,
                                        int*  word_ids,
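The new forceId kernel writes force_bos_ids[b] into every beam slot of batch b at row `step` of the step-major word_ids buffer, and invokeForceId launches it with 256-thread blocks. Below is a small standalone harness — not part of the commit — that copies the kernel body from the diff above and can be built with nvcc to check the indexing; the batch/beam sizes and token values are made up for illustration.

// Standalone check of the forceId indexing (harness is illustrative; kernel copied from the diff above).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void forceId(int* word_ids, const int* force_bos_ids, const int batch_size, const int beam_width, const int step)
{
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
         index += blockDim.x * gridDim.x) {
        if (word_ids != nullptr) {
            // Row `step` of the [max_seq_len, batch_size * beam_width] buffer; every beam of batch b gets force_bos_ids[b].
            word_ids[index + step * batch_size * beam_width] = force_bos_ids[index / beam_width];
        }
    }
}

int main()
{
    const int batch_size = 2, beam_width = 3, max_seq_len = 4, step = 1;

    int* d_word_ids;
    int* d_bos_ids;
    cudaMalloc(&d_word_ids, sizeof(int) * max_seq_len * batch_size * beam_width);
    cudaMalloc(&d_bos_ids, sizeof(int) * batch_size);
    cudaMemset(d_word_ids, 0, sizeof(int) * max_seq_len * batch_size * beam_width);

    // Two arbitrary "forced BOS" ids, e.g. mBART target-language tokens, one per batch entry.
    const int h_bos_ids[batch_size] = {250004, 250020};
    cudaMemcpy(d_bos_ids, h_bos_ids, sizeof(h_bos_ids), cudaMemcpyHostToDevice);

    // Same launch configuration as invokeForceId: ceil(batch_size * beam_width / 256) blocks of 256 threads.
    dim3 grid((batch_size * beam_width + 255) / 256);
    dim3 block(256);
    forceId<<<grid, block, 0, 0>>>(d_word_ids, d_bos_ids, batch_size, beam_width, step);

    int h_word_ids[max_seq_len * batch_size * beam_width];
    cudaMemcpy(h_word_ids, d_word_ids, sizeof(h_word_ids), cudaMemcpyDeviceToHost);

    // Expect row 1 (step == 1) to read: 250004 250004 250004 250020 250020 250020; all other rows stay 0.
    for (int s = 0; s < max_seq_len; s++) {
        for (int i = 0; i < batch_size * beam_width; i++) {
            printf("%d ", h_word_ids[s * batch_size * beam_width + i]);
        }
        printf("\n");
    }

    cudaFree(d_word_ids);
    cudaFree(d_bos_ids);
    return 0;
}

With step = 1, only row 1 of the buffer is overwritten, which matches how BartDecoding applies the kernel right after the first decoded token.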

src/fastertransformer/kernels/decoding_kernels.h

Lines changed: 7 additions & 0 deletions
@@ -33,6 +33,13 @@ void invokeDecodingInitialize(bool* finished,
                               const int    max_input_length,
                               cudaStream_t stream);
 
+void invokeForceId(int*         word_ids,
+                   const int*   force_bos_ids,
+                   const int    batch_size,
+                   const int    beam_width,
+                   const int    step,
+                   cudaStream_t stream);
+
 // get token from all_ids at step, then lookup from the embedding table
 // by the token
 template<typename T>
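The declaration makes the launcher available to model code. A hypothetical host-side wrapper is sketched below; the prototype is re-declared only so the snippet stands alone (a real caller would include decoding_kernels.h instead), and the wrapper and buffer names are illustrative.

#include <cuda_runtime.h>

// Prototype copied from the header above; real code should include decoding_kernels.h.
void invokeForceId(int* word_ids, const int* force_bos_ids, const int batch_size,
                   const int beam_width, const int step, cudaStream_t stream);

// Overwrite the token generated at `step` with a per-batch forced id.
// d_output_ids: [max_seq_len, batch_size * beam_width] on the device, step-major.
// d_forced_bos: [batch_size] on the device.
void forceBosToken(int* d_output_ids, const int* d_forced_bos, int batch_size, int beam_width, int step, cudaStream_t stream)
{
    invokeForceId(d_output_ids, d_forced_bos, batch_size, beam_width, step, stream);
}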

src/fastertransformer/models/bart/BartDecoding.cc

Lines changed: 30 additions & 0 deletions
@@ -116,6 +116,7 @@ void BartDecoding<T>::allocateBuffer(
 
     start_ids_buf_ = (int*)(allocator_->reMalloc(start_ids_buf_, sizeof(int) * batch_size, false));
     end_ids_buf_   = (int*)(allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false));
+    forced_bos_ids_buf_ = (int*)(allocator_->reMalloc(forced_bos_ids_buf_, sizeof(int) * batch_size, false));
 
     output_ids_buf_ =
         (int*)(allocator_->reMalloc(output_ids_buf_, sizeof(int) * batchxbeam * (max_seq_len + 1), false));
@@ -182,6 +183,7 @@ void BartDecoding<T>::freeBuffer()
     allocator_->free((void**)(&tiled_encoder_sequence_length_));
 
     allocator_->free((void**)(&start_ids_buf_));
+    allocator_->free((void**)(&forced_bos_ids_buf_));
     allocator_->free((void**)(&end_ids_buf_));
 
     allocator_->free((void**)(&output_ids_buf_));
@@ -343,6 +345,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
     // stop_words_list [batch_size, 2, stop_words_length], optional
     // bad_words_list [batch_size, 2, stop_words_length], optional
     // start_id [batch_size] on cpu, optional
+    // forced_bos_id [batch_size] on cpu, optional
     // end_id [batch_size] on cpu, optional
     // runtime_top_k [1] or [batch_size] on cpu, optional, uint.
     // runtime_top_p [1] or [batch_size] on cpu, optional, float.
@@ -382,6 +385,7 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
         dynamic_decode_layer_->setup(batch_size, beam_width, &input_map);
         handleOptArg(&input_map, "start_id", start_ids_buf_, start_id_, batch_size);
         handleOptArg(&input_map, "end_id", end_ids_buf_, end_id_, batch_size);
+        handleOptArg(&input_map, "forced_bos_id", forced_bos_ids_buf_, -1, batch_size);
     }
 
     FT_CHECK_WITH_INFO(input_tensors->at("encoder_output").shape[2] == d_model_,
@@ -792,6 +796,32 @@ void BartDecoding<T>::forward(TensorMap* output_tensors,
             dynamic_decode_output_tensors.insert(*t);
         }
         dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
+        if (step == 1 && input_tensors->isExist("forced_bos_id")) {
+            invokeForceId(output_ids_buf_,
+                          forced_bos_ids_buf_,
+                          batch_size,
+                          beam_width,
+                          step,
+                          stream_);
+            sync_check_cuda_error();
+        }
+        // {
+        //     for (auto t = dynamic_decode_output_tensors.begin(); t != dynamic_decode_output_tensors.end(); ++t) {
+        //         printf("step: %d, t->first: %s\n", step, t->first.c_str());
+        //         // printf("%s\n", t->second.toString().c_str());
+        //         {
+        //             int* buf;
+        //             int  st = t->second.size();
+        //             buf     = new int[st];
+        //             cudaMemcpy(buf, t->second.data, sizeof(int) * t->second.size(), cudaMemcpyDeviceToHost);
+        //             for (int i = 0; i < st; i++) {
+        //                 printf("%d ", buf[i]);
+        //             }
+        //             printf("\n");
+        //         }
+        //     }
+        //     printf("\n\n");
+        // }
     }
 }
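The model-side hookup is complete once a caller supplies the optional "forced_bos_id" input: forward() copies the ids into forced_bos_ids_buf_ via handleOptArg(), and right after the dynamic decode layer produces step 1, invokeForceId() overwrites that token for every beam. Below is a hedged caller-side sketch; the helper name, the BartDecodingWeight parameter, and the exact Tensor/TensorMap constructor calls are assumptions modeled on how start_id and end_id are normally passed, not a verbatim API reference.

#include <vector>
#include "src/fastertransformer/models/bart/BartDecoding.h"

using namespace fastertransformer;

// Illustrative helper: request a forced BOS id when running BartDecoding.
// All other required inputs (encoder_output, sequence lengths, ...) are assumed
// to be present in input_tensors already.
template<typename T>
void decodeWithForcedBos(BartDecoding<T>&             decoding,
                         TensorMap*                   output_tensors,
                         TensorMap*                   input_tensors,
                         const BartDecodingWeight<T>* decoding_weights,
                         size_t                       batch_size)
{
    // One BOS id per batch entry, on the CPU, matching the
    // "forced_bos_id [batch_size] on cpu, optional" contract in forward().
    std::vector<int> forced_bos_ids(batch_size, 250004);  // e.g. an mBART target-language token id

    input_tensors->insert("forced_bos_id",
                          Tensor{MEMORY_CPU, TYPE_INT32, {batch_size}, forced_bos_ids.data()});

    decoding.forward(output_tensors, input_tensors, decoding_weights);
}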

src/fastertransformer/models/bart/BartDecoding.h

Lines changed: 3 additions & 2 deletions
@@ -94,8 +94,9 @@ class BartDecoding: public BaseLayer {
     bool* finished_buf_   = nullptr;
     bool* h_finished_buf_ = nullptr;
 
-    int* start_ids_buf_ = nullptr;
-    int* end_ids_buf_   = nullptr;
+    int* start_ids_buf_      = nullptr;
+    int* forced_bos_ids_buf_ = nullptr;
+    int* end_ids_buf_        = nullptr;
 
     T* key_cache_   = nullptr;
     T* value_cache_ = nullptr;
