@@ -730,6 +730,15 @@ struct UnaryAbs
730
730
{
731
731
y = ck::type_convert<f8_t >(ck::math::abs (ck::type_convert<float >(x)));
732
732
};
733
+
734
+ template <typename Y, typename X>
735
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
736
+
737
+ template <>
738
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
739
+ {
740
+ y = ck::type_convert<bhalf_t >(ck::math::abs (x));
741
+ };
733
742
};
734
743
735
744
struct UnarySqrt
@@ -744,6 +753,79 @@ struct UnarySqrt
744
753
};
745
754
};
746
755
756
+ struct Clamp
757
+ {
758
+ Clamp (float floor = 0 .f, float ceil = NumericLimits<float >::Max())
759
+ : floor_(floor), ceil_(ceil){};
760
+
761
+ template <typename Y, typename X>
762
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
763
+
764
+ template <>
765
+ __host__ __device__ constexpr void operator ()<float, float>(float & y, const float & x) const
766
+ {
767
+ const float & a = x;
768
+ y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
769
+ };
770
+
771
+ template <>
772
+ __host__ __device__ constexpr void operator ()<double, double>(double & y, const double & x) const
773
+ {
774
+ const double & a = x;
775
+ y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
776
+ };
777
+
778
+ template <>
779
+ __host__ __device__ constexpr void operator ()<half_t, half_t>(half_t & y, const half_t & x) const
780
+ {
781
+ const float a = type_convert<half_t >(x);
782
+ const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
783
+ y = type_convert<half_t >(b);
784
+ };
785
+
786
+ template <>
787
+ __host__ __device__ constexpr void operator ()<half_t, float>(half_t & y, const float & x) const
788
+ {
789
+ const float & a = x;
790
+ const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
791
+ y = type_convert<half_t >(b);
792
+ };
793
+
794
+ template <>
795
+ __host__ __device__ constexpr void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
796
+ {
797
+ const float & a = x;
798
+ const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
799
+ y = type_convert<bhalf_t >(b);
800
+ };
801
+
802
+ template <>
803
+ __host__ __device__ constexpr void operator ()<bhalf_t, bhalf_t>(bhalf_t & y,
804
+ const bhalf_t & x) const
805
+ {
806
+ const float a = type_convert<float >(x);
807
+ const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
808
+ y = type_convert<bhalf_t >(b);
809
+ };
810
+
811
+ template <>
812
+ __host__ __device__ constexpr void operator ()<int, int>(int & y, const int & x) const
813
+ {
814
+ const int8_t & a = x;
815
+ y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
816
+ };
817
+
818
+ template <>
819
+ __host__ __device__ constexpr void operator ()<int8_t, int8_t>(int8_t & y, const int8_t & x) const
820
+ {
821
+ const int8_t & a = x;
822
+ y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
823
+ };
824
+
825
+ const float floor_;
826
+ const float ceil_;
827
+ };
828
+
747
829
struct Relu
748
830
{
749
831
template <typename T>
@@ -756,13 +838,23 @@ struct Relu
756
838
y = x > 0 ? x : 0 ;
757
839
}
758
840
841
+ template <typename Y, typename X>
842
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
843
+
759
844
template <>
760
845
__host__ __device__ void operator ()(bhalf_t & y, const bhalf_t & x) const
761
846
{
762
847
float x_f32 = type_convert<float >(x);
763
848
float y_f32 = x_f32 > 0 ? x_f32 : 0 ;
764
849
y = type_convert<bhalf_t >(y_f32);
765
850
}
851
+
852
+ template <>
853
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
854
+ {
855
+ float y_f32 = x > 0 ? x : 0 ;
856
+ y = type_convert<bhalf_t >(y_f32);
857
+ };
766
858
};
767
859
768
860
// Fast GeLU
@@ -915,6 +1007,16 @@ struct Sigmoid
915
1007
constexpr T one = type_convert<T>(1 );
916
1008
y = one / (one + math::exp (-x));
917
1009
};
1010
+
1011
+ template <typename Y, typename X>
1012
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1013
+
1014
+ template <>
1015
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1016
+ {
1017
+ constexpr float one = 1 .f ;
1018
+ y = type_convert<bhalf_t >(one / (one + math::exp (-x)));
1019
+ };
918
1020
};
919
1021
920
1022
struct Silu
@@ -942,6 +1044,15 @@ struct TanH
942
1044
943
1045
y = math::tanh (x);
944
1046
};
1047
+
1048
+ template <typename Y, typename X>
1049
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1050
+
1051
+ template <>
1052
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1053
+ {
1054
+ y = type_convert<bhalf_t >(math::tanh (x));
1055
+ };
945
1056
};
946
1057
947
1058
struct ACos
@@ -1201,6 +1312,13 @@ struct Swish
1201
1312
y = type_convert<Y>(x / (1 .f + math::exp (bx)));
1202
1313
};
1203
1314
1315
+ template <>
1316
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1317
+ {
1318
+ float bx = -beta_ * x;
1319
+ y = type_convert<bhalf_t >(x / (1 .f + math::exp (bx)));
1320
+ };
1321
+
1204
1322
const float beta_;
1205
1323
};
1206
1324
@@ -1219,6 +1337,16 @@ struct SoftRelu
1219
1337
constexpr T one = type_convert<T>(1 );
1220
1338
y = math::log (one + math::exp (x * casted_alpha)) / casted_alpha;
1221
1339
}
1340
+
1341
+ template <typename Y, typename X>
1342
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1343
+
1344
+ template <>
1345
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1346
+ {
1347
+ constexpr float one = 1 .f ;
1348
+ y = type_convert<bhalf_t >(math::log (one + math::exp (x * alpha_)) / alpha_);
1349
+ };
1222
1350
const float alpha_;
1223
1351
};
1224
1352
@@ -1240,6 +1368,17 @@ struct Power
1240
1368
T shifted_scaled_x = casted_alpha + casted_beta * x;
1241
1369
y = math::pow (shifted_scaled_x, casted_gamma);
1242
1370
}
1371
+
1372
+ template <typename Y, typename X>
1373
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1374
+
1375
+ template <>
1376
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1377
+ {
1378
+ const float shifted_scaled_x = alpha_ + beta_ * x;
1379
+ y = type_convert<bhalf_t >(math::pow (shifted_scaled_x, gamma_));
1380
+ };
1381
+
1243
1382
const float alpha_;
1244
1383
const float beta_;
1245
1384
const float gamma_;
@@ -1260,6 +1399,16 @@ struct ClippedRelu
1260
1399
T casted_beta = type_convert<T>(beta_);
1261
1400
y = math::min (casted_beta, math::max (casted_alpha, x));
1262
1401
}
1402
+
1403
+ template <typename Y, typename X>
1404
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1405
+
1406
+ template <>
1407
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1408
+ {
1409
+ y = type_convert<bhalf_t >(math::min (beta_, math::max (alpha_, x)));
1410
+ };
1411
+
1263
1412
const float alpha_;
1264
1413
const float beta_;
1265
1414
};
@@ -1278,6 +1427,16 @@ struct LeakyRelu
1278
1427
T casted_alpha = type_convert<T>(alpha_);
1279
1428
y = x >= 0 ? x : x * casted_alpha;
1280
1429
}
1430
+
1431
+ template <typename Y, typename X>
1432
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1433
+
1434
+ template <>
1435
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1436
+ {
1437
+ y = type_convert<bhalf_t >(x >= 0 ? x : x * alpha_);
1438
+ };
1439
+
1281
1440
const float alpha_;
1282
1441
};
1283
1442
@@ -1295,6 +1454,16 @@ struct Elu
1295
1454
T casted_alpha = type_convert<T>(alpha_);
1296
1455
y = x > 0 ? x : casted_alpha * math::expm1 (x);
1297
1456
}
1457
+
1458
+ template <typename Y, typename X>
1459
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1460
+
1461
+ template <>
1462
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1463
+ {
1464
+ y = type_convert<bhalf_t >(x > 0 ? x : alpha_ * math::expm1 (x));
1465
+ };
1466
+
1298
1467
const float alpha_;
1299
1468
};
1300
1469
@@ -1313,6 +1482,16 @@ struct Logistic
1313
1482
constexpr T one = type_convert<T>(1 );
1314
1483
y = casted_alpha / (one + ck::math::exp (-x) * casted_alpha);
1315
1484
}
1485
+
1486
+ template <typename Y, typename X>
1487
+ __host__ __device__ constexpr void operator ()(Y& y, const X& x) const ;
1488
+
1489
+ template <>
1490
+ __host__ __device__ void operator ()<bhalf_t, float>(bhalf_t & y, const float & x) const
1491
+ {
1492
+ constexpr float one = 1 .f ;
1493
+ y = type_convert<bhalf_t >(alpha_ / (one + ck::math::exp (-x) * alpha_));
1494
+ };
1316
1495
const float alpha_;
1317
1496
};
1318
1497
0 commit comments