chibipub

experimental activitypub node in C
git clone git://jb55.com/chibipub

blake3_neon.c (12720B)


#include "blake3_impl.h"

#include <arm_neon.h>
#include <string.h> // memcpy is used by the unaligned load/store helpers below

// TODO: This is probably incorrect for big-endian ARM. How should that work?
INLINE uint32x4_t loadu_128(const uint8_t src[16]) {
  // vld1q_u32 has alignment requirements. Don't use it.
  uint32x4_t x;
  memcpy(&x, src, 16);
  return x;
}
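
// Hedged sketch for the TODO above, not part of the original file: on a
// big-endian ARM target the little-endian BLAKE3 words would likely need a
// byte swap within each 32-bit lane after the raw load. The helper name and
// the exact fix are untested assumptions.
#if defined(__ARM_BIG_ENDIAN)
INLINE uint32x4_t loadu_128_byteswap(const uint8_t src[16]) {
  uint8x16_t bytes;
  memcpy(&bytes, src, 16);
  // vrev32q_u8 reverses the byte order inside each 32-bit element.
  return vreinterpretq_u32_u8(vrev32q_u8(bytes));
}
#endif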

INLINE void storeu_128(uint32x4_t src, uint8_t dest[16]) {
  // vst1q_u32 has alignment requirements. Don't use it.
  memcpy(dest, &src, 16);
}

INLINE uint32x4_t add_128(uint32x4_t a, uint32x4_t b) {
  return vaddq_u32(a, b);
}

INLINE uint32x4_t xor_128(uint32x4_t a, uint32x4_t b) {
  return veorq_u32(a, b);
}

INLINE uint32x4_t set1_128(uint32_t x) { return vld1q_dup_u32(&x); }

INLINE uint32x4_t set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
  uint32_t array[4] = {a, b, c, d};
  return vld1q_u32(array);
}

INLINE uint32x4_t rot16_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 16), vshlq_n_u32(x, 32 - 16));
}

INLINE uint32x4_t rot12_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 12), vshlq_n_u32(x, 32 - 12));
}

INLINE uint32x4_t rot8_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 8), vshlq_n_u32(x, 32 - 8));
}

INLINE uint32x4_t rot7_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 7), vshlq_n_u32(x, 32 - 7));
}
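
// NEON has no 32-bit rotate instruction, so the rot helpers above spell the
// rotate as a shift pair plus an OR. A hedged, unused alternative (not in the
// original file) is the shift-right-and-insert form shown below for the
// 16-bit case; whether it is actually faster depends on the core.
INLINE uint32x4_t rot16_128_sri(uint32x4_t x) {
  // The high 16 bits come from x << 16 and the low 16 bits are inserted
  // from x >> 16, which together form a rotate right by 16.
  return vsriq_n_u32(vshlq_n_u32(x, 16), x, 16);
}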

// TODO: compress_neon

// TODO: hash2_neon

/*
 * ----------------------------------------------------------------------------
 * hash4_neon
 * ----------------------------------------------------------------------------
 */

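/*
 * round_fn4 runs one round of the BLAKE3 compression function on four states
 * in parallel: each vector v[0..15] holds the same state word for inputs
 * 0..3, one input per lane, and m[0..15] holds the transposed message words.
 * The first half of the body mixes the columns of the 4x4 state and the
 * second half (after the blank line) mixes the diagonals, following the
 * add/xor/rotate-16-12-8-7 pattern of the G function.
 */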
INLINE void round_fn4(uint32x4_t v[16], uint32x4_t m[16], size_t r) {
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[15] = rot16_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot12_128(v[4]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[15] = rot8_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot7_128(v[4]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);

  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot16_128(v[15]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[4] = rot12_128(v[4]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot8_128(v[15]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);
  v[4] = rot7_128(v[4]);
}

INLINE void transpose_vecs_128(uint32x4_t vecs[4]) {
  // Individually transpose the four 2x2 sub-matrices in each corner.
  uint32x4x2_t rows01 = vtrnq_u32(vecs[0], vecs[1]);
  uint32x4x2_t rows23 = vtrnq_u32(vecs[2], vecs[3]);

  // Swap the top-right and bottom-left 2x2s (which just got transposed).
  vecs[0] =
      vcombine_u32(vget_low_u32(rows01.val[0]), vget_low_u32(rows23.val[0]));
  vecs[1] =
      vcombine_u32(vget_low_u32(rows01.val[1]), vget_low_u32(rows23.val[1]));
  vecs[2] =
      vcombine_u32(vget_high_u32(rows01.val[0]), vget_high_u32(rows23.val[0]));
  vecs[3] =
      vcombine_u32(vget_high_u32(rows01.val[1]), vget_high_u32(rows23.val[1]));
}

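// Loads one 64-byte block from each of the four inputs at `block_offset` and
// transposes the result, so that out[w] ends up holding message word w of
// inputs 0..3 in lanes 0..3, ready for round_fn4.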
INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
                                size_t block_offset, uint32x4_t out[16]) {
  out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(uint32x4_t)]);
  out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(uint32x4_t)]);
  out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(uint32x4_t)]);
  out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(uint32x4_t)]);
  out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(uint32x4_t)]);
  out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(uint32x4_t)]);
  out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(uint32x4_t)]);
  out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(uint32x4_t)]);
  out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(uint32x4_t)]);
  out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(uint32x4_t)]);
  out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(uint32x4_t)]);
  out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(uint32x4_t)]);
  out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(uint32x4_t)]);
  out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(uint32x4_t)]);
  out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(uint32x4_t)]);
  out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(uint32x4_t)]);
  transpose_vecs_128(&out[0]);
  transpose_vecs_128(&out[4]);
  transpose_vecs_128(&out[8]);
  transpose_vecs_128(&out[12]);
}

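// When increment_counter is set, `mask` is all ones and the four lanes get
// the 64-bit counters counter+0 .. counter+3, split into low and high halves;
// otherwise every lane shares the same counter value.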
INLINE void load_counters4(uint64_t counter, bool increment_counter,
                           uint32x4_t *out_low, uint32x4_t *out_high) {
  uint64_t mask = (increment_counter ? ~0 : 0);
  *out_low = set4(
      counter_low(counter + (mask & 0)), counter_low(counter + (mask & 1)),
      counter_low(counter + (mask & 2)), counter_low(counter + (mask & 3)));
  *out_high = set4(
      counter_high(counter + (mask & 0)), counter_high(counter + (mask & 1)),
      counter_high(counter + (mask & 2)), counter_high(counter + (mask & 3)));
}

void blake3_hash4_neon(const uint8_t *const *inputs, size_t blocks,
                       const uint32_t key[8], uint64_t counter,
                       bool increment_counter, uint8_t flags,
                       uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
  uint32x4_t h_vecs[8] = {
      set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
      set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
  };
  uint32x4_t counter_low_vec, counter_high_vec;
  load_counters4(counter, increment_counter, &counter_low_vec,
                 &counter_high_vec);
  uint8_t block_flags = flags | flags_start;

  for (size_t block = 0; block < blocks; block++) {
    if (block + 1 == blocks) {
      block_flags |= flags_end;
    }
    uint32x4_t block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
    uint32x4_t block_flags_vec = set1_128(block_flags);
    uint32x4_t msg_vecs[16];
    transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);

    uint32x4_t v[16] = {
        h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
        h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
        set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
        counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
    };
    round_fn4(v, msg_vecs, 0);
    round_fn4(v, msg_vecs, 1);
    round_fn4(v, msg_vecs, 2);
    round_fn4(v, msg_vecs, 3);
    round_fn4(v, msg_vecs, 4);
    round_fn4(v, msg_vecs, 5);
    round_fn4(v, msg_vecs, 6);
    h_vecs[0] = xor_128(v[0], v[8]);
    h_vecs[1] = xor_128(v[1], v[9]);
    h_vecs[2] = xor_128(v[2], v[10]);
    h_vecs[3] = xor_128(v[3], v[11]);
    h_vecs[4] = xor_128(v[4], v[12]);
    h_vecs[5] = xor_128(v[5], v[13]);
    h_vecs[6] = xor_128(v[6], v[14]);
    h_vecs[7] = xor_128(v[7], v[15]);

    block_flags = flags;
  }

  transpose_vecs_128(&h_vecs[0]);
  transpose_vecs_128(&h_vecs[4]);
  // The first four vecs now contain the first half of each output, and the
  // second four vecs contain the second half of each output.
  storeu_128(h_vecs[0], &out[0 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[4], &out[1 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[1], &out[2 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[5], &out[3 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[2], &out[4 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[6], &out[5 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[3], &out[6 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[7], &out[7 * sizeof(uint32x4_t)]);
}

/*
 * ----------------------------------------------------------------------------
 * hash_many_neon
 * ----------------------------------------------------------------------------
 */

void blake3_compress_in_place_portable(uint32_t cv[8],
                                       const uint8_t block[BLAKE3_BLOCK_LEN],
                                       uint8_t block_len, uint64_t counter,
                                       uint8_t flags);

INLINE void hash_one_neon(const uint8_t *input, size_t blocks,
                          const uint32_t key[8], uint64_t counter,
                          uint8_t flags, uint8_t flags_start, uint8_t flags_end,
                          uint8_t out[BLAKE3_OUT_LEN]) {
  uint32_t cv[8];
  memcpy(cv, key, BLAKE3_KEY_LEN);
  uint8_t block_flags = flags | flags_start;
  while (blocks > 0) {
    if (blocks == 1) {
      block_flags |= flags_end;
    }
    // TODO: Implement compress_neon. However note that according to
    // https://github.com/BLAKE2/BLAKE2/commit/7965d3e6e1b4193438b8d3a656787587d2579227,
    // compress_neon might not be any faster than compress_portable.
    blake3_compress_in_place_portable(cv, input, BLAKE3_BLOCK_LEN, counter,
                                      block_flags);
    input = &input[BLAKE3_BLOCK_LEN];
    blocks -= 1;
    block_flags = flags;
  }
  memcpy(out, cv, BLAKE3_OUT_LEN);
}

void blake3_hash_many_neon(const uint8_t *const *inputs, size_t num_inputs,
                           size_t blocks, const uint32_t key[8],
                           uint64_t counter, bool increment_counter,
                           uint8_t flags, uint8_t flags_start,
                           uint8_t flags_end, uint8_t *out) {
  while (num_inputs >= 4) {
    blake3_hash4_neon(inputs, blocks, key, counter, increment_counter, flags,
                      flags_start, flags_end, out);
    if (increment_counter) {
      counter += 4;
    }
    inputs += 4;
    num_inputs -= 4;
    out = &out[4 * BLAKE3_OUT_LEN];
  }
  while (num_inputs > 0) {
    hash_one_neon(inputs[0], blocks, key, counter, flags, flags_start,
                  flags_end, out);
    if (increment_counter) {
      counter += 1;
    }
    inputs += 1;
    num_inputs -= 1;
    out = &out[BLAKE3_OUT_LEN];
  }
}
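
/*
 * Hedged usage sketch (illustrative, not part of the original file): this
 * function is normally reached through the library's dispatch layer rather
 * than called directly. The unused wrapper below only shows the calling
 * convention for four equal-length inputs; the flag arguments are left at
 * zero purely for illustration, real callers pass the appropriate chunk and
 * parent flag bits. The wrapper name is hypothetical.
 */
INLINE void example_hash_many_neon(const uint8_t *in0, const uint8_t *in1,
                                   const uint8_t *in2, const uint8_t *in3,
                                   size_t blocks, const uint32_t key[8],
                                   uint8_t out[4 * BLAKE3_OUT_LEN]) {
  const uint8_t *inputs[4] = {in0, in1, in2, in3};
  // Each input is blocks * BLAKE3_BLOCK_LEN bytes; counters 0..3 are assigned
  // to the four inputs because increment_counter is true.
  blake3_hash_many_neon(inputs, 4, blocks, key, /*counter=*/0,
                        /*increment_counter=*/true, /*flags=*/0,
                        /*flags_start=*/0, /*flags_end=*/0, out);
}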