Skip to content

Commit

Permalink
+add AVX2 optimizations of class ResizerBf16Bilinear (part 3).
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Jan 5, 2025
1 parent 7ed5541 commit fa8c22f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 49 deletions.
134 changes: 86 additions & 48 deletions src/Simd/SimdAvx2ResizerBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,13 +954,22 @@ namespace Simd
{
__m128 s0 = Sse41::BFloat16ToFloat32(Sse41::UnpackU16<0>(_mm_loadl_epi64((__m128i*)src)));
__m128 s1 = Sse41::BFloat16ToFloat32(Sse41::UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src + channels))));
return _mm_add_ps(_mm_mul_ps(fx0, s0), _mm_mul_ps(fx1, s1));
return _mm_fmadd_ps(fx0, s0, _mm_mul_ps(fx1, s1));
}

SIMD_INLINE __m256 BilinearRowSumBf16(const uint16_t* src, size_t channels, __m256 fx0, __m256 fx1)
{
__m256 s0 = BFloat16ToFloat32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i*)src)));
__m256 s1 = BFloat16ToFloat32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i*)(src + channels))));
return _mm256_fmadd_ps(fx0, s0, _mm256_mul_ps(fx1, s1));
}

void ResizerBf16Bilinear::Run(const uint16_t* src, size_t srcStride, uint16_t* dst, size_t dstStride)
{
size_t cn = _param.channels, cnF = AlignLo(cn, Sse41::F), cnT = cn - cnF, cnL = cnT - Sse41::F;
__m128 _1 = _mm_set1_ps(1.0f);
size_t cn = _param.channels,
cnH = AlignLo(cn, Sse41::F), cnTH = cn - cnH, cnLH = cnTH - Sse41::F,
cnF = AlignLo(cn, F), cnTF = cn - cnF, cnLF = cnTF - F;
__m256 _1 = _mm256_set1_ps(1.0f);
if (_rowBuf)
{
size_t rs = _param.dstW * cn, rsH = AlignLo(rs, Sse41::F), rsF = AlignLo(rs, F);
Expand Down Expand Up @@ -988,6 +997,19 @@ namespace Simd
float* pb = pbx[k];
const uint16_t* ps = src + (sy + k) * srcStride;
size_t dx = 0;
if (cn >= 4)
{
for (; dx < rs;)
{
const uint16_t* ps0 = ps + _ix[dx];
__m128 fx1 = _mm_set1_ps(_ax[dx]);
__m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1);
for (size_t end = dx + cnH; dx < end; dx += Sse41::F, ps0 += Sse41::F)
_mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps0, cn, fx0, fx1));
if (cnTH)
_mm_storeu_ps(pb + dx + cnLH, BilinearRowSumBf16(ps0 + cnLH, cn, fx0, fx1)), dx += cnTH;
}
}
if (cn == 1)
{
for (; dx < rsH; dx += Sse41::F)
Expand All @@ -1001,10 +1023,8 @@ namespace Simd
__m128 s0 = Sse41::BFloat16ToFloat32Even(_src);
__m128 s1 = Sse41::BFloat16ToFloat32Odd(_src);
__m128 fx1 = _mm_loadu_ps(_ax.data + dx);
__m128 fx0 = _mm_sub_ps(_1, fx1);
__m128 m0 = _mm_mul_ps(fx0, s0);
__m128 m1 = _mm_mul_ps(fx1, s1);
_mm_storeu_ps(pb + dx, _mm_add_ps(m0, m1));
__m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1);
_mm_storeu_ps(pb + dx, _mm_fmadd_ps(fx0, s0, _mm_mul_ps(fx1, s1)));
}
}
if (cn == 2)
Expand All @@ -1015,10 +1035,8 @@ namespace Simd
__m128 s0 = _mm_castsi128_ps(_mm_shuffle_epi8(_src, K8_IDX_20));
__m128 s1 = _mm_castsi128_ps(_mm_shuffle_epi8(_src, K8_IDX_21));
__m128 fx1 = _mm_loadu_ps(_ax.data + dx);
__m128 fx0 = _mm_sub_ps(_1, fx1);
__m128 m0 = _mm_mul_ps(fx0, s0);
__m128 m1 = _mm_mul_ps(fx1, s1);
_mm_storeu_ps(pb + dx, _mm_add_ps(m0, m1));
__m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1);
_mm_storeu_ps(pb + dx, _mm_fmadd_ps(fx0, s0, _mm_mul_ps(fx1, s1)));
}
}
if (cn == 3 && rs > 3)
Expand All @@ -1027,23 +1045,10 @@ namespace Simd
for (; dx < rs3; dx += 3)
{
__m128 fx1 = _mm_set1_ps(_ax.data[dx]);
__m128 fx0 = _mm_sub_ps(_1, fx1);
__m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1);
_mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps + _ix[dx], cn, fx0, fx1));
}
}
if (cn >= 4)
{
for (; dx < rs;)
{
const uint16_t* ps0 = ps + _ix[dx];
__m128 fx1 = _mm_set1_ps(_ax[dx]);
__m128 fx0 = _mm_sub_ps(_1, fx1);
for (size_t end = dx + cnF; dx < end; dx += Sse41::F, ps0 += Sse41::F)
_mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps0, cn, fx0, fx1));
if (cnT)
_mm_storeu_ps(pb + dx + cnL, BilinearRowSumBf16(ps0 + cnL, cn, fx0, fx1)), dx += cnT;
}
}
for (; dx < rs; dx++)
{
int32_t sx = _ix[dx];
Expand All @@ -1062,9 +1067,8 @@ namespace Simd
}
for (; dx < rsH; dx += Sse41::F)
{
__m128 m0 = _mm_mul_ps(_mm_loadu_ps(pbx[0] + dx), _mm256_castps256_ps128(_fy0));
__m128 m1 = _mm_mul_ps(_mm_loadu_ps(pbx[1] + dx), _mm256_castps256_ps128(_fy1));
__m128i d0 = Sse41::Float32ToBFloat16(_mm_add_ps(m0, m1));
__m128i d0 = Sse41::Float32ToBFloat16(_mm_fmadd_ps(_mm_loadu_ps(pbx[0] + dx), _mm256_castps256_ps128(_fy0),
_mm_mul_ps(_mm_loadu_ps(pbx[1] + dx), _mm256_castps256_ps128(_fy1))));
_mm_storel_epi64((__m128i*)(dst + dx), _mm_packus_epi32(d0, Sse41::K_ZERO));
}
for (; dx < rs; dx++)
Expand All @@ -1073,31 +1077,65 @@ namespace Simd
}
else
{
for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
if (cnF)
{
__m128 fy1 = _mm_set1_ps(_ay[dy]);
__m128 fy0 = _mm_sub_ps(_1, fy1);
const uint16_t* src0 = src + _iy[dy] * srcStride, * src1 = src0 + srcStride;
for (size_t dx = 0; dx < _param.dstW; dx++)
for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
{
size_t os = _ix[dx], end = os + cnF, od = dx * cn;
__m128 fx1 = _mm_set1_ps(_ax[dx]);
__m128 fx0 = _mm_sub_ps(_1, fx1);
for (; os < end; os += Sse41::F, od += Sse41::F)
__m256 fy1 = _mm256_set1_ps(_ay[dy]);
__m256 fy0 = _mm256_sub_ps(_1, fy1);
const uint16_t* src0 = src + _iy[dy] * srcStride, * src1 = src0 + srcStride;
for (size_t dx = 0; dx < _param.dstW; dx++)
{
__m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1);
__m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1);
__m128i d0 = Sse41::Float32ToBFloat16(_mm_add_ps(_mm_mul_ps(r0, fy0), _mm_mul_ps(r1, fy1)));
_mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO));
size_t os = _ix[dx], end = os + cnF, od = dx * cn;
__m256 fx1 = _mm256_set1_ps(_ax[dx]);
__m256 fx0 = _mm256_sub_ps(_1, fx1);
for (; os < end; os += F, od += F)
{
__m256 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1);
__m256 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1);
__m256i d0 = Float32ToBFloat16(_mm256_fmadd_ps(r0, fy0, _mm256_mul_ps(r1, fy1)));
_mm_storeu_si128((__m128i*)(dst + od), _mm256_castsi256_si128(_mm256_permute4x64_epi64(_mm256_packus_epi32(d0, K_ZERO), 0xD8)));
}
if (cnTH)
{
os += cnLH;
od += cnLH;
__m256 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1);
__m256 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1);
__m256i d0 = Float32ToBFloat16(_mm256_fmadd_ps(r0, fy0, _mm256_mul_ps(r1, fy1)));
_mm_storeu_si128((__m128i*)(dst + od), _mm256_castsi256_si128(_mm256_permute4x64_epi64(_mm256_packus_epi32(d0, K_ZERO), 0xD8)));
}
}
if (cnT)
}
}
else
{
for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
{
__m128 fy1 = _mm_set1_ps(_ay[dy]);
__m128 fy0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fy1);
const uint16_t* src0 = src + _iy[dy] * srcStride, * src1 = src0 + srcStride;
for (size_t dx = 0; dx < _param.dstW; dx++)
{
os += cnL;
od += cnL;
__m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1);
__m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1);
__m128i d0 = Sse41::Float32ToBFloat16(_mm_add_ps(_mm_mul_ps(r0, fy0), _mm_mul_ps(r1, fy1)));
_mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO));
size_t os = _ix[dx], end = os + cnH, od = dx * cn;
__m128 fx1 = _mm_set1_ps(_ax[dx]);
__m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1);
for (; os < end; os += Sse41::F, od += Sse41::F)
{
__m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1);
__m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1);
__m128i d0 = Sse41::Float32ToBFloat16(_mm_fmadd_ps(r0, fy0, _mm_mul_ps(r1, fy1)));
_mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO));
}
if (cnTH)
{
os += cnLH;
od += cnLH;
__m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1);
__m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1);
__m128i d0 = Sse41::Float32ToBFloat16(_mm_fmadd_ps(r0, fy0, _mm_mul_ps(r1, fy1)));
_mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO));
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/Simd/SimdBaseResizerBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ namespace Simd
ResizerBf16Bilinear::ResizerBf16Bilinear(const ResParam& param)
: Resizer(param)
{
_rowBuf = !(_param.align >= 16 && (_param.channels >= _param.align / 4 || _param.channels == 64)) || _param.dstH >= _param.srcH;
_rowBuf = _param.align < 16 || _param.channels < 4 || _param.dstH >= _param.srcH;
_ay.Resize(_param.dstH, false, _param.align);
_iy.Resize(_param.dstH, false, _param.align);
EstimateIndexAlpha(_param, _param.srcH, _param.dstH, 1, 1, _iy.data, _ay.data);
Expand Down

0 comments on commit fa8c22f

Please sign in to comment.