Line data Source code
1 : /*
2 : * Copyright(c) 2019 Intel Corporation
3 : * SPDX-License-Identifier: BSD-2-Clause-Patent
4 : */
5 :
6 : #include "EbDefinitions.h"
7 : #include <immintrin.h>
8 : #include <math.h>
9 :
10 : #define REDUCED_PRI_STRENGTHS 8
11 : #define REDUCED_TOTAL_STRENGTHS (REDUCED_PRI_STRENGTHS * CDEF_SEC_STRENGTHS)
12 : #define TOTAL_STRENGTHS (CDEF_PRI_STRENGTHS * CDEF_SEC_STRENGTHS)
13 :
14 : #ifndef _mm256_set_m128i
15 : #define _mm256_set_m128i(/* __m128i */ hi, /* __m128i */ lo) \
16 : _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 0x1)
17 : #endif
18 :
19 : #ifndef _mm256_setr_m128i
20 : #define _mm256_setr_m128i(/* __m128i */ lo, /* __m128i */ hi) \
21 : _mm256_set_m128i((hi), (lo))
22 : #endif
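
For toolchains whose immintrin.h predates these helpers, the fallbacks above build a 256-bit register from two 128-bit halves, with the first argument of _mm256_setr_m128i landing in the low lane. A minimal standalone sketch of the expected lane order (test values are arbitrary and purely illustrative):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    const __m128i lo = _mm_set1_epi32(1); /* intended low 128-bit lane */
    const __m128i hi = _mm_set1_epi32(2); /* intended high 128-bit lane */
    /* What _mm256_setr_m128i(lo, hi) expands to: lo occupies lane 0, hi lane 1. */
    const __m256i v = _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 0x1);
    int out[8];
    _mm256_storeu_si256((__m256i *)out, v);
    for (int i = 0; i < 8; i++)
        printf("%d ", out[i]); /* prints: 1 1 1 1 2 2 2 2 */
    printf("\n");
    return 0;
}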
23 :
24 : /* Search for the best luma+chroma strength to add as an option, knowing we
25 : already selected nb_strengths options. */
26 9000 : uint64_t search_one_dual_avx2(int *lev0, int *lev1, int nb_strengths,
27 : uint64_t(**mse)[TOTAL_STRENGTHS], int sb_count,
28 : int fast, int start_gi, int end_gi) {
29 : DECLARE_ALIGNED(32, uint64_t, tot_mse[TOTAL_STRENGTHS][TOTAL_STRENGTHS]);
30 : int i, j;
31 9000 : uint64_t best_tot_mse = (uint64_t)1 << 62;
32 9000 : int best_id0 = 0;
33 9000 : int best_id1 = 0;
34 : (void)fast;
35 9000 : const int total_strengths = end_gi;
36 : __m256i best_mse_;
37 : __m256i curr;
38 : __m256i v_tot;
39 : __m256i v_mse;
40 : __m256i mask;
41 : __m256i tmp;
42 :
43 9000 : memset(tot_mse, 0, sizeof(tot_mse));
44 :
45 116534 : for (i = 0; i < sb_count; i++) {
46 : int gi;
47 107534 : uint64_t best_mse = (uint64_t)1 << 62;
48 : /* Find best mse among already selected options. */
49 559248 : for (gi = 0; gi < nb_strengths; gi++) {
50 451714 : uint64_t curr = mse[0][i][lev0[gi]];
51 451714 : curr += mse[1][i][lev1[gi]];
52 451714 : if (curr < best_mse)
53 192270 : best_mse = curr;
54 : }
55 107534 : best_mse_ = _mm256_set1_epi64x(best_mse);
56 : /* Find best mse when adding each possible new option. */
57 : // total_strengths must be a multiple of 4: the inner k-loop processes 4 tot_mse entries per iteration.
58 2628380 : for (int j = start_gi; j < total_strengths; ++j) {
59 2520850 : tmp = _mm256_set1_epi64x(mse[0][i][j]);
60 26990800 : for (int k = 0; k < total_strengths; k += 4) {
61 24469900 : v_mse = _mm256_loadu_si256((const __m256i*)&mse[1][i][k]);
62 48939900 : v_tot = _mm256_loadu_si256((const __m256i*)&tot_mse[j][k]);
63 24469900 : curr = _mm256_add_epi64(tmp, v_mse);
64 24469900 : mask = _mm256_cmpgt_epi64(best_mse_, curr);
65 73409800 : v_tot = _mm256_add_epi64(v_tot, _mm256_or_si256(
66 : _mm256_andnot_si256(mask, best_mse_),
67 : _mm256_and_si256(mask, curr)));
68 24469900 : _mm256_storeu_si256((__m256i*)&tot_mse[j][k], v_tot);
69 : }
70 : }
71 : }
72 162858 : for (j = start_gi; j < total_strengths; j++) {
73 : int k;
74 3788880 : for (k = start_gi; k < total_strengths; k++) {
75 3635020 : if (tot_mse[j][k] < best_tot_mse) {
76 55266 : best_tot_mse = tot_mse[j][k];
77 55266 : best_id0 = j;
78 55266 : best_id1 = k;
79 : }
80 : }
81 : }
82 9000 : lev0[nb_strengths] = best_id0;
83 9000 : lev1[nb_strengths] = best_id1;
84 :
85 9000 : return best_tot_mse;
86 : }
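
To make the intent of the masked accumulation above easier to follow, here is a plain-C sketch of the same selection. The helper name is hypothetical, it reuses this file's TOTAL_STRENGTHS and memset, and it mirrors the AVX2 loops (k starting at 0, 1 << 62 sentinels) rather than being the project's C reference:

/* For every candidate pair (j, k) of luma/chroma strengths, accumulate over all
 * superblocks the smaller of (a) the best MSE among already-selected pairs and
 * (b) the MSE of the candidate pair, then return the best total. */
static uint64_t search_one_dual_sketch(int *lev0, int *lev1, int nb_strengths,
                                       uint64_t (**mse)[TOTAL_STRENGTHS],
                                       int sb_count, int start_gi, int end_gi) {
    uint64_t tot_mse[TOTAL_STRENGTHS][TOTAL_STRENGTHS];
    uint64_t best_tot_mse = (uint64_t)1 << 62;
    int best_id0 = 0, best_id1 = 0;
    memset(tot_mse, 0, sizeof(tot_mse));
    for (int i = 0; i < sb_count; i++) {
        uint64_t best_mse = (uint64_t)1 << 62;
        /* Best MSE among the strength pairs selected so far. */
        for (int gi = 0; gi < nb_strengths; gi++) {
            const uint64_t curr = mse[0][i][lev0[gi]] + mse[1][i][lev1[gi]];
            if (curr < best_mse)
                best_mse = curr;
        }
        /* Cost of adding each candidate pair: min(candidate, best so far). */
        for (int j = start_gi; j < end_gi; j++) {
            for (int k = 0; k < end_gi; k++) {
                const uint64_t curr = mse[0][i][j] + mse[1][i][k];
                tot_mse[j][k] += (curr < best_mse) ? curr : best_mse;
            }
        }
    }
    for (int j = start_gi; j < end_gi; j++) {
        for (int k = start_gi; k < end_gi; k++) {
            if (tot_mse[j][k] < best_tot_mse) {
                best_tot_mse = tot_mse[j][k];
                best_id0 = j;
                best_id1 = k;
            }
        }
    }
    lev0[nb_strengths] = best_id0;
    lev1[nb_strengths] = best_id1;
    return best_tot_mse;
}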
87 :
88 0 : static INLINE void mse_4x4_16bit_avx2(const uint16_t **src, const uint16_t *dst, const int32_t dstride, __m256i *sum) {
89 0 : const __m256i s = _mm256_loadu_si256((const __m256i*)*src);
90 0 : const __m256i d = _mm256_setr_epi64x(
91 0 : *(uint64_t*)(dst + 0 * dstride),
92 0 : *(uint64_t*)(dst + 1 * dstride),
93 0 : *(uint64_t*)(dst + 2 * dstride),
94 0 : *(uint64_t*)(dst + 3 * dstride));
95 0 : const __m256i diff = _mm256_sub_epi16(d, s);
96 0 : const __m256i mse = _mm256_madd_epi16(diff, diff);
97 0 : *sum = _mm256_add_epi32(*sum, mse);
98 0 : *src += 16;
99 0 : }
100 :
101 2370060 : static INLINE void mse_4x4_8bit_avx2(const uint8_t **src, const uint8_t *dst, const int32_t dstride, __m256i *sum) {
102 2370060 : const __m128i s = _mm_loadu_si128((const __m128i*)*src);
103 2370060 : const __m128i d = _mm_setr_epi32(
104 2370060 : *(uint32_t*)(dst + 0 * dstride),
105 2370060 : *(uint32_t*)(dst + 1 * dstride),
106 2370060 : *(uint32_t*)(dst + 2 * dstride),
107 2370060 : *(uint32_t*)(dst + 3 * dstride));
108 :
109 2370060 : const __m256i s_16 = _mm256_cvtepu8_epi16(s);
110 2370060 : const __m256i d_16 = _mm256_cvtepu8_epi16(d);
111 :
112 2370060 : const __m256i diff = _mm256_sub_epi16(d_16, s_16);
113 2370060 : const __m256i mse = _mm256_madd_epi16(diff, diff);
114 2370060 : *sum = _mm256_add_epi32(*sum, mse);
115 2370060 : *src += 16;
116 2370060 : }
117 :
118 0 : static INLINE void mse_8x2_16bit_avx2(const uint16_t **src, const uint16_t *dst, const int32_t dstride, __m256i *sum) {
119 0 : const __m256i s = _mm256_loadu_si256((const __m256i*)*src);
120 0 : const __m128i d0 = _mm_loadu_si128((const __m128i*)(dst + 0 * dstride));
121 0 : const __m128i d1 = _mm_loadu_si128((const __m128i*)(dst + 1 * dstride));
122 0 : const __m256i d = _mm256_setr_m128i(d0, d1);
123 0 : const __m256i diff = _mm256_sub_epi16(d, s);
124 0 : const __m256i mse = _mm256_madd_epi16(diff, diff);
125 0 : *sum = _mm256_add_epi32(*sum, mse);
126 0 : *src += 16;
127 0 : }
128 :
129 0 : static INLINE void mse_8x2_8bit_avx2(const uint8_t **src, const uint8_t *dst, const int32_t dstride, __m256i *sum) {
130 0 : const __m128i s = _mm_loadu_si128((const __m128i*)*src);
131 0 : const __m128i d = _mm_set_epi64x(*(uint64_t*)(dst + 1 * dstride),
132 0 : *(uint64_t*)(dst + 0 * dstride));
133 :
134 0 : const __m256i s_16 = _mm256_cvtepu8_epi16(s);
135 0 : const __m256i d_16 = _mm256_cvtepu8_epi16(d);
136 :
137 0 : const __m256i diff = _mm256_sub_epi16(d_16, s_16);
138 0 : const __m256i mse = _mm256_madd_epi16(diff, diff);
139 0 : *sum = _mm256_add_epi32(*sum, mse);
140 0 : *src += 16;
141 0 : }
142 :
143 0 : static INLINE void mse_8x4_16bit_avx2(const uint16_t **src, const uint16_t *dst, const int32_t dstride, __m256i *sum) {
144 0 : mse_8x2_16bit_avx2(src, dst + 0 * dstride, dstride, sum);
145 0 : mse_8x2_16bit_avx2(src, dst + 2 * dstride, dstride, sum);
146 0 : }
147 :
148 0 : static INLINE void mse_8x4_8bit_avx2(const uint8_t **src, const uint8_t *dst, const int32_t dstride, __m256i *sum) {
149 0 : mse_8x2_8bit_avx2(src, dst + 0 * dstride, dstride, sum);
150 0 : mse_8x2_8bit_avx2(src, dst + 2 * dstride, dstride, sum);
151 0 : }
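
Each helper above widens one small tile to 16-bit samples and accumulates squared differences with _mm256_madd_epi16, which multiplies adjacent 16-bit pairs and adds each product pair into a 32-bit lane, so *sum collects 32-bit partial sums of squared differences. A scalar sketch of the 8-bit 4x4 case (hypothetical helper name, illustrative only; it returns the tile total instead of a vector of partial sums):

/* Sum of squared differences over one 4x4 tile; the filtered source is
 * consumed linearly, 16 samples per call, while dst is strided. */
static uint32_t mse_4x4_8bit_sketch(const uint8_t **src, const uint8_t *dst,
                                    int32_t dstride) {
    uint32_t sum = 0;
    for (int r = 0; r < 4; r++) {
        for (int c = 0; c < 4; c++) {
            const int32_t diff = (int32_t)dst[r * dstride + c] - (int32_t)(*src)[r * 4 + c];
            sum += (uint32_t)(diff * diff);
        }
    }
    *src += 16;
    return sum;
}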
152 :
153 3558670 : static INLINE uint32_t sum32(const __m256i src) {
154 3558670 : const __m128i src_L = _mm256_extracti128_si256(src, 0);
155 3558670 : const __m128i src_H = _mm256_extracti128_si256(src, 1);
156 3558670 : const __m128i s = _mm_add_epi32(src_L, src_H);
157 : __m128i dst;
158 :
159 3558670 : dst = _mm_hadd_epi32(s, s);
160 3558670 : dst = _mm_hadd_epi32(dst, dst);
161 :
162 3558670 : return (uint32_t)_mm_cvtsi128_si32(dst);
163 : }
164 :
165 0 : static INLINE uint64_t dist_8x8_16bit_avx2(const uint16_t **src, const uint16_t *dst, const int32_t dstride, const int32_t coeff_shift) {
166 0 : __m256i ss = _mm256_setzero_si256();
167 0 : __m256i dd = _mm256_setzero_si256();
168 0 : __m256i s2 = _mm256_setzero_si256();
169 0 : __m256i sd = _mm256_setzero_si256();
170 0 : __m256i d2 = _mm256_setzero_si256();
171 : __m256i ssdd;
172 : __m128i sum;
173 :
174 0 : for (int32_t r = 0; r < 4; r++) {
175 0 : const __m256i s = _mm256_loadu_si256((const __m256i*)*src);
176 0 : const __m128i d0 = _mm_loadu_si128((const __m128i*)(dst + 2 * r * dstride + 0 * dstride));
177 0 : const __m128i d1 = _mm_loadu_si128((const __m128i*)(dst + 2 * r * dstride + 1 * dstride));
178 0 : const __m256i d = _mm256_setr_m128i(d0, d1);
179 0 : ss = _mm256_add_epi16(ss, s);
180 0 : dd = _mm256_add_epi16(dd, d);
181 0 : s2 = _mm256_add_epi32(s2, _mm256_madd_epi16(s, s));
182 0 : sd = _mm256_add_epi32(sd, _mm256_madd_epi16(s, d));
183 0 : d2 = _mm256_add_epi32(d2, _mm256_madd_epi16(d, d));
184 0 : *src += 16;
185 : }
186 :
187 0 : ssdd = _mm256_hadd_epi16(ss, dd);
188 0 : ssdd = _mm256_hadd_epi16(ssdd, ssdd);
189 0 : ssdd = _mm256_unpacklo_epi16(ssdd, _mm256_setzero_si256());
190 0 : const __m128i ssdd_L = _mm256_extracti128_si256(ssdd, 0);
191 0 : const __m128i ssdd_H = _mm256_extracti128_si256(ssdd, 1);
192 0 : sum = _mm_add_epi32(ssdd_L, ssdd_H);
193 0 : sum = _mm_hadd_epi32(sum, sum);
194 :
195 : /* Reduce the vector accumulators to scalar sums. */
196 0 : uint64_t sum_s = _mm_cvtsi128_si32(sum);
197 0 : uint64_t sum_d = _mm_extract_epi32(sum, 1);
198 0 : uint64_t sum_s2 = sum32(s2);
199 0 : uint64_t sum_d2 = sum32(d2);
200 0 : uint64_t sum_sd = sum32(sd);
201 :
202 : /* Compute the variance -- the calculation cannot go negative. */
203 0 : uint64_t svar = sum_s2 - ((sum_s * sum_s + 32) >> 6);
204 0 : uint64_t dvar = sum_d2 - ((sum_d * sum_d + 32) >> 6);
205 0 : return (uint64_t)floor(
206 0 : .5 + (sum_d2 + sum_s2 - 2 * sum_sd) * .5 *
207 0 : (svar + dvar + (400 << 2 * coeff_shift)) /
208 0 : (sqrt((20000 << 4 * coeff_shift) + svar * (double)dvar)));
209 : }
210 :
211 1184380 : static INLINE uint64_t dist_8x8_8bit_avx2(const uint8_t **src, const uint8_t *dst, const int32_t dstride, const int32_t coeff_shift) {
212 1184380 : __m256i ss = _mm256_setzero_si256();
213 1184380 : __m256i dd = _mm256_setzero_si256();
214 1184380 : __m256i s2 = _mm256_setzero_si256();
215 1184380 : __m256i sd = _mm256_setzero_si256();
216 1184380 : __m256i d2 = _mm256_setzero_si256();
217 : __m256i ssdd;
218 : __m128i sum;
219 :
220 5869610 : for (int32_t r = 0; r < 4; r++) {
221 4685240 : const __m128i s = _mm_loadu_si128((const __m128i*)*src);
222 4685240 : const __m128i d = _mm_set_epi64x(*(uint64_t*)(dst + 2 * r * dstride + 1 * dstride),
223 4685240 : *(uint64_t*)(dst + 2 * r * dstride + 0 * dstride));
224 :
225 4685240 : const __m256i s_16 = _mm256_cvtepu8_epi16(s);
226 4685240 : const __m256i d_16 = _mm256_cvtepu8_epi16(d);
227 :
228 4685240 : ss = _mm256_add_epi16(ss, s_16);
229 4685240 : dd = _mm256_add_epi16(dd, d_16);
230 9370470 : s2 = _mm256_add_epi32(s2, _mm256_madd_epi16(s_16, s_16));
231 9370470 : sd = _mm256_add_epi32(sd, _mm256_madd_epi16(s_16, d_16));
232 4685240 : d2 = _mm256_add_epi32(d2, _mm256_madd_epi16(d_16, d_16));
233 4685240 : *src += 16;
234 : }
235 :
236 1184380 : ssdd = _mm256_hadd_epi16(ss, dd);
237 1184380 : ssdd = _mm256_hadd_epi16(ssdd, ssdd);
238 1184380 : ssdd = _mm256_unpacklo_epi16(ssdd, _mm256_setzero_si256());
239 1184380 : const __m128i ssdd_L = _mm256_extracti128_si256(ssdd, 0);
240 1184380 : const __m128i ssdd_H = _mm256_extracti128_si256(ssdd, 1);
241 1184380 : sum = _mm_add_epi32(ssdd_L, ssdd_H);
242 1184380 : sum = _mm_hadd_epi32(sum, sum);
243 :
244 : /* Reduce the vector accumulators to scalar sums. */
245 1184380 : uint64_t sum_s = _mm_cvtsi128_si32(sum);
246 1184380 : uint64_t sum_d = _mm_extract_epi32(sum, 1);
247 1184380 : uint64_t sum_s2 = sum32(s2);
248 1197190 : uint64_t sum_d2 = sum32(d2);
249 1195240 : uint64_t sum_sd = sum32(sd);
250 :
251 : /* Compute the variance -- the calculation cannot go negative. */
252 1191870 : uint64_t svar = sum_s2 - ((sum_s * sum_s + 32) >> 6);
253 1191870 : uint64_t dvar = sum_d2 - ((sum_d * sum_d + 32) >> 6);
254 2383730 : return (uint64_t)floor(
255 1191870 : .5 + (sum_d2 + sum_s2 - 2 * sum_sd) * .5 *
256 1191870 : (svar + dvar + (400 << 2 * coeff_shift)) /
257 1191870 : (sqrt((20000 << 4 * coeff_shift) + svar * (double)dvar)));
258 : }
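
The tail of both dist_8x8 variants turns the reduced sums into a distortion value: sum_d2 + sum_s2 - 2 * sum_sd is the plain SSE of the 8x8 block, which is then weighted by a term built from the source and reconstructed variances. A scalar sketch of that final step (hypothetical helper name, illustrative only; relies on math.h, which this file already includes):

static uint64_t dist_8x8_from_sums_sketch(uint64_t sum_s, uint64_t sum_d,
                                          uint64_t sum_s2, uint64_t sum_d2,
                                          uint64_t sum_sd, int32_t coeff_shift) {
    /* Variances over the 64 pixels (the >> 6 divides the squared sum by 64, rounded). */
    const uint64_t svar = sum_s2 - ((sum_s * sum_s + 32) >> 6);
    const uint64_t dvar = sum_d2 - ((sum_d * sum_d + 32) >> 6);
    /* sum_d2 + sum_s2 - 2 * sum_sd == sum of (d - s)^2 over the block. */
    const double sse = (double)(sum_d2 + sum_s2 - 2 * sum_sd);
    return (uint64_t)floor(.5 + sse * .5 *
                           (svar + dvar + (400 << 2 * coeff_shift)) /
                           sqrt((20000 << 4 * coeff_shift) + svar * (double)dvar));
}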
259 :
260 2380200 : static INLINE void sum_32_to_64(const __m256i src, __m256i *dst) {
261 4760400 : const __m256i src_L = _mm256_unpacklo_epi32(src, _mm256_setzero_si256());
262 2380200 : const __m256i src_H = _mm256_unpackhi_epi32(src, _mm256_setzero_si256());
263 2380200 : *dst = _mm256_add_epi64(*dst, src_L);
264 2380200 : *dst = _mm256_add_epi64(*dst, src_H);
265 2380200 : }
266 :
267 67248 : static INLINE uint64_t sum64(const __m256i src) {
268 67248 : const __m128i src_L = _mm256_extracti128_si256(src, 0);
269 67248 : const __m128i src_H = _mm256_extracti128_si256(src, 1);
270 67248 : const __m128i s = _mm_add_epi64(src_L, src_H);
271 134496 : const __m128i dst = _mm_add_epi64(s, _mm_srli_si128(s, 8));
272 :
273 67248 : return (uint64_t)_mm_cvtsi128_si64(dst);
274 : }
275 :
276 : /* Compute MSE only on the blocks we filtered. */
277 0 : uint64_t compute_cdef_dist_avx2(const uint16_t *dst, int32_t dstride, const uint16_t *src, const cdef_list *dlist, int32_t cdef_count, BlockSize bsize, int32_t coeff_shift, int32_t pli) {
278 : uint64_t sum;
279 : int32_t bi, bx, by;
280 :
281 0 : if ((bsize == BLOCK_8X8) && (pli == 0)) {
282 0 : sum = 0;
283 0 : for (bi = 0; bi < cdef_count; bi++) {
284 0 : by = dlist[bi].by;
285 0 : bx = dlist[bi].bx;
286 0 : sum += dist_8x8_16bit_avx2(&src, dst + 8 * by * dstride + 8 * bx, dstride, coeff_shift);
287 : }
288 : }
289 : else {
290 0 : __m256i mse64 = _mm256_setzero_si256();
291 :
292 0 : if (bsize == BLOCK_8X8) {
293 0 : for (bi = 0; bi < cdef_count; bi++) {
294 0 : __m256i mse32 = _mm256_setzero_si256();
295 0 : by = dlist[bi].by;
296 0 : bx = dlist[bi].bx;
297 0 : mse_8x4_16bit_avx2(&src, dst + (8 * by + 0) * dstride + 8 * bx, dstride, &mse32);
298 0 : mse_8x4_16bit_avx2(&src, dst + (8 * by + 4) * dstride + 8 * bx, dstride, &mse32);
299 0 : sum_32_to_64(mse32, &mse64);
300 : }
301 : }
302 0 : else if (bsize == BLOCK_4X8) {
303 0 : for (bi = 0; bi < cdef_count; bi++) {
304 0 : __m256i mse32 = _mm256_setzero_si256();
305 0 : by = dlist[bi].by;
306 0 : bx = dlist[bi].bx;
307 0 : mse_4x4_16bit_avx2(&src, dst + (8 * by + 0) * dstride + 4 * bx, dstride, &mse32);
308 0 : mse_4x4_16bit_avx2(&src, dst + (8 * by + 4) * dstride + 4 * bx, dstride, &mse32);
309 0 : sum_32_to_64(mse32, &mse64);
310 : }
311 : }
312 0 : else if (bsize == BLOCK_8X4) {
313 0 : for (bi = 0; bi < cdef_count; bi++) {
314 0 : __m256i mse32 = _mm256_setzero_si256();
315 0 : by = dlist[bi].by;
316 0 : bx = dlist[bi].bx;
317 0 : mse_8x4_16bit_avx2(&src, dst + 4 * by * dstride + 8 * bx, dstride, &mse32);
318 0 : sum_32_to_64(mse32, &mse64);
319 : }
320 : }
321 : else {
322 : assert(bsize == BLOCK_4X4);
323 0 : for (bi = 0; bi < cdef_count; bi++) {
324 0 : __m256i mse32 = _mm256_setzero_si256();
325 0 : by = dlist[bi].by;
326 0 : bx = dlist[bi].bx;
327 0 : mse_4x4_16bit_avx2(&src, dst + 4 * by * dstride + 4 * bx, dstride, &mse32);
328 0 : sum_32_to_64(mse32, &mse64);
329 : }
330 : }
331 :
332 0 : sum = sum64(mse64);
333 : }
334 :
335 0 : return sum >> 2 * coeff_shift;
336 : }
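
The dispatcher above (and its 8-bit counterpart below) accumulates per-block SSE only over the blocks recorded in dlist, then scales the total by the coefficient shift. A plain-C sketch of the 16-bit BLOCK_4X4 path (hypothetical helper name, illustrative only; the filtered source is laid out contiguously, 16 samples per block, exactly as the tile helpers consume it):

static uint64_t compute_cdef_dist_4x4_sketch(const uint16_t *dst, int32_t dstride,
                                             const uint16_t *src, const cdef_list *dlist,
                                             int32_t cdef_count, int32_t coeff_shift) {
    uint64_t sum = 0;
    for (int32_t bi = 0; bi < cdef_count; bi++) {
        const uint16_t *d = dst + 4 * dlist[bi].by * dstride + 4 * dlist[bi].bx;
        for (int32_t r = 0; r < 4; r++) {
            for (int32_t c = 0; c < 4; c++) {
                const int64_t diff = (int64_t)d[r * dstride + c] - (int64_t)src[r * 4 + c];
                sum += (uint64_t)(diff * diff);
            }
        }
        src += 16; /* advance to the next filtered 4x4 block */
    }
    return sum >> 2 * coeff_shift;
}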
337 :
338 100727 : uint64_t compute_cdef_dist_8bit_avx2(const uint8_t *dst8, int32_t dstride, const uint8_t *src8, const cdef_list *dlist, int32_t cdef_count, BlockSize bsize, int32_t coeff_shift, int32_t pli) {
339 : uint64_t sum;
340 : int32_t bi, bx, by;
341 :
342 100727 : if ((bsize == BLOCK_8X8) && (pli == 0)) {
343 33594 : sum = 0;
344 1225640 : for (bi = 0; bi < cdef_count; bi++) {
345 1191710 : by = dlist[bi].by;
346 1191710 : bx = dlist[bi].bx;
347 1191710 : sum += dist_8x8_8bit_avx2(&src8, dst8 + 8 * by * dstride + 8 * bx, dstride, coeff_shift);
348 : }
349 : }
350 : else {
351 67133 : __m256i mse64 = _mm256_setzero_si256();
352 :
353 67133 : if (bsize == BLOCK_8X8) {
354 0 : for (bi = 0; bi < cdef_count; bi++) {
355 0 : __m256i mse32 = _mm256_setzero_si256();
356 0 : by = dlist[bi].by;
357 0 : bx = dlist[bi].bx;
358 0 : mse_8x4_8bit_avx2(&src8, dst8 + (8 * by + 0) * dstride + 8 * bx, dstride, &mse32);
359 0 : mse_8x4_8bit_avx2(&src8, dst8 + (8 * by + 4) * dstride + 8 * bx, dstride, &mse32);
360 0 : sum_32_to_64(mse32, &mse64);
361 : }
362 : }
363 67133 : else if (bsize == BLOCK_4X8) {
364 0 : for (bi = 0; bi < cdef_count; bi++) {
365 0 : __m256i mse32 = _mm256_setzero_si256();
366 0 : by = dlist[bi].by;
367 0 : bx = dlist[bi].bx;
368 0 : mse_4x4_8bit_avx2(&src8, dst8 + (8 * by + 0) * dstride + 4 * bx, dstride, &mse32);
369 0 : mse_4x4_8bit_avx2(&src8, dst8 + (8 * by + 4) * dstride + 4 * bx, dstride, &mse32);
370 0 : sum_32_to_64(mse32, &mse64);
371 : }
372 : }
373 67133 : else if (bsize == BLOCK_8X4) {
374 0 : for (bi = 0; bi < cdef_count; bi++) {
375 0 : __m256i mse32 = _mm256_setzero_si256();
376 0 : by = dlist[bi].by;
377 0 : bx = dlist[bi].bx;
378 0 : mse_8x4_8bit_avx2(&src8, dst8 + 4 * by * dstride + 8 * bx, dstride, &mse32);
379 0 : sum_32_to_64(mse32, &mse64);
380 : }
381 : }
382 : else {
383 : assert(bsize == BLOCK_4X4);
384 2438530 : for (bi = 0; bi < cdef_count; bi++) {
385 2371600 : __m256i mse32 = _mm256_setzero_si256();
386 2371600 : by = dlist[bi].by;
387 2371600 : bx = dlist[bi].bx;
388 2371600 : mse_4x4_8bit_avx2(&src8, dst8 + 4 * by * dstride + 4 * bx, dstride, &mse32);
389 2380990 : sum_32_to_64(mse32, &mse64);
390 : }
391 : }
392 :
393 66928 : sum = sum64(mse64);
394 : }
395 101185 : return sum >> 2 * coeff_shift;
396 : }
|