diff options
| author | Jack O'Connor <[email protected]> | 2022-03-09 00:17:59 -0500 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2022-03-09 00:29:37 -0500 |
| commit | 4c929ddac1ac3d39a1285a1527fd916d7934d7ad (patch) | |
| tree | ea8b714a9e8f119c7e0328659c6dd334bd206241 | |
| parent | 5d4655920151a941b997bc0c59a86d493f7e3548 (diff) | |
blake3_avx512_xof_stream_16
| -rw-r--r-- | benches/bench.rs | 19 | ||||
| -rw-r--r-- | src/kernel.rs | 245 |
2 files changed, 262 insertions, 2 deletions
diff --git a/benches/bench.rs b/benches/bench.rs index 4aa62aa..d82f3ce 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -572,3 +572,22 @@ fn bench_two_updates(b: &mut Bencher) { hasher.finalize() }); } + +#[bench] +fn bench_xof_kernel(b: &mut Bencher) { + let mut output = [0; 16 * 64]; + b.bytes = output.len() as u64; + let message_words = [0; 16]; + let key_words = [0; 8]; + let flags = 1 | 2 | 16; // CHUNK_START | CHUNK_END | KEYED_HASH + b.iter(|| unsafe { + blake3::kernel::xof_stream16(&message_words, &key_words, 0, flags, &mut output); + }); + // Double check that this output is reasonable. + let mut expected = [0; 16 * 64]; + blake3::Hasher::new_keyed(&[0; 32]) + .update(&[0; 64]) + .finalize_xof() + .fill(&mut expected); + assert_eq!(expected, output); +} diff --git a/src/kernel.rs b/src/kernel.rs index 3f5498d..eeab5e7 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -970,7 +970,7 @@ global_asm!( // // zmm0-zmm31: [clobbered] // rdi: pointer to 16 contiguous chunks of 1024 bytes each, unaligned - // rsi: pointer to the 32-byte key, unaligned + // rsi: pointer to the 8-word key, 4-byte aligned // rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high // ecx: [clobbered] // r8d: flags (other than CHUNK_START and CHUNK_END) @@ -1051,7 +1051,7 @@ global_asm!( // zmm0-zmm31: [clobbered] // rdi: pointer to the left child CVs, 8 transposed state vectors, 64-byte aligned // rsi: pointer to the right child CVs, 8 transposed state vectors, 64-byte aligned - // rdx: pointer to the 32-byte key, unaligned + // rdx: pointer to the 8-word key, 4-byte aligned // ecx: [clobbered] // r8d: flags (other than PARENT) // r9: out pointer to 8x64 bytes, 64-byte aligned @@ -1160,6 +1160,193 @@ global_asm!( "vmovdqa32 zmmword ptr [r9 + 7 * 64], zmm7", "vzeroupper", "ret", + // + // -------------------------------------------------------------------------------------------- + // blake3_avx512_xof_stream_16 + // + // zmm0-zmm31: [clobbered] + // 
rdi: pointer to the 16-word message block, 4-byte aligned + // rsi: pointer to the 8-word input CV, 4-byte aligned + // rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high + // ecx: [clobbered] + // r8d: flags (other than ROOT) + // r9: out pointer to 16x64=1024 bytes, unaligned + // + // This routine performs the root compression for 16 consecutive output blocks and writes 1024 + // bytes of output to the out pointer. + // -------------------------------------------------------------------------------------------- + "blake3_avx512_xof_stream_16:", + // Broadcast the input CV into zmm0-zmm7, the first two rows of the state. + "vpbroadcastd zmm0, dword ptr [rsi + 0 * 4]", + "vpbroadcastd zmm1, dword ptr [rsi + 1 * 4]", + "vpbroadcastd zmm2, dword ptr [rsi + 2 * 4]", + "vpbroadcastd zmm3, dword ptr [rsi + 3 * 4]", + "vpbroadcastd zmm4, dword ptr [rsi + 4 * 4]", + "vpbroadcastd zmm5, dword ptr [rsi + 5 * 4]", + "vpbroadcastd zmm6, dword ptr [rsi + 6 * 4]", + "vpbroadcastd zmm7, dword ptr [rsi + 7 * 4]", + // Initialize zmm8-zmm15, the third and fourth rows of the state. 
+ "vmovdqa32 zmm8, zmmword ptr [BLAKE3_IV0_16 + rip]", // IV constants + "vmovdqa32 zmm9, zmmword ptr [BLAKE3_IV1_16 + rip]", + "vmovdqa32 zmm10, zmmword ptr [BLAKE3_IV2_16 + rip]", + "vmovdqa32 zmm11, zmmword ptr [BLAKE3_IV3_16 + rip]", + "vmovdqa32 zmm12, zmmword ptr [rdx + 64 * 0]", // counter low + "vmovdqa32 zmm13, zmmword ptr [rdx + 64 * 1]", // counter high + "mov ecx, 64", + "vpbroadcastd zmm14, ecx", // block length (always 64) + "or r8d, 8", // set the ROOT flag + "vpbroadcastd zmm15, r8d", // flags + // Broadcast the message block into zmm16-zmm31 + "vpbroadcastd zmm16, dword ptr [rdi + 0 * 4]", + "vpbroadcastd zmm17, dword ptr [rdi + 1 * 4]", + "vpbroadcastd zmm18, dword ptr [rdi + 2 * 4]", + "vpbroadcastd zmm19, dword ptr [rdi + 3 * 4]", + "vpbroadcastd zmm20, dword ptr [rdi + 4 * 4]", + "vpbroadcastd zmm21, dword ptr [rdi + 5 * 4]", + "vpbroadcastd zmm22, dword ptr [rdi + 6 * 4]", + "vpbroadcastd zmm23, dword ptr [rdi + 7 * 4]", + "vpbroadcastd zmm24, dword ptr [rdi + 8 * 4]", + "vpbroadcastd zmm25, dword ptr [rdi + 9 * 4]", + "vpbroadcastd zmm26, dword ptr [rdi + 10 * 4]", + "vpbroadcastd zmm27, dword ptr [rdi + 11 * 4]", + "vpbroadcastd zmm28, dword ptr [rdi + 12 * 4]", + "vpbroadcastd zmm29, dword ptr [rdi + 13 * 4]", + "vpbroadcastd zmm30, dword ptr [rdi + 14 * 4]", + "vpbroadcastd zmm31, dword ptr [rdi + 15 * 4]", + // Run the kernel. + "call blake3_avx512_kernel_16", + // Re-broadcast the input CV and feed it forward into the second half of the state. 
+ "vpbroadcastd zmm16, dword ptr [rsi + 0 * 4]", + "vpxord zmm8, zmm8, zmm16", + "vpbroadcastd zmm17, dword ptr [rsi + 1 * 4]", + "vpxord zmm9, zmm9, zmm17", + "vpbroadcastd zmm18, dword ptr [rsi + 2 * 4]", + "vpxord zmm10, zmm10, zmm18", + "vpbroadcastd zmm19, dword ptr [rsi + 3 * 4]", + "vpxord zmm11, zmm11, zmm19", + "vpbroadcastd zmm20, dword ptr [rsi + 4 * 4]", + "vpxord zmm12, zmm12, zmm20", + "vpbroadcastd zmm21, dword ptr [rsi + 5 * 4]", + "vpxord zmm13, zmm13, zmm21", + "vpbroadcastd zmm22, dword ptr [rsi + 6 * 4]", + "vpxord zmm14, zmm14, zmm22", + "vpbroadcastd zmm23, dword ptr [rsi + 7 * 4]", + "vpxord zmm15, zmm15, zmm23", + // zmm0-zmm15 now contain the final extended state vectors, transposed. We need to un-transpose + // them before we write them out. As with blake3_avx512_blocks_16, we prefer to avoid expensive + // operations across 128-bit lanes, so we do a couple of interleaving passes and then write out + // 128 bits at a time. + // + // First, interleave 32-bit words. Use zmm16-zmm31 to hold the intermediate results. 
This + // takes the input vectors like: + // + // a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0 + // + // And produces vectors like: + // + // a0, a1, b0, b1, e0, e1, f0, f1, i0, i1, j0, j1, m0, m1, n0, n1 + "vpunpckldq zmm16, zmm0, zmm1", + "vpunpckhdq zmm17, zmm0, zmm1", + "vpunpckldq zmm18, zmm2, zmm3", + "vpunpckhdq zmm19, zmm2, zmm3", + "vpunpckldq zmm20, zmm4, zmm5", + "vpunpckhdq zmm21, zmm4, zmm5", + "vpunpckldq zmm22, zmm6, zmm7", + "vpunpckhdq zmm23, zmm6, zmm7", + "vpunpckldq zmm24, zmm8, zmm9", + "vpunpckhdq zmm25, zmm8, zmm9", + "vpunpckldq zmm26, zmm10, zmm11", + "vpunpckhdq zmm27, zmm10, zmm11", + "vpunpckldq zmm28, zmm12, zmm13", + "vpunpckhdq zmm29, zmm12, zmm13", + "vpunpckldq zmm30, zmm14, zmm15", + "vpunpckhdq zmm31, zmm14, zmm15", + // Then interleave 64-bit words back into zmm0-zmm15, producing vectors like: + // + // a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3 + "vpunpcklqdq zmm0, zmm16, zmm18", + "vpunpckhqdq zmm1, zmm16, zmm18", + "vpunpcklqdq zmm2, zmm17, zmm19", + "vpunpckhqdq zmm3, zmm17, zmm19", + "vpunpcklqdq zmm4, zmm20, zmm22", + "vpunpckhqdq zmm5, zmm20, zmm22", + "vpunpcklqdq zmm6, zmm21, zmm23", + "vpunpckhqdq zmm7, zmm21, zmm23", + "vpunpcklqdq zmm8, zmm24, zmm26", + "vpunpckhqdq zmm9, zmm24, zmm26", + "vpunpcklqdq zmm10, zmm25, zmm27", + "vpunpckhqdq zmm11, zmm25, zmm27", + "vpunpcklqdq zmm12, zmm28, zmm30", + "vpunpckhqdq zmm13, zmm28, zmm30", + "vpunpcklqdq zmm14, zmm29, zmm31", + "vpunpckhqdq zmm15, zmm29, zmm31", + // Finally, write out each 128-bit group, unaligned. 
+ "vmovdqu32 xmmword ptr [r9 + 0 * 16], xmm0", + "vmovdqu32 xmmword ptr [r9 + 1 * 16], xmm4", + "vmovdqu32 xmmword ptr [r9 + 2 * 16], xmm8", + "vmovdqu32 xmmword ptr [r9 + 3 * 16], xmm12", + "vmovdqu32 xmmword ptr [r9 + 4 * 16], xmm1", + "vmovdqu32 xmmword ptr [r9 + 5 * 16], xmm5", + "vmovdqu32 xmmword ptr [r9 + 6 * 16], xmm9", + "vmovdqu32 xmmword ptr [r9 + 7 * 16], xmm13", + "vmovdqu32 xmmword ptr [r9 + 8 * 16], xmm2", + "vmovdqu32 xmmword ptr [r9 + 9 * 16], xmm6", + "vmovdqu32 xmmword ptr [r9 + 10 * 16], xmm10", + "vmovdqu32 xmmword ptr [r9 + 11 * 16], xmm14", + "vmovdqu32 xmmword ptr [r9 + 12 * 16], xmm3", + "vmovdqu32 xmmword ptr [r9 + 13 * 16], xmm7", + "vmovdqu32 xmmword ptr [r9 + 14 * 16], xmm11", + "vmovdqu32 xmmword ptr [r9 + 15 * 16], xmm15", + "vextracti32x4 xmmword ptr [r9 + 16 * 16], zmm0, 1", + "vextracti32x4 xmmword ptr [r9 + 17 * 16], zmm4, 1", + "vextracti32x4 xmmword ptr [r9 + 18 * 16], zmm8, 1", + "vextracti32x4 xmmword ptr [r9 + 19 * 16], zmm12, 1", + "vextracti32x4 xmmword ptr [r9 + 20 * 16], zmm1, 1", + "vextracti32x4 xmmword ptr [r9 + 21 * 16], zmm5, 1", + "vextracti32x4 xmmword ptr [r9 + 22 * 16], zmm9, 1", + "vextracti32x4 xmmword ptr [r9 + 23 * 16], zmm13, 1", + "vextracti32x4 xmmword ptr [r9 + 24 * 16], zmm2, 1", + "vextracti32x4 xmmword ptr [r9 + 25 * 16], zmm6, 1", + "vextracti32x4 xmmword ptr [r9 + 26 * 16], zmm10, 1", + "vextracti32x4 xmmword ptr [r9 + 27 * 16], zmm14, 1", + "vextracti32x4 xmmword ptr [r9 + 28 * 16], zmm3, 1", + "vextracti32x4 xmmword ptr [r9 + 29 * 16], zmm7, 1", + "vextracti32x4 xmmword ptr [r9 + 30 * 16], zmm11, 1", + "vextracti32x4 xmmword ptr [r9 + 31 * 16], zmm15, 1", + "vextracti32x4 xmmword ptr [r9 + 32 * 16], zmm0, 2", + "vextracti32x4 xmmword ptr [r9 + 33 * 16], zmm4, 2", + "vextracti32x4 xmmword ptr [r9 + 34 * 16], zmm8, 2", + "vextracti32x4 xmmword ptr [r9 + 35 * 16], zmm12, 2", + "vextracti32x4 xmmword ptr [r9 + 36 * 16], zmm1, 2", + "vextracti32x4 xmmword ptr [r9 + 37 * 16], zmm5, 2", + "vextracti32x4 
xmmword ptr [r9 + 38 * 16], zmm9, 2", + "vextracti32x4 xmmword ptr [r9 + 39 * 16], zmm13, 2", + "vextracti32x4 xmmword ptr [r9 + 40 * 16], zmm2, 2", + "vextracti32x4 xmmword ptr [r9 + 41 * 16], zmm6, 2", + "vextracti32x4 xmmword ptr [r9 + 42 * 16], zmm10, 2", + "vextracti32x4 xmmword ptr [r9 + 43 * 16], zmm14, 2", + "vextracti32x4 xmmword ptr [r9 + 44 * 16], zmm3, 2", + "vextracti32x4 xmmword ptr [r9 + 45 * 16], zmm7, 2", + "vextracti32x4 xmmword ptr [r9 + 46 * 16], zmm11, 2", + "vextracti32x4 xmmword ptr [r9 + 47 * 16], zmm15, 2", + "vextracti32x4 xmmword ptr [r9 + 48 * 16], zmm0, 3", + "vextracti32x4 xmmword ptr [r9 + 49 * 16], zmm4, 3", + "vextracti32x4 xmmword ptr [r9 + 50 * 16], zmm8, 3", + "vextracti32x4 xmmword ptr [r9 + 51 * 16], zmm12, 3", + "vextracti32x4 xmmword ptr [r9 + 52 * 16], zmm1, 3", + "vextracti32x4 xmmword ptr [r9 + 53 * 16], zmm5, 3", + "vextracti32x4 xmmword ptr [r9 + 54 * 16], zmm9, 3", + "vextracti32x4 xmmword ptr [r9 + 55 * 16], zmm13, 3", + "vextracti32x4 xmmword ptr [r9 + 56 * 16], zmm2, 3", + "vextracti32x4 xmmword ptr [r9 + 57 * 16], zmm6, 3", + "vextracti32x4 xmmword ptr [r9 + 58 * 16], zmm10, 3", + "vextracti32x4 xmmword ptr [r9 + 59 * 16], zmm14, 3", + "vextracti32x4 xmmword ptr [r9 + 60 * 16], zmm3, 3", + "vextracti32x4 xmmword ptr [r9 + 61 * 16], zmm7, 3", + "vextracti32x4 xmmword ptr [r9 + 62 * 16], zmm11, 3", + "vextracti32x4 xmmword ptr [r9 + 63 * 16], zmm15, 3", + "vzeroupper", + "ret", ); #[repr(C, align(64))] @@ -1233,6 +1420,38 @@ pub unsafe fn parents16( ); } +pub unsafe fn xof_stream16( + message_words: &[u32; 16], + cv_words: &[u32; 8], + counter: u64, + flags: u32, + out_ptr: *mut [u8; 16 * 64], +) { + // Prepare the counter vectors, the low words and high words. 
+ let mut counter_vectors = [Words16([0; 16]); 2]; + for i in 0..16 { + counter_vectors[0].0[i] = (counter + i as u64) as u32; + counter_vectors[1].0[i] = ((counter + i as u64) >> 32) as u32; + } + asm!( + "call blake3_avx512_xof_stream_16", + inout("rdi") message_words => _, + inout("rsi") cv_words => _, + inout("rdx") &counter_vectors => _, + out("ecx") _, + inout("r8d") flags => _, + inout("r9") out_ptr => _, + out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _, + out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _, + out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _, + out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _, + out("zmm28") _, out("zmm29") _, out("zmm30") _, out("zmm31") _, + ); +} + #[test] fn test_chunks16() { let mut message = [0u8; 16 * CHUNK_LEN]; @@ -1341,3 +1560,25 @@ fn test_parents16() { } assert_eq!(expected_out, found_out_transposed); } + +#[test] +fn test_xof_stream16() { + let mut block = [0; 64]; + let mut key = [0; 32]; + crate::test::paint_test_input(&mut block); + crate::test::paint_test_input(&mut key); + let mut expected = [0; 1024]; + crate::Hasher::new_keyed(&key) + .update(&block) + .finalize_xof() + .fill(&mut expected); + + let block_words = crate::platform::words_from_le_bytes_64(&block); + let key_words = crate::platform::words_from_le_bytes_32(&key); + let flags = crate::KEYED_HASH | crate::CHUNK_START | crate::CHUNK_END; + let mut found = [0; 1024]; + unsafe { + xof_stream16(&block_words, &key_words, 0, flags as u32, &mut found); + } + assert_eq!(expected, found); +} |
