blake3_avx2.c (12411B)
1 #include "blake3_impl.h" 2 3 #include <immintrin.h> 4 5 #define DEGREE 8 6 7 INLINE __m256i loadu(const uint8_t src[32]) { 8 return _mm256_loadu_si256((const __m256i *)src); 9 } 10 11 INLINE void storeu(__m256i src, uint8_t dest[16]) { 12 _mm256_storeu_si256((__m256i *)dest, src); 13 } 14 15 INLINE __m256i addv(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); } 16 17 // Note that clang-format doesn't like the name "xor" for some reason. 18 INLINE __m256i xorv(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); } 19 20 INLINE __m256i set1(uint32_t x) { return _mm256_set1_epi32((int32_t)x); } 21 22 INLINE __m256i rot16(__m256i x) { 23 return _mm256_shuffle_epi8( 24 x, _mm256_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 25 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2)); 26 } 27 28 INLINE __m256i rot12(__m256i x) { 29 return _mm256_or_si256(_mm256_srli_epi32(x, 12), _mm256_slli_epi32(x, 32 - 12)); 30 } 31 32 INLINE __m256i rot8(__m256i x) { 33 return _mm256_shuffle_epi8( 34 x, _mm256_set_epi8(12, 15, 14, 13, 8, 11, 10, 9, 4, 7, 6, 5, 0, 3, 2, 1, 35 12, 15, 14, 13, 8, 11, 10, 9, 4, 7, 6, 5, 0, 3, 2, 1)); 36 } 37 38 INLINE __m256i rot7(__m256i x) { 39 return _mm256_or_si256(_mm256_srli_epi32(x, 7), _mm256_slli_epi32(x, 32 - 7)); 40 } 41 42 INLINE void round_fn(__m256i v[16], __m256i m[16], size_t r) { 43 v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][0]]); 44 v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][2]]); 45 v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][4]]); 46 v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][6]]); 47 v[0] = addv(v[0], v[4]); 48 v[1] = addv(v[1], v[5]); 49 v[2] = addv(v[2], v[6]); 50 v[3] = addv(v[3], v[7]); 51 v[12] = xorv(v[12], v[0]); 52 v[13] = xorv(v[13], v[1]); 53 v[14] = xorv(v[14], v[2]); 54 v[15] = xorv(v[15], v[3]); 55 v[12] = rot16(v[12]); 56 v[13] = rot16(v[13]); 57 v[14] = rot16(v[14]); 58 v[15] = rot16(v[15]); 59 v[8] = addv(v[8], v[12]); 60 v[9] = addv(v[9], v[13]); 61 v[10] = addv(v[10], v[14]); 62 v[11] = addv(v[11], v[15]); 63 v[4] = xorv(v[4], v[8]); 64 v[5] = xorv(v[5], v[9]); 65 v[6] = xorv(v[6], v[10]); 66 v[7] = xorv(v[7], v[11]); 67 v[4] = rot12(v[4]); 68 v[5] = rot12(v[5]); 69 v[6] = rot12(v[6]); 70 v[7] = rot12(v[7]); 71 v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][1]]); 72 v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][3]]); 73 v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][5]]); 74 v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][7]]); 75 v[0] = addv(v[0], v[4]); 76 v[1] = addv(v[1], v[5]); 77 v[2] = addv(v[2], v[6]); 78 v[3] = addv(v[3], v[7]); 79 v[12] = xorv(v[12], v[0]); 80 v[13] = xorv(v[13], v[1]); 81 v[14] = xorv(v[14], v[2]); 82 v[15] = xorv(v[15], v[3]); 83 v[12] = rot8(v[12]); 84 v[13] = rot8(v[13]); 85 v[14] = rot8(v[14]); 86 v[15] = rot8(v[15]); 87 v[8] = addv(v[8], v[12]); 88 v[9] = addv(v[9], v[13]); 89 v[10] = addv(v[10], v[14]); 90 v[11] = addv(v[11], v[15]); 91 v[4] = xorv(v[4], v[8]); 92 v[5] = xorv(v[5], v[9]); 93 v[6] = xorv(v[6], v[10]); 94 v[7] = xorv(v[7], v[11]); 95 v[4] = rot7(v[4]); 96 v[5] = rot7(v[5]); 97 v[6] = rot7(v[6]); 98 v[7] = rot7(v[7]); 99 100 v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][8]]); 101 v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][10]]); 102 v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][12]]); 103 v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][14]]); 104 v[0] = addv(v[0], v[5]); 105 v[1] = addv(v[1], v[6]); 106 v[2] = addv(v[2], v[7]); 107 v[3] = addv(v[3], v[4]); 108 v[15] = xorv(v[15], v[0]); 109 v[12] = xorv(v[12], v[1]); 110 v[13] = xorv(v[13], v[2]); 111 v[14] = xorv(v[14], v[3]); 112 v[15] = rot16(v[15]); 113 v[12] = rot16(v[12]); 114 v[13] = rot16(v[13]); 115 v[14] = rot16(v[14]); 116 v[10] = addv(v[10], v[15]); 117 v[11] = addv(v[11], v[12]); 118 v[8] = addv(v[8], v[13]); 119 v[9] = addv(v[9], v[14]); 120 v[5] = xorv(v[5], v[10]); 121 v[6] = xorv(v[6], v[11]); 122 v[7] = xorv(v[7], v[8]); 123 v[4] = xorv(v[4], v[9]); 124 v[5] = rot12(v[5]); 125 v[6] = rot12(v[6]); 126 v[7] = rot12(v[7]); 127 v[4] = rot12(v[4]); 128 v[0] = addv(v[0], m[(size_t)MSG_SCHEDULE[r][9]]); 129 v[1] = addv(v[1], m[(size_t)MSG_SCHEDULE[r][11]]); 130 v[2] = addv(v[2], m[(size_t)MSG_SCHEDULE[r][13]]); 131 v[3] = addv(v[3], m[(size_t)MSG_SCHEDULE[r][15]]); 132 v[0] = addv(v[0], v[5]); 133 v[1] = addv(v[1], v[6]); 134 v[2] = addv(v[2], v[7]); 135 v[3] = addv(v[3], v[4]); 136 v[15] = xorv(v[15], v[0]); 137 v[12] = xorv(v[12], v[1]); 138 v[13] = xorv(v[13], v[2]); 139 v[14] = xorv(v[14], v[3]); 140 v[15] = rot8(v[15]); 141 v[12] = rot8(v[12]); 142 v[13] = rot8(v[13]); 143 v[14] = rot8(v[14]); 144 v[10] = addv(v[10], v[15]); 145 v[11] = addv(v[11], v[12]); 146 v[8] = addv(v[8], v[13]); 147 v[9] = addv(v[9], v[14]); 148 v[5] = xorv(v[5], v[10]); 149 v[6] = xorv(v[6], v[11]); 150 v[7] = xorv(v[7], v[8]); 151 v[4] = xorv(v[4], v[9]); 152 v[5] = rot7(v[5]); 153 v[6] = rot7(v[6]); 154 v[7] = rot7(v[7]); 155 v[4] = rot7(v[4]); 156 } 157 158 INLINE void transpose_vecs(__m256i vecs[DEGREE]) { 159 // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high 160 // is 22/33/66/77. 161 __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]); 162 __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]); 163 __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]); 164 __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]); 165 __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]); 166 __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]); 167 __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]); 168 __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]); 169 170 // Interleave 64-bit lates. The low unpack is lanes 00/22 and the high is 171 // 11/33. 172 __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145); 173 __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145); 174 __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367); 175 __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367); 176 __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145); 177 __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145); 178 __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367); 179 __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367); 180 181 // Interleave 128-bit lanes. 182 vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20); 183 vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20); 184 vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20); 185 vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20); 186 vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31); 187 vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31); 188 vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31); 189 vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31); 190 } 191 192 INLINE void transpose_msg_vecs(const uint8_t *const *inputs, 193 size_t block_offset, __m256i out[16]) { 194 out[0] = loadu(&inputs[0][block_offset + 0 * sizeof(__m256i)]); 195 out[1] = loadu(&inputs[1][block_offset + 0 * sizeof(__m256i)]); 196 out[2] = loadu(&inputs[2][block_offset + 0 * sizeof(__m256i)]); 197 out[3] = loadu(&inputs[3][block_offset + 0 * sizeof(__m256i)]); 198 out[4] = loadu(&inputs[4][block_offset + 0 * sizeof(__m256i)]); 199 out[5] = loadu(&inputs[5][block_offset + 0 * sizeof(__m256i)]); 200 out[6] = loadu(&inputs[6][block_offset + 0 * sizeof(__m256i)]); 201 out[7] = loadu(&inputs[7][block_offset + 0 * sizeof(__m256i)]); 202 out[8] = loadu(&inputs[0][block_offset + 1 * sizeof(__m256i)]); 203 out[9] = loadu(&inputs[1][block_offset + 1 * sizeof(__m256i)]); 204 out[10] = loadu(&inputs[2][block_offset + 1 * sizeof(__m256i)]); 205 out[11] = loadu(&inputs[3][block_offset + 1 * sizeof(__m256i)]); 206 out[12] = loadu(&inputs[4][block_offset + 1 * sizeof(__m256i)]); 207 out[13] = loadu(&inputs[5][block_offset + 1 * sizeof(__m256i)]); 208 out[14] = loadu(&inputs[6][block_offset + 1 * sizeof(__m256i)]); 209 out[15] = loadu(&inputs[7][block_offset + 1 * sizeof(__m256i)]); 210 for (size_t i = 0; i < 8; ++i) { 211 _mm_prefetch(&inputs[i][block_offset + 256], _MM_HINT_T0); 212 } 213 transpose_vecs(&out[0]); 214 transpose_vecs(&out[8]); 215 } 216 217 INLINE void load_counters(uint64_t counter, bool increment_counter, 218 __m256i *out_lo, __m256i *out_hi) { 219 const __m256i mask = _mm256_set1_epi32(-(int32_t)increment_counter); 220 const __m256i add0 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); 221 const __m256i add1 = _mm256_and_si256(mask, add0); 222 __m256i l = _mm256_add_epi32(_mm256_set1_epi32(counter), add1); 223 __m256i carry = _mm256_cmpgt_epi32(_mm256_xor_si256(add1, _mm256_set1_epi32(0x80000000)), 224 _mm256_xor_si256( l, _mm256_set1_epi32(0x80000000))); 225 __m256i h = _mm256_sub_epi32(_mm256_set1_epi32(counter >> 32), carry); 226 *out_lo = l; 227 *out_hi = h; 228 } 229 230 void blake3_hash8_avx2(const uint8_t *const *inputs, size_t blocks, 231 const uint32_t key[8], uint64_t counter, 232 bool increment_counter, uint8_t flags, 233 uint8_t flags_start, uint8_t flags_end, uint8_t *out) { 234 __m256i h_vecs[8] = { 235 set1(key[0]), set1(key[1]), set1(key[2]), set1(key[3]), 236 set1(key[4]), set1(key[5]), set1(key[6]), set1(key[7]), 237 }; 238 __m256i counter_low_vec, counter_high_vec; 239 load_counters(counter, increment_counter, &counter_low_vec, 240 &counter_high_vec); 241 uint8_t block_flags = flags | flags_start; 242 243 for (size_t block = 0; block < blocks; block++) { 244 if (block + 1 == blocks) { 245 block_flags |= flags_end; 246 } 247 __m256i block_len_vec = set1(BLAKE3_BLOCK_LEN); 248 __m256i block_flags_vec = set1(block_flags); 249 __m256i msg_vecs[16]; 250 transpose_msg_vecs(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs); 251 252 __m256i v[16] = { 253 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], 254 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], 255 set1(IV[0]), set1(IV[1]), set1(IV[2]), set1(IV[3]), 256 counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, 257 }; 258 round_fn(v, msg_vecs, 0); 259 round_fn(v, msg_vecs, 1); 260 round_fn(v, msg_vecs, 2); 261 round_fn(v, msg_vecs, 3); 262 round_fn(v, msg_vecs, 4); 263 round_fn(v, msg_vecs, 5); 264 round_fn(v, msg_vecs, 6); 265 h_vecs[0] = xorv(v[0], v[8]); 266 h_vecs[1] = xorv(v[1], v[9]); 267 h_vecs[2] = xorv(v[2], v[10]); 268 h_vecs[3] = xorv(v[3], v[11]); 269 h_vecs[4] = xorv(v[4], v[12]); 270 h_vecs[5] = xorv(v[5], v[13]); 271 h_vecs[6] = xorv(v[6], v[14]); 272 h_vecs[7] = xorv(v[7], v[15]); 273 274 block_flags = flags; 275 } 276 277 transpose_vecs(h_vecs); 278 storeu(h_vecs[0], &out[0 * sizeof(__m256i)]); 279 storeu(h_vecs[1], &out[1 * sizeof(__m256i)]); 280 storeu(h_vecs[2], &out[2 * sizeof(__m256i)]); 281 storeu(h_vecs[3], &out[3 * sizeof(__m256i)]); 282 storeu(h_vecs[4], &out[4 * sizeof(__m256i)]); 283 storeu(h_vecs[5], &out[5 * sizeof(__m256i)]); 284 storeu(h_vecs[6], &out[6 * sizeof(__m256i)]); 285 storeu(h_vecs[7], &out[7 * sizeof(__m256i)]); 286 } 287 288 #if !defined(BLAKE3_NO_SSE41) 289 void blake3_hash_many_sse41(const uint8_t *const *inputs, size_t num_inputs, 290 size_t blocks, const uint32_t key[8], 291 uint64_t counter, bool increment_counter, 292 uint8_t flags, uint8_t flags_start, 293 uint8_t flags_end, uint8_t *out); 294 #else 295 void blake3_hash_many_portable(const uint8_t *const *inputs, size_t num_inputs, 296 size_t blocks, const uint32_t key[8], 297 uint64_t counter, bool increment_counter, 298 uint8_t flags, uint8_t flags_start, 299 uint8_t flags_end, uint8_t *out); 300 #endif 301 302 void blake3_hash_many_avx2(const uint8_t *const *inputs, size_t num_inputs, 303 size_t blocks, const uint32_t key[8], 304 uint64_t counter, bool increment_counter, 305 uint8_t flags, uint8_t flags_start, 306 uint8_t flags_end, uint8_t *out) { 307 while (num_inputs >= DEGREE) { 308 blake3_hash8_avx2(inputs, blocks, key, counter, increment_counter, flags, 309 flags_start, flags_end, out); 310 if (increment_counter) { 311 counter += DEGREE; 312 } 313 inputs += DEGREE; 314 num_inputs -= DEGREE; 315 out = &out[DEGREE * BLAKE3_OUT_LEN]; 316 } 317 #if !defined(BLAKE3_NO_SSE41) 318 blake3_hash_many_sse41(inputs, num_inputs, blocks, key, counter, 319 increment_counter, flags, flags_start, flags_end, out); 320 #else 321 blake3_hash_many_portable(inputs, num_inputs, blocks, key, counter, 322 increment_counter, flags, flags_start, flags_end, 323 out); 324 #endif 325 }