diff --git a/CMakeLists.txt b/CMakeLists.txt
index 870e67f0a..22c0f9c1f 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -322,8 +322,13 @@ add_library(transformer-shared SHARED
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
+  $<TARGET_OBJECTS:Deberta>
+  $<TARGET_OBJECTS:DebertaLayerWeight>
+  $<TARGET_OBJECTS:DebertaTritonBackend>
+  $<TARGET_OBJECTS:DebertaWeight>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
+  $<TARGET_OBJECTS:DisentangledAttentionLayer>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
@@ -356,6 +361,7 @@ add_library(transformer-shared SHARED
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
+  $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
@@ -384,6 +390,7 @@ add_library(transformer-shared SHARED
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
+  $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
   $<TARGET_OBJECTS:...>
diff --git a/src/fastertransformer/models/deberta/Deberta.h b/src/fastertransformer/models/deberta/Deberta.h
index b14ffd7bd..ac948899a 100644
--- a/src/fastertransformer/models/deberta/Deberta.h
+++ b/src/fastertransformer/models/deberta/Deberta.h
@@ -128,6 +128,11 @@ class Deberta: public BaseLayer {
                  const std::vector<Tensor>* input_tensors,
                  const DebertaWeight<T>* deberta_weights);
     void forward(TensorMap* output_tensors, TensorMap* input_tensors, const DebertaWeight<T>* deberta_weights);
+
+    inline size_t getHiddenUnits()
+    {
+        return hidden_units_;
+    }
 };
 
 } // namespace fastertransformer
diff --git a/src/fastertransformer/models/deberta/DebertaWeight.cc b/src/fastertransformer/models/deberta/DebertaWeight.cc
index 7811253c8..de059fb46 100644
--- a/src/fastertransformer/models/deberta/DebertaWeight.cc
+++ b/src/fastertransformer/models/deberta/DebertaWeight.cc
@@ -174,7 +174,7 @@ void DebertaWeight<T>::loadModel(std::string dir_path)
 
     for (uint l = 0; l < num_layer_; l++) {
         if (isValidLayerParallelId(l)) {
-            deberta_layer_weights[l].loadModel(dir_path + "model.encoder.layer." + std::to_string(l) + ".",
+            deberta_layer_weights[l].loadModel(dir_path + "/model.encoder.layer." + std::to_string(l) + ".",
                                                model_file_type);
         }
     }
diff --git a/src/fastertransformer/triton_backend/CMakeLists.txt b/src/fastertransformer/triton_backend/CMakeLists.txt
index 0079e087a..56cda1bde 100644
--- a/src/fastertransformer/triton_backend/CMakeLists.txt
+++ b/src/fastertransformer/triton_backend/CMakeLists.txt
@@ -26,3 +26,4 @@ if (ENABLE_FP8)
     add_subdirectory(multi_gpu_gpt_fp8)
 endif()
 add_subdirectory(bert)
+add_subdirectory(deberta)
diff --git a/src/fastertransformer/triton_backend/deberta/CMakeLists.txt b/src/fastertransformer/triton_backend/deberta/CMakeLists.txt
new file mode 100644
index 000000000..5245ff7c3
--- /dev/null
+++ b/src/fastertransformer/triton_backend/deberta/CMakeLists.txt
@@ -0,0 +1,25 @@
+# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_minimum_required(VERSION 3.8)
+
+set(deberta_triton_backend_files
+    DebertaTritonModel.cc
+    DebertaTritonModelInstance.cc
+)
+
+add_library(DebertaTritonBackend STATIC ${deberta_triton_backend_files})
+set_property(TARGET DebertaTritonBackend PROPERTY POSITION_INDEPENDENT_CODE ON)
+target_link_libraries(DebertaTritonBackend PRIVATE Deberta TransformerTritonBackend -lcublasLt)
+target_compile_features(DebertaTritonBackend PRIVATE cxx_std_14)
diff --git a/src/fastertransformer/triton_backend/deberta/DebertaTritonModel.cc b/src/fastertransformer/triton_backend/deberta/DebertaTritonModel.cc
new file mode 100644
index 000000000..7cc58303f
--- /dev/null
+++ b/src/fastertransformer/triton_backend/deberta/DebertaTritonModel.cc
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "3rdparty/INIReader.h"
+
+#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h"
+#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.h"
+
+namespace ft = fastertransformer;
+
+template<typename T>
+DebertaTritonModel<T>::DebertaTritonModel(size_t tensor_para_size,
+                                          size_t pipeline_para_size,
+                                          bool enable_custom_all_reduce,
+                                          std::string model_dir,
+                                          bool is_sparse,
+                                          bool is_remove_padding):
+    tensor_para_size_(tensor_para_size),
+    pipeline_para_size_(pipeline_para_size),
+    shared_weights_(std::vector<std::shared_ptr<ft::DebertaWeight<T>>>(ft::getDeviceCount())),
+    enable_custom_all_reduce_(enable_custom_all_reduce),
+    model_dir_(model_dir),
+    is_sparse_(is_sparse),
+    is_remove_padding_(is_remove_padding)
+{
+    FT_CHECK_WITH_INFO(is_sparse == false, "sparse is not supported in the deberta backend yet");
+
+    INIReader reader = INIReader(model_dir + "/config.ini");
+    if (reader.ParseError() < 0) {
+        std::cout << "[ERROR] Can't load '" << model_dir << "/config.ini"
+                  << "'\n";
+        ft::FT_CHECK(false);
+    }
+
+    /* DeBERTa base configuration file example
+    [deberta]
+    model_name = deberta
+    hidden_size = 1024
+    num_layer = 24
+    head_num = 16
+    size_per_head = 64
+    activation_type = gelu
+    inter_size = 4096
+    vocab_size = 128100
+    max_relative_positions = 512
+    relative_position_buckets = 256
+    weight_data_type = fp32
+    */
+
+    model_name_                = reader.Get("deberta", "model_name");
+    head_num_                  = reader.GetInteger("deberta", "head_num");
+    size_per_head_             = reader.GetInteger("deberta", "size_per_head");
+    inter_size_                = reader.GetInteger("deberta", "inter_size");
+    vocab_size_                = reader.GetInteger("deberta", "vocab_size");
+    num_layer_                 = reader.GetInteger("deberta", "num_layer");
+    max_relative_positions_    = reader.GetInteger("deberta", "max_relative_positions");
+    relative_position_buckets_ = reader.GetInteger("deberta", "relative_position_buckets");
+    layernorm_type_            = ft::getLayerNormType("post_layernorm");
+    activation_type_           = ft::getActivationType(reader.Get("deberta", "activation_type", "Gelu"));
+    q_scaling_                 = reader.GetFloat("deberta", "q_scaling", sqrtf(3.0f));
+}
+
+template<typename T>
+std::unique_ptr<AbstractTransformerModelInstance>
+DebertaTritonModel<T>::createModelInstance(int device_id,
+                                           int rank,
+                                           cudaStream_t stream,
+                                           std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
+                                           std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm)
+{
+    ft::check_cuda_error(cudaSetDevice(device_id));
+    const int comms_rank = device_id % (tensor_para_size_ * pipeline_para_size_);
+    const int tensor_para_rank = rank % tensor_para_size_;
+    const int pipeline_para_rank = rank / tensor_para_size_;
+
+    std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator(
+        new ft::Allocator<ft::AllocatorType::CUDA>(device_id));
+
+    allocator->setStream(stream);
+
+    cublasHandle_t cublas_handle;
+    cublasLtHandle_t cublaslt_handle;
+
+    cublasCreate(&cublas_handle);
+    cublasLtCreate(&cublaslt_handle);
+    cublasSetStream(cublas_handle, stream);
+
+    std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map(new ft::cublasAlgoMap("gemm_config.in"));
+    std::unique_ptr<std::mutex> cublas_wrapper_mutex(new std::mutex());
+    std::unique_ptr<ft::cublasMMWrapper> cublas_wrapper(new ft::cublasMMWrapper(
+        cublas_handle, cublaslt_handle, stream, cublas_algo_map.get(), cublas_wrapper_mutex.get(), allocator.get()));
+
+    std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr(new cudaDeviceProp);
+    ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id));
+
+    if (std::is_same<T, half>::value) {
+        cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
+    }
+#ifdef ENABLE_BF16
+    else if (std::is_same<T, __nv_bfloat16>::value) {
+        cublas_wrapper->setBF16GemmConfig();
+    }
+#endif
+    else if (std::is_same<T, float>::value) {
+        cublas_wrapper->setFP32GemmConfig();
+    }
+
+    ft::NcclParam tensor_para = nccl_params.first[comms_rank];
+    ft::NcclParam pipeline_para = nccl_params.second[comms_rank];
+
+    auto deberta =
+        std::make_unique<ft::Deberta<T>>(ft::Deberta<T>(0,  // max_batch_size, FT will adjust the buffer automatically.
+                                                        0,  // max_seq_len, FT will adjust the buffer automatically.
+                                                        head_num_,
+                                                        size_per_head_,
+                                                        max_relative_positions_,
+                                                        relative_position_buckets_,
+                                                        inter_size_,
+                                                        num_layer_,
+                                                        q_scaling_,
+                                                        stream,
+                                                        cublas_wrapper.get(),
+                                                        allocator.get(),
+                                                        false,
+                                                        is_sparse_,
+                                                        activation_type_,
+                                                        layernorm_type_,
+                                                        tensor_para,
+                                                        pipeline_para,
+                                                        custom_all_reduce_comm,
+                                                        enable_custom_all_reduce_));
+
+#ifdef SPARSITY_ENABLED
+    if (is_sparse_) {
+        for (int i = 0; i < num_layer_; ++i) {
+            shared_weights_[device_id]->deberta_layer_weights[i].compress_weights(*(cublas_wrapper.get()),
+                                                                                  head_num_ * size_per_head_);
+        }
+    }
+#endif
+
+    return std::unique_ptr<DebertaTritonModelInstance<T>>(
+        new DebertaTritonModelInstance<T>(std::move(deberta),
+                                          shared_weights_[device_id],
+                                          std::move(allocator),
+                                          std::move(cublas_algo_map),
+                                          std::move(cublas_wrapper_mutex),
+                                          std::move(cublas_wrapper),
+                                          std::move(cuda_device_prop_ptr)));
+}
+
+template<typename T>
+void DebertaTritonModel<T>::createSharedWeights(int device_id, int rank)
+{
+    ft::check_cuda_error(cudaSetDevice(device_id));
+    const int tensor_para_rank = rank % tensor_para_size_;
+    const int pipeline_para_rank = rank / tensor_para_size_;
+    shared_weights_[device_id] = std::make_shared<ft::DebertaWeight<T>>(head_num_ * size_per_head_,
+                                                                        inter_size_,
+                                                                        max_relative_positions_,
+                                                                        relative_position_buckets_,
+                                                                        vocab_size_,
+                                                                        num_layer_,
+                                                                        tensor_para_size_,
+                                                                        tensor_para_rank,
+                                                                        pipeline_para_size_,
+                                                                        pipeline_para_rank);
+
+    shared_weights_[device_id]->loadModel(model_dir_);
+    return;
+}
+
+template<typename T>
+std::string DebertaTritonModel<T>::toString()
+{
+    std::stringstream ss;
+    ss << "Model: " << model_name_ << "\nmodel_dir: " << model_dir_ << "\nhead_num: " << head_num_
+       << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_
+       << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
+       << "\nmax_relative_positions: " << max_relative_positions_
+       << "\nrelative_position_buckets: " << relative_position_buckets_
+       << "\nq_scaling: " << q_scaling_ << "\nis_remove_padding: " << is_remove_padding_
+       << "\nis_sparse: " << is_sparse_ << "\nactivation_type: " << static_cast<int>(activation_type_)
+       << "\nlayernorm_type: " << static_cast<int>(layernorm_type_) << "\nvocab_size: " << vocab_size_
+       << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << std::endl;
+
+    return ss.str();
+}
+
+template<typename T>
+void DebertaTritonModel<T>::createCustomComms(
+    std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms, int world_size)
+{
+    using commDataType = typename ft::CustomARCommTypeConverter<T>::Type;
+    ft::initCustomAllReduceComm<commDataType>(custom_all_reduce_comms, enable_custom_all_reduce_, world_size);
+}
+
+template<typename T>
+int DebertaTritonModel<T>::getTensorParaSize()
+{
+    return tensor_para_size_;
+}
+
+template<typename T>
+int DebertaTritonModel<T>::getPipelineParaSize()
+{
+    return pipeline_para_size_;
+}
+
+template struct DebertaTritonModel<float>;
+template struct DebertaTritonModel<half>;
+#ifdef ENABLE_BF16
+template struct DebertaTritonModel<__nv_bfloat16>;
+#endif
diff --git a/src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h b/src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h
new file mode 100644
index 000000000..6ac9f71f7
--- /dev/null
+++ b/src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "src/fastertransformer/models/deberta/Deberta.h"
+#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"
+
+namespace ft = fastertransformer;
+
+template<typename T>
+struct DebertaTritonModel: public AbstractTransformerModel {
+    DebertaTritonModel(size_t tensor_para_size,
+                       size_t pipeline_para_size,
+                       bool enable_custom_all_reduce,
+                       std::string model_dir,
+                       bool is_sparse,
+                       bool is_remove_padding);
+
+    virtual std::unique_ptr<AbstractTransformerModelInstance>
+    createModelInstance(int deviceId,
+                        int rank,
+                        cudaStream_t stream,
+                        std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
+                        std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr) override;
+
+    virtual void createSharedWeights(int deviceId, int rank) override;
+
+    virtual void createCustomComms(std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms,
+                                   int world_size) override;
+
+    virtual std::string toString() override;
+    virtual int getTensorParaSize() override;
+    virtual int getPipelineParaSize() override;
+
+private:
+    size_t head_num_;
+    size_t size_per_head_;
+    size_t inter_size_;
+    size_t num_layer_;
+    size_t vocab_size_;
+    size_t tensor_para_size_;
+    size_t pipeline_para_size_;
+    size_t max_relative_positions_;
+    size_t relative_position_buckets_;
+
+    float q_scaling_;
+    bool is_remove_padding_;
+    bool is_sparse_;
+    ft::ActivationType activation_type_;
+    ft::LayerNormType layernorm_type_;
+
+    std::string model_name_;
+    std::string model_dir_;
+    bool enable_custom_all_reduce_ = false;
+    std::vector<std::shared_ptr<ft::DebertaWeight<T>>> shared_weights_;
+};
diff --git a/src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.cc b/src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.cc
new file mode 100644
index 000000000..603fea7a8
--- /dev/null
+++ b/src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.cc
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.h"
+#include "src/fastertransformer/triton_backend/triton_utils.hpp"
+
+namespace ft = fastertransformer;
+
+template<typename T>
+void triton_stream_callback(std::unordered_map<std::string, triton::Tensor>* output_tensors, void* ctx)
+{
+    DebertaTritonModelInstance<T>* model = reinterpret_cast<DebertaTritonModelInstance<T>*>(ctx);
+    auto result = DebertaTritonModelInstance<T>::convert_outputs(*output_tensors);
+
+    model->stream_cb_(result, model->stream_ctx_);
+}
+
+template<typename T>
+DebertaTritonModelInstance<T>::DebertaTritonModelInstance(std::unique_ptr<ft::Deberta<T>> deberta,
+                                                          std::shared_ptr<ft::DebertaWeight<T>> deberta_weight,
+                                                          std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator,
+                                                          std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map,
+                                                          std::unique_ptr<std::mutex> cublas_wrapper_mutex,
+                                                          std::unique_ptr<ft::cublasMMWrapper> cublas_wrapper,
+                                                          std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr):
+    deberta_(std::move(deberta)),
+    deberta_weight_(deberta_weight),
+    allocator_(std::move(allocator)),
+    cublas_algo_map_(std::move(cublas_algo_map)),
+    cublas_wrapper_mutex_(std::move(cublas_wrapper_mutex)),
+    cublas_wrapper_(std::move(cublas_wrapper)),
+    cuda_device_prop_ptr_(std::move(cuda_device_prop_ptr))
+{
+}
+
+template<typename T>
+std::shared_ptr<std::vector<triton::Tensor>>
+DebertaTritonModelInstance<T>::forward(std::shared_ptr<std::vector<triton::Tensor>> input_tensors)
+{
+    ft::FT_CHECK(false);
+    return nullptr;
+}
+
+template<typename T>
+ft::TensorMap DebertaTritonModelInstance<T>::convert_inputs(
+    std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors)
+{
+    move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_);
+    move_tensor_H2D(input_tensors->at("sequence_lengths"), d_input_lengths_, &allocator_);
+
+    ft::TensorMap ft_input_tensors(
+        {{"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)},
+         {"sequence_lengths", as_GPU_tensor(input_tensors->at("sequence_lengths"), d_input_lengths_)}});
+
+    return ft_input_tensors;
+}
+
+template<typename T>
+std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>
+DebertaTritonModelInstance<T>::convert_outputs(ft::TensorMap& output_tensors)
+{
+    std::unordered_map<std::string, triton::Tensor>* outputs_mapping =
+        new std::unordered_map<std::string, triton::Tensor>();
+
+    for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) {
+        outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)});
+    }
+
+    return std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>(outputs_mapping);
+}
+
+template<typename T>
+std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>
+DebertaTritonModelInstance<T>::forward(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors)
+{
+    const size_t batch_size = input_tensors->at("input_ids").shape[0];
+    const size_t max_seq_len = input_tensors->at("input_ids").shape[1];
+    const size_t hidden_units = deberta_->getHiddenUnits();
+
+    allocateBuffer(batch_size, max_seq_len, hidden_units);
+
+    ft::TensorMap ft_input_tensors = convert_inputs(input_tensors);
+
+    ft::TensorMap output_tensors = ft::TensorMap({{"output_hidden_state",
+                                                   ft::Tensor{ft::MEMORY_GPU,
+                                                              ft::getTensorType<T>(),
+                                                              std::vector<size_t>{batch_size, max_seq_len, hidden_units},
+                                                              d_output_hidden_state_}}});
+
+    try {
+        deberta_->forward(&output_tensors, &ft_input_tensors, deberta_weight_.get());
+        cudaStreamSynchronize(deberta_->getStream());
+    }
+    catch (...) {
+        h_exception_ = std::current_exception();
+        output_tensors.insert({"error_message", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, &h_exception_}});
+    }
+
+    return convert_outputs(output_tensors);
+}
+
+template<typename T>
+DebertaTritonModelInstance<T>::~DebertaTritonModelInstance()
+{
+    freeBuffer();
+}
+
+template<typename T>
+void DebertaTritonModelInstance<T>::allocateBuffer(const size_t batch_size,
+                                                   const size_t max_seq_len,
+                                                   const size_t hidden_units)
+{
+    d_output_hidden_state_ =
+        (T*)(allocator_->reMalloc(d_output_hidden_state_, sizeof(T) * batch_size * max_seq_len * hidden_units, false));
+}
+
+template<typename T>
+void DebertaTritonModelInstance<T>::freeBuffer()
+{
+    if (d_output_hidden_state_ != nullptr) {
+        allocator_->free((void**)(&d_output_hidden_state_));
+    }
+    if (d_input_ids_ != nullptr) {
+        allocator_->free((void**)(&d_input_ids_));
+    }
+    if (d_input_lengths_ != nullptr) {
+        allocator_->free((void**)(&d_input_lengths_));
+    }
+}
+
+template struct DebertaTritonModelInstance<float>;
+template struct DebertaTritonModelInstance<half>;
+#ifdef ENABLE_BF16
+template struct DebertaTritonModelInstance<__nv_bfloat16>;
+#endif
diff --git a/src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.h b/src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.h
new file mode 100644
index 000000000..8c35e42a3
--- /dev/null
+++ b/src/fastertransformer/triton_backend/deberta/DebertaTritonModelInstance.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "src/fastertransformer/models/deberta/Deberta.h"
+#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"
+
+namespace ft = fastertransformer;
+
+template<typename T>
+struct DebertaTritonModelInstance: AbstractTransformerModelInstance {
+
+    DebertaTritonModelInstance(std::unique_ptr<ft::Deberta<T>> deberta,
+                               std::shared_ptr<ft::DebertaWeight<T>> deberta_weight,
+                               std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator,
+                               std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map,
+                               std::unique_ptr<std::mutex> cublas_wrapper_mutex,
+                               std::unique_ptr<ft::cublasMMWrapper> cublas_wrapper,
+                               std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr);
+    ~DebertaTritonModelInstance();
+
+    std::shared_ptr<std::vector<triton::Tensor>>
+    forward(std::shared_ptr<std::vector<triton::Tensor>> input_tensors) override;
+
+    std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>
+    forward(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors) override;
+
+    static std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>
+    convert_outputs(ft::TensorMap& output_tensors);
+
+private:
+    const std::unique_ptr<ft::Deberta<T>> deberta_;
+    const std::shared_ptr<ft::DebertaWeight<T>> deberta_weight_;
+    const std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator_;
+    const std::unique_ptr<ft::cublasAlgoMap> cublas_algo_map_;
+    const std::unique_ptr<std::mutex> cublas_wrapper_mutex_;
+    const std::unique_ptr<ft::cublasMMWrapper> cublas_wrapper_;
+    const std::unique_ptr<cudaDeviceProp> cuda_device_prop_ptr_;
+
+    ft::TensorMap convert_inputs(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors);
+
+    void allocateBuffer(const size_t batch_size, const size_t max_seq_len, const size_t hidden_units);
+    void freeBuffer();
+
+    int* d_input_ids_ = nullptr;
+    int* d_input_lengths_ = nullptr;
+    T* d_output_hidden_state_ = nullptr;
+
+    std::exception_ptr h_exception_ = nullptr;
+};
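
For reviewers who want to exercise the new backend outside Triton, here is a minimal, hypothetical driver sketch. It uses only the API added in this diff (`DebertaTritonModel`, `createSharedWeights`, `createModelInstance`, `toString`); the model directory path and the single-entry `NcclParam` vectors are assumptions for a one-GPU, no-custom-all-reduce setup, not part of this PR.

```cpp
// Hypothetical single-GPU smoke test for the DeBERTa Triton backend.
// The model directory is an assumed path holding config.ini plus the
// converted weights; production use goes through the Triton server.
#include <cuda_fp16.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

#include "src/fastertransformer/triton_backend/deberta/DebertaTritonModel.h"

int main()
{
    // tensor_para_size=1, pipeline_para_size=1, custom all-reduce off,
    // dense (non-sparse) weights, padding removal enabled.
    auto model = std::make_shared<DebertaTritonModel<half>>(
        1, 1, false, "/models/deberta/1-gpu", false, true);

    // Load weights onto GPU 0 before creating an instance on it.
    model->createSharedWeights(/*device_id=*/0, /*rank=*/0);

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // With a world size of 1 the comms rank is always 0, so one
    // default-constructed NcclParam per vector is enough here.
    auto nccl_params = std::make_pair(std::vector<ft::NcclParam>(1),
                                      std::vector<ft::NcclParam>(1));

    auto instance = model->createModelInstance(0, 0, stream, nccl_params);
    std::cout << model->toString();

    // instance->forward(...) would then take an unordered_map of
    // triton::Tensor with "input_ids" and "sequence_lengths" and
    // return "output_hidden_state", per DebertaTritonModelInstance.
    return 0;
}
```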