diff options
| author | Jack O'Connor <[email protected]> | 2022-03-09 12:05:33 -0500 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2022-03-09 12:19:14 -0500 |
| commit | 09c2b9141c81e5afc0720ce9c4937856e0dbfdb6 (patch) | |
| tree | 441af5f8f73bb26bcf84609674a459ee88a6dced | |
| parent | 506ae0b0fe255c00c69c3ca6a6388e0a20eebe40 (diff) | |
broadcast the block length and domain flags inside blake3_avx512_kernel_16
blake3_avx512_xof_stream_16 was also incorrectly hardcoding a block
length of 64. The block length parameter is the *input* block length,
which is independent of the output block length. (The output block
length is not a compression function parameter.)
| -rw-r--r-- | benches/bench.rs | 12 | ||||
| -rw-r--r-- | src/kernel.rs | 70 |
2 files changed, 51 insertions, 31 deletions
diff --git a/benches/bench.rs b/benches/bench.rs index d82f3ce..9b16bee 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -579,14 +579,22 @@ fn bench_xof_kernel(b: &mut Bencher) { b.bytes = output.len() as u64; let message_words = [0; 16]; let key_words = [0; 8]; + let counter = 0; + let block_length = 0; let flags = 1 | 2 | 16; // CHUNK_START | CHUNK_END | KEYED_HASH b.iter(|| unsafe { - blake3::kernel::xof_stream16(&message_words, &key_words, 0, flags, &mut output); + blake3::kernel::xof_stream16( + &message_words, + &key_words, + counter, + block_length, + flags, + &mut output, + ); }); // Double check that this output is reasonable. let mut expected = [0; 16 * 64]; blake3::Hasher::new_keyed(&[0; 32]) - .update(&[0; 64]) .finalize_xof() .fill(&mut expected); assert_eq!(expected, output); diff --git a/src/kernel.rs b/src/kernel.rs index ae5712e..49ac3f6 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -5,6 +5,8 @@ global_asm!( // -------------------------------------------------------------------------------------------- // blake3_avx512_kernel_16 // + // ecx: block length + // r8d: domain flags // zmm0-zmm7: transposed input CV (which may be the key or the IV) // zmm12: transposed lower order counter words // zmm13: transposed higher order counter words @@ -12,9 +14,11 @@ global_asm!( // zmm15: transposed flag words // zmm16-zmm31: transposed message vectors // - // This routine overwrites zmm8-zmm11 (the third row of the state) with IV bytes, executes all - // 7 rounds of compression, and performs the XOR of the upper half of the state into the lower - // half (but not the feed-forward). The result is left in zmm0-zmm7. + // This routine overwrites zmm8-zmm11 (the third row of the state) with IV bytes, broadcasts + // the block length into zmm14, and broadcasts the domain flags into zmm15. This completes the + // transposed state in zmm0-zmm15. It then executes all 7 rounds of compression, and performs + // the XOR of the upper half of the state into the lower half (but not the feed-forward). The + // result is left in zmm0-zmm7. // -------------------------------------------------------------------------------------------- ".p2align 6", "BLAKE3_IV0_16:", @@ -35,6 +39,10 @@ global_asm!( "vmovdqa32 zmm9, zmmword ptr [BLAKE3_IV1_16 + rip]", "vmovdqa32 zmm10, zmmword ptr [BLAKE3_IV2_16 + rip]", "vmovdqa32 zmm11, zmmword ptr [BLAKE3_IV3_16 + rip]", + // broadcast the block length + "vpbroadcastd zmm14, ecx", + // broadcast the domain flags + "vpbroadcastd zmm15, r8d", // round 1 "vpaddd zmm0, zmm0, zmm16", "vpaddd zmm1, zmm1, zmm18", @@ -844,7 +852,7 @@ global_asm!( // rdi: pointer to first message block in rdi, subsequent blocks offset by 1024 bytes each // rsi: [unused] // rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high - // ecx: block len (always 64) + // ecx: block len // r8d: flags (other than CHUNK_START and CHUNK_END) // // This routine loads and transposes message words, populates the rest of the state registers, @@ -974,12 +982,9 @@ global_asm!( "vpunpckhqdq zmm29, zmm30, zmm13", "vpunpcklqdq zmm30, zmm31, zmm14", "vpunpckhqdq zmm31, zmm31, zmm14", - // Initialize fourth row of the state, part of which we just used as scratch space during - // transposition. + // Load the low and high counter words. "vmovdqa32 zmm12, zmmword ptr [rdx + 64 * 0]", // counter low "vmovdqa32 zmm13, zmmword ptr [rdx + 64 * 1]", // counter high - "vpbroadcastd zmm14, ecx", // block length (always 64) - "vpbroadcastd zmm15, r8d", // flags // Run the kernel and then exit. "call blake3_avx512_kernel_16", "ret", @@ -1010,7 +1015,7 @@ global_asm!( "vpbroadcastd zmm5, dword ptr [rsi + 5 * 4]", "vpbroadcastd zmm6, dword ptr [rsi + 6 * 4]", "vpbroadcastd zmm7, dword ptr [rsi + 7 * 4]", - // ecx is the block length parameter for blake3_avx512_blocks_16. It is always 64. + // The block length is always 64 here. "mov ecx, 64", // Set the CHUNK_START flag. "or r8d, 1", @@ -1052,7 +1057,7 @@ global_asm!( // Compress the last block. "add rdi, 64", "call blake3_avx512_blocks_16", - // Write the output and exit. + // Write the output in transposed form and exit. "vmovdqa32 zmmword ptr [r9 + 0 * 64], zmm0", "vmovdqa32 zmmword ptr [r9 + 1 * 64], zmm1", "vmovdqa32 zmmword ptr [r9 + 2 * 64], zmm2", @@ -1154,14 +1159,13 @@ global_asm!( "vpbroadcastd zmm5, dword ptr [rdx + 5 * 4]", "vpbroadcastd zmm6, dword ptr [rdx + 6 * 4]", "vpbroadcastd zmm7, dword ptr [rdx + 7 * 4]", - // Initialize the fourth row of the state. - "xor ecx, ecx", // zero - "vpbroadcastd zmm12, ecx", // counter low (always 0) - "vpbroadcastd zmm13, ecx", // counter high (always 0) + // Initialize the low and high counter words. + "vpxorq zmm12, zmm12, zmm12", // counter low, always zero for parents + "vpxorq zmm13, zmm13, zmm13", // counter high, always zero for parents + // The block length is always 64 for parents. "mov ecx, 64", - "vpbroadcastd zmm14, ecx", // block length (always 64) - "or r8d, 4", // set the PARENT flag - "vpbroadcastd zmm15, r8d", // flags + // Set the PARENT flag. + "or r8d, 4", // Run the kernel. "call blake3_avx512_kernel_16", // Write the output and exit. @@ -1183,7 +1187,7 @@ global_asm!( // rdi: pointer to the 16-word message block, 4-byte aligned // rsi: pointer to the 8-word input CV, 4-byte aligned // rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high - // ecx: [clobbered] + // ecx: block length // r8d: flags (other than ROOT) // r9: out pointer to 16x64=1024 bytes, unaligned // @@ -1200,13 +1204,11 @@ global_asm!( "vpbroadcastd zmm5, dword ptr [rsi + 5 * 4]", "vpbroadcastd zmm6, dword ptr [rsi + 6 * 4]", "vpbroadcastd zmm7, dword ptr [rsi + 7 * 4]", - // Initialize zmm12-zmm15, fourth row of the state. + // Load the low and high counter words. "vmovdqa32 zmm12, zmmword ptr [rdx + 64 * 0]", // counter low "vmovdqa32 zmm13, zmmword ptr [rdx + 64 * 1]", // counter high - "mov ecx, 64", - "vpbroadcastd zmm14, ecx", // block length (always 64) - "or r8d, 8", // set the ROOT flag - "vpbroadcastd zmm15, r8d", // flags + // Set the ROOT flag. + "or r8d, 8", // Broadcast the message block into zmm16-zmm31 "vpbroadcastd zmm16, dword ptr [rdi + 0 * 4]", "vpbroadcastd zmm17, dword ptr [rdi + 1 * 4]", @@ -1428,6 +1430,7 @@ pub unsafe fn xof_stream16( message_words: &[u32; 16], cv_words: &[u32; 8], counter: u64, + block_len: u32, flags: u32, out_ptr: *mut [u8; 16 * 64], ) { @@ -1442,7 +1445,7 @@ pub unsafe fn xof_stream16( inout("rdi") message_words => _, inout("rsi") cv_words => _, inout("rdx") &counter_vectors => _, - out("ecx") _, + inout("ecx") block_len => _, inout("r8d") flags => _, inout("r9") out_ptr => _, out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, @@ -1567,22 +1570,31 @@ fn test_parents16() { #[test] fn test_xof_stream16() { - let mut block = [0; 64]; + let mut padded_block = [0; 64]; + let block_len = 53; + let block = &mut padded_block[..block_len]; let mut key = [0; 32]; - crate::test::paint_test_input(&mut block); + crate::test::paint_test_input(block); crate::test::paint_test_input(&mut key); let mut expected = [0; 1024]; crate::Hasher::new_keyed(&key) - .update(&block) + .update(block) .finalize_xof() .fill(&mut expected); - let block_words = crate::platform::words_from_le_bytes_64(&block); + let block_words = crate::platform::words_from_le_bytes_64(&padded_block); let key_words = crate::platform::words_from_le_bytes_32(&key); let flags = crate::KEYED_HASH | crate::CHUNK_START | crate::CHUNK_END; let mut found = [0; 1024]; unsafe { - xof_stream16(&block_words, &key_words, 0, flags as u32, &mut found); + xof_stream16( + &block_words, + &key_words, + 0, + block_len as u32, + flags as u32, + &mut found, + ); } assert_eq!(expected, found); } |
