@@ -730,6 +730,15 @@ struct UnaryAbs
730
730
{
731
731
y = ck::type_convert<f8_t >(ck::math::abs (ck::type_convert<float >(x)));
732
732
};
733
+
734
+ template <typename Y, typename X>
735
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
736
+
737
+ template <>
738
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
739
+ {
740
+ y = ck::type_convert<bhalf_t >(ck::math::abs (x));
741
+ };
733
742
};
734
743
735
744
struct UnarySqrt
@@ -829,13 +838,23 @@ struct Relu
829
838
y = x > 0 ? x : 0 ;
830
839
}
831
840
841
+ template <typename Y, typename X>
842
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
843
+
832
844
template <>
833
845
__host__ __device__ void operator ()(bhalf_t & y, const bhalf_t & x) const
834
846
{
835
847
float x_f32 = type_convert<float >(x);
836
848
float y_f32 = x_f32 > 0 ? x_f32 : 0 ;
837
849
y = type_convert<bhalf_t >(y_f32);
838
850
}
851
+
852
+ template <>
853
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
854
+ {
855
+ float y_f32 = x > 0 ? x : 0 ;
856
+ y = type_convert<bhalf_t >(y_f32);
857
+ };
839
858
};
840
859
841
860
// Fast GeLU
@@ -988,6 +1007,16 @@ struct Sigmoid
988
1007
constexpr T one = type_convert<T>(1 );
989
1008
y = one / (one + math::exp (-x));
990
1009
};
1010
+
1011
+ template <typename Y, typename X>
1012
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1013
+
1014
+ template <>
1015
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1016
+ {
1017
+ constexpr float one = 1 .f ;
1018
+ y = type_convert<bhalf_t >(one / (one + math::exp (-x)));
1019
+ };
991
1020
};
992
1021
993
1022
struct Silu
@@ -1015,6 +1044,15 @@ struct TanH
1015
1044
1016
1045
y = math::tanh (x);
1017
1046
};
1047
+
1048
+ template <typename Y, typename X>
1049
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1050
+
1051
+ template <>
1052
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1053
+ {
1054
+ y = type_convert<bhalf_t >(math::tanh (x));
1055
+ };
1018
1056
};
1019
1057
1020
1058
struct ACos
@@ -1274,6 +1312,13 @@ struct Swish
1274
1312
y = type_convert<Y>(x / (1 .f + math::exp (bx)));
1275
1313
};
1276
1314
1315
+ template <>
1316
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1317
+ {
1318
+ float bx = -beta_ * x;
1319
+ y = type_convert<bhalf_t >(x / (1 .f + math::exp (bx)));
1320
+ };
1321
+
1277
1322
const float beta_;
1278
1323
};
1279
1324
@@ -1292,6 +1337,16 @@ struct SoftRelu
1292
1337
constexpr T one = type_convert<T>(1 );
1293
1338
y = math::log (one + math::exp (x * casted_alpha)) / casted_alpha;
1294
1339
}
1340
+
1341
+ template <typename Y, typename X>
1342
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1343
+
1344
+ template <>
1345
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1346
+ {
1347
+ constexpr float one = 1 .f ;
1348
+ y = type_convert<bhalf_t >(math::log (one + math::exp (x * alpha_)) / alpha_);
1349
+ };
1295
1350
const float alpha_;
1296
1351
};
1297
1352
@@ -1313,6 +1368,17 @@ struct Power
1313
1368
T shifted_scaled_x = casted_alpha + casted_beta * x;
1314
1369
y = math::pow (shifted_scaled_x, casted_gamma);
1315
1370
}
1371
+
1372
+ template <typename Y, typename X>
1373
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1374
+
1375
+ template <>
1376
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1377
+ {
1378
+ const float shifted_scaled_x = alpha_ + beta_ * x;
1379
+ y = type_convert<bhalf_t >(math::pow (shifted_scaled_x, gamma_));
1380
+ };
1381
+
1316
1382
const float alpha_;
1317
1383
const float beta_;
1318
1384
const float gamma_;
@@ -1333,6 +1399,16 @@ struct ClippedRelu
1333
1399
T casted_beta = type_convert<T>(beta_);
1334
1400
y = math::min (casted_beta, math::max (casted_alpha, x));
1335
1401
}
1402
+
1403
+ template <typename Y, typename X>
1404
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1405
+
1406
+ template <>
1407
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1408
+ {
1409
+ y = type_convert<bhalf_t >(math::min (beta_, math::max (alpha_, x)));
1410
+ };
1411
+
1336
1412
const float alpha_;
1337
1413
const float beta_;
1338
1414
};
@@ -1351,6 +1427,16 @@ struct LeakyRelu
1351
1427
T casted_alpha = type_convert<T>(alpha_);
1352
1428
y = x >= 0 ? x : x * casted_alpha;
1353
1429
}
1430
+
1431
+ template <typename Y, typename X>
1432
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1433
+
1434
+ template <>
1435
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1436
+ {
1437
+ y = type_convert<bhalf_t >(x >= 0 ? x : x * alpha_);
1438
+ };
1439
+
1354
1440
const float alpha_;
1355
1441
};
1356
1442
@@ -1368,6 +1454,16 @@ struct Elu
1368
1454
T casted_alpha = type_convert<T>(alpha_);
1369
1455
y = x > 0 ? x : casted_alpha * math::expm1 (x);
1370
1456
}
1457
+
1458
+ template <typename Y, typename X>
1459
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1460
+
1461
+ template <>
1462
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1463
+ {
1464
+ y = type_convert<bhalf_t >(x > 0 ? x : alpha_ * math::expm1 (x));
1465
+ };
1466
+
1371
1467
const float alpha_;
1372
1468
};
1373
1469
@@ -1386,6 +1482,16 @@ struct Logistic
1386
1482
constexpr T one = type_convert<T>(1 );
1387
1483
y = casted_alpha / (one + ck::math::exp (-x) * casted_alpha);
1388
1484
}
1485
+
1486
+ template <typename Y, typename X>
1487
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1488
+
1489
+ template <>
1490
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1491
+ {
1492
+ constexpr float one = 1 .f ;
1493
+ y = type_convert<bhalf_t >(alpha_ / (one + ck::math::exp (-x) * alpha_));
1494
+ };
1389
1495
const float alpha_;
1390
1496
};
1391
1497
0 commit comments