Skip to content

Commit aeb0cb2

Browse files
committed
unary fixes
1 parent 2d4c512 commit aeb0cb2

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,15 @@ struct UnaryAbs
730730
{
731731
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
732732
};
733+
734+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
735+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
736+
737+
template <>
738+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
739+
{
740+
y = ck::type_convert<bhalf_t>(ck::math::abs(x));
741+
};
733742
};
734743

735744
struct UnarySqrt
@@ -829,13 +838,23 @@ struct Relu
829838
y = x > 0 ? x : 0;
830839
}
831840

841+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; a <bhalf_t, float> explicit specialization further down
// provides a definition.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
842+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
843+
832844
template <>
833845
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
834846
{
835847
float x_f32 = type_convert<float>(x);
836848
float y_f32 = x_f32 > 0 ? x_f32 : 0;
837849
y = type_convert<bhalf_t>(y_f32);
838850
}
851+
852+
template <>
853+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
854+
{
855+
float y_f32 = x > 0 ? x : 0;
856+
y = type_convert<bhalf_t>(y_f32);
857+
};
839858
};
840859

841860
// Fast GeLU
@@ -988,6 +1007,16 @@ struct Sigmoid
9881007
constexpr T one = type_convert<T>(1);
9891008
y = one / (one + math::exp(-x));
9901009
};
1010+
1011+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1012+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1013+
1014+
template <>
1015+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1016+
{
1017+
constexpr float one = 1.f;
1018+
y = type_convert<bhalf_t>(one / (one + math::exp(-x)));
1019+
};
9911020
};
9921021

9931022
struct Silu
@@ -1015,6 +1044,15 @@ struct TanH
10151044

10161045
y = math::tanh(x);
10171046
};
1047+
1048+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1049+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1050+
1051+
template <>
1052+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1053+
{
1054+
y = type_convert<bhalf_t>(math::tanh(x));
1055+
};
10181056
};
10191057

10201058
struct ACos
@@ -1274,6 +1312,13 @@ struct Swish
12741312
y = type_convert<Y>(x / (1.f + math::exp(bx)));
12751313
};
12761314

1315+
template <>
1316+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1317+
{
1318+
float bx = -beta_ * x;
1319+
y = type_convert<bhalf_t>(x / (1.f + math::exp(bx)));
1320+
};
1321+
12771322
const float beta_;
12781323
};
12791324

@@ -1292,6 +1337,16 @@ struct SoftRelu
12921337
constexpr T one = type_convert<T>(1);
12931338
y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
12941339
}
1340+
1341+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1342+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1343+
1344+
template <>
1345+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1346+
{
1347+
constexpr float one = 1.f;
1348+
y = type_convert<bhalf_t>(math::log(one + math::exp(x * alpha_)) / alpha_);
1349+
};
12951350
const float alpha_;
12961351
};
12971352

@@ -1313,6 +1368,17 @@ struct Power
13131368
T shifted_scaled_x = casted_alpha + casted_beta * x;
13141369
y = math::pow(shifted_scaled_x, casted_gamma);
13151370
}
1371+
1372+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1373+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1374+
1375+
template <>
1376+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1377+
{
1378+
const float shifted_scaled_x = alpha_ + beta_ * x;
1379+
y = type_convert<bhalf_t>(math::pow(shifted_scaled_x, gamma_));
1380+
};
1381+
13161382
const float alpha_;
13171383
const float beta_;
13181384
const float gamma_;
@@ -1333,6 +1399,16 @@ struct ClippedRelu
13331399
T casted_beta = type_convert<T>(beta_);
13341400
y = math::min(casted_beta, math::max(casted_alpha, x));
13351401
}
1402+
1403+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1404+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1405+
1406+
template <>
1407+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1408+
{
1409+
y = type_convert<bhalf_t>(math::min(beta_, math::max(alpha_, x)));
1410+
};
1411+
13361412
const float alpha_;
13371413
const float beta_;
13381414
};
@@ -1351,6 +1427,16 @@ struct LeakyRelu
13511427
T casted_alpha = type_convert<T>(alpha_);
13521428
y = x >= 0 ? x : x * casted_alpha;
13531429
}
1430+
1431+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1432+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1433+
1434+
template <>
1435+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1436+
{
1437+
y = type_convert<bhalf_t>(x >= 0 ? x : x * alpha_);
1438+
};
1439+
13541440
const float alpha_;
13551441
};
13561442

@@ -1368,6 +1454,16 @@ struct Elu
13681454
T casted_alpha = type_convert<T>(alpha_);
13691455
y = x > 0 ? x : casted_alpha * math::expm1(x);
13701456
}
1457+
1458+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1459+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1460+
1461+
template <>
1462+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1463+
{
1464+
y = type_convert<bhalf_t>(x > 0 ? x : alpha_ * math::expm1(x));
1465+
};
1466+
13711467
const float alpha_;
13721468
};
13731469

@@ -1386,6 +1482,16 @@ struct Logistic
13861482
constexpr T one = type_convert<T>(1);
13871483
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
13881484
}
1485+
1486+
// Two-type call operator: output type Y and input type X are separate template parameters.
// Declared here without a body; the explicit specialization that follows defines the
// <bhalf_t, float> case.
// NOTE(review): other <Y, X> combinations have no definition visible here - confirm they
// are specialized elsewhere or never used.
template <typename Y, typename X>
1487+
__host__ __device__ constexpr void operator()(Y& y, const X& x) const;
1488+
1489+
template <>
1490+
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
1491+
{
1492+
constexpr float one = 1.f;
1493+
y = type_convert<bhalf_t>(alpha_ / (one + ck::math::exp(-x) * alpha_));
1494+
};
13891495
const float alpha_;
13901496
};
13911497

0 commit comments

Comments
 (0)