chibipub

experimental activitypub node in C
git clone git://jb55.com/chibipub

blake3_avx512.c (47960B)


      1 #include "blake3_impl.h"
      2 
      3 #include <immintrin.h>
      4 
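        // _mm_shuffle_ps is reused on integer vectors here because SSE has no
        // single integer instruction that picks two 32-bit words from each of
        // two sources; the casts below are bitcasts and cost nothing at runtime.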
      5 #define _mm_shuffle_ps2(a, b, c)                                               \
      6   (_mm_castps_si128(                                                           \
      7       _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c))))
      8 
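        // Width-suffixed wrappers over the unaligned load/store, add, xor,
        // broadcast, and rotate intrinsics. The 32-bit rotates map to VPRORD;
        // the 128- and 256-bit forms additionally require AVX-512VL.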
      9 INLINE __m128i loadu_128(const uint8_t src[16]) {
     10   return _mm_loadu_si128((const __m128i *)src);
     11 }
     12 
     13 INLINE __m256i loadu_256(const uint8_t src[32]) {
     14   return _mm256_loadu_si256((const __m256i *)src);
     15 }
     16 
     17 INLINE __m512i loadu_512(const uint8_t src[64]) {
     18   return _mm512_loadu_si512((const __m512i *)src);
     19 }
     20 
     21 INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
     22   _mm_storeu_si128((__m128i *)dest, src);
     23 }
     24 
     25 INLINE void storeu_256(__m256i src, uint8_t dest[32]) {
     26   _mm256_storeu_si256((__m256i *)dest, src);
     27 }
     28 
     29 INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }
     30 
     31 INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }
     32 
     33 INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }
     34 
     35 INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }
     36 
     37 INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }
     38 
     39 INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }
     40 
     41 INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }
     42 
     43 INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }
     44 
     45 INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }
     46 
     47 INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
     48   return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
     49 }
     50 
     51 INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); }
     52 
     53 INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); }
     54 
     55 INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); }
     56 
     57 INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); }
     58 
     59 INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); }
     60 
     61 INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); }
     62 
     63 INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); }
     64 
     65 INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); }
     66 
     67 INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); }
     68 
     69 INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); }
     70 
     71 INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); }
     72 
     73 INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); }
     74 
     75 /*
     76  * ----------------------------------------------------------------------------
     77  * compress_avx512
     78  * ----------------------------------------------------------------------------
     79  */
     80 
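        // g1/g2 are the two halves of the BLAKE3 quarter-round (the G
        // function): each adds one message word into the column, then applies
        // the 16/12 or 8/7 rotation pair.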
     81 INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
     82                __m128i m) {
     83   *row0 = add_128(add_128(*row0, m), *row1);
     84   *row3 = xor_128(*row3, *row0);
     85   *row3 = rot16_128(*row3);
     86   *row2 = add_128(*row2, *row3);
     87   *row1 = xor_128(*row1, *row2);
     88   *row1 = rot12_128(*row1);
     89 }
     90 
     91 INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
     92                __m128i m) {
     93   *row0 = add_128(add_128(*row0, m), *row1);
     94   *row3 = xor_128(*row3, *row0);
     95   *row3 = rot8_128(*row3);
     96   *row2 = add_128(*row2, *row3);
     97   *row1 = xor_128(*row1, *row2);
     98   *row1 = rot7_128(*row1);
     99 }
    100 
    101 // Note the optimization here of leaving row1 as the unrotated row, rather than
    102 // row0. All the message loads below are adjusted to compensate for this. See
    103 // discussion at https://github.com/sneves/blake2-avx2/pull/4
    104 INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
    105   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3));
    106   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
    107   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1));
    108 }
    109 
    110 INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
    111   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1));
    112   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
    113   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3));
    114 }
    115 
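        // Runs the full 7-round compression of one block with the 4x4 state
        // held in rows[0..3]. Callers finalize by xoring rows together (and
        // with the input cv for extended output).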
    116 INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8],
    117                          const uint8_t block[BLAKE3_BLOCK_LEN],
    118                          uint8_t block_len, uint64_t counter, uint8_t flags) {
    119   rows[0] = loadu_128((uint8_t *)&cv[0]);
    120   rows[1] = loadu_128((uint8_t *)&cv[4]);
    121   rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
    122   rows[3] = set4(counter_low(counter), counter_high(counter),
    123                  (uint32_t)block_len, (uint32_t)flags);
    124 
    125   __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
    126   __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
    127   __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
    128   __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);
    129 
    130   __m128i t0, t1, t2, t3, tt;
    131 
    132   // Round 1. The first round permutes the message words from the original
    133   // input order, into the groups that get mixed in parallel.
    134   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); //  6  4  2  0
    135   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
    136   t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); //  7  5  3  1
    137   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
    138   diagonalize(&rows[0], &rows[2], &rows[3]);
    139   t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10  8
    140   t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3));   // 12 10  8 14
    141   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
    142   t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11  9
    143   t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3));   // 13 11  9 15
    144   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
    145   undiagonalize(&rows[0], &rows[2], &rows[3]);
    146   m0 = t0;
    147   m1 = t1;
    148   m2 = t2;
    149   m3 = t3;
    150 
    151   // Round 2. This round and all following rounds apply a fixed permutation
    152   // to the message words from the round before.
    153   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
    154   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
    155   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
    156   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
    157   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
    158   t1 = _mm_blend_epi16(tt, t1, 0xCC);
    159   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
    160   diagonalize(&rows[0], &rows[2], &rows[3]);
    161   t2 = _mm_unpacklo_epi64(m3, m1);
    162   tt = _mm_blend_epi16(t2, m2, 0xC0);
    163   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
    164   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
    165   t3 = _mm_unpackhi_epi32(m1, m3);
    166   tt = _mm_unpacklo_epi32(m2, t3);
    167   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
    168   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
    169   undiagonalize(&rows[0], &rows[2], &rows[3]);
    170   m0 = t0;
    171   m1 = t1;
    172   m2 = t2;
    173   m3 = t3;
    174 
    175   // Round 3
    176   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
    177   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
    178   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
    179   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
    180   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
    181   t1 = _mm_blend_epi16(tt, t1, 0xCC);
    182   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
    183   diagonalize(&rows[0], &rows[2], &rows[3]);
    184   t2 = _mm_unpacklo_epi64(m3, m1);
    185   tt = _mm_blend_epi16(t2, m2, 0xC0);
    186   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
    187   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
    188   t3 = _mm_unpackhi_epi32(m1, m3);
    189   tt = _mm_unpacklo_epi32(m2, t3);
    190   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
    191   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
    192   undiagonalize(&rows[0], &rows[2], &rows[3]);
    193   m0 = t0;
    194   m1 = t1;
    195   m2 = t2;
    196   m3 = t3;
    197 
    198   // Round 4
    199   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
    200   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
    201   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
    202   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
    203   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
    204   t1 = _mm_blend_epi16(tt, t1, 0xCC);
    205   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
    206   diagonalize(&rows[0], &rows[2], &rows[3]);
    207   t2 = _mm_unpacklo_epi64(m3, m1);
    208   tt = _mm_blend_epi16(t2, m2, 0xC0);
    209   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
    210   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
    211   t3 = _mm_unpackhi_epi32(m1, m3);
    212   tt = _mm_unpacklo_epi32(m2, t3);
    213   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
    214   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
    215   undiagonalize(&rows[0], &rows[2], &rows[3]);
    216   m0 = t0;
    217   m1 = t1;
    218   m2 = t2;
    219   m3 = t3;
    220 
    221   // Round 5
    222   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
    223   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
    224   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
    225   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
    226   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
    227   t1 = _mm_blend_epi16(tt, t1, 0xCC);
    228   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
    229   diagonalize(&rows[0], &rows[2], &rows[3]);
    230   t2 = _mm_unpacklo_epi64(m3, m1);
    231   tt = _mm_blend_epi16(t2, m2, 0xC0);
    232   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
    233   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
    234   t3 = _mm_unpackhi_epi32(m1, m3);
    235   tt = _mm_unpacklo_epi32(m2, t3);
    236   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
    237   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
    238   undiagonalize(&rows[0], &rows[2], &rows[3]);
    239   m0 = t0;
    240   m1 = t1;
    241   m2 = t2;
    242   m3 = t3;
    243 
    244   // Round 6
    245   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
    246   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
    247   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
    248   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
    249   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
    250   t1 = _mm_blend_epi16(tt, t1, 0xCC);
    251   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
    252   diagonalize(&rows[0], &rows[2], &rows[3]);
    253   t2 = _mm_unpacklo_epi64(m3, m1);
    254   tt = _mm_blend_epi16(t2, m2, 0xC0);
    255   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
    256   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
    257   t3 = _mm_unpackhi_epi32(m1, m3);
    258   tt = _mm_unpacklo_epi32(m2, t3);
    259   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
    260   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
    261   undiagonalize(&rows[0], &rows[2], &rows[3]);
    262   m0 = t0;
    263   m1 = t1;
    264   m2 = t2;
    265   m3 = t3;
    266 
    267   // Round 7
    268   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
    269   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
    270   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
    271   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
    272   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
    273   t1 = _mm_blend_epi16(tt, t1, 0xCC);
    274   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
    275   diagonalize(&rows[0], &rows[2], &rows[3]);
    276   t2 = _mm_unpacklo_epi64(m3, m1);
    277   tt = _mm_blend_epi16(t2, m2, 0xC0);
    278   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
    279   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
    280   t3 = _mm_unpackhi_epi32(m1, m3);
    281   tt = _mm_unpacklo_epi32(m2, t3);
    282   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
    283   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
    284   undiagonalize(&rows[0], &rows[2], &rows[3]);
    285 }
    286 
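        // Extended-output compression: bytes 0..31 are the regular compression
        // result, bytes 32..63 xor the last two rows with the input chaining
        // value, as required for BLAKE3's XOF output blocks.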
    287 void blake3_compress_xof_avx512(const uint32_t cv[8],
    288                                 const uint8_t block[BLAKE3_BLOCK_LEN],
    289                                 uint8_t block_len, uint64_t counter,
    290                                 uint8_t flags, uint8_t out[64]) {
    291   __m128i rows[4];
    292   compress_pre(rows, cv, block, block_len, counter, flags);
    293   storeu_128(xor_128(rows[0], rows[2]), &out[0]);
    294   storeu_128(xor_128(rows[1], rows[3]), &out[16]);
    295   storeu_128(xor_128(rows[2], loadu_128((uint8_t *)&cv[0])), &out[32]);
    296   storeu_128(xor_128(rows[3], loadu_128((uint8_t *)&cv[4])), &out[48]);
    297 }
    298 
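        // In-place variant: cv is overwritten with the first eight words of
        // the compression output, i.e. the new chaining value.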
    299 void blake3_compress_in_place_avx512(uint32_t cv[8],
    300                                      const uint8_t block[BLAKE3_BLOCK_LEN],
    301                                      uint8_t block_len, uint64_t counter,
    302                                      uint8_t flags) {
    303   __m128i rows[4];
    304   compress_pre(rows, cv, block, block_len, counter, flags);
    305   storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
    306   storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
    307 }
    308 
    309 /*
    310  * ----------------------------------------------------------------------------
    311  * hash4_avx512
    312  * ----------------------------------------------------------------------------
    313  */
    314 
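        // One full BLAKE3 round over four hash states at once. v[] holds the
        // 16 state words with the same word of all four lanes packed into each
        // vector; the first half mixes columns, the second half diagonals, and
        // MSG_SCHEDULE[r] supplies the message word order for round r.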
    315 INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) {
    316   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
    317   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
    318   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
    319   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
    320   v[0] = add_128(v[0], v[4]);
    321   v[1] = add_128(v[1], v[5]);
    322   v[2] = add_128(v[2], v[6]);
    323   v[3] = add_128(v[3], v[7]);
    324   v[12] = xor_128(v[12], v[0]);
    325   v[13] = xor_128(v[13], v[1]);
    326   v[14] = xor_128(v[14], v[2]);
    327   v[15] = xor_128(v[15], v[3]);
    328   v[12] = rot16_128(v[12]);
    329   v[13] = rot16_128(v[13]);
    330   v[14] = rot16_128(v[14]);
    331   v[15] = rot16_128(v[15]);
    332   v[8] = add_128(v[8], v[12]);
    333   v[9] = add_128(v[9], v[13]);
    334   v[10] = add_128(v[10], v[14]);
    335   v[11] = add_128(v[11], v[15]);
    336   v[4] = xor_128(v[4], v[8]);
    337   v[5] = xor_128(v[5], v[9]);
    338   v[6] = xor_128(v[6], v[10]);
    339   v[7] = xor_128(v[7], v[11]);
    340   v[4] = rot12_128(v[4]);
    341   v[5] = rot12_128(v[5]);
    342   v[6] = rot12_128(v[6]);
    343   v[7] = rot12_128(v[7]);
    344   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
    345   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
    346   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
    347   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
    348   v[0] = add_128(v[0], v[4]);
    349   v[1] = add_128(v[1], v[5]);
    350   v[2] = add_128(v[2], v[6]);
    351   v[3] = add_128(v[3], v[7]);
    352   v[12] = xor_128(v[12], v[0]);
    353   v[13] = xor_128(v[13], v[1]);
    354   v[14] = xor_128(v[14], v[2]);
    355   v[15] = xor_128(v[15], v[3]);
    356   v[12] = rot8_128(v[12]);
    357   v[13] = rot8_128(v[13]);
    358   v[14] = rot8_128(v[14]);
    359   v[15] = rot8_128(v[15]);
    360   v[8] = add_128(v[8], v[12]);
    361   v[9] = add_128(v[9], v[13]);
    362   v[10] = add_128(v[10], v[14]);
    363   v[11] = add_128(v[11], v[15]);
    364   v[4] = xor_128(v[4], v[8]);
    365   v[5] = xor_128(v[5], v[9]);
    366   v[6] = xor_128(v[6], v[10]);
    367   v[7] = xor_128(v[7], v[11]);
    368   v[4] = rot7_128(v[4]);
    369   v[5] = rot7_128(v[5]);
    370   v[6] = rot7_128(v[6]);
    371   v[7] = rot7_128(v[7]);
    372 
    373   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
    374   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
    375   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
    376   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
    377   v[0] = add_128(v[0], v[5]);
    378   v[1] = add_128(v[1], v[6]);
    379   v[2] = add_128(v[2], v[7]);
    380   v[3] = add_128(v[3], v[4]);
    381   v[15] = xor_128(v[15], v[0]);
    382   v[12] = xor_128(v[12], v[1]);
    383   v[13] = xor_128(v[13], v[2]);
    384   v[14] = xor_128(v[14], v[3]);
    385   v[15] = rot16_128(v[15]);
    386   v[12] = rot16_128(v[12]);
    387   v[13] = rot16_128(v[13]);
    388   v[14] = rot16_128(v[14]);
    389   v[10] = add_128(v[10], v[15]);
    390   v[11] = add_128(v[11], v[12]);
    391   v[8] = add_128(v[8], v[13]);
    392   v[9] = add_128(v[9], v[14]);
    393   v[5] = xor_128(v[5], v[10]);
    394   v[6] = xor_128(v[6], v[11]);
    395   v[7] = xor_128(v[7], v[8]);
    396   v[4] = xor_128(v[4], v[9]);
    397   v[5] = rot12_128(v[5]);
    398   v[6] = rot12_128(v[6]);
    399   v[7] = rot12_128(v[7]);
    400   v[4] = rot12_128(v[4]);
    401   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
    402   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
    403   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
    404   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
    405   v[0] = add_128(v[0], v[5]);
    406   v[1] = add_128(v[1], v[6]);
    407   v[2] = add_128(v[2], v[7]);
    408   v[3] = add_128(v[3], v[4]);
    409   v[15] = xor_128(v[15], v[0]);
    410   v[12] = xor_128(v[12], v[1]);
    411   v[13] = xor_128(v[13], v[2]);
    412   v[14] = xor_128(v[14], v[3]);
    413   v[15] = rot8_128(v[15]);
    414   v[12] = rot8_128(v[12]);
    415   v[13] = rot8_128(v[13]);
    416   v[14] = rot8_128(v[14]);
    417   v[10] = add_128(v[10], v[15]);
    418   v[11] = add_128(v[11], v[12]);
    419   v[8] = add_128(v[8], v[13]);
    420   v[9] = add_128(v[9], v[14]);
    421   v[5] = xor_128(v[5], v[10]);
    422   v[6] = xor_128(v[6], v[11]);
    423   v[7] = xor_128(v[7], v[8]);
    424   v[4] = xor_128(v[4], v[9]);
    425   v[5] = rot7_128(v[5]);
    426   v[6] = rot7_128(v[6]);
    427   v[7] = rot7_128(v[7]);
    428   v[4] = rot7_128(v[4]);
    429 }
    430 
    431 INLINE void transpose_vecs_128(__m128i vecs[4]) {
     432   // Interleave 32-bit lanes. The low unpack is lanes 00/11 and the high is
    433   // 22/33. Note that this doesn't split the vector into two lanes, as the
    434   // AVX2 counterparts do.
    435   __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
    436   __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
    437   __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
    438   __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
    439 
    440   // Interleave 64-bit lanes.
    441   __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
    442   __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
    443   __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
    444   __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);
    445 
    446   vecs[0] = abcd_0;
    447   vecs[1] = abcd_1;
    448   vecs[2] = abcd_2;
    449   vecs[3] = abcd_3;
    450 }
    451 
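        // Loads one 64-byte block from each of the four inputs, prefetches the
        // data four blocks ahead, and transposes so that out[i] holds message
        // word i from all four blocks.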
    452 INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
    453                                 size_t block_offset, __m128i out[16]) {
    454   out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
    455   out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
    456   out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
    457   out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
    458   out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
    459   out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
    460   out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
    461   out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
    462   out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
    463   out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
    464   out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
    465   out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
    466   out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
    467   out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
    468   out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
    469   out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
    470   for (size_t i = 0; i < 4; ++i) {
    471     _mm_prefetch(&inputs[i][block_offset + 256], _MM_HINT_T0);
    472   }
    473   transpose_vecs_128(&out[0]);
    474   transpose_vecs_128(&out[4]);
    475   transpose_vecs_128(&out[8]);
    476   transpose_vecs_128(&out[12]);
    477 }
    478 
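        // Builds the per-lane 64-bit counters counter+0..3 (all equal to
        // counter when increment_counter is false) and splits them into the
        // low and high 32-bit vectors used for state words 12 and 13.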
    479 INLINE void load_counters4(uint64_t counter, bool increment_counter,
    480                            __m128i *out_lo, __m128i *out_hi) {
    481   uint64_t mask = (increment_counter ? ~0 : 0);
    482   __m256i mask_vec = _mm256_set1_epi64x(mask);
    483   __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
    484   deltas = _mm256_and_si256(mask_vec, deltas);
    485   __m256i counters =
    486       _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
    487   *out_lo = _mm256_cvtepi64_epi32(counters);
    488   *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
    489 }
    490 
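        // Hashes four equal-length inputs of `blocks` blocks each in parallel,
        // writing four consecutive 32-byte chaining values to out.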
    491 void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
    492                          const uint32_t key[8], uint64_t counter,
    493                          bool increment_counter, uint8_t flags,
    494                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
    495   __m128i h_vecs[8] = {
    496       set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
    497       set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
    498   };
    499   __m128i counter_low_vec, counter_high_vec;
    500   load_counters4(counter, increment_counter, &counter_low_vec,
    501                  &counter_high_vec);
    502   uint8_t block_flags = flags | flags_start;
    503 
    504   for (size_t block = 0; block < blocks; block++) {
    505     if (block + 1 == blocks) {
    506       block_flags |= flags_end;
    507     }
    508     __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
    509     __m128i block_flags_vec = set1_128(block_flags);
    510     __m128i msg_vecs[16];
    511     transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
    512 
    513     __m128i v[16] = {
    514         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
    515         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
    516         set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
    517         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
    518     };
    519     round_fn4(v, msg_vecs, 0);
    520     round_fn4(v, msg_vecs, 1);
    521     round_fn4(v, msg_vecs, 2);
    522     round_fn4(v, msg_vecs, 3);
    523     round_fn4(v, msg_vecs, 4);
    524     round_fn4(v, msg_vecs, 5);
    525     round_fn4(v, msg_vecs, 6);
    526     h_vecs[0] = xor_128(v[0], v[8]);
    527     h_vecs[1] = xor_128(v[1], v[9]);
    528     h_vecs[2] = xor_128(v[2], v[10]);
    529     h_vecs[3] = xor_128(v[3], v[11]);
    530     h_vecs[4] = xor_128(v[4], v[12]);
    531     h_vecs[5] = xor_128(v[5], v[13]);
    532     h_vecs[6] = xor_128(v[6], v[14]);
    533     h_vecs[7] = xor_128(v[7], v[15]);
    534 
    535     block_flags = flags;
    536   }
    537 
    538   transpose_vecs_128(&h_vecs[0]);
    539   transpose_vecs_128(&h_vecs[4]);
    540   // The first four vecs now contain the first half of each output, and the
    541   // second four vecs contain the second half of each output.
    542   storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
    543   storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
    544   storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
    545   storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
    546   storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
    547   storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
    548   storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
    549   storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
    550 }
    551 
    552 /*
    553  * ----------------------------------------------------------------------------
    554  * hash8_avx512
    555  * ----------------------------------------------------------------------------
    556  */
    557 
    558 INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) {
    559   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
    560   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
    561   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
    562   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
    563   v[0] = add_256(v[0], v[4]);
    564   v[1] = add_256(v[1], v[5]);
    565   v[2] = add_256(v[2], v[6]);
    566   v[3] = add_256(v[3], v[7]);
    567   v[12] = xor_256(v[12], v[0]);
    568   v[13] = xor_256(v[13], v[1]);
    569   v[14] = xor_256(v[14], v[2]);
    570   v[15] = xor_256(v[15], v[3]);
    571   v[12] = rot16_256(v[12]);
    572   v[13] = rot16_256(v[13]);
    573   v[14] = rot16_256(v[14]);
    574   v[15] = rot16_256(v[15]);
    575   v[8] = add_256(v[8], v[12]);
    576   v[9] = add_256(v[9], v[13]);
    577   v[10] = add_256(v[10], v[14]);
    578   v[11] = add_256(v[11], v[15]);
    579   v[4] = xor_256(v[4], v[8]);
    580   v[5] = xor_256(v[5], v[9]);
    581   v[6] = xor_256(v[6], v[10]);
    582   v[7] = xor_256(v[7], v[11]);
    583   v[4] = rot12_256(v[4]);
    584   v[5] = rot12_256(v[5]);
    585   v[6] = rot12_256(v[6]);
    586   v[7] = rot12_256(v[7]);
    587   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
    588   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
    589   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
    590   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
    591   v[0] = add_256(v[0], v[4]);
    592   v[1] = add_256(v[1], v[5]);
    593   v[2] = add_256(v[2], v[6]);
    594   v[3] = add_256(v[3], v[7]);
    595   v[12] = xor_256(v[12], v[0]);
    596   v[13] = xor_256(v[13], v[1]);
    597   v[14] = xor_256(v[14], v[2]);
    598   v[15] = xor_256(v[15], v[3]);
    599   v[12] = rot8_256(v[12]);
    600   v[13] = rot8_256(v[13]);
    601   v[14] = rot8_256(v[14]);
    602   v[15] = rot8_256(v[15]);
    603   v[8] = add_256(v[8], v[12]);
    604   v[9] = add_256(v[9], v[13]);
    605   v[10] = add_256(v[10], v[14]);
    606   v[11] = add_256(v[11], v[15]);
    607   v[4] = xor_256(v[4], v[8]);
    608   v[5] = xor_256(v[5], v[9]);
    609   v[6] = xor_256(v[6], v[10]);
    610   v[7] = xor_256(v[7], v[11]);
    611   v[4] = rot7_256(v[4]);
    612   v[5] = rot7_256(v[5]);
    613   v[6] = rot7_256(v[6]);
    614   v[7] = rot7_256(v[7]);
    615 
    616   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
    617   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
    618   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
    619   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
    620   v[0] = add_256(v[0], v[5]);
    621   v[1] = add_256(v[1], v[6]);
    622   v[2] = add_256(v[2], v[7]);
    623   v[3] = add_256(v[3], v[4]);
    624   v[15] = xor_256(v[15], v[0]);
    625   v[12] = xor_256(v[12], v[1]);
    626   v[13] = xor_256(v[13], v[2]);
    627   v[14] = xor_256(v[14], v[3]);
    628   v[15] = rot16_256(v[15]);
    629   v[12] = rot16_256(v[12]);
    630   v[13] = rot16_256(v[13]);
    631   v[14] = rot16_256(v[14]);
    632   v[10] = add_256(v[10], v[15]);
    633   v[11] = add_256(v[11], v[12]);
    634   v[8] = add_256(v[8], v[13]);
    635   v[9] = add_256(v[9], v[14]);
    636   v[5] = xor_256(v[5], v[10]);
    637   v[6] = xor_256(v[6], v[11]);
    638   v[7] = xor_256(v[7], v[8]);
    639   v[4] = xor_256(v[4], v[9]);
    640   v[5] = rot12_256(v[5]);
    641   v[6] = rot12_256(v[6]);
    642   v[7] = rot12_256(v[7]);
    643   v[4] = rot12_256(v[4]);
    644   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
    645   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
    646   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
    647   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
    648   v[0] = add_256(v[0], v[5]);
    649   v[1] = add_256(v[1], v[6]);
    650   v[2] = add_256(v[2], v[7]);
    651   v[3] = add_256(v[3], v[4]);
    652   v[15] = xor_256(v[15], v[0]);
    653   v[12] = xor_256(v[12], v[1]);
    654   v[13] = xor_256(v[13], v[2]);
    655   v[14] = xor_256(v[14], v[3]);
    656   v[15] = rot8_256(v[15]);
    657   v[12] = rot8_256(v[12]);
    658   v[13] = rot8_256(v[13]);
    659   v[14] = rot8_256(v[14]);
    660   v[10] = add_256(v[10], v[15]);
    661   v[11] = add_256(v[11], v[12]);
    662   v[8] = add_256(v[8], v[13]);
    663   v[9] = add_256(v[9], v[14]);
    664   v[5] = xor_256(v[5], v[10]);
    665   v[6] = xor_256(v[6], v[11]);
    666   v[7] = xor_256(v[7], v[8]);
    667   v[4] = xor_256(v[4], v[9]);
    668   v[5] = rot7_256(v[5]);
    669   v[6] = rot7_256(v[6]);
    670   v[7] = rot7_256(v[7]);
    671   v[4] = rot7_256(v[4]);
    672 }
    673 
    674 INLINE void transpose_vecs_256(__m256i vecs[8]) {
    675   // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high
    676   // is 22/33/66/77.
    677   __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
    678   __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
    679   __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
    680   __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
    681   __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
    682   __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
    683   __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
    684   __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);
    685 
     686   // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is
    687   // 11/33.
    688   __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
    689   __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
    690   __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
    691   __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
    692   __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
    693   __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
    694   __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
    695   __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);
    696 
    697   // Interleave 128-bit lanes.
    698   vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20);
    699   vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20);
    700   vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20);
    701   vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20);
    702   vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31);
    703   vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31);
    704   vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31);
    705   vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31);
    706 }
    707 
    708 INLINE void transpose_msg_vecs8(const uint8_t *const *inputs,
    709                                 size_t block_offset, __m256i out[16]) {
    710   out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
    711   out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
    712   out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
    713   out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
    714   out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
    715   out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
    716   out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
    717   out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
    718   out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
    719   out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
    720   out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
    721   out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
    722   out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
    723   out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
    724   out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
    725   out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
    726   for (size_t i = 0; i < 8; ++i) {
    727     _mm_prefetch(&inputs[i][block_offset + 256], _MM_HINT_T0);
    728   }
    729   transpose_vecs_256(&out[0]);
    730   transpose_vecs_256(&out[8]);
    731 }
    732 
    733 INLINE void load_counters8(uint64_t counter, bool increment_counter,
    734                            __m256i *out_lo, __m256i *out_hi) {
    735   uint64_t mask = (increment_counter ? ~0 : 0);
    736   __m512i mask_vec = _mm512_set1_epi64(mask);
    737   __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
    738   deltas = _mm512_and_si512(mask_vec, deltas);
    739   __m512i counters =
    740       _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
    741   *out_lo = _mm512_cvtepi64_epi32(counters);
    742   *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
    743 }
    744 
    745 void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
    746                          const uint32_t key[8], uint64_t counter,
    747                          bool increment_counter, uint8_t flags,
    748                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
    749   __m256i h_vecs[8] = {
    750       set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]),
    751       set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]),
    752   };
    753   __m256i counter_low_vec, counter_high_vec;
    754   load_counters8(counter, increment_counter, &counter_low_vec,
    755                  &counter_high_vec);
    756   uint8_t block_flags = flags | flags_start;
    757 
    758   for (size_t block = 0; block < blocks; block++) {
    759     if (block + 1 == blocks) {
    760       block_flags |= flags_end;
    761     }
    762     __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN);
    763     __m256i block_flags_vec = set1_256(block_flags);
    764     __m256i msg_vecs[16];
    765     transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
    766 
    767     __m256i v[16] = {
    768         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
    769         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
    770         set1_256(IV[0]), set1_256(IV[1]),  set1_256(IV[2]), set1_256(IV[3]),
    771         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
    772     };
    773     round_fn8(v, msg_vecs, 0);
    774     round_fn8(v, msg_vecs, 1);
    775     round_fn8(v, msg_vecs, 2);
    776     round_fn8(v, msg_vecs, 3);
    777     round_fn8(v, msg_vecs, 4);
    778     round_fn8(v, msg_vecs, 5);
    779     round_fn8(v, msg_vecs, 6);
    780     h_vecs[0] = xor_256(v[0], v[8]);
    781     h_vecs[1] = xor_256(v[1], v[9]);
    782     h_vecs[2] = xor_256(v[2], v[10]);
    783     h_vecs[3] = xor_256(v[3], v[11]);
    784     h_vecs[4] = xor_256(v[4], v[12]);
    785     h_vecs[5] = xor_256(v[5], v[13]);
    786     h_vecs[6] = xor_256(v[6], v[14]);
    787     h_vecs[7] = xor_256(v[7], v[15]);
    788 
    789     block_flags = flags;
    790   }
    791 
    792   transpose_vecs_256(h_vecs);
    793   storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
    794   storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
    795   storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
    796   storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
    797   storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
    798   storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
    799   storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
    800   storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
    801 }
    802 
    803 /*
    804  * ----------------------------------------------------------------------------
    805  * hash16_avx512
    806  * ----------------------------------------------------------------------------
    807  */
    808 
    809 INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) {
    810   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
    811   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
    812   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
    813   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
    814   v[0] = add_512(v[0], v[4]);
    815   v[1] = add_512(v[1], v[5]);
    816   v[2] = add_512(v[2], v[6]);
    817   v[3] = add_512(v[3], v[7]);
    818   v[12] = xor_512(v[12], v[0]);
    819   v[13] = xor_512(v[13], v[1]);
    820   v[14] = xor_512(v[14], v[2]);
    821   v[15] = xor_512(v[15], v[3]);
    822   v[12] = rot16_512(v[12]);
    823   v[13] = rot16_512(v[13]);
    824   v[14] = rot16_512(v[14]);
    825   v[15] = rot16_512(v[15]);
    826   v[8] = add_512(v[8], v[12]);
    827   v[9] = add_512(v[9], v[13]);
    828   v[10] = add_512(v[10], v[14]);
    829   v[11] = add_512(v[11], v[15]);
    830   v[4] = xor_512(v[4], v[8]);
    831   v[5] = xor_512(v[5], v[9]);
    832   v[6] = xor_512(v[6], v[10]);
    833   v[7] = xor_512(v[7], v[11]);
    834   v[4] = rot12_512(v[4]);
    835   v[5] = rot12_512(v[5]);
    836   v[6] = rot12_512(v[6]);
    837   v[7] = rot12_512(v[7]);
    838   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
    839   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
    840   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
    841   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
    842   v[0] = add_512(v[0], v[4]);
    843   v[1] = add_512(v[1], v[5]);
    844   v[2] = add_512(v[2], v[6]);
    845   v[3] = add_512(v[3], v[7]);
    846   v[12] = xor_512(v[12], v[0]);
    847   v[13] = xor_512(v[13], v[1]);
    848   v[14] = xor_512(v[14], v[2]);
    849   v[15] = xor_512(v[15], v[3]);
    850   v[12] = rot8_512(v[12]);
    851   v[13] = rot8_512(v[13]);
    852   v[14] = rot8_512(v[14]);
    853   v[15] = rot8_512(v[15]);
    854   v[8] = add_512(v[8], v[12]);
    855   v[9] = add_512(v[9], v[13]);
    856   v[10] = add_512(v[10], v[14]);
    857   v[11] = add_512(v[11], v[15]);
    858   v[4] = xor_512(v[4], v[8]);
    859   v[5] = xor_512(v[5], v[9]);
    860   v[6] = xor_512(v[6], v[10]);
    861   v[7] = xor_512(v[7], v[11]);
    862   v[4] = rot7_512(v[4]);
    863   v[5] = rot7_512(v[5]);
    864   v[6] = rot7_512(v[6]);
    865   v[7] = rot7_512(v[7]);
    866 
    867   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
    868   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
    869   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
    870   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
    871   v[0] = add_512(v[0], v[5]);
    872   v[1] = add_512(v[1], v[6]);
    873   v[2] = add_512(v[2], v[7]);
    874   v[3] = add_512(v[3], v[4]);
    875   v[15] = xor_512(v[15], v[0]);
    876   v[12] = xor_512(v[12], v[1]);
    877   v[13] = xor_512(v[13], v[2]);
    878   v[14] = xor_512(v[14], v[3]);
    879   v[15] = rot16_512(v[15]);
    880   v[12] = rot16_512(v[12]);
    881   v[13] = rot16_512(v[13]);
    882   v[14] = rot16_512(v[14]);
    883   v[10] = add_512(v[10], v[15]);
    884   v[11] = add_512(v[11], v[12]);
    885   v[8] = add_512(v[8], v[13]);
    886   v[9] = add_512(v[9], v[14]);
    887   v[5] = xor_512(v[5], v[10]);
    888   v[6] = xor_512(v[6], v[11]);
    889   v[7] = xor_512(v[7], v[8]);
    890   v[4] = xor_512(v[4], v[9]);
    891   v[5] = rot12_512(v[5]);
    892   v[6] = rot12_512(v[6]);
    893   v[7] = rot12_512(v[7]);
    894   v[4] = rot12_512(v[4]);
    895   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
    896   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
    897   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
    898   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
    899   v[0] = add_512(v[0], v[5]);
    900   v[1] = add_512(v[1], v[6]);
    901   v[2] = add_512(v[2], v[7]);
    902   v[3] = add_512(v[3], v[4]);
    903   v[15] = xor_512(v[15], v[0]);
    904   v[12] = xor_512(v[12], v[1]);
    905   v[13] = xor_512(v[13], v[2]);
    906   v[14] = xor_512(v[14], v[3]);
    907   v[15] = rot8_512(v[15]);
    908   v[12] = rot8_512(v[12]);
    909   v[13] = rot8_512(v[13]);
    910   v[14] = rot8_512(v[14]);
    911   v[10] = add_512(v[10], v[15]);
    912   v[11] = add_512(v[11], v[12]);
    913   v[8] = add_512(v[8], v[13]);
    914   v[9] = add_512(v[9], v[14]);
    915   v[5] = xor_512(v[5], v[10]);
    916   v[6] = xor_512(v[6], v[11]);
    917   v[7] = xor_512(v[7], v[8]);
    918   v[4] = xor_512(v[4], v[9]);
    919   v[5] = rot7_512(v[5]);
    920   v[6] = rot7_512(v[6]);
    921   v[7] = rot7_512(v[7]);
    922   v[4] = rot7_512(v[4]);
    923 }
    924 
    925 // 0b10001000, or lanes a0/a2/b0/b2 in little-endian order
    926 #define LO_IMM8 0x88
    927 
    928 INLINE __m512i unpack_lo_128(__m512i a, __m512i b) {
    929   return _mm512_shuffle_i32x4(a, b, LO_IMM8);
    930 }
    931 
    932 // 0b11011101, or lanes a1/a3/b1/b3 in little-endian order
    933 #define HI_IMM8 0xdd
    934 
    935 INLINE __m512i unpack_hi_128(__m512i a, __m512i b) {
    936   return _mm512_shuffle_i32x4(a, b, HI_IMM8);
    937 }
    938 
    939 INLINE void transpose_vecs_512(__m512i vecs[16]) {
    940   // Interleave 32-bit lanes. The _0 unpack is lanes
    941   // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes
    942   // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15.
    943   __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
    944   __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
    945   __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
    946   __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
    947   __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
    948   __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
    949   __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
    950   __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
    951   __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
    952   __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
    953   __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
    954   __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
    955   __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
    956   __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
    957   __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
    958   __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);
    959 
     960   // Interleave 64-bit lanes. The _0 unpack is lanes
    961   // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes
    962   // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes
    963   // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes
    964   // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15.
    965   __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
    966   __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
    967   __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
    968   __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
    969   __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
    970   __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
    971   __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
    972   __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
    973   __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
    974   __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
    975   __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
    976   __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
    977   __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
    978   __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
    979   __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
    980   __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
    981 
    982   // Interleave 128-bit lanes. The _0 unpack is
    983   // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is
    984   // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on.
    985   __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
    986   __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
    987   __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
    988   __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
    989   __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
    990   __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
    991   __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
    992   __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
    993   __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
    994   __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
    995   __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
    996   __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
    997   __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
    998   __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
    999   __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
   1000   __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
   1001 
   1002   // Interleave 128-bit lanes again for the final outputs.
   1003   vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
   1004   vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
   1005   vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
   1006   vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
   1007   vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
   1008   vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
   1009   vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
   1010   vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
   1011   vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
   1012   vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
   1013   vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
   1014   vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
   1015   vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
   1016   vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
   1017   vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
   1018   vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
   1019 }
   1020 
   1021 INLINE void transpose_msg_vecs16(const uint8_t *const *inputs,
   1022                                  size_t block_offset, __m512i out[16]) {
   1023   out[0] = loadu_512(&inputs[0][block_offset]);
   1024   out[1] = loadu_512(&inputs[1][block_offset]);
   1025   out[2] = loadu_512(&inputs[2][block_offset]);
   1026   out[3] = loadu_512(&inputs[3][block_offset]);
   1027   out[4] = loadu_512(&inputs[4][block_offset]);
   1028   out[5] = loadu_512(&inputs[5][block_offset]);
   1029   out[6] = loadu_512(&inputs[6][block_offset]);
   1030   out[7] = loadu_512(&inputs[7][block_offset]);
   1031   out[8] = loadu_512(&inputs[8][block_offset]);
   1032   out[9] = loadu_512(&inputs[9][block_offset]);
   1033   out[10] = loadu_512(&inputs[10][block_offset]);
   1034   out[11] = loadu_512(&inputs[11][block_offset]);
   1035   out[12] = loadu_512(&inputs[12][block_offset]);
   1036   out[13] = loadu_512(&inputs[13][block_offset]);
   1037   out[14] = loadu_512(&inputs[14][block_offset]);
   1038   out[15] = loadu_512(&inputs[15][block_offset]);
   1039   for (size_t i = 0; i < 16; ++i) {
   1040     _mm_prefetch(&inputs[i][block_offset + 256], _MM_HINT_T0);
   1041   }
   1042   transpose_vecs_512(out);
   1043 }
   1044 
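        // 16-lane counters computed directly in 32-bit halves: the unsigned
        // compare detects which low words wrapped past 2^32 and the masked add
        // carries into the corresponding high words.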
   1045 INLINE void load_counters16(uint64_t counter, bool increment_counter,
   1046                             __m512i *out_lo, __m512i *out_hi) {
   1047   const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
   1048   const __m512i add0 = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
   1049   const __m512i add1 = _mm512_and_si512(mask, add0);
    1050   __m512i l = _mm512_add_epi32(_mm512_set1_epi32((int32_t)counter), add1);
    1051   __mmask16 carry = _mm512_cmp_epu32_mask(l, add1, _MM_CMPINT_LT);
    1052   __m512i h = _mm512_mask_add_epi32(_mm512_set1_epi32((int32_t)(counter >> 32)), carry, _mm512_set1_epi32((int32_t)(counter >> 32)), _mm512_set1_epi32(1));
   1053   *out_lo = l;
   1054   *out_hi = h;
   1055 }
   1056 
   1057 void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
   1058                           const uint32_t key[8], uint64_t counter,
   1059                           bool increment_counter, uint8_t flags,
   1060                           uint8_t flags_start, uint8_t flags_end,
   1061                           uint8_t *out) {
   1062   __m512i h_vecs[8] = {
   1063       set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]),
   1064       set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]),
   1065   };
   1066   __m512i counter_low_vec, counter_high_vec;
   1067   load_counters16(counter, increment_counter, &counter_low_vec,
   1068                   &counter_high_vec);
   1069   uint8_t block_flags = flags | flags_start;
   1070 
   1071   for (size_t block = 0; block < blocks; block++) {
   1072     if (block + 1 == blocks) {
   1073       block_flags |= flags_end;
   1074     }
   1075     __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN);
   1076     __m512i block_flags_vec = set1_512(block_flags);
   1077     __m512i msg_vecs[16];
   1078     transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
   1079 
   1080     __m512i v[16] = {
   1081         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
   1082         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
   1083         set1_512(IV[0]), set1_512(IV[1]),  set1_512(IV[2]), set1_512(IV[3]),
   1084         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
   1085     };
   1086     round_fn16(v, msg_vecs, 0);
   1087     round_fn16(v, msg_vecs, 1);
   1088     round_fn16(v, msg_vecs, 2);
   1089     round_fn16(v, msg_vecs, 3);
   1090     round_fn16(v, msg_vecs, 4);
   1091     round_fn16(v, msg_vecs, 5);
   1092     round_fn16(v, msg_vecs, 6);
   1093     h_vecs[0] = xor_512(v[0], v[8]);
   1094     h_vecs[1] = xor_512(v[1], v[9]);
   1095     h_vecs[2] = xor_512(v[2], v[10]);
   1096     h_vecs[3] = xor_512(v[3], v[11]);
   1097     h_vecs[4] = xor_512(v[4], v[12]);
   1098     h_vecs[5] = xor_512(v[5], v[13]);
   1099     h_vecs[6] = xor_512(v[6], v[14]);
   1100     h_vecs[7] = xor_512(v[7], v[15]);
   1101 
   1102     block_flags = flags;
   1103   }
   1104 
   1105   // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8
   1106   // state vectors. Pad the matrix with zeros. After transposition, store the
   1107   // lower half of each vector.
   1108   __m512i padded[16] = {
   1109       h_vecs[0],   h_vecs[1],   h_vecs[2],   h_vecs[3],
   1110       h_vecs[4],   h_vecs[5],   h_vecs[6],   h_vecs[7],
   1111       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
   1112       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
   1113   };
   1114   transpose_vecs_512(padded);
   1115   _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
   1116   _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
   1117   _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
   1118   _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
   1119   _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
   1120   _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
   1121   _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
   1122   _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
   1123   _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
   1124   _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
   1125   _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
   1126   _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
   1127   _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
   1128   _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
   1129   _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
   1130   _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
   1131 }
   1132 
   1133 /*
   1134  * ----------------------------------------------------------------------------
   1135  * hash_many_avx512
   1136  * ----------------------------------------------------------------------------
   1137  */
   1138 
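        // Serial fallback for a single input: compresses one block at a time
        // with the in-place compressor, carrying cv between blocks.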
   1139 INLINE void hash_one_avx512(const uint8_t *input, size_t blocks,
   1140                             const uint32_t key[8], uint64_t counter,
   1141                             uint8_t flags, uint8_t flags_start,
   1142                             uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) {
   1143   uint32_t cv[8];
   1144   memcpy(cv, key, BLAKE3_KEY_LEN);
   1145   uint8_t block_flags = flags | flags_start;
   1146   while (blocks > 0) {
   1147     if (blocks == 1) {
   1148       block_flags |= flags_end;
   1149     }
   1150     blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
   1151                                     block_flags);
   1152     input = &input[BLAKE3_BLOCK_LEN];
   1153     blocks -= 1;
   1154     block_flags = flags;
   1155   }
   1156   memcpy(out, cv, BLAKE3_OUT_LEN);
   1157 }
   1158 
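        // Top-level dispatcher: peels off batches of 16, 8, and 4 inputs for
        // the wide kernels, then falls back to the serial path. A caller-side
        // sketch (hypothetical names; a whole 1 KiB chunk is 16 64-byte
        // blocks, and CHUNK_START/CHUNK_END come from blake3_impl.h):
        //
        //   blake3_hash_many_avx512(chunk_ptrs, num_chunks, 16, key,
        //                           chunk_counter, true, 0,
        //                           CHUNK_START, CHUNK_END, out_cvs);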
   1159 void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
   1160                              size_t blocks, const uint32_t key[8],
   1161                              uint64_t counter, bool increment_counter,
   1162                              uint8_t flags, uint8_t flags_start,
   1163                              uint8_t flags_end, uint8_t *out) {
   1164   while (num_inputs >= 16) {
   1165     blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags,
   1166                          flags_start, flags_end, out);
   1167     if (increment_counter) {
   1168       counter += 16;
   1169     }
   1170     inputs += 16;
   1171     num_inputs -= 16;
   1172     out = &out[16 * BLAKE3_OUT_LEN];
   1173   }
   1174   while (num_inputs >= 8) {
   1175     blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags,
   1176                         flags_start, flags_end, out);
   1177     if (increment_counter) {
   1178       counter += 8;
   1179     }
   1180     inputs += 8;
   1181     num_inputs -= 8;
   1182     out = &out[8 * BLAKE3_OUT_LEN];
   1183   }
   1184   while (num_inputs >= 4) {
   1185     blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags,
   1186                         flags_start, flags_end, out);
   1187     if (increment_counter) {
   1188       counter += 4;
   1189     }
   1190     inputs += 4;
   1191     num_inputs -= 4;
   1192     out = &out[4 * BLAKE3_OUT_LEN];
   1193   }
   1194   while (num_inputs > 0) {
   1195     hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
   1196                     flags_end, out);
   1197     if (increment_counter) {
   1198       counter += 1;
   1199     }
   1200     inputs += 1;
   1201     num_inputs -= 1;
   1202     out = &out[BLAKE3_OUT_LEN];
   1203   }
   1204 }