diff options
Diffstat (limited to 'c/blake3_avx512.c')
| -rw-r--r-- | c/blake3_avx512.c | 178 |
1 files changed, 173 insertions, 5 deletions
diff --git a/c/blake3_avx512.c b/c/blake3_avx512.c index d6b1ae9..f88a32d 100644 --- a/c/blake3_avx512.c +++ b/c/blake3_avx512.c @@ -7,23 +7,27 @@ _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c)))) INLINE __m128i loadu_128(const uint8_t src[16]) { - return _mm_loadu_si128((const __m128i *)src); + return _mm_loadu_si128((void*)src); } INLINE __m256i loadu_256(const uint8_t src[32]) { - return _mm256_loadu_si256((const __m256i *)src); + return _mm256_loadu_si256((void*)src); } INLINE __m512i loadu_512(const uint8_t src[64]) { - return _mm512_loadu_si512((const __m512i *)src); + return _mm512_loadu_si512((void*)src); } INLINE void storeu_128(__m128i src, uint8_t dest[16]) { - _mm_storeu_si128((__m128i *)dest, src); + _mm_storeu_si128((void*)dest, src); } INLINE void storeu_256(__m256i src, uint8_t dest[16]) { - _mm256_storeu_si256((__m256i *)dest, src); + _mm256_storeu_si256((void*)dest, src); +} + +INLINE void storeu_512(__m512i src, uint8_t dest[16]) { + _mm512_storeu_si512((void*)dest, src); } INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); } @@ -550,6 +554,54 @@ void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks, storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]); } +static +void blake3_xof4_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t out[4 * 64]) { + __m128i h_vecs[8] = { + set1_128(cv[0]), set1_128(cv[1]), set1_128(cv[2]), set1_128(cv[3]), + set1_128(cv[4]), set1_128(cv[5]), set1_128(cv[6]), set1_128(cv[7]), + }; + uint32_t block_words[16]; + load_block_words(block, block_words); + __m128i msg_vecs[16]; + for (size_t i = 0; i < 16; i++) { + msg_vecs[i] = set1_128(block_words[i]); + } + __m128i counter_low_vec, counter_high_vec; + load_counters4(counter, true, &counter_low_vec, &counter_high_vec); + __m128i block_len_vec = set1_128(block_len); + __m128i block_flags_vec = set1_128(flags); + __m128i v[16] = { + h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], + h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], + set1_128(IV[0]), set1_128(IV[1]), set1_128(IV[2]), set1_128(IV[3]), + counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, + }; + round_fn4(v, msg_vecs, 0); + round_fn4(v, msg_vecs, 1); + round_fn4(v, msg_vecs, 2); + round_fn4(v, msg_vecs, 3); + round_fn4(v, msg_vecs, 4); + round_fn4(v, msg_vecs, 5); + round_fn4(v, msg_vecs, 6); + for (size_t i = 0; i < 8; i++) { + v[i] = xor_128(v[i], v[i+8]); + v[i+8] = xor_128(v[i+8], h_vecs[i]); + } + transpose_vecs_128(&v[0]); + transpose_vecs_128(&v[4]); + transpose_vecs_128(&v[8]); + transpose_vecs_128(&v[12]); + for (size_t i = 0; i < 4; i++) { + storeu_128(v[i+ 0], &out[(4*i+0) * sizeof(__m128i)]); + storeu_128(v[i+ 4], &out[(4*i+1) * sizeof(__m128i)]); + storeu_128(v[i+ 8], &out[(4*i+2) * sizeof(__m128i)]); + storeu_128(v[i+12], &out[(4*i+3) * sizeof(__m128i)]); + } +} + /* * ---------------------------------------------------------------------------- * hash8_avx512 @@ -802,6 +854,50 @@ void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks, storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]); } +static +void blake3_xof8_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t out[8 * 64]) { + __m256i h_vecs[8] = { + set1_256(cv[0]), set1_256(cv[1]), set1_256(cv[2]), set1_256(cv[3]), + set1_256(cv[4]), set1_256(cv[5]), set1_256(cv[6]), set1_256(cv[7]), + }; + uint32_t block_words[16]; + load_block_words(block, block_words); + __m256i msg_vecs[16]; + for (size_t i = 0; i < 16; i++) { + msg_vecs[i] = set1_256(block_words[i]); + } + __m256i counter_low_vec, counter_high_vec; + load_counters8(counter, true, &counter_low_vec, &counter_high_vec); + __m256i block_len_vec = set1_256(block_len); + __m256i block_flags_vec = set1_256(flags); + __m256i v[16] = { + h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], + h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], + set1_256(IV[0]), set1_256(IV[1]), set1_256(IV[2]), set1_256(IV[3]), + counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, + }; + round_fn8(v, msg_vecs, 0); + round_fn8(v, msg_vecs, 1); + round_fn8(v, msg_vecs, 2); + round_fn8(v, msg_vecs, 3); + round_fn8(v, msg_vecs, 4); + round_fn8(v, msg_vecs, 5); + round_fn8(v, msg_vecs, 6); + for (size_t i = 0; i < 8; i++) { + v[i] = xor_256(v[i], v[i+8]); + v[i+8] = xor_256(v[i+8], h_vecs[i]); + } + transpose_vecs_256(&v[0]); + transpose_vecs_256(&v[8]); + for (size_t i = 0; i < 8; i++) { + storeu_256(v[i+0], &out[(2*i+0) * sizeof(__m256i)]); + storeu_256(v[i+8], &out[(2*i+1) * sizeof(__m256i)]); + } +} + /* * ---------------------------------------------------------------------------- * hash16_avx512 @@ -1146,6 +1242,48 @@ void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks, _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15])); } +static +void blake3_xof16_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t out[16 * 64]) { + __m512i h_vecs[8] = { + set1_512(cv[0]), set1_512(cv[1]), set1_512(cv[2]), set1_512(cv[3]), + set1_512(cv[4]), set1_512(cv[5]), set1_512(cv[6]), set1_512(cv[7]), + }; + uint32_t block_words[16]; + load_block_words(block, block_words); + __m512i msg_vecs[16]; + for (size_t i = 0; i < 16; i++) { + msg_vecs[i] = set1_512(block_words[i]); + } + __m512i counter_low_vec, counter_high_vec; + load_counters16(counter, true, &counter_low_vec, &counter_high_vec); + __m512i block_len_vec = set1_512(block_len); + __m512i block_flags_vec = set1_512(flags); + __m512i v[16] = { + h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], + h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], + set1_512(IV[0]), set1_512(IV[1]), set1_512(IV[2]), set1_512(IV[3]), + counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, + }; + round_fn16(v, msg_vecs, 0); + round_fn16(v, msg_vecs, 1); + round_fn16(v, msg_vecs, 2); + round_fn16(v, msg_vecs, 3); + round_fn16(v, msg_vecs, 4); + round_fn16(v, msg_vecs, 5); + round_fn16(v, msg_vecs, 6); + for (size_t i = 0; i < 8; i++) { + v[i] = xor_512(v[i], v[i+8]); + v[i+8] = xor_512(v[i+8], h_vecs[i]); + } + transpose_vecs_512(&v[0]); + for (size_t i = 0; i < 16; i++) { + storeu_512(v[i], &out[i * sizeof(__m512i)]); + } +} + /* * ---------------------------------------------------------------------------- * hash_many_avx512 @@ -1218,3 +1356,33 @@ void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs, out = &out[BLAKE3_OUT_LEN]; } } + +void blake3_xof_many_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t* out, size_t outblocks) { + while (outblocks >= 16) { + blake3_xof16_avx512(cv, block, block_len, counter, flags, out); + counter += 16; + outblocks -= 16; + out += 16 * BLAKE3_BLOCK_LEN; + } + while (outblocks >= 8) { + blake3_xof8_avx512(cv, block, block_len, counter, flags, out); + counter += 8; + outblocks -= 8; + out += 8 * BLAKE3_BLOCK_LEN; + } + while (outblocks >= 4) { + blake3_xof4_avx512(cv, block, block_len, counter, flags, out); + counter += 4; + outblocks -= 4; + out += 4 * BLAKE3_BLOCK_LEN; + } + while (outblocks > 0) { + blake3_compress_xof_avx512(cv, block, block_len, counter, flags, out); + counter += 1; + outblocks -= 1; + out += BLAKE3_BLOCK_LEN; + } +} |
