
Commit f6c2ff9

Grouped convolution forward with clamp (#2334)
* Grouped convolution forward with clamp
* Optimize clamp
* unary fixes
* test gk bias
* Revert "test gk bias"
  This reverts commit 8e42e29.
* Revert "Revert "test gk bias""
  This reverts commit e73c055.
* workaround comment
1 parent d996bc7 commit f6c2ff9
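
At its core, the commit adds a Clamp elementwise operator to unary_element_wise_operation.hpp and teaches the grouped convolution forward device ops to fuse non-PassThrough elementwise ops into the GEMM epilogue for bf16 outputs. As orientation before the diffs, here is a minimal standalone sketch of what the new functor computes; it mirrors the committed code but substitutes plain float and std::numeric_limits for CK's half_t/bhalf_t and NumericLimits machinery, so it is an illustration, not the CK API:

// Standalone sketch of the Clamp functor added in this commit (float only;
// the real header also specializes double, half_t, bhalf_t, int and int8_t).
#include <cassert>
#include <limits>

struct Clamp
{
    Clamp(float floor = 0.f, float ceil = std::numeric_limits<float>::max())
        : floor_(floor), ceil_(ceil)
    {
    }

    // Same branch structure as the committed code: y = min(max(x, floor_), ceil_).
    void operator()(float& y, const float& x) const
    {
        y = x > floor_ ? (x < ceil_ ? x : ceil_) : floor_;
    }

    const float floor_;
    const float ceil_;
};

int main()
{
    Clamp clamp{0.f, 6.f}; // ReLU6-style bounds
    float y = 0.f;
    clamp(y, 7.5f);
    assert(y == 6.f); // above the ceiling
    clamp(y, -1.f);
    assert(y == 0.f); // below the floor
    clamp(y, 3.25f);
    assert(y == 3.25f); // in range, passes through
    return 0;
}

With the default arguments the functor degenerates to a plain ReLU with an open ceiling, which is why floor_ defaults to 0.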

File tree

41 files changed, +2103 -106 lines


include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 8 additions & 3 deletions
@@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
 
     static_assert(NumGroupsToMerge >= 1);
 
-    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
-    static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
+    static constexpr bool isMultiA  = is_detected<is_tuple, ADataType>::value;
+    static constexpr bool isMultiB  = is_detected<is_tuple, BDataType>::value;
+    static constexpr bool isMultiAB = isMultiA || isMultiB;
 
     // NGCHW is not supported for multiAB
     static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
@@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
     static constexpr index_t NumDTensor = DsDataType::Size();
 
+    static constexpr bool DoElementwiseBeforeCShuffle =
+        NumDTensor == 0 && !isMultiAB && is_same_v<EDataType, bhalf_t> &&
+        !is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
         CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
-        BComputeDataType
+        BComputeDataType, DoElementwiseBeforeCShuffle
     // Use appropriate gridwise gemm
     using GridwiseGemm = std::conditional_t<
         isMultiA || isMultiB,
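
The new DoElementwiseBeforeCShuffle flag is pure compile-time bookkeeping: it turns on only when the output type is bhalf_t and the CDE elementwise op actually does work (it is not PassThrough), in which case the op can be applied to the float accumulator before the CShuffle stage narrows to bf16, which is exactly what the new (bhalf_t&, const float&) overloads in unary_element_wise_operation.hpp below support. A stand-in sketch of the condition, with dummy types replacing the real template parameters; only the boolean expression itself is taken from the commit:

// Hedged sketch: PassThrough/Clamp/bhalf_t here are stand-ins, not CK types.
#include <type_traits>

struct PassThrough { };         // stands in for element_wise::PassThrough
struct Clamp { };               // stands in for the new element_wise::Clamp
using bhalf_t = unsigned short; // stands in for CK's bf16 storage type

template <typename EDataType, typename CDEElementwiseOperation, int NumDTensor, bool isMultiAB>
constexpr bool DoElementwiseBeforeCShuffle =
    NumDTensor == 0 && !isMultiAB && std::is_same_v<EDataType, bhalf_t> &&
    !std::is_same_v<CDEElementwiseOperation, PassThrough>;

// bf16 output with a real op: apply it in float, before the final narrowing.
static_assert(DoElementwiseBeforeCShuffle<bhalf_t, Clamp, 0, false>);
// PassThrough has nothing to apply, so the flag stays off.
static_assert(!DoElementwiseBeforeCShuffle<bhalf_t, PassThrough, 0, false>);

int main() { return 0; }

The NumDTensor == 0 and !isMultiAB guards keep the early-elementwise path away from fusions that still need D tensors or multi-A/B inputs applied after the GEMM.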

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

Lines changed: 19 additions & 3 deletions
@@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     static constexpr bool isMultiD = DsDataType::Size() > 0;
     static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;
 
+    static constexpr bool DoElementwiseBeforeCShuffle =
+        !isMultiABD && is_same_v<EDataType, bhalf_t> &&
+        !is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
+
     static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
     static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
         CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \
-        AComputeDataType, BComputeDataType
+        AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle
 
     // Use appropriate gridwise gemm
     using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
@@ -780,8 +784,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
                     sizeof(EDataType);
             }
 
-            typename GridwiseGemm::Argument gemm_arg{
-                p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
+            typename GridwiseGemm::Argument gemm_arg{p_a_grid,
+                                                     p_b_grid,
+                                                     p_e_grid,
+                                                     GemmM,
+                                                     GemmN,
+                                                     GemmK,
+                                                     I0,
+                                                     I0,
+                                                     I0,
+                                                     I1,
+                                                     false,
+                                                     arg.a_element_op_,
+                                                     arg.b_element_op_,
+                                                     arg.cde_element_op_};
 
             const auto Run = [&](const auto& kernel) {
                 if(stream_config.flush_cache)

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp

Lines changed: 4 additions & 1 deletion
@@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
 
     static constexpr index_t NumDTensor = DsDataType::Size();
     static constexpr index_t MaxGemmsNum = 32;
+    static constexpr bool DoElementwiseBeforeCShuffle =
+        NumDTensor == 0 && is_same_v<EDataType, bhalf_t> &&
+        !is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough>;
 
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
         CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
-        AComputeDataType
+        AComputeDataType, DoElementwiseBeforeCShuffle
     // Use appropriate gridwise gemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;

include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

Lines changed: 179 additions & 0 deletions
@@ -730,6 +730,15 @@ struct UnaryAbs
     {
         y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
     };
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        y = ck::type_convert<bhalf_t>(ck::math::abs(x));
+    };
 };
 
 struct UnarySqrt
@@ -744,6 +753,79 @@ struct UnarySqrt
     };
 };
 
+struct Clamp
+{
+    Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
+        : floor_(floor), ceil_(ceil){};
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float, float>(float& y, const float& x) const
+    {
+        const float& a = x;
+        y              = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<double, double>(double& y, const double& x) const
+    {
+        const double& a = x;
+        y               = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        const float a = type_convert<float>(x);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<half_t>(b);
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, float>(half_t& y, const float& x) const
+    {
+        const float& a = x;
+        const float b  = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y              = type_convert<half_t>(b);
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        const float& a = x;
+        const float b  = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y              = type_convert<bhalf_t>(b);
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t>(bhalf_t& y,
+                                                                    const bhalf_t& x) const
+    {
+        const float a = type_convert<float>(x);
+        const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+        y             = type_convert<bhalf_t>(b);
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<int, int>(int& y, const int& x) const
+    {
+        const int& a = x;
+        y            = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+    };
+
+    template <>
+    __host__ __device__ constexpr void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
+    {
+        const int8_t& a = x;
+        y               = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
+    };
+
+    const float floor_;
+    const float ceil_;
+};
+
 struct Relu
 {
     template <typename T>
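
Clamp and the other functors in this file share a dispatch idiom: the generic operator()(Y&, const X&) is declared but never defined, so only the explicitly specialized (Y, X) pairs are actually callable and an unsupported combination fails at link time instead of silently converting. A portable sketch of the idiom; the in-class explicit specializations used in the CK header rely on class-scope specialization support (CWG 727) in clang/hipcc, while the namespace-scope form below compiles anywhere:

// Sketch of the declared-but-undefined dispatch idiom (Abs is illustrative).
#include <cmath>

struct Abs
{
    // Generic form: declared only, intentionally never defined.
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const;
};

// Explicit specialization at namespace scope, portable across compilers.
template <>
void Abs::operator()<float, float>(float& y, const float& x) const
{
    y = std::fabs(x);
}

int main()
{
    Abs abs_op;
    float y = 0.f;
    abs_op(y, -2.5f); // fine: (float, float) is specialized
    // double d; abs_op(d, -2.5); // would compile here but fail to link
    return 0;
}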
@@ -756,13 +838,23 @@ struct Relu
         y = x > 0 ? x : 0;
     }
 
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
     template <>
     __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
     {
         float x_f32 = type_convert<float>(x);
         float y_f32 = x_f32 > 0 ? x_f32 : 0;
         y           = type_convert<bhalf_t>(y_f32);
     }
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        float y_f32 = x > 0 ? x : 0;
+        y           = type_convert<bhalf_t>(y_f32);
+    };
 };
 
 // Fast GeLU
@@ -915,6 +1007,16 @@ struct Sigmoid
         constexpr T one = type_convert<T>(1);
         y               = one / (one + math::exp(-x));
     };
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        constexpr float one = 1.f;
+        y                   = type_convert<bhalf_t>(one / (one + math::exp(-x)));
+    };
 };
 
 struct Silu
@@ -942,6 +1044,15 @@ struct TanH
 
         y = math::tanh(x);
     };
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(math::tanh(x));
+    };
 };
 
 struct ACos
@@ -1201,6 +1312,13 @@ struct Swish
         y = type_convert<Y>(x / (1.f + math::exp(bx)));
     };
 
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        float bx = -beta_ * x;
+        y        = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
+    };
+
     const float beta_;
 };
 
@@ -1219,6 +1337,16 @@ struct SoftRelu
         constexpr T one = type_convert<T>(1);
         y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        constexpr float one = 1.f;
+        y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
+    };
     const float alpha_;
 };
 
@@ -1240,6 +1368,17 @@ struct Power
         T shifted_scaled_x = casted_alpha + casted_beta * x;
         y = math::pow(shifted_scaled_x, casted_gamma);
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        const float shifted_scaled_x = alpha_ + beta_ * x;
+        y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
+    };
+
     const float alpha_;
     const float beta_;
     const float gamma_;
@@ -1260,6 +1399,16 @@ struct ClippedRelu
         T casted_beta = type_convert<T>(beta_);
         y = math::min(casted_beta, math::max(casted_alpha, x));
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
+    };
+
     const float alpha_;
     const float beta_;
 };
@@ -1278,6 +1427,16 @@ struct LeakyRelu
         T casted_alpha = type_convert<T>(alpha_);
         y = x >= 0 ? x : x * casted_alpha;
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
+    };
+
     const float alpha_;
 };
 
@@ -1295,6 +1454,16 @@ struct Elu
         T casted_alpha = type_convert<T>(alpha_);
        y = x > 0 ? x : casted_alpha * math::expm1(x);
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
+    };
+
     const float alpha_;
 };
 
@@ -1313,6 +1482,16 @@ struct Logistic
         constexpr T one = type_convert<T>(1);
         y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
     }
+
+    template <typename Y, typename X>
+    __host__ __device__ constexpr void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
+    {
+        constexpr float one = 1.f;
+        y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
+    };
     const float alpha_;
 };
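
A pattern worth noting across all of these hunks: every functor gains a (bhalf_t& y, const float& x) specialization, and the evident motivation is that the GEMM epilogue keeps partial results in float. With these overloads the activation runs at full precision and the narrowing to bf16 happens exactly once, whereas the pre-existing (bhalf_t&, const bhalf_t&) overloads (see Relu above) must operate on values that were already truncated. A self-contained illustration of why the order matters; to_bf16/to_float are assumed truncating stand-ins for ck::type_convert, not the CK implementation, and the subtract-one op is just an arbitrary elementwise function:

// Why applying the elementwise op before the bf16 narrowing matters.
#include <cstdint>
#include <cstdio>
#include <cstring>

using bhalf_t = std::uint16_t; // storage-only stand-in for CK's bhalf_t

// Truncating float -> bf16 (keep the upper 16 bits of the fp32 pattern).
bhalf_t to_bf16(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<bhalf_t>(bits >> 16);
}

// bf16 -> float (pad the low 16 bits with zeros).
float to_float(bhalf_t x)
{
    std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}

float op(float v) { return v - 1.f; } // arbitrary elementwise op

int main()
{
    const float acc = 1.00390625f; // 1 + 2^-8: exact in fp32, not in bf16

    // New path: op on the float accumulator, one narrowing at the end.
    float before = to_float(to_bf16(op(acc)));                    // 0.00390625
    // Old path: narrow to bf16 first, then apply the op.
    float after = to_float(to_bf16(op(to_float(to_bf16(acc))))); // 0

    std::printf("op before convert: %g, op after convert: %g\n", before, after);
    return 0;
}

With round-to-nearest conversions the gap shrinks but does not vanish; applying the op on the float side is the safe order either way.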
