diff options
| author | Jack O'Connor <[email protected]> | 2023-05-28 14:44:56 +0200 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2024-08-15 16:02:10 -0700 |
| commit | 80b83effbd50425939483e5503b186db4dac4d9d (patch) | |
| tree | 36285a604224a643cc07f7fed156547d20de80f0 | |
| parent | a3ca51ff9eb659de8663315fc973556a5d96a586 (diff) | |
add an intrinsics implementation of blake3_xof_many_avx512
| -rw-r--r-- | c/blake3.c | 3 | ||||
| -rw-r--r-- | c/blake3_avx512.c | 178 | ||||
| -rw-r--r-- | c/blake3_impl.h | 9 |
3 files changed, 183 insertions, 7 deletions
@@ -100,8 +100,9 @@ INLINE void output_root_bytes(const output_t *self, uint64_t seek, uint8_t *out, out_len -= bytes; output_block_counter += 1; } - if(out_len / 64) + if(out_len / 64) { blake3_xof_many(self->input_cv, self->block, self->block_len, output_block_counter, self->flags | ROOT, out, out_len / 64); + } output_block_counter += out_len / 64; out += out_len & -64; out_len -= out_len & -64; diff --git a/c/blake3_avx512.c b/c/blake3_avx512.c index d6b1ae9..f88a32d 100644 --- a/c/blake3_avx512.c +++ b/c/blake3_avx512.c @@ -7,23 +7,27 @@ _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c)))) INLINE __m128i loadu_128(const uint8_t src[16]) { - return _mm_loadu_si128((const __m128i *)src); + return _mm_loadu_si128((void*)src); } INLINE __m256i loadu_256(const uint8_t src[32]) { - return _mm256_loadu_si256((const __m256i *)src); + return _mm256_loadu_si256((void*)src); } INLINE __m512i loadu_512(const uint8_t src[64]) { - return _mm512_loadu_si512((const __m512i *)src); + return _mm512_loadu_si512((void*)src); } INLINE void storeu_128(__m128i src, uint8_t dest[16]) { - _mm_storeu_si128((__m128i *)dest, src); + _mm_storeu_si128((void*)dest, src); } INLINE void storeu_256(__m256i src, uint8_t dest[16]) { - _mm256_storeu_si256((__m256i *)dest, src); + _mm256_storeu_si256((void*)dest, src); +} + +INLINE void storeu_512(__m512i src, uint8_t dest[16]) { + _mm512_storeu_si512((void*)dest, src); } INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); } @@ -550,6 +554,54 @@ void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks, storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]); } +static +void blake3_xof4_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t out[4 * 64]) { + __m128i h_vecs[8] = { + set1_128(cv[0]), set1_128(cv[1]), set1_128(cv[2]), set1_128(cv[3]), + set1_128(cv[4]), set1_128(cv[5]), set1_128(cv[6]), set1_128(cv[7]), + }; + uint32_t block_words[16]; + load_block_words(block, block_words); + __m128i msg_vecs[16]; + for (size_t i = 0; i < 16; i++) { + msg_vecs[i] = set1_128(block_words[i]); + } + __m128i counter_low_vec, counter_high_vec; + load_counters4(counter, true, &counter_low_vec, &counter_high_vec); + __m128i block_len_vec = set1_128(block_len); + __m128i block_flags_vec = set1_128(flags); + __m128i v[16] = { + h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], + h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], + set1_128(IV[0]), set1_128(IV[1]), set1_128(IV[2]), set1_128(IV[3]), + counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, + }; + round_fn4(v, msg_vecs, 0); + round_fn4(v, msg_vecs, 1); + round_fn4(v, msg_vecs, 2); + round_fn4(v, msg_vecs, 3); + round_fn4(v, msg_vecs, 4); + round_fn4(v, msg_vecs, 5); + round_fn4(v, msg_vecs, 6); + for (size_t i = 0; i < 8; i++) { + v[i] = xor_128(v[i], v[i+8]); + v[i+8] = xor_128(v[i+8], h_vecs[i]); + } + transpose_vecs_128(&v[0]); + transpose_vecs_128(&v[4]); + transpose_vecs_128(&v[8]); + transpose_vecs_128(&v[12]); + for (size_t i = 0; i < 4; i++) { + storeu_128(v[i+ 0], &out[(4*i+0) * sizeof(__m128i)]); + storeu_128(v[i+ 4], &out[(4*i+1) * sizeof(__m128i)]); + storeu_128(v[i+ 8], &out[(4*i+2) * sizeof(__m128i)]); + storeu_128(v[i+12], &out[(4*i+3) * sizeof(__m128i)]); + } +} + /* * ---------------------------------------------------------------------------- * hash8_avx512 @@ -802,6 +854,50 @@ void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks, storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]); } +static +void blake3_xof8_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t out[8 * 64]) { + __m256i h_vecs[8] = { + set1_256(cv[0]), set1_256(cv[1]), set1_256(cv[2]), set1_256(cv[3]), + set1_256(cv[4]), set1_256(cv[5]), set1_256(cv[6]), set1_256(cv[7]), + }; + uint32_t block_words[16]; + load_block_words(block, block_words); + __m256i msg_vecs[16]; + for (size_t i = 0; i < 16; i++) { + msg_vecs[i] = set1_256(block_words[i]); + } + __m256i counter_low_vec, counter_high_vec; + load_counters8(counter, true, &counter_low_vec, &counter_high_vec); + __m256i block_len_vec = set1_256(block_len); + __m256i block_flags_vec = set1_256(flags); + __m256i v[16] = { + h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], + h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], + set1_256(IV[0]), set1_256(IV[1]), set1_256(IV[2]), set1_256(IV[3]), + counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, + }; + round_fn8(v, msg_vecs, 0); + round_fn8(v, msg_vecs, 1); + round_fn8(v, msg_vecs, 2); + round_fn8(v, msg_vecs, 3); + round_fn8(v, msg_vecs, 4); + round_fn8(v, msg_vecs, 5); + round_fn8(v, msg_vecs, 6); + for (size_t i = 0; i < 8; i++) { + v[i] = xor_256(v[i], v[i+8]); + v[i+8] = xor_256(v[i+8], h_vecs[i]); + } + transpose_vecs_256(&v[0]); + transpose_vecs_256(&v[8]); + for (size_t i = 0; i < 8; i++) { + storeu_256(v[i+0], &out[(2*i+0) * sizeof(__m256i)]); + storeu_256(v[i+8], &out[(2*i+1) * sizeof(__m256i)]); + } +} + /* * ---------------------------------------------------------------------------- * hash16_avx512 @@ -1146,6 +1242,48 @@ void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks, _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15])); } +static +void blake3_xof16_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t out[16 * 64]) { + __m512i h_vecs[8] = { + set1_512(cv[0]), set1_512(cv[1]), set1_512(cv[2]), set1_512(cv[3]), + set1_512(cv[4]), set1_512(cv[5]), set1_512(cv[6]), set1_512(cv[7]), + }; + uint32_t block_words[16]; + load_block_words(block, block_words); + __m512i msg_vecs[16]; + for (size_t i = 0; i < 16; i++) { + msg_vecs[i] = set1_512(block_words[i]); + } + __m512i counter_low_vec, counter_high_vec; + load_counters16(counter, true, &counter_low_vec, &counter_high_vec); + __m512i block_len_vec = set1_512(block_len); + __m512i block_flags_vec = set1_512(flags); + __m512i v[16] = { + h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], + h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7], + set1_512(IV[0]), set1_512(IV[1]), set1_512(IV[2]), set1_512(IV[3]), + counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec, + }; + round_fn16(v, msg_vecs, 0); + round_fn16(v, msg_vecs, 1); + round_fn16(v, msg_vecs, 2); + round_fn16(v, msg_vecs, 3); + round_fn16(v, msg_vecs, 4); + round_fn16(v, msg_vecs, 5); + round_fn16(v, msg_vecs, 6); + for (size_t i = 0; i < 8; i++) { + v[i] = xor_512(v[i], v[i+8]); + v[i+8] = xor_512(v[i+8], h_vecs[i]); + } + transpose_vecs_512(&v[0]); + for (size_t i = 0; i < 16; i++) { + storeu_512(v[i], &out[i * sizeof(__m512i)]); + } +} + /* * ---------------------------------------------------------------------------- * hash_many_avx512 @@ -1218,3 +1356,33 @@ void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs, out = &out[BLAKE3_OUT_LEN]; } } + +void blake3_xof_many_avx512(const uint32_t cv[8], + const uint8_t block[BLAKE3_BLOCK_LEN], + uint8_t block_len, uint64_t counter, uint8_t flags, + uint8_t* out, size_t outblocks) { + while (outblocks >= 16) { + blake3_xof16_avx512(cv, block, block_len, counter, flags, out); + counter += 16; + outblocks -= 16; + out += 16 * BLAKE3_BLOCK_LEN; + } + while (outblocks >= 8) { + blake3_xof8_avx512(cv, block, block_len, counter, flags, out); + counter += 8; + outblocks -= 8; + out += 8 * BLAKE3_BLOCK_LEN; + } + while (outblocks >= 4) { + blake3_xof4_avx512(cv, block, block_len, counter, flags, out); + counter += 4; + outblocks -= 4; + out += 4 * BLAKE3_BLOCK_LEN; + } + while (outblocks > 0) { + blake3_compress_xof_avx512(cv, block, block_len, counter, flags, out); + counter += 1; + outblocks -= 1; + out += BLAKE3_BLOCK_LEN; + } +} diff --git a/c/blake3_impl.h b/c/blake3_impl.h index e652f86..b3abce2 100644 --- a/c/blake3_impl.h +++ b/c/blake3_impl.h @@ -162,6 +162,13 @@ INLINE void load_key_words(const uint8_t key[BLAKE3_KEY_LEN], key_words[7] = load32(&key[7 * 4]); } +INLINE void load_block_words(const uint8_t block[BLAKE3_BLOCK_LEN], + uint32_t block_words[16]) { + for (size_t i = 0; i < 16; i++) { + block_words[i] = load32(&block[i * 4]); + } +} + INLINE void store32(void *dst, uint32_t w) { uint8_t *p = (uint8_t *)dst; p[0] = (uint8_t)(w >> 0); @@ -279,7 +286,7 @@ void blake3_compress_xof_avx512(const uint32_t cv[8], void blake3_xof_many_avx512(const uint32_t cv[8], const uint8_t block[BLAKE3_BLOCK_LEN], uint8_t block_len, uint64_t counter, uint8_t flags, - uint8_t out[64], size_t outblocks); + uint8_t* out, size_t outblocks); void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs, size_t blocks, const uint32_t key[8], |
