diff --git a/docs/2024.html b/docs/2024.html
index 73829bcfa9..75f8b7d96c 100644
--- a/docs/2024.html
+++ b/docs/2024.html
@@ -41,6 +41,7 @@
New features
- Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of class SynetConvolution16bNhwcDepthwise.
- AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w4 for class SynetConvolution32fNhwcDepthwise.
+ - AMX-BF16 kernel DepthwiseConvolution_k7p3d1s1w4 for class SynetMergedConvolution16b.
Im
Improving
diff --git a/src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise.cpp b/src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise.cpp
index 4256a3e3b6..66df58e3a6 100644
--- a/src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise.cpp
+++ b/src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise.cpp
@@ -1067,9 +1067,127 @@ namespace Simd
//-------------------------------------------------------------------------------------------------
+ template static void DepthwiseConvolution_k7p3d1s1w4(const uint8_t* src8,
+ const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float* weight, const float* bias, const float* params, uint8_t* dst)
+ {
+ assert(p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4));
+ const T* src = (T*)src8;
+ size_t srcH = p.srcH, srcW = p.srcW;
+ size_t sM = (a.bufH[1] - 1), sD = a.bufH[1] ? a.bufH[1] * p.srcW * F : F, sX = a.bufH[1] ? F : p.srcC, sY = sX * p.srcW, dstC = maC;
+ size_t dX = (a.bufH[2] ? a.maC * 2 : p.dstC * a.elem[1]), dY = p.dstW * dX, dy0 = a.bufH[2] ? yBeg : 0, dD = a.bufH[2] ? F * 2 : F * a.elem[1];
+ size_t wD = 49 * F, dstCF = AlignLo(dstC, F), dstW = p.dstW, endW = dstW - 4;
+ size_t dstCe = a.bufH[2] ? AlignHi(dstC, DF) : dstC;
+
+ __m512 s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, w0, w1, w2, w3, w4, w5, w6, d0, d1, d2, d3;
+
+ __m512 _params[2], _bias[1];
+ _params[0] = _mm512_set1_ps(params[0]);
+ if (type == SimdConvolutionActivationRestrictRange ||
+ type == SimdConvolutionActivationHswish ||
+ type == SimdConvolutionActivationHardSigmoid)
+ _params[1] = _mm512_set1_ps(params[1]);
+ for (size_t dc = 0; dc < dstCe; dc += F)
+ {
+ _bias[0] = _mm512_loadu_ps(bias + dc);
+ if (type == ::SimdConvolutionActivationPrelu)
+ _params[0] = _mm512_loadu_ps(params + dc);
+ __mmask16 tailS = TailMask16(dstC - dc);
+ __mmask32 tailC = (dc == dstCF && a.bufH[2]) ? TailMask32(dstCe - dstCF) : tailS;
+ for (size_t dy = yBeg; dy < yEnd; ++dy)
+ {
+ for (size_t dx = 0; dx < dstW; dx += 4)
+ {
+ d0 = _mm512_setzero_ps();
+ d1 = _mm512_setzero_ps();
+ d2 = _mm512_setzero_ps();
+ d3 = _mm512_setzero_ps();
+ for (size_t ky = 0; ky < 7; ++ky)
+ {
+ size_t sy = dy + ky - 3;
+ const T* ps = src + (sy & sM) * sY + (dx - 3) * sX;
+ const float* pw = weight + ky * 7 * F;
+ if (sy < srcH)
+ {
+ w0 = _mm512_maskz_loadu_ps(tailS, pw + 0 * F);
+ w1 = _mm512_maskz_loadu_ps(tailS, pw + 1 * F);
+ w2 = _mm512_maskz_loadu_ps(tailS, pw + 2 * F);
+ if (dx)
+ {
+ s0 = LoadSrc(ps + 0 * sX, tailS);
+ d0 = _mm512_fmadd_ps(s0, w0, d0);
+
+ s1 = LoadSrc(ps + 1 * sX, tailS);
+ d0 = _mm512_fmadd_ps(s1, w1, d0);
+ d1 = _mm512_fmadd_ps(s1, w0, d1);
+
+ s2 = LoadSrc(ps + 2 * sX, tailS);
+ d0 = _mm512_fmadd_ps(s2, w2, d0);
+ d1 = _mm512_fmadd_ps(s2, w1, d1);
+ d2 = _mm512_fmadd_ps(s2, w0, d2);
+ }
+ s3 = LoadSrc(ps + 3 * sX, tailS);
+ w3 = _mm512_maskz_loadu_ps(tailS, pw + 3 * F);
+ d0 = _mm512_fmadd_ps(s3, w3, d0);
+ d1 = _mm512_fmadd_ps(s3, w2, d1);
+ d2 = _mm512_fmadd_ps(s3, w1, d2);
+ d3 = _mm512_fmadd_ps(s3, w0, d3);
+
+ s4 = LoadSrc(ps + 4 * sX, tailS);
+ w4 = _mm512_maskz_loadu_ps(tailS, pw + 4 * F);
+ d0 = _mm512_fmadd_ps(s4, w4, d0);
+ d1 = _mm512_fmadd_ps(s4, w3, d1);
+ d2 = _mm512_fmadd_ps(s4, w2, d2);
+ d3 = _mm512_fmadd_ps(s4, w1, d3);
+
+ s5 = LoadSrc(ps + 5 * sX, tailS);
+ w5 = _mm512_maskz_loadu_ps(tailS, pw + 5 * F);
+ d0 = _mm512_fmadd_ps(s5, w5, d0);
+ d1 = _mm512_fmadd_ps(s5, w4, d1);
+ d2 = _mm512_fmadd_ps(s5, w3, d2);
+ d3 = _mm512_fmadd_ps(s5, w2, d3);
+
+ s6 = LoadSrc(ps + 6 * sX, tailS);
+ w6 = _mm512_maskz_loadu_ps(tailS, pw + 6 * F);
+ d0 = _mm512_fmadd_ps(s6, w6, d0);
+ d1 = _mm512_fmadd_ps(s6, w5, d1);
+ d2 = _mm512_fmadd_ps(s6, w4, d2);
+ d3 = _mm512_fmadd_ps(s6, w3, d3);
+ if (dx < endW)
+ {
+ s7 = LoadSrc(ps + 7 * sX, tailS);
+ d1 = _mm512_fmadd_ps(s7, w6, d1);
+ d2 = _mm512_fmadd_ps(s7, w5, d2);
+ d3 = _mm512_fmadd_ps(s7, w4, d3);
+
+ s8 = LoadSrc(ps + 8 * sX, tailS);
+ d2 = _mm512_fmadd_ps(s8, w6, d2);
+ d3 = _mm512_fmadd_ps(s8, w5, d3);
+
+ s9 = LoadSrc(ps + 9 * sX, tailS);
+ d3 = _mm512_fmadd_ps(s9, w6, d3);
+ }
+ }
+ }
+ uint8_t* pd = dst + (dy - dy0) * dY + dx * dX;
+ Save1(pd + 0 * dX, dD, d0, _bias, _params, tailC);
+ Save1(pd + 1 * dX, dD, d1, _bias, _params, tailC);
+ Save1(pd + 2 * dX, dD, d2, _bias, _params, tailC);
+ Save1(pd + 3 * dX, dD, d3, _bias, _params, tailC);
+ }
+ }
+ src += sD;
+ dst += dD;
+ weight += wD;
+ }
+ }
+
+ //-------------------------------------------------------------------------------------------------
+
template static void SetDepthwise(const ConvParam& p, DepthwisePtr& depthwise)
{
- if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
+ if (p.IsKernel(7) && p.IsPad(3) && p.IsStride(1) && p.IsDilation(1) && Aligned(p.srcW, 4))
+ depthwise = DepthwiseConvolution_k7p3d1s1w4;
+ else if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
depthwise = DepthwiseConvolution3x3;
else if(p.padX + p.padW > 2 && p.srcC >= 128)
depthwise = DepthwiseConvolutionLargePad;
diff --git a/src/Test/TestSynetMergedConvolution16b.cpp b/src/Test/TestSynetMergedConvolution16b.cpp
index 1ca4499613..334ff0d093 100644
--- a/src/Test/TestSynetMergedConvolution16b.cpp
+++ b/src/Test/TestSynetMergedConvolution16b.cpp
@@ -285,6 +285,8 @@ namespace Test
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 116, 15, 5), Cnv(a1, 3, 2), Cnv(a2, 1, 1, 116), f32, f32, c), f1, f2);
#endif
#if 1
+ result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 16, 16), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 256), f32, f32, c), f1, f2);
+ result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 256, 16, 16), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 256), b16, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 304, 17, 15), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 1216), b16, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 76, 64, 64), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 304), f32, b16, c), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 152, 32, 32), Cnv(a1, 7, 1), Cnv(a2, 1, 1, 608), f32, b16, c), f1, f2);