aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Jack O'Connor <[email protected]> 2022-03-09 00:17:59 -0500
committer: Jack O'Connor <[email protected]> 2022-03-09 00:29:37 -0500
commit: 4c929ddac1ac3d39a1285a1527fd916d7934d7ad (patch)
tree: ea8b714a9e8f119c7e0328659c6dd334bd206241
parent: 5d4655920151a941b997bc0c59a86d493f7e3548 (diff)
blake3_avx512_xof_stream_16
-rw-r--r-- benches/bench.rs 19
-rw-r--r-- src/kernel.rs 245
2 files changed, 262 insertions, 2 deletions
diff --git a/benches/bench.rs b/benches/bench.rs
index 4aa62aa..d82f3ce 100644
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -572,3 +572,22 @@ fn bench_two_updates(b: &mut Bencher) {
hasher.finalize()
});
}
+
+#[bench]
+fn bench_xof_kernel(b: &mut Bencher) {
+ let mut output = [0; 16 * 64];
+ b.bytes = output.len() as u64;
+ let message_words = [0; 16];
+ let key_words = [0; 8];
+ let flags = 1 | 2 | 16; // CHUNK_START | CHUNK_END | KEYED_HASH
+ b.iter(|| unsafe {
+ blake3::kernel::xof_stream16(&message_words, &key_words, 0, flags, &mut output);
+ });
+ // Double check that this output is reasonable.
+ let mut expected = [0; 16 * 64];
+ blake3::Hasher::new_keyed(&[0; 32])
+ .update(&[0; 64])
+ .finalize_xof()
+ .fill(&mut expected);
+ assert_eq!(expected, output);
+}
diff --git a/src/kernel.rs b/src/kernel.rs
index 3f5498d..eeab5e7 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -970,7 +970,7 @@ global_asm!(
//
// zmm0-zmm31: [clobbered]
// rdi: pointer to 16 contiguous chunks of 1024 bytes each, unaligned
- // rsi: pointer to the 32-byte key, unaligned
+ // rsi: pointer to the 8-word key, 4-byte aligned
// rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high
// ecx: [clobbered]
// r8d: flags (other than CHUNK_START and CHUNK_END)
@@ -1051,7 +1051,7 @@ global_asm!(
// zmm0-zmm31: [clobbered]
// rdi: pointer to the left child CVs, 8 transposed state vectors, 64-byte aligned
// rsi: pointer to the right child CVs, 8 transposed state vectors, 64-byte aligned
- // rdx: pointer to the 32-byte key, unaligned
+ // rdx: pointer to the 8-word key, 4-byte aligned
// ecx: [clobbered]
// r8d: flags (other than PARENT)
// r9: out pointer to 8x64 bytes, 64-byte aligned
@@ -1160,6 +1160,193 @@ global_asm!(
"vmovdqa32 zmmword ptr [r9 + 7 * 64], zmm7",
"vzeroupper",
"ret",
+ //
+ // --------------------------------------------------------------------------------------------
+ // blake3_avx512_xof_stream_16
+ //
+ // zmm0-zmm31: [clobbered]
+ // rdi: pointer to the 16-word message block, 4-byte aligned
+ // rsi: pointer to the 8-word input CV, 4-byte aligned
+ // rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high
+ // ecx: [clobbered]
+ // r8d: flags (other than ROOT, which this routine sets itself) [clobbered]
+ // r9: out pointer to 16x64=1024 bytes, unaligned
+ //
+ // This routine performs the root compression for 16 consecutive output blocks and writes 1024
+ // bytes of output to the out pointer.
+ // --------------------------------------------------------------------------------------------
+ "blake3_avx512_xof_stream_16:",
+ // Broadcast the input CV into zmm0-zmm7, the first two rows of the state.
+ "vpbroadcastd zmm0, dword ptr [rsi + 0 * 4]",
+ "vpbroadcastd zmm1, dword ptr [rsi + 1 * 4]",
+ "vpbroadcastd zmm2, dword ptr [rsi + 2 * 4]",
+ "vpbroadcastd zmm3, dword ptr [rsi + 3 * 4]",
+ "vpbroadcastd zmm4, dword ptr [rsi + 4 * 4]",
+ "vpbroadcastd zmm5, dword ptr [rsi + 5 * 4]",
+ "vpbroadcastd zmm6, dword ptr [rsi + 6 * 4]",
+ "vpbroadcastd zmm7, dword ptr [rsi + 7 * 4]",
+ // Initialize zmm8-zmm15, the third and fourth rows of the state.
+ "vmovdqa32 zmm8, zmmword ptr [BLAKE3_IV0_16 + rip]", // IV constants
+ "vmovdqa32 zmm9, zmmword ptr [BLAKE3_IV1_16 + rip]",
+ "vmovdqa32 zmm10, zmmword ptr [BLAKE3_IV2_16 + rip]",
+ "vmovdqa32 zmm11, zmmword ptr [BLAKE3_IV3_16 + rip]",
+ "vmovdqa32 zmm12, zmmword ptr [rdx + 64 * 0]", // counter low
+ "vmovdqa32 zmm13, zmmword ptr [rdx + 64 * 1]", // counter high
+ "mov ecx, 64",
+ "vpbroadcastd zmm14, ecx", // block length (always 64)
+ "or r8d, 8", // set the ROOT flag
+ "vpbroadcastd zmm15, r8d", // flags
+ // Broadcast the message block into zmm16-zmm31
+ "vpbroadcastd zmm16, dword ptr [rdi + 0 * 4]",
+ "vpbroadcastd zmm17, dword ptr [rdi + 1 * 4]",
+ "vpbroadcastd zmm18, dword ptr [rdi + 2 * 4]",
+ "vpbroadcastd zmm19, dword ptr [rdi + 3 * 4]",
+ "vpbroadcastd zmm20, dword ptr [rdi + 4 * 4]",
+ "vpbroadcastd zmm21, dword ptr [rdi + 5 * 4]",
+ "vpbroadcastd zmm22, dword ptr [rdi + 6 * 4]",
+ "vpbroadcastd zmm23, dword ptr [rdi + 7 * 4]",
+ "vpbroadcastd zmm24, dword ptr [rdi + 8 * 4]",
+ "vpbroadcastd zmm25, dword ptr [rdi + 9 * 4]",
+ "vpbroadcastd zmm26, dword ptr [rdi + 10 * 4]",
+ "vpbroadcastd zmm27, dword ptr [rdi + 11 * 4]",
+ "vpbroadcastd zmm28, dword ptr [rdi + 12 * 4]",
+ "vpbroadcastd zmm29, dword ptr [rdi + 13 * 4]",
+ "vpbroadcastd zmm30, dword ptr [rdi + 14 * 4]",
+ "vpbroadcastd zmm31, dword ptr [rdi + 15 * 4]",
+ // Run the kernel.
+ "call blake3_avx512_kernel_16",
+ // Re-broadcast the input CV and feed it forward into the second half of the state.
+ "vpbroadcastd zmm16, dword ptr [rsi + 0 * 4]",
+ "vpxord zmm8, zmm8, zmm16",
+ "vpbroadcastd zmm17, dword ptr [rsi + 1 * 4]",
+ "vpxord zmm9, zmm9, zmm17",
+ "vpbroadcastd zmm18, dword ptr [rsi + 2 * 4]",
+ "vpxord zmm10, zmm10, zmm18",
+ "vpbroadcastd zmm19, dword ptr [rsi + 3 * 4]",
+ "vpxord zmm11, zmm11, zmm19",
+ "vpbroadcastd zmm20, dword ptr [rsi + 4 * 4]",
+ "vpxord zmm12, zmm12, zmm20",
+ "vpbroadcastd zmm21, dword ptr [rsi + 5 * 4]",
+ "vpxord zmm13, zmm13, zmm21",
+ "vpbroadcastd zmm22, dword ptr [rsi + 6 * 4]",
+ "vpxord zmm14, zmm14, zmm22",
+ "vpbroadcastd zmm23, dword ptr [rsi + 7 * 4]",
+ "vpxord zmm15, zmm15, zmm23",
+ // zmm0-zmm15 now contain the final extended state vectors, transposed. We need to un-transpose
+ // them before we write them out. As with blake3_avx512_blocks_16, we prefer to avoid expensive
+ // operations across 128-bit lanes, so we do a couple of interleaving passes and then write out
+ // 128 bits at a time.
+ //
+ // First, interleave 32-bit words. Use zmm16-zmm31 to hold the intermediate results. This
+ // takes the input vectors like:
+ //
+ // a0, b0, c0, d0, e0, f0, g0, h0, i0, j0, k0, l0, m0, n0, o0, p0
+ //
+ // And produces vectors like:
+ //
+ // a0, a1, b0, b1, e0, e1, f0, f1, i0, i1, j0, j1, m0, m1, n0, n1
+ "vpunpckldq zmm16, zmm0, zmm1",
+ "vpunpckhdq zmm17, zmm0, zmm1",
+ "vpunpckldq zmm18, zmm2, zmm3",
+ "vpunpckhdq zmm19, zmm2, zmm3",
+ "vpunpckldq zmm20, zmm4, zmm5",
+ "vpunpckhdq zmm21, zmm4, zmm5",
+ "vpunpckldq zmm22, zmm6, zmm7",
+ "vpunpckhdq zmm23, zmm6, zmm7",
+ "vpunpckldq zmm24, zmm8, zmm9",
+ "vpunpckhdq zmm25, zmm8, zmm9",
+ "vpunpckldq zmm26, zmm10, zmm11",
+ "vpunpckhdq zmm27, zmm10, zmm11",
+ "vpunpckldq zmm28, zmm12, zmm13",
+ "vpunpckhdq zmm29, zmm12, zmm13",
+ "vpunpckldq zmm30, zmm14, zmm15",
+ "vpunpckhdq zmm31, zmm14, zmm15",
+ // Then interleave 64-bit words back into zmm0-zmm15, producing vectors like:
+ //
+ // a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3
+ "vpunpcklqdq zmm0, zmm16, zmm18",
+ "vpunpckhqdq zmm1, zmm16, zmm18",
+ "vpunpcklqdq zmm2, zmm17, zmm19",
+ "vpunpckhqdq zmm3, zmm17, zmm19",
+ "vpunpcklqdq zmm4, zmm20, zmm22",
+ "vpunpckhqdq zmm5, zmm20, zmm22",
+ "vpunpcklqdq zmm6, zmm21, zmm23",
+ "vpunpckhqdq zmm7, zmm21, zmm23",
+ "vpunpcklqdq zmm8, zmm24, zmm26",
+ "vpunpckhqdq zmm9, zmm24, zmm26",
+ "vpunpcklqdq zmm10, zmm25, zmm27",
+ "vpunpckhqdq zmm11, zmm25, zmm27",
+ "vpunpcklqdq zmm12, zmm28, zmm30",
+ "vpunpckhqdq zmm13, zmm28, zmm30",
+ "vpunpcklqdq zmm14, zmm29, zmm31",
+ "vpunpckhqdq zmm15, zmm29, zmm31",
+ // Finally, write out each 128-bit group, unaligned.
+ "vmovdqu32 xmmword ptr [r9 + 0 * 16], xmm0",
+ "vmovdqu32 xmmword ptr [r9 + 1 * 16], xmm4",
+ "vmovdqu32 xmmword ptr [r9 + 2 * 16], xmm8",
+ "vmovdqu32 xmmword ptr [r9 + 3 * 16], xmm12",
+ "vmovdqu32 xmmword ptr [r9 + 4 * 16], xmm1",
+ "vmovdqu32 xmmword ptr [r9 + 5 * 16], xmm5",
+ "vmovdqu32 xmmword ptr [r9 + 6 * 16], xmm9",
+ "vmovdqu32 xmmword ptr [r9 + 7 * 16], xmm13",
+ "vmovdqu32 xmmword ptr [r9 + 8 * 16], xmm2",
+ "vmovdqu32 xmmword ptr [r9 + 9 * 16], xmm6",
+ "vmovdqu32 xmmword ptr [r9 + 10 * 16], xmm10",
+ "vmovdqu32 xmmword ptr [r9 + 11 * 16], xmm14",
+ "vmovdqu32 xmmword ptr [r9 + 12 * 16], xmm3",
+ "vmovdqu32 xmmword ptr [r9 + 13 * 16], xmm7",
+ "vmovdqu32 xmmword ptr [r9 + 14 * 16], xmm11",
+ "vmovdqu32 xmmword ptr [r9 + 15 * 16], xmm15",
+ "vextracti32x4 xmmword ptr [r9 + 16 * 16], zmm0, 1",
+ "vextracti32x4 xmmword ptr [r9 + 17 * 16], zmm4, 1",
+ "vextracti32x4 xmmword ptr [r9 + 18 * 16], zmm8, 1",
+ "vextracti32x4 xmmword ptr [r9 + 19 * 16], zmm12, 1",
+ "vextracti32x4 xmmword ptr [r9 + 20 * 16], zmm1, 1",
+ "vextracti32x4 xmmword ptr [r9 + 21 * 16], zmm5, 1",
+ "vextracti32x4 xmmword ptr [r9 + 22 * 16], zmm9, 1",
+ "vextracti32x4 xmmword ptr [r9 + 23 * 16], zmm13, 1",
+ "vextracti32x4 xmmword ptr [r9 + 24 * 16], zmm2, 1",
+ "vextracti32x4 xmmword ptr [r9 + 25 * 16], zmm6, 1",
+ "vextracti32x4 xmmword ptr [r9 + 26 * 16], zmm10, 1",
+ "vextracti32x4 xmmword ptr [r9 + 27 * 16], zmm14, 1",
+ "vextracti32x4 xmmword ptr [r9 + 28 * 16], zmm3, 1",
+ "vextracti32x4 xmmword ptr [r9 + 29 * 16], zmm7, 1",
+ "vextracti32x4 xmmword ptr [r9 + 30 * 16], zmm11, 1",
+ "vextracti32x4 xmmword ptr [r9 + 31 * 16], zmm15, 1",
+ "vextracti32x4 xmmword ptr [r9 + 32 * 16], zmm0, 2",
+ "vextracti32x4 xmmword ptr [r9 + 33 * 16], zmm4, 2",
+ "vextracti32x4 xmmword ptr [r9 + 34 * 16], zmm8, 2",
+ "vextracti32x4 xmmword ptr [r9 + 35 * 16], zmm12, 2",
+ "vextracti32x4 xmmword ptr [r9 + 36 * 16], zmm1, 2",
+ "vextracti32x4 xmmword ptr [r9 + 37 * 16], zmm5, 2",
+ "vextracti32x4 xmmword ptr [r9 + 38 * 16], zmm9, 2",
+ "vextracti32x4 xmmword ptr [r9 + 39 * 16], zmm13, 2",
+ "vextracti32x4 xmmword ptr [r9 + 40 * 16], zmm2, 2",
+ "vextracti32x4 xmmword ptr [r9 + 41 * 16], zmm6, 2",
+ "vextracti32x4 xmmword ptr [r9 + 42 * 16], zmm10, 2",
+ "vextracti32x4 xmmword ptr [r9 + 43 * 16], zmm14, 2",
+ "vextracti32x4 xmmword ptr [r9 + 44 * 16], zmm3, 2",
+ "vextracti32x4 xmmword ptr [r9 + 45 * 16], zmm7, 2",
+ "vextracti32x4 xmmword ptr [r9 + 46 * 16], zmm11, 2",
+ "vextracti32x4 xmmword ptr [r9 + 47 * 16], zmm15, 2",
+ "vextracti32x4 xmmword ptr [r9 + 48 * 16], zmm0, 3",
+ "vextracti32x4 xmmword ptr [r9 + 49 * 16], zmm4, 3",
+ "vextracti32x4 xmmword ptr [r9 + 50 * 16], zmm8, 3",
+ "vextracti32x4 xmmword ptr [r9 + 51 * 16], zmm12, 3",
+ "vextracti32x4 xmmword ptr [r9 + 52 * 16], zmm1, 3",
+ "vextracti32x4 xmmword ptr [r9 + 53 * 16], zmm5, 3",
+ "vextracti32x4 xmmword ptr [r9 + 54 * 16], zmm9, 3",
+ "vextracti32x4 xmmword ptr [r9 + 55 * 16], zmm13, 3",
+ "vextracti32x4 xmmword ptr [r9 + 56 * 16], zmm2, 3",
+ "vextracti32x4 xmmword ptr [r9 + 57 * 16], zmm6, 3",
+ "vextracti32x4 xmmword ptr [r9 + 58 * 16], zmm10, 3",
+ "vextracti32x4 xmmword ptr [r9 + 59 * 16], zmm14, 3",
+ "vextracti32x4 xmmword ptr [r9 + 60 * 16], zmm3, 3",
+ "vextracti32x4 xmmword ptr [r9 + 61 * 16], zmm7, 3",
+ "vextracti32x4 xmmword ptr [r9 + 62 * 16], zmm11, 3",
+ "vextracti32x4 xmmword ptr [r9 + 63 * 16], zmm15, 3",
+ "vzeroupper",
+ "ret",
);
#[repr(C, align(64))]
@@ -1233,6 +1420,38 @@ pub unsafe fn parents16(
);
}
+pub unsafe fn xof_stream16(
+ message_words: &[u32; 16],
+ cv_words: &[u32; 8],
+ counter: u64,
+ flags: u32,
+ out_ptr: *mut [u8; 16 * 64],
+) {
+ // Prepare the counter vectors, the low words and high words.
+ let mut counter_vectors = [Words16([0; 16]); 2];
+ for i in 0..16 {
+ counter_vectors[0].0[i] = (counter + i as u64) as u32;
+ counter_vectors[1].0[i] = ((counter + i as u64) >> 32) as u32;
+ }
+ asm!(
+ "call blake3_avx512_xof_stream_16",
+ inout("rdi") message_words => _,
+ inout("rsi") cv_words => _,
+ inout("rdx") &counter_vectors => _,
+ out("ecx") _,
+ inout("r8d") flags => _,
+ inout("r9") out_ptr => _,
+ out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
+ out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
+ out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
+ out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _,
+ out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
+ out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _,
+ out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _,
+ out("zmm28") _, out("zmm29") _, out("zmm30") _, out("zmm31") _,
+ );
+}
+
#[test]
fn test_chunks16() {
let mut message = [0u8; 16 * CHUNK_LEN];
@@ -1341,3 +1560,25 @@ fn test_parents16() {
}
assert_eq!(expected_out, found_out_transposed);
}
+
+#[test]
+fn test_xof_stream16() {
+ let mut block = [0; 64];
+ let mut key = [0; 32];
+ crate::test::paint_test_input(&mut block);
+ crate::test::paint_test_input(&mut key);
+ let mut expected = [0; 1024];
+ crate::Hasher::new_keyed(&key)
+ .update(&block)
+ .finalize_xof()
+ .fill(&mut expected);
+
+ let block_words = crate::platform::words_from_le_bytes_64(&block);
+ let key_words = crate::platform::words_from_le_bytes_32(&key);
+ let flags = crate::KEYED_HASH | crate::CHUNK_START | crate::CHUNK_END;
+ let mut found = [0; 1024];
+ unsafe {
+ xof_stream16(&block_words, &key_words, 0, flags as u32, &mut found);
+ }
+ assert_eq!(expected, found);
+}