aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack O'Connor <[email protected]>2022-03-09 12:05:33 -0500
committerJack O'Connor <[email protected]>2022-03-09 12:19:14 -0500
commit09c2b9141c81e5afc0720ce9c4937856e0dbfdb6 (patch)
tree441af5f8f73bb26bcf84609674a459ee88a6dced
parent506ae0b0fe255c00c69c3ca6a6388e0a20eebe40 (diff)
broadcast the block length and domain flags inside blake3_avx512_kernel_16
blake3_avx512_xof_stream_16 was also incorrectly hardcoding a block length of 64. The block length parameter is the *input* block length, which is independent of the output block length. (The output block length is not a compression function parameter.)
-rw-r--r--benches/bench.rs12
-rw-r--r--src/kernel.rs70
2 files changed, 51 insertions, 31 deletions
diff --git a/benches/bench.rs b/benches/bench.rs
index d82f3ce..9b16bee 100644
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -579,14 +579,22 @@ fn bench_xof_kernel(b: &mut Bencher) {
b.bytes = output.len() as u64;
let message_words = [0; 16];
let key_words = [0; 8];
+ let counter = 0;
+ let block_length = 0;
let flags = 1 | 2 | 16; // CHUNK_START | CHUNK_END | KEYED_HASH
b.iter(|| unsafe {
- blake3::kernel::xof_stream16(&message_words, &key_words, 0, flags, &mut output);
+ blake3::kernel::xof_stream16(
+ &message_words,
+ &key_words,
+ counter,
+ block_length,
+ flags,
+ &mut output,
+ );
});
// Double check that this output is reasonable.
let mut expected = [0; 16 * 64];
blake3::Hasher::new_keyed(&[0; 32])
- .update(&[0; 64])
.finalize_xof()
.fill(&mut expected);
assert_eq!(expected, output);
diff --git a/src/kernel.rs b/src/kernel.rs
index ae5712e..49ac3f6 100644
--- a/src/kernel.rs
+++ b/src/kernel.rs
@@ -5,6 +5,8 @@ global_asm!(
// --------------------------------------------------------------------------------------------
// blake3_avx512_kernel_16
//
+ // ecx: block length
+ // r8d: domain flags
// zmm0-zmm7: transposed input CV (which may be the key or the IV)
// zmm12: transposed lower order counter words
// zmm13: transposed higher order counter words
@@ -12,9 +14,11 @@ global_asm!(
// zmm15: transposed flag words
// zmm16-zmm31: transposed message vectors
//
- // This routine overwrites zmm8-zmm11 (the third row of the state) with IV bytes, executes all
- // 7 rounds of compression, and performs the XOR of the upper half of the state into the lower
- // half (but not the feed-forward). The result is left in zmm0-zmm7.
+ // This routine overwrites zmm8-zmm11 (the third row of the state) with IV bytes, broadcasts
+ // the block length into zmm14, and broadcasts the domain flags into zmm15. This completes the
+ // transposed state in zmm0-zmm15. It then executes all 7 rounds of compression, and performs
+ // the XOR of the upper half of the state into the lower half (but not the feed-forward). The
+ // result is left in zmm0-zmm7.
// --------------------------------------------------------------------------------------------
".p2align 6",
"BLAKE3_IV0_16:",
@@ -35,6 +39,10 @@ global_asm!(
"vmovdqa32 zmm9, zmmword ptr [BLAKE3_IV1_16 + rip]",
"vmovdqa32 zmm10, zmmword ptr [BLAKE3_IV2_16 + rip]",
"vmovdqa32 zmm11, zmmword ptr [BLAKE3_IV3_16 + rip]",
+ // broadcast the block length
+ "vpbroadcastd zmm14, ecx",
+ // broadcast the domain flags
+ "vpbroadcastd zmm15, r8d",
// round 1
"vpaddd zmm0, zmm0, zmm16",
"vpaddd zmm1, zmm1, zmm18",
@@ -844,7 +852,7 @@ global_asm!(
// rdi: pointer to first message block in rdi, subsequent blocks offset by 1024 bytes each
// rsi: [unused]
// rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high
- // ecx: block len (always 64)
+ // ecx: block len
// r8d: flags (other than CHUNK_START and CHUNK_END)
//
// This routine loads and transposes message words, populates the rest of the state registers,
@@ -974,12 +982,9 @@ global_asm!(
"vpunpckhqdq zmm29, zmm30, zmm13",
"vpunpcklqdq zmm30, zmm31, zmm14",
"vpunpckhqdq zmm31, zmm31, zmm14",
- // Initialize fourth row of the state, part of which we just used as scratch space during
- // transposition.
+ // Load the low and high counter words.
"vmovdqa32 zmm12, zmmword ptr [rdx + 64 * 0]", // counter low
"vmovdqa32 zmm13, zmmword ptr [rdx + 64 * 1]", // counter high
- "vpbroadcastd zmm14, ecx", // block length (always 64)
- "vpbroadcastd zmm15, r8d", // flags
// Run the kernel and then exit.
"call blake3_avx512_kernel_16",
"ret",
@@ -1010,7 +1015,7 @@ global_asm!(
"vpbroadcastd zmm5, dword ptr [rsi + 5 * 4]",
"vpbroadcastd zmm6, dword ptr [rsi + 6 * 4]",
"vpbroadcastd zmm7, dword ptr [rsi + 7 * 4]",
- // ecx is the block length parameter for blake3_avx512_blocks_16. It is always 64.
+ // The block length is always 64 here.
"mov ecx, 64",
// Set the CHUNK_START flag.
"or r8d, 1",
@@ -1052,7 +1057,7 @@ global_asm!(
// Compress the last block.
"add rdi, 64",
"call blake3_avx512_blocks_16",
- // Write the output and exit.
+ // Write the output in transposed form and exit.
"vmovdqa32 zmmword ptr [r9 + 0 * 64], zmm0",
"vmovdqa32 zmmword ptr [r9 + 1 * 64], zmm1",
"vmovdqa32 zmmword ptr [r9 + 2 * 64], zmm2",
@@ -1154,14 +1159,13 @@ global_asm!(
"vpbroadcastd zmm5, dword ptr [rdx + 5 * 4]",
"vpbroadcastd zmm6, dword ptr [rdx + 6 * 4]",
"vpbroadcastd zmm7, dword ptr [rdx + 7 * 4]",
- // Initialize the fourth row of the state.
- "xor ecx, ecx", // zero
- "vpbroadcastd zmm12, ecx", // counter low (always 0)
- "vpbroadcastd zmm13, ecx", // counter high (always 0)
+ // Initialize the low and high counter words.
+ "vpxorq zmm12, zmm12, zmm12", // counter low, always zero for parents
+ "vpxorq zmm13, zmm13, zmm13", // counter high, always zero for parents
+ // The block length is always 64 for parents.
"mov ecx, 64",
- "vpbroadcastd zmm14, ecx", // block length (always 64)
- "or r8d, 4", // set the PARENT flag
- "vpbroadcastd zmm15, r8d", // flags
+ // Set the PARENT flag.
+ "or r8d, 4",
// Run the kernel.
"call blake3_avx512_kernel_16",
// Write the output and exit.
@@ -1183,7 +1187,7 @@ global_asm!(
// rdi: pointer to the 16-word message block, 4-byte aligned
// rsi: pointer to the 8-word input CV, 4-byte aligned
// rdx: pointer to two 64-byte aligned vectors, counter-low followed by counter-high
- // ecx: [clobbered]
+ // ecx: block length
// r8d: flags (other than ROOT)
// r9: out pointer to 16x64=1024 bytes, unaligned
//
@@ -1200,13 +1204,11 @@ global_asm!(
"vpbroadcastd zmm5, dword ptr [rsi + 5 * 4]",
"vpbroadcastd zmm6, dword ptr [rsi + 6 * 4]",
"vpbroadcastd zmm7, dword ptr [rsi + 7 * 4]",
- // Initialize zmm12-zmm15, fourth row of the state.
+ // Load the low and high counter words.
"vmovdqa32 zmm12, zmmword ptr [rdx + 64 * 0]", // counter low
"vmovdqa32 zmm13, zmmword ptr [rdx + 64 * 1]", // counter high
- "mov ecx, 64",
- "vpbroadcastd zmm14, ecx", // block length (always 64)
- "or r8d, 8", // set the ROOT flag
- "vpbroadcastd zmm15, r8d", // flags
+ // Set the ROOT flag.
+ "or r8d, 8",
// Broadcast the message block into zmm16-zmm31
"vpbroadcastd zmm16, dword ptr [rdi + 0 * 4]",
"vpbroadcastd zmm17, dword ptr [rdi + 1 * 4]",
@@ -1428,6 +1430,7 @@ pub unsafe fn xof_stream16(
message_words: &[u32; 16],
cv_words: &[u32; 8],
counter: u64,
+ block_len: u32,
flags: u32,
out_ptr: *mut [u8; 16 * 64],
) {
@@ -1442,7 +1445,7 @@ pub unsafe fn xof_stream16(
inout("rdi") message_words => _,
inout("rsi") cv_words => _,
inout("rdx") &counter_vectors => _,
- out("ecx") _,
+ inout("ecx") block_len => _,
inout("r8d") flags => _,
inout("r9") out_ptr => _,
out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
@@ -1567,22 +1570,31 @@ fn test_parents16() {
#[test]
fn test_xof_stream16() {
- let mut block = [0; 64];
+ let mut padded_block = [0; 64];
+ let block_len = 53;
+ let block = &mut padded_block[..block_len];
let mut key = [0; 32];
- crate::test::paint_test_input(&mut block);
+ crate::test::paint_test_input(block);
crate::test::paint_test_input(&mut key);
let mut expected = [0; 1024];
crate::Hasher::new_keyed(&key)
- .update(&block)
+ .update(block)
.finalize_xof()
.fill(&mut expected);
- let block_words = crate::platform::words_from_le_bytes_64(&block);
+ let block_words = crate::platform::words_from_le_bytes_64(&padded_block);
let key_words = crate::platform::words_from_le_bytes_32(&key);
let flags = crate::KEYED_HASH | crate::CHUNK_START | crate::CHUNK_END;
let mut found = [0; 1024];
unsafe {
- xof_stream16(&block_words, &key_words, 0, flags as u32, &mut found);
+ xof_stream16(
+ &block_words,
+ &key_words,
+ 0,
+ block_len as u32,
+ flags as u32,
+ &mut found,
+ );
}
assert_eq!(expected, found);
}