 rust/guts/Cargo.toml      |  18 +
 rust/guts/readme.md       |  62 +
 rust/guts/src/lib.rs      | 956 +
 rust/guts/src/portable.rs | 262 +
 rust/guts/src/test.rs     | 523 +
 5 files changed, 1821 insertions, 0 deletions
diff --git a/rust/guts/Cargo.toml b/rust/guts/Cargo.toml
new file mode 100644
index 0000000..ebcf77f
--- /dev/null
+++ b/rust/guts/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "blake3_guts"
+version = "0.0.0"
+authors = ["Jack O'Connor <[email protected]>", "Samuel Neves"]
+description = "low-level building blocks for the BLAKE3 hash function"
+repository = "https://github.com/BLAKE3-team/BLAKE3"
+license = "CC0-1.0 OR Apache-2.0"
+documentation = "https://docs.rs/blake3_guts"
+readme = "readme.md"
+edition = "2021"
+
+[dev-dependencies]
+hex = "0.4.3"
+reference_impl = { path = "../../reference_impl" }
+
+[features]
+default = ["std"]
+std = []
diff --git a/rust/guts/readme.md b/rust/guts/readme.md
new file mode 100644
index 0000000..a1adbf1
--- /dev/null
+++ b/rust/guts/readme.md
@@ -0,0 +1,62 @@
+# The BLAKE3 Guts API
+
+## Introduction
+
+This crate contains low-level, high-performance, platform-specific
+implementations of the BLAKE3 compression function. This API is complicated and
+unsafe, and this crate will never have a stable release. For the standard
+BLAKE3 hash function, see the [`blake3`](https://crates.io/crates/blake3)
+crate, which depends on this one.
+
+The most important ingredient in a high-performance implementation of BLAKE3 is
+parallelism. The BLAKE3 tree structure lets us hash different parts of the tree
+in parallel, and modern computers have a _lot_ of parallelism to offer.
+Sometimes that means using multiple threads running on multiple cores, but
+multithreading isn't appropriate for all applications, and it's not the usual
+default for library APIs. More commonly, BLAKE3 implementations use SIMD
+instructions ("Single Instruction Multiple Data") to improve the performance of
+a single thread. When we do use multithreading, the performance benefits
+multiply.
+
+The tricky thing about SIMD is that each instruction set works differently.
+Instead of writing portable code once and letting the compiler do most of the
+optimization work, we need to write platform-specific implementations, and
+sometimes more than one per platform. We maintain *four* different
+implementations on x86 alone (targeting SSE2, SSE4.1, AVX2, and AVX-512), in
+addition to ARM NEON and the RISC-V vector extensions. In the future we might
+add ARM SVE2.
+
+All of that means a lot of duplicated logic and maintenance. So while the main
+goal of this API is high performance, it's also important to keep the API as
+small and simple as possible. Higher level details like the "CV stack", input
+buffering, and multithreading are handled by portable code in the main `blake3`
+crate. These are just building blocks.
+
+## The private API
+
+This is the API that each platform reimplements. It's completely `unsafe`,
+inputs and outputs are allowed to alias, and bounds checking is the caller's
+responsibility.
+
+- `degree`
+- `compress`
+- `hash_chunks`
+- `hash_parents`
+- `xof`
+- `xof_xor`
+- `universal_hash`
+
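+For example, "inputs and outputs are allowed to alias" means an implementation
+may be called with the same pointer for its input CV and its output. Here's a
+minimal crate-internal sketch of how the chunk loop in `src/lib.rs` advances a
+chaining value in place through the portable `compress` entry point (the helper
+name here is just for illustration):
+
+```rust
+use crate::{portable, BlockBytes, CVBytes, BLOCK_LEN};
+
+// Illustrative helper: update a chaining value in place by letting compress's
+// `out` pointer alias its `cv` pointer, exactly as hash_chunks_using_compress
+// does for each block of a chunk. The caller sets CHUNK_START/CHUNK_END in
+// `flags` as appropriate.
+unsafe fn advance_cv_one_block(cv: &mut CVBytes, block: &BlockBytes, counter: u64, flags: u32) {
+    let cv_ptr: *mut CVBytes = cv;
+    portable::compress(block, BLOCK_LEN as u32, cv_ptr, counter, flags, cv_ptr);
+}
+```
+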
+## The public API
+
+This is the API that this crate exposes to callers, i.e. to the main `blake3`
+crate. It's a thin, portable layer on top of the private API above. The Rust
+version of this API is memory-safe.
+
+- `degree`
+- `compress`
+- `hash_chunks`
+- `hash_parents`
+- `reduce_parents`
+- `xof`
+- `xof_xor`
+- `universal_hash`
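+
+To make the shape of this API concrete, here's a short usage sketch (an
+illustration, not code from this commit) that hashes an input of exactly two
+chunks and finalizes the root, using the public items defined in `src/lib.rs`.
+The result should match the ordinary BLAKE3 hash of the same input.
+
+```rust
+use blake3_guts::{TransposedVectors, BLOCK_LEN, CHUNK_LEN, DETECTED_IMPL, IV_BYTES, PARENT, ROOT};
+
+fn main() {
+    // Two full chunks of arbitrary input.
+    let input = vec![0xab_u8; 2 * CHUNK_LEN];
+
+    // Hash both chunks into the left half of a transposed CV buffer. The chunk
+    // counter starts at 0, and there are no extra flags for regular hashing.
+    let mut cvs = TransposedVectors::new();
+    {
+        let (left, _right) = DETECTED_IMPL.split_transposed_vectors(&mut cvs);
+        let num_cvs = DETECTED_IMPL.hash_chunks(&input, &IV_BYTES, 0, 0, left);
+        assert_eq!(num_cvs, 2);
+    }
+
+    // The two chunk CVs form a single parent node, which is also the root here.
+    let root_node = cvs.extract_parent_node(0);
+    let hash = DETECTED_IMPL.compress(&root_node, BLOCK_LEN as u32, &IV_BYTES, 0, PARENT | ROOT);
+    println!("{:02x?}", hash);
+}
+```
+
+Real callers (the main `blake3` crate) also handle partial chunks, single-chunk
+and empty inputs, keyed and derive-key modes, and the "CV stack" for incremental
+hashing, none of which this sketch covers.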
diff --git a/rust/guts/src/lib.rs b/rust/guts/src/lib.rs
new file mode 100644
index 0000000..67f7a05
--- /dev/null
+++ b/rust/guts/src/lib.rs
@@ -0,0 +1,956 @@
+use core::cmp;
+use core::marker::PhantomData;
+use core::mem;
+use core::ptr;
+use core::sync::atomic::{AtomicPtr, Ordering::Relaxed};
+
+pub mod portable;
+
+#[cfg(test)]
+mod test;
+
+pub const OUT_LEN: usize = 32;
+pub const BLOCK_LEN: usize = 64;
+pub const CHUNK_LEN: usize = 1024;
+pub const WORD_LEN: usize = 4;
+pub const UNIVERSAL_HASH_LEN: usize = 16;
+
+pub const CHUNK_START: u32 = 1 << 0;
+pub const CHUNK_END: u32 = 1 << 1;
+pub const PARENT: u32 = 1 << 2;
+pub const ROOT: u32 = 1 << 3;
+pub const KEYED_HASH: u32 = 1 << 4;
+pub const DERIVE_KEY_CONTEXT: u32 = 1 << 5;
+pub const DERIVE_KEY_MATERIAL: u32 = 1 << 6;
+
+pub const IV: CVWords = [
+ 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
+];
+pub const IV_BYTES: CVBytes = le_bytes_from_words_32(&IV);
+
+pub const MSG_SCHEDULE: [[usize; 16]; 7] = [
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8],
+ [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1],
+ [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6],
+ [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4],
+ [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7],
+ [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13],
+];
+
+// never less than 2
+pub const MAX_SIMD_DEGREE: usize = 2;
+
+pub type CVBytes = [u8; 32];
+pub type CVWords = [u32; 8];
+pub type BlockBytes = [u8; 64];
+pub type BlockWords = [u32; 16];
+
+pub static DETECTED_IMPL: Implementation = Implementation::new(
+ degree_init,
+ compress_init,
+ hash_chunks_init,
+ hash_parents_init,
+ xof_init,
+ xof_xor_init,
+ universal_hash_init,
+);
+
+fn detect() -> Implementation {
+ portable::implementation()
+}
+
+fn init_detected_impl() {
+ let detected = detect();
+
+ DETECTED_IMPL
+ .degree_ptr
+ .store(detected.degree_ptr.load(Relaxed), Relaxed);
+ DETECTED_IMPL
+ .compress_ptr
+ .store(detected.compress_ptr.load(Relaxed), Relaxed);
+ DETECTED_IMPL
+ .hash_chunks_ptr
+ .store(detected.hash_chunks_ptr.load(Relaxed), Relaxed);
+ DETECTED_IMPL
+ .hash_parents_ptr
+ .store(detected.hash_parents_ptr.load(Relaxed), Relaxed);
+ DETECTED_IMPL
+ .xof_ptr
+ .store(detected.xof_ptr.load(Relaxed), Relaxed);
+ DETECTED_IMPL
+ .xof_xor_ptr
+ .store(detected.xof_xor_ptr.load(Relaxed), Relaxed);
+ DETECTED_IMPL
+ .universal_hash_ptr
+ .store(detected.universal_hash_ptr.load(Relaxed), Relaxed);
+}
+
+pub struct Implementation {
+ degree_ptr: AtomicPtr<()>,
+ compress_ptr: AtomicPtr<()>,
+ hash_chunks_ptr: AtomicPtr<()>,
+ hash_parents_ptr: AtomicPtr<()>,
+ xof_ptr: AtomicPtr<()>,
+ xof_xor_ptr: AtomicPtr<()>,
+ universal_hash_ptr: AtomicPtr<()>,
+}
+
+impl Implementation {
+ const fn new(
+ degree_fn: DegreeFn,
+ compress_fn: CompressFn,
+ hash_chunks_fn: HashChunksFn,
+ hash_parents_fn: HashParentsFn,
+ xof_fn: XofFn,
+ xof_xor_fn: XofFn,
+ universal_hash_fn: UniversalHashFn,
+ ) -> Self {
+ Self {
+ degree_ptr: AtomicPtr::new(degree_fn as *mut ()),
+ compress_ptr: AtomicPtr::new(compress_fn as *mut ()),
+ hash_chunks_ptr: AtomicPtr::new(hash_chunks_fn as *mut ()),
+ hash_parents_ptr: AtomicPtr::new(hash_parents_fn as *mut ()),
+ xof_ptr: AtomicPtr::new(xof_fn as *mut ()),
+ xof_xor_ptr: AtomicPtr::new(xof_xor_fn as *mut ()),
+ universal_hash_ptr: AtomicPtr::new(universal_hash_fn as *mut ()),
+ }
+ }
+
+ #[inline]
+ fn degree_fn(&self) -> DegreeFn {
+ unsafe { mem::transmute(self.degree_ptr.load(Relaxed)) }
+ }
+
+ #[inline]
+ pub fn degree(&self) -> usize {
+ let degree = unsafe { self.degree_fn()() };
+ debug_assert!(degree >= 2);
+ debug_assert!(degree <= MAX_SIMD_DEGREE);
+ debug_assert_eq!(1, degree.count_ones(), "power of 2");
+ degree
+ }
+
+ #[inline]
+ pub fn split_transposed_vectors<'v>(
+ &self,
+ vectors: &'v mut TransposedVectors,
+ ) -> (TransposedSplit<'v>, TransposedSplit<'v>) {
+ unsafe { vectors.split(self.degree()) }
+ }
+
+ #[inline]
+ fn compress_fn(&self) -> CompressFn {
+ unsafe { mem::transmute(self.compress_ptr.load(Relaxed)) }
+ }
+
+ #[inline]
+ pub fn compress(
+ &self,
+ block: &BlockBytes,
+ block_len: u32,
+ cv: &CVBytes,
+ counter: u64,
+ flags: u32,
+ ) -> CVBytes {
+ let mut out = [0u8; 32];
+ unsafe {
+ self.compress_fn()(block, block_len, cv, counter, flags, &mut out);
+ }
+ out
+ }
+
+ // The contract for HashChunksFn doesn't require the implementation to support single-chunk
+ // inputs. Instead we handle that case here by calling compress in a loop.
+ #[inline]
+ fn hash_one_chunk(
+ &self,
+ mut input: &[u8],
+ key: &CVBytes,
+ counter: u64,
+ mut flags: u32,
+ output: TransposedSplit,
+ ) {
+ debug_assert!(input.len() <= CHUNK_LEN);
+ let mut cv = *key;
+ flags |= CHUNK_START;
+ while input.len() > BLOCK_LEN {
+ cv = self.compress(
+ input[..BLOCK_LEN].try_into().unwrap(),
+ BLOCK_LEN as u32,
+ &cv,
+ counter,
+ flags,
+ );
+ input = &input[BLOCK_LEN..];
+ flags &= !CHUNK_START;
+ }
+ let mut final_block = [0u8; BLOCK_LEN];
+ final_block[..input.len()].copy_from_slice(input);
+ cv = self.compress(
+ &final_block,
+ input.len() as u32,
+ &cv,
+ counter,
+ flags | CHUNK_END,
+ );
+ unsafe {
+ write_transposed_cv(&words_from_le_bytes_32(&cv), output.ptr);
+ }
+ }
+
+ #[inline]
+ fn hash_chunks_fn(&self) -> HashChunksFn {
+ unsafe { mem::transmute(self.hash_chunks_ptr.load(Relaxed)) }
+ }
+
+ #[inline]
+ pub fn hash_chunks(
+ &self,
+ input: &[u8],
+ key: &CVBytes,
+ counter: u64,
+ flags: u32,
+ transposed_output: TransposedSplit,
+ ) -> usize {
+ debug_assert!(input.len() <= self.degree() * CHUNK_LEN);
+ if input.len() <= CHUNK_LEN {
+ // The underlying hash_chunks_fn isn't required to support this case. Instead we handle
+ // it by calling compress_fn in a loop. But note that we still don't support root
+ // finalization or the empty input here.
+ self.hash_one_chunk(input, key, counter, flags, transposed_output);
+ return 1;
+ }
+ // SAFETY: If the caller passes in more than MAX_SIMD_DEGREE * CHUNK_LEN bytes, silently
+ // ignore the remainder. This makes it impossible to write out of bounds in a properly
+ // constructed TransposedSplit.
+ let len = cmp::min(input.len(), MAX_SIMD_DEGREE * CHUNK_LEN);
+ unsafe {
+ self.hash_chunks_fn()(
+ input.as_ptr(),
+ len,
+ key,
+ counter,
+ flags,
+ transposed_output.ptr,
+ );
+ }
+ if input.len() % CHUNK_LEN == 0 {
+ input.len() / CHUNK_LEN
+ } else {
+ (input.len() / CHUNK_LEN) + 1
+ }
+ }
+
+ #[inline]
+ fn hash_parents_fn(&self) -> HashParentsFn {
+ unsafe { mem::transmute(self.hash_parents_ptr.load(Relaxed)) }
+ }
+
+ #[inline]
+ pub fn hash_parents(
+ &self,
+ transposed_input: &TransposedVectors,
+ mut num_cvs: usize,
+ key: &CVBytes,
+ flags: u32,
+ transposed_output: TransposedSplit,
+ ) -> usize {
+ debug_assert!(num_cvs <= 2 * MAX_SIMD_DEGREE);
+ // SAFETY: Cap num_cvs at 2 * MAX_SIMD_DEGREE, to guarantee no out-of-bounds accesses.
+ num_cvs = cmp::min(num_cvs, 2 * MAX_SIMD_DEGREE);
+ let mut odd_cv = [0u32; 8];
+ if num_cvs % 2 == 1 {
+ unsafe {
+ odd_cv = read_transposed_cv(transposed_input.as_ptr().add(num_cvs - 1));
+ }
+ }
+ let num_parents = num_cvs / 2;
+ unsafe {
+ self.hash_parents_fn()(
+ transposed_input.as_ptr(),
+ num_parents,
+ key,
+ flags | PARENT,
+ transposed_output.ptr,
+ );
+ }
+ if num_cvs % 2 == 1 {
+ unsafe {
+ write_transposed_cv(&odd_cv, transposed_output.ptr.add(num_parents));
+ }
+ num_parents + 1
+ } else {
+ num_parents
+ }
+ }
+
+ #[inline]
+ pub fn reduce_parents(
+ &self,
+ transposed_in_out: &mut TransposedVectors,
+ mut num_cvs: usize,
+ key: &CVBytes,
+ flags: u32,
+ ) -> usize {
+ debug_assert!(num_cvs <= 2 * MAX_SIMD_DEGREE);
+ // SAFETY: Cap num_cvs at 2 * MAX_SIMD_DEGREE, to guarantee no out-of-bounds accesses.
+ num_cvs = cmp::min(num_cvs, 2 * MAX_SIMD_DEGREE);
+ let in_out_ptr = transposed_in_out.as_mut_ptr();
+ let mut odd_cv = [0u32; 8];
+ if num_cvs % 2 == 1 {
+ unsafe {
+ odd_cv = read_transposed_cv(in_out_ptr.add(num_cvs - 1));
+ }
+ }
+ let num_parents = num_cvs / 2;
+ unsafe {
+ self.hash_parents_fn()(in_out_ptr, num_parents, key, flags | PARENT, in_out_ptr);
+ }
+ if num_cvs % 2 == 1 {
+ unsafe {
+ write_transposed_cv(&odd_cv, in_out_ptr.add(num_parents));
+ }
+ num_parents + 1
+ } else {
+ num_parents
+ }
+ }
+
+ #[inline]
+ fn xof_fn(&self) -> XofFn {
+ unsafe { mem::transmute(self.xof_ptr.load(Relaxed)) }
+ }
+
+ #[inline]
+ pub fn xof(
+ &self,
+ block: &BlockBytes,
+ block_len: u32,
+ cv: &CVBytes,
+ mut counter: u64,
+ flags: u32,
+ mut out: &mut [u8],
+ ) {
+ let degree = self.degree();
+ let simd_len = degree * BLOCK_LEN;
+ while !out.is_empty() {
+ let take = cmp::min(simd_len, out.len());
+ unsafe {
+ self.xof_fn()(
+ block,
+ block_len,
+ cv,
+ counter,
+ flags | ROOT,
+ out.as_mut_ptr(),
+ take,
+ );
+ }
+ out = &mut out[take..];
+ counter += degree as u64;
+ }
+ }
+
+ #[inline]
+ fn xof_xor_fn(&self) -> XofFn {
+ unsafe { mem::transmute(self.xof_xor_ptr.load(Relaxed)) }
+ }
+
+ #[inline]
+ pub fn xof_xor(
+ &self,
+ block: &BlockBytes,
+ block_len: u32,
+ cv: &CVBytes,
+ mut counter: u64,
+ flags: u32,
+ mut out: &mut [u8],
+ ) {
+ let degree = self.degree();
+ let simd_len = degree * BLOCK_LEN;
+ while !out.is_empty() {
+ let take = cmp::min(simd_len, out.len());
+ unsafe {
+ self.xof_xor_fn()(
+ block,
+ block_len,
+ cv,
+ counter,
+ flags | ROOT,
+ out.as_mut_ptr(),
+ take,
+ );
+ }
+ out = &mut out[take..];
+ counter += degree as u64;
+ }
+ }
+
+ #[inline]
+ fn universal_hash_fn(&self) -> UniversalHashFn {
+ unsafe { mem::transmute(self.universal_hash_ptr.load(Relaxed)) }
+ }
+
+ #[inline]
+ pub fn universal_hash(&self, mut input: &[u8], key: &CVBytes, mut counter: u64) -> [u8; 16] {
+ let degree = self.degree();
+ let simd_len = degree * BLOCK_LEN;
+ let mut ret = [0u8; 16];
+ while !input.is_empty() {
+ let take = cmp::min(simd_len, input.len());
+ let mut output = [0u8; 16];
+ unsafe {
+ self.universal_hash_fn()(input.as_ptr(), take, key, counter, &mut output);
+ }
+ input = &input[take..];
+ counter += degree as u64;
+ for byte_index in 0..16 {
+ ret[byte_index] ^= output[byte_index];
+ }
+ }
+ ret
+ }
+}
+
+impl Clone for Implementation {
+ fn clone(&self) -> Self {
+ Self {
+ degree_ptr: AtomicPtr::new(self.degree_ptr.load(Relaxed)),
+ compress_ptr: AtomicPtr::new(self.compress_ptr.load(Relaxed)),
+ hash_chunks_ptr: AtomicPtr::new(self.hash_chunks_ptr.load(Relaxed)),
+ hash_parents_ptr: AtomicPtr::new(self.hash_parents_ptr.load(Relaxed)),
+ xof_ptr: AtomicPtr::new(self.xof_ptr.load(Relaxed)),
+ xof_xor_ptr: AtomicPtr::new(self.xof_xor_ptr.load(Relaxed)),
+ universal_hash_ptr: AtomicPtr::new(self.universal_hash_ptr.load(Relaxed)),
+ }
+ }
+}
+
+// never less than 2
+type DegreeFn = unsafe extern "C" fn() -> usize;
+
+unsafe extern "C" fn degree_init() -> usize {
+ init_detected_impl();
+ DETECTED_IMPL.degree_fn()()
+}
+
+type CompressFn = unsafe extern "C" fn(
+ block: *const BlockBytes, // zero padded to 64 bytes
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut CVBytes, // may overlap the input
+);
+
+unsafe extern "C" fn compress_init(
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut CVBytes,
+) {
+ init_detected_impl();
+ DETECTED_IMPL.compress_fn()(block, block_len, cv, counter, flags, out);
+}
+
+type CompressXofFn = unsafe extern "C" fn(
+ block: *const BlockBytes, // zero padded to 64 bytes
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut BlockBytes, // may overlap the input
+);
+
+type HashChunksFn = unsafe extern "C" fn(
+ input: *const u8,
+ input_len: usize,
+ key: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ transposed_output: *mut u32,
+);
+
+unsafe extern "C" fn hash_chunks_init(
+ input: *const u8,
+ input_len: usize,
+ key: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ transposed_output: *mut u32,
+) {
+ init_detected_impl();
+ DETECTED_IMPL.hash_chunks_fn()(input, input_len, key, counter, flags, transposed_output);
+}
+
+type HashParentsFn = unsafe extern "C" fn(
+ transposed_input: *const u32,
+ num_parents: usize,
+ key: *const CVBytes,
+ flags: u32,
+ transposed_output: *mut u32, // may overlap the input
+);
+
+unsafe extern "C" fn hash_parents_init(
+ transposed_input: *const u32,
+ num_parents: usize,
+ key: *const CVBytes,
+ flags: u32,
+ transposed_output: *mut u32,
+) {
+ init_detected_impl();
+ DETECTED_IMPL.hash_parents_fn()(transposed_input, num_parents, key, flags, transposed_output);
+}
+
+// This signature covers both xof() and xof_xor().
+type XofFn = unsafe extern "C" fn(
+ block: *const BlockBytes, // zero padded to 64 bytes
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut u8,
+ out_len: usize,
+);
+
+unsafe extern "C" fn xof_init(
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut u8,
+ out_len: usize,
+) {
+ init_detected_impl();
+ DETECTED_IMPL.xof_fn()(block, block_len, cv, counter, flags, out, out_len);
+}
+
+unsafe extern "C" fn xof_xor_init(
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut u8,
+ out_len: usize,
+) {
+ init_detected_impl();
+ DETECTED_IMPL.xof_xor_fn()(block, block_len, cv, counter, flags, out, out_len);
+}
+
+type UniversalHashFn = unsafe extern "C" fn(
+ input: *const u8,
+ input_len: usize,
+ key: *const CVBytes,
+ counter: u64,
+ out: *mut [u8; 16],
+);
+
+unsafe extern "C" fn universal_hash_init(
+ input: *const u8,
+ input_len: usize,
+ key: *const CVBytes,
+ counter: u64,
+ out: *mut [u8; 16],
+) {
+ init_detected_impl();
+ DETECTED_IMPL.universal_hash_fn()(input, input_len, key, counter, out);
+}
+
+// The implicit degree of this implementation is MAX_SIMD_DEGREE.
+#[inline(always)]
+unsafe fn hash_chunks_using_compress(
+ compress: CompressFn,
+ mut input: *const u8,
+ mut input_len: usize,
+ key: *const CVBytes,
+ mut counter: u64,
+ flags: u32,
+ mut transposed_output: *mut u32,
+) {
+ debug_assert!(input_len > 0);
+ debug_assert!(input_len <= MAX_SIMD_DEGREE * CHUNK_LEN);
+ input_len = cmp::min(input_len, MAX_SIMD_DEGREE * CHUNK_LEN);
+ while input_len > 0 {
+ let mut chunk_len = cmp::min(input_len, CHUNK_LEN);
+ input_len -= chunk_len;
+        // Each chunk's CV starts from the key and is updated in place, one block at a time.
+ let mut cv = *key;
+ let cv_ptr: *mut CVBytes = &mut cv;
+ let mut chunk_flags = flags | CHUNK_START;
+ while chunk_len > BLOCK_LEN {
+ compress(
+ input as *const BlockBytes,
+ BLOCK_LEN as u32,
+ cv_ptr,
+ counter,
+ chunk_flags,
+ cv_ptr,
+ );
+ input = input.add(BLOCK_LEN);
+ chunk_len -= BLOCK_LEN;
+ chunk_flags &= !CHUNK_START;
+ }
+ let mut last_block = [0u8; BLOCK_LEN];
+ ptr::copy_nonoverlapping(input, last_block.as_mut_ptr(), chunk_len);
+ input = input.add(chunk_len);
+ compress(
+ &last_block,
+ chunk_len as u32,
+ cv_ptr,
+ counter,
+ chunk_flags | CHUNK_END,
+ cv_ptr,
+ );
+ let cv_words = words_from_le_bytes_32(&cv);
+ for word_index in 0..8 {
+ transposed_output
+ .add(word_index * TRANSPOSED_STRIDE)
+ .write(cv_words[word_index]);
+ }
+ transposed_output = transposed_output.add(1);
+ counter += 1;
+ }
+}
+
+// The implicit degree of this implementation is MAX_SIMD_DEGREE.
+#[inline(always)]
+unsafe fn hash_parents_using_compress(
+ compress: CompressFn,
+ mut transposed_input: *const u32,
+ mut num_parents: usize,
+ key: *const CVBytes,
+ flags: u32,
+ mut transposed_output: *mut u32, // may overlap the input
+) {
+ debug_assert!(num_parents > 0);
+ debug_assert!(num_parents <= MAX_SIMD_DEGREE);
+ while num_parents > 0 {
+ let mut block_bytes = [0u8; 64];
+ for word_index in 0..8 {
+ let left_child_word = transposed_input.add(word_index * TRANSPOSED_STRIDE).read();
+ block_bytes[WORD_LEN * word_index..][..WORD_LEN]
+ .copy_from_slice(&left_child_word.to_le_bytes());
+ let right_child_word = transposed_input
+ .add(word_index * TRANSPOSED_STRIDE + 1)
+ .read();
+ block_bytes[WORD_LEN * (word_index + 8)..][..WORD_LEN]
+ .copy_from_slice(&right_child_word.to_le_bytes());
+ }
+ let mut cv = [0u8; 32];
+ compress(&block_bytes, BLOCK_LEN as u32, key, 0, flags, &mut cv);
+ let cv_words = words_from_le_bytes_32(&cv);
+ for word_index in 0..8 {
+ transposed_output
+ .add(word_index * TRANSPOSED_STRIDE)
+ .write(cv_words[word_index]);
+ }
+ transposed_input = transposed_input.add(2);
+ transposed_output = transposed_output.add(1);
+ num_parents -= 1;
+ }
+}
+
+#[inline(always)]
+unsafe fn xof_using_compress_xof(
+ compress_xof: CompressXofFn,
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ mut counter: u64,
+ flags: u32,
+ mut out: *mut u8,
+ mut out_len: usize,
+) {
+ debug_assert!(out_len <= MAX_SIMD_DEGREE * BLOCK_LEN);
+ while out_len > 0 {
+ let mut block_output = [0u8; 64];
+ compress_xof(block, block_len, cv, counter, flags, &mut block_output);
+ let take = cmp::min(out_len, BLOCK_LEN);
+ ptr::copy_nonoverlapping(block_output.as_ptr(), out, take);
+ out = out.add(take);
+ out_len -= take;
+ counter += 1;
+ }
+}
+
+#[inline(always)]
+unsafe fn xof_xor_using_compress_xof(
+ compress_xof: CompressXofFn,
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ mut counter: u64,
+ flags: u32,
+ mut out: *mut u8,
+ mut out_len: usize,
+) {
+ debug_assert!(out_len <= MAX_SIMD_DEGREE * BLOCK_LEN);
+ while out_len > 0 {
+ let mut block_output = [0u8; 64];
+ compress_xof(block, block_len, cv, counter, flags, &mut block_output);
+ let take = cmp::min(out_len, BLOCK_LEN);
+ for i in 0..take {
+ *out.add(i) ^= block_output[i];
+ }
+ out = out.add(take);
+ out_len -= take;
+ counter += 1;
+ }
+}
+
+#[inline(always)]
+unsafe fn universal_hash_using_compress(
+ compress: CompressFn,
+ mut input: *const u8,
+ mut input_len: usize,
+ key: *const CVBytes,
+ mut counter: u64,
+ out: *mut [u8; 16],
+) {
+ let flags = KEYED_HASH | CHUNK_START | CHUNK_END | ROOT;
+ let mut result = [0u8; 16];
+ while input_len > 0 {
+ let block_len = cmp::min(input_len, BLOCK_LEN);
+ let mut block = [0u8; BLOCK_LEN];
+ ptr::copy_nonoverlapping(input, block.as_mut_ptr(), block_len);
+ let mut block_output = [0u8; 32];
+ compress(
+ &block,
+ block_len as u32,
+ key,
+ counter,
+ flags,
+ &mut block_output,
+ );
+ for i in 0..16 {
+ result[i] ^= block_output[i];
+ }
+ input = input.add(block_len);
+ input_len -= block_len;
+ counter += 1;
+ }
+ *out = result;
+}
+
+// this is in units of *words*, for pointer operations on *const/*mut u32
+const TRANSPOSED_STRIDE: usize = 2 * MAX_SIMD_DEGREE;
+
+#[cfg_attr(any(target_arch = "x86", target_arch = "x86_64"), repr(C, align(64)))]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct TransposedVectors([[u32; 2 * MAX_SIMD_DEGREE]; 8]);
+
+impl TransposedVectors {
+ pub fn new() -> Self {
+ Self([[0; 2 * MAX_SIMD_DEGREE]; 8])
+ }
+
+ pub fn extract_cv(&self, cv_index: usize) -> CVBytes {
+ let mut words = [0u32; 8];
+ for word_index in 0..8 {
+ words[word_index] = self.0[word_index][cv_index];
+ }
+ le_bytes_from_words_32(&words)
+ }
+
+ pub fn extract_parent_node(&self, parent_index: usize) -> BlockBytes {
+ let mut bytes = [0u8; 64];
+        bytes[..32].copy_from_slice(&self.extract_cv(2 * parent_index));
+        bytes[32..].copy_from_slice(&self.extract_cv(2 * parent_index + 1));
+ bytes
+ }
+
+ fn as_ptr(&self) -> *const u32 {
+ self.0[0].as_ptr()
+ }
+
+ fn as_mut_ptr(&mut self) -> *mut u32 {
+ self.0[0].as_mut_ptr()
+ }
+
+ // SAFETY: This function is just pointer arithmetic, but callers assume that it's safe (not
+ // necessarily correct) to write up to `degree` words to either side of the split, possibly
+ // from different threads.
+ unsafe fn split(&mut self, degree: usize) -> (TransposedSplit, TransposedSplit) {
+ debug_assert!(degree > 0);
+ debug_assert!(degree <= MAX_SIMD_DEGREE);
+ debug_assert_eq!(degree.count_ones(), 1, "power of 2");
+ let ptr = self.as_mut_ptr();
+ let left = TransposedSplit {
+ ptr,
+ phantom_data: PhantomData,
+ };
+ let right = TransposedSplit {
+ ptr: ptr.wrapping_add(degree),
+ phantom_data: PhantomData,
+ };
+ (left, right)
+ }
+}
+
+pub struct TransposedSplit<'vectors> {
+ ptr: *mut u32,
+ phantom_data: PhantomData<&'vectors mut u32>,
+}
+
+unsafe impl<'vectors> Send for TransposedSplit<'vectors> {}
+unsafe impl<'vectors> Sync for TransposedSplit<'vectors> {}
+
+unsafe fn read_transposed_cv(src: *const u32) -> CVWords {
+ let mut cv = [0u32; 8];
+ for word_index in 0..8 {
+ let offset_words = word_index * TRANSPOSED_STRIDE;
+ cv[word_index] = src.add(offset_words).read();
+ }
+ cv
+}
+
+unsafe fn write_transposed_cv(cv: &CVWords, dest: *mut u32) {
+ for word_index in 0..8 {
+ let offset_words = word_index * TRANSPOSED_STRIDE;
+ dest.add(offset_words).write(cv[word_index]);
+ }
+}
+
+#[inline(always)]
+pub const fn le_bytes_from_words_32(words: &CVWords) -> CVBytes {
+ let mut bytes = [0u8; 32];
+ // This loop is super verbose because currently that's what it takes to be const.
+ let mut word_index = 0;
+ while word_index < bytes.len() / WORD_LEN {
+ let word_bytes = words[word_index].to_le_bytes();
+ let mut byte_index = 0;
+ while byte_index < WORD_LEN {
+ bytes[word_index * WORD_LEN + byte_index] = word_bytes[byte_index];
+ byte_index += 1;
+ }
+ word_index += 1;
+ }
+ bytes
+}
+
+#[inline(always)]
+pub const fn le_bytes_from_words_64(words: &BlockWords) -> BlockBytes {
+ let mut bytes = [0u8; 64];
+ // This loop is super verbose because currently that's what it takes to be const.
+ let mut word_index = 0;
+ while word_index < bytes.len() / WORD_LEN {
+ let word_bytes = words[word_index].to_le_bytes();
+ let mut byte_index = 0;
+ while byte_index < WORD_LEN {
+ bytes[word_index * WORD_LEN + byte_index] = word_bytes[byte_index];
+ byte_index += 1;
+ }
+ word_index += 1;
+ }
+ bytes
+}
+
+#[inline(always)]
+pub const fn words_from_le_bytes_32(bytes: &CVBytes) -> CVWords {
+ let mut words = [0u32; 8];
+ // This loop is super verbose because currently that's what it takes to be const.
+ let mut word_index = 0;
+ while word_index < words.len() {
+ let mut word_bytes = [0u8; WORD_LEN];
+ let mut byte_index = 0;
+ while byte_index < WORD_LEN {
+ word_bytes[byte_index] = bytes[word_index * WORD_LEN + byte_index];
+ byte_index += 1;
+ }
+ words[word_index] = u32::from_le_bytes(word_bytes);
+ word_index += 1;
+ }
+ words
+}
+
+#[inline(always)]
+pub const fn words_from_le_bytes_64(bytes: &BlockBytes) -> BlockWords {
+ let mut words = [0u32; 16];
+ // This loop is super verbose because currently that's what it takes to be const.
+ let mut word_index = 0;
+ while word_index < words.len() {
+ let mut word_bytes = [0u8; WORD_LEN];
+ let mut byte_index = 0;
+ while byte_index < WORD_LEN {
+ word_bytes[byte_index] = bytes[word_index * WORD_LEN + byte_index];
+ byte_index += 1;
+ }
+ words[word_index] = u32::from_le_bytes(word_bytes);
+ word_index += 1;
+ }
+ words
+}
+
+#[test]
+fn test_byte_word_round_trips() {
+ let cv = *b"This is 32 LE bytes/eight words.";
+ assert_eq!(cv, le_bytes_from_words_32(&words_from_le_bytes_32(&cv)));
+ let block = *b"This is sixty-four little-endian bytes, or sixteen 32-bit words.";
+ assert_eq!(
+ block,
+ le_bytes_from_words_64(&words_from_le_bytes_64(&block)),
+ );
+}
+
+// The largest power of two less than or equal to `n`, used for left_len()
+// immediately below, and also directly in Hasher::update().
+pub fn largest_power_of_two_leq(n: usize) -> usize {
+ ((n / 2) + 1).next_power_of_two()
+}
+
+#[test]
+fn test_largest_power_of_two_leq() {
+ let input_output = &[
+ // The zero case is nonsensical, but it does work.
+ (0, 1),
+ (1, 1),
+ (2, 2),
+ (3, 2),
+ (4, 4),
+ (5, 4),
+ (6, 4),
+ (7, 4),
+ (8, 8),
+ // the largest possible usize
+ (usize::MAX, (usize::MAX >> 1) + 1),
+ ];
+ for &(input, output) in input_output {
+ assert_eq!(
+ output,
+ crate::largest_power_of_two_leq(input),
+ "wrong output for n={}",
+ input
+ );
+ }
+}
+
+// Given some input larger than one chunk, return the number of bytes that
+// should go in the left subtree. This is the largest power-of-2 number of
+// chunks that leaves at least 1 byte for the right subtree.
+pub fn left_len(content_len: usize) -> usize {
+ debug_assert!(content_len > CHUNK_LEN);
+ // Subtract 1 to reserve at least one byte for the right side.
+ let full_chunks = (content_len - 1) / CHUNK_LEN;
+ largest_power_of_two_leq(full_chunks) * CHUNK_LEN
+}
+
+#[test]
+fn test_left_len() {
+ let input_output = &[
+ (CHUNK_LEN + 1, CHUNK_LEN),
+ (2 * CHUNK_LEN - 1, CHUNK_LEN),
+ (2 * CHUNK_LEN, CHUNK_LEN),
+ (2 * CHUNK_LEN + 1, 2 * CHUNK_LEN),
+ (4 * CHUNK_LEN - 1, 2 * CHUNK_LEN),
+ (4 * CHUNK_LEN, 2 * CHUNK_LEN),
+ (4 * CHUNK_LEN + 1, 4 * CHUNK_LEN),
+ ];
+ for &(input, output) in input_output {
+ assert_eq!(left_len(input), output);
+ }
+}
diff --git a/rust/guts/src/portable.rs b/rust/guts/src/portable.rs
new file mode 100644
index 0000000..d597644
--- /dev/null
+++ b/rust/guts/src/portable.rs
@@ -0,0 +1,262 @@
+use crate::{
+ le_bytes_from_words_32, le_bytes_from_words_64, words_from_le_bytes_32, words_from_le_bytes_64,
+ BlockBytes, BlockWords, CVBytes, CVWords, Implementation, IV, MAX_SIMD_DEGREE, MSG_SCHEDULE,
+};
+
+const DEGREE: usize = MAX_SIMD_DEGREE;
+
+unsafe extern "C" fn degree() -> usize {
+ DEGREE
+}
+
+#[inline(always)]
+fn g(state: &mut BlockWords, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) {
+ state[a] = state[a].wrapping_add(state[b]).wrapping_add(x);
+ state[d] = (state[d] ^ state[a]).rotate_right(16);
+ state[c] = state[c].wrapping_add(state[d]);
+ state[b] = (state[b] ^ state[c]).rotate_right(12);
+ state[a] = state[a].wrapping_add(state[b]).wrapping_add(y);
+ state[d] = (state[d] ^ state[a]).rotate_right(8);
+ state[c] = state[c].wrapping_add(state[d]);
+ state[b] = (state[b] ^ state[c]).rotate_right(7);
+}
+
+#[inline(always)]
+fn round(state: &mut [u32; 16], msg: &BlockWords, round: usize) {
+ // Select the message schedule based on the round.
+ let schedule = MSG_SCHEDULE[round];
+
+ // Mix the columns.
+ g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
+ g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
+ g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
+ g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
+
+ // Mix the diagonals.
+ g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
+ g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
+ g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
+ g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
+}
+
+#[inline(always)]
+fn compress_inner(
+ block_words: &BlockWords,
+ block_len: u32,
+ cv_words: &CVWords,
+ counter: u64,
+ flags: u32,
+) -> [u32; 16] {
+ let mut state = [
+ cv_words[0],
+ cv_words[1],
+ cv_words[2],
+ cv_words[3],
+ cv_words[4],
+ cv_words[5],
+ cv_words[6],
+ cv_words[7],
+ IV[0],
+ IV[1],
+ IV[2],
+ IV[3],
+ counter as u32,
+ (counter >> 32) as u32,
+ block_len as u32,
+ flags as u32,
+ ];
+ for round_number in 0..7 {
+ round(&mut state, &block_words, round_number);
+ }
+ state
+}
+
+pub(crate) unsafe extern "C" fn compress(
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut CVBytes,
+) {
+ let block_words = words_from_le_bytes_64(&*block);
+ let cv_words = words_from_le_bytes_32(&*cv);
+ let mut state = compress_inner(&block_words, block_len, &cv_words, counter, flags);
+ for word_index in 0..8 {
+ state[word_index] ^= state[word_index + 8];
+ }
+ *out = le_bytes_from_words_32(state[..8].try_into().unwrap());
+}
+
+pub(crate) unsafe extern "C" fn compress_xof(
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut BlockBytes,
+) {
+ let block_words = words_from_le_bytes_64(&*block);
+ let cv_words = words_from_le_bytes_32(&*cv);
+ let mut state = compress_inner(&block_words, block_len, &cv_words, counter, flags);
+ for word_index in 0..8 {
+ state[word_index] ^= state[word_index + 8];
+ state[word_index + 8] ^= cv_words[word_index];
+ }
+ *out = le_bytes_from_words_64(&state);
+}
+
+pub(crate) unsafe extern "C" fn hash_chunks(
+ input: *const u8,
+ input_len: usize,
+ key: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ transposed_output: *mut u32,
+) {
+ crate::hash_chunks_using_compress(
+ compress,
+ input,
+ input_len,
+ key,
+ counter,
+ flags,
+ transposed_output,
+ )
+}
+
+pub(crate) unsafe extern "C" fn hash_parents(
+ transposed_input: *const u32,
+ num_parents: usize,
+ key: *const CVBytes,
+ flags: u32,
+ transposed_output: *mut u32, // may overlap the input
+) {
+ crate::hash_parents_using_compress(
+ compress,
+ transposed_input,
+ num_parents,
+ key,
+ flags,
+ transposed_output,
+ )
+}
+
+pub(crate) unsafe extern "C" fn xof(
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut u8,
+ out_len: usize,
+) {
+ crate::xof_using_compress_xof(
+ compress_xof,
+ block,
+ block_len,
+ cv,
+ counter,
+ flags,
+ out,
+ out_len,
+ )
+}
+
+pub(crate) unsafe extern "C" fn xof_xor(
+ block: *const BlockBytes,
+ block_len: u32,
+ cv: *const CVBytes,
+ counter: u64,
+ flags: u32,
+ out: *mut u8,
+ out_len: usize,
+) {
+ crate::xof_xor_using_compress_xof(
+ compress_xof,
+ block,
+ block_len,
+ cv,
+ counter,
+ flags,
+ out,
+ out_len,
+ )
+}
+
+pub(crate) unsafe extern "C" fn universal_hash(
+ input: *const u8,
+ input_len: usize,
+ key: *const CVBytes,
+ counter: u64,
+ out: *mut [u8; 16],
+) {
+ crate::universal_hash_using_compress(compress, input, input_len, key, counter, out)
+}
+
+pub fn implementation() -> Implementation {
+ Implementation::new(
+ degree,
+ compress,
+ hash_chunks,
+ hash_parents,
+ xof,
+ xof_xor,
+ universal_hash,
+ )
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ // This is circular but do it anyway.
+ #[test]
+ fn test_compress_vs_portable() {
+ crate::test::test_compress_vs_portable(&implementation());
+ }
+
+ #[test]
+ fn test_compress_vs_reference() {
+ crate::test::test_compress_vs_reference(&implementation());
+ }
+
+ // This is circular but do it anyway.
+ #[test]
+ fn test_hash_chunks_vs_portable() {
+ crate::test::test_hash_chunks_vs_portable(&implementation());
+ }
+
+ // This is circular but do it anyway.
+ #[test]
+ fn test_hash_parents_vs_portable() {
+ crate::test::test_hash_parents_vs_portable(&implementation());
+ }
+
+ #[test]
+ fn test_chunks_and_parents_vs_reference() {
+ crate::test::test_chunks_and_parents_vs_reference(&implementation());
+ }
+
+ // This is circular but do it anyway.
+ #[test]
+ fn test_xof_vs_portable() {
+ crate::test::test_xof_vs_portable(&implementation());
+ }
+
+ #[test]
+ fn test_xof_vs_reference() {
+ crate::test::test_xof_vs_reference(&implementation());
+ }
+
+ // This is circular but do it anyway.
+ #[test]
+ fn test_universal_hash_vs_portable() {
+ crate::test::test_universal_hash_vs_portable(&implementation());
+ }
+
+ #[test]
+ fn test_universal_hash_vs_reference() {
+ crate::test::test_universal_hash_vs_reference(&implementation());
+ }
+}
diff --git a/rust/guts/src/test.rs b/rust/guts/src/test.rs
new file mode 100644
index 0000000..83bd790
--- /dev/null
+++ b/rust/guts/src/test.rs
@@ -0,0 +1,523 @@
+use crate::*;
+
+pub const TEST_KEY: CVBytes = *b"whats the Elvish word for friend";
+
+// Test a few different initial counter values.
+// - 0: The base case.
+// - i32::MAX: *No* overflow. But carry bugs in tricky SIMD code can screw this up, if you XOR when
+// you're supposed to ANDNOT.
+// - u32::MAX: The low word of the counter overflows for all inputs except the first.
+// - (42 << 32) + u32::MAX: Same but with a non-zero value in the high word.
+const INITIAL_COUNTERS: [u64; 4] = [
+ 0,
+ i32::MAX as u64,
+ u32::MAX as u64,
+ (42u64 << 32) + u32::MAX as u64,
+];
+
+const BLOCK_LENGTHS: [usize; 4] = [0, 1, 63, 64];
+
+pub fn paint_test_input(buf: &mut [u8]) {
+ for (i, b) in buf.iter_mut().enumerate() {
+ *b = (i % 251) as u8;
+ }
+}
+
+pub fn test_compress_vs_portable(test_impl: &Implementation) {
+ for block_len in BLOCK_LENGTHS {
+ dbg!(block_len);
+ let mut block = [0; BLOCK_LEN];
+ paint_test_input(&mut block[..block_len]);
+ for counter in INITIAL_COUNTERS {
+ dbg!(counter);
+ let portable_cv = portable::implementation().compress(
+ &block,
+ block_len as u32,
+ &TEST_KEY,
+ counter,
+ KEYED_HASH,
+ );
+
+ let test_cv =
+ test_impl.compress(&block, block_len as u32, &TEST_KEY, counter, KEYED_HASH);
+
+ assert_eq!(portable_cv, test_cv);
+ }
+ }
+}
+
+pub fn test_compress_vs_reference(test_impl: &Implementation) {
+ for block_len in BLOCK_LENGTHS {
+ dbg!(block_len);
+ let mut block = [0; BLOCK_LEN];
+ paint_test_input(&mut block[..block_len]);
+
+ let mut ref_hasher = reference_impl::Hasher::new_keyed(&TEST_KEY);
+ ref_hasher.update(&block[..block_len]);
+ let mut ref_hash = [0u8; 32];
+ ref_hasher.finalize(&mut ref_hash);
+
+ let test_cv = test_impl.compress(
+ &block,
+ block_len as u32,
+ &TEST_KEY,
+ 0,
+ CHUNK_START | CHUNK_END | ROOT | KEYED_HASH,
+ );
+
+ assert_eq!(ref_hash, test_cv);
+ }
+}
+
+fn check_transposed_eq(output_a: &TransposedVectors, output_b: &TransposedVectors) {
+ if output_a == output_b {
+ return;
+ }
+ for cv_index in 0..2 * MAX_SIMD_DEGREE {
+ let cv_a = output_a.extract_cv(cv_index);
+ let cv_b = output_b.extract_cv(cv_index);
+ if cv_a == [0; 32] && cv_b == [0; 32] {
+ println!("CV {cv_index:2} empty");
+ } else if cv_a == cv_b {
+ println!("CV {cv_index:2} matches");
+ } else {
+ println!("CV {cv_index:2} mismatch:");
+ println!(" {}", hex::encode(cv_a));
+ println!(" {}", hex::encode(cv_b));
+ }
+ }
+ panic!("transposed outputs are not equal");
+}
+
+pub fn test_hash_chunks_vs_portable(test_impl: &Implementation) {
+ assert!(test_impl.degree() <= MAX_SIMD_DEGREE);
+ dbg!(test_impl.degree() * CHUNK_LEN);
+ // Allocate 4 extra bytes of padding so we can make aligned slices.
+ let mut input_buf = [0u8; 2 * 2 * MAX_SIMD_DEGREE * CHUNK_LEN + 4];
+ let mut input_slice = &mut input_buf[..];
+ // Make sure the start of the input is word-aligned.
+ while input_slice.as_ptr() as usize % 4 != 0 {
+ input_slice = &mut input_slice[1..];
+ }
+ let (aligned_input, mut unaligned_input) =
+ input_slice.split_at_mut(2 * MAX_SIMD_DEGREE * CHUNK_LEN);
+ unaligned_input = &mut unaligned_input[1..][..2 * MAX_SIMD_DEGREE * CHUNK_LEN];
+ assert_eq!(aligned_input.as_ptr() as usize % 4, 0);
+ assert_eq!(unaligned_input.as_ptr() as usize % 4, 1);
+ paint_test_input(aligned_input);
+ paint_test_input(unaligned_input);
+ // Try just below, equal to, and just above every whole number of chunks.
+ let mut input_2_lengths = Vec::new();
+ let mut next_len = 2 * CHUNK_LEN;
+ loop {
+ // 95 is one whole block plus one interesting part of another
+ input_2_lengths.push(next_len - 95);
+ input_2_lengths.push(next_len);
+ if next_len == test_impl.degree() * CHUNK_LEN {
+ break;
+ }
+ input_2_lengths.push(next_len + 95);
+ next_len += CHUNK_LEN;
+ }
+ for input_2_len in input_2_lengths {
+ dbg!(input_2_len);
+ let aligned_input1 = &aligned_input[..test_impl.degree() * CHUNK_LEN];
+ let aligned_input2 = &aligned_input[test_impl.degree() * CHUNK_LEN..][..input_2_len];
+ let unaligned_input1 = &unaligned_input[..test_impl.degree() * CHUNK_LEN];
+ let unaligned_input2 = &unaligned_input[test_impl.degree() * CHUNK_LEN..][..input_2_len];
+ for initial_counter in INITIAL_COUNTERS {
+ dbg!(initial_counter);
+ // Make two calls, to test the output_column parameter.
+ let mut portable_output = TransposedVectors::new();
+ let (portable_left, portable_right) =
+ test_impl.split_transposed_vectors(&mut portable_output);
+ portable::implementation().hash_chunks(
+ aligned_input1,
+ &IV_BYTES,
+ initial_counter,
+ 0,
+ portable_left,
+ );
+ portable::implementation().hash_chunks(
+ aligned_input2,
+ &TEST_KEY,
+ initial_counter + test_impl.degree() as u64,
+ KEYED_HASH,
+ portable_right,
+ );
+
+ let mut test_output = TransposedVectors::new();
+ let (test_left, test_right) = test_impl.split_transposed_vectors(&mut test_output);
+ test_impl.hash_chunks(aligned_input1, &IV_BYTES, initial_counter, 0, test_left);
+ test_impl.hash_chunks(
+ aligned_input2,
+ &TEST_KEY,
+ initial_counter + test_impl.degree() as u64,
+ KEYED_HASH,
+ test_right,
+ );
+ check_transposed_eq(&portable_output, &test_output);
+
+ // Do the same thing with unaligned input.
+ let mut unaligned_test_output = TransposedVectors::new();
+ let (unaligned_left, unaligned_right) =
+ test_impl.split_transposed_vectors(&mut unaligned_test_output);
+ test_impl.hash_chunks(
+ unaligned_input1,
+ &IV_BYTES,
+ initial_counter,
+ 0,
+ unaligned_left,
+ );
+ test_impl.hash_chunks(
+ unaligned_input2,
+ &TEST_KEY,
+ initial_counter + test_impl.degree() as u64,
+ KEYED_HASH,
+ unaligned_right,
+ );
+ check_transposed_eq(&portable_output, &unaligned_test_output);
+ }
+ }
+}
+
+fn painted_transposed_input() -> TransposedVectors {
+ let mut vectors = TransposedVectors::new();
+ let mut val = 0;
+ for col in 0..2 * MAX_SIMD_DEGREE {
+ for row in 0..8 {
+ vectors.0[row][col] = val;
+ val += 1;
+ }
+ }
+ vectors
+}
+
+pub fn test_hash_parents_vs_portable(test_impl: &Implementation) {
+ assert!(test_impl.degree() <= MAX_SIMD_DEGREE);
+ let input = painted_transposed_input();
+ for num_parents in 2..=(test_impl.degree() / 2) {
+ dbg!(num_parents);
+ let mut portable_output = TransposedVectors::new();
+ let (portable_left, portable_right) =
+ test_impl.split_transposed_vectors(&mut portable_output);
+ portable::implementation().hash_parents(
+ &input,
+ 2 * num_parents, // num_cvs
+ &IV_BYTES,
+ 0,
+ portable_left,
+ );
+ portable::implementation().hash_parents(
+ &input,
+ 2 * num_parents, // num_cvs
+ &TEST_KEY,
+ KEYED_HASH,
+ portable_right,
+ );
+
+ let mut test_output = TransposedVectors::new();
+ let (test_left, test_right) = test_impl.split_transposed_vectors(&mut test_output);
+ test_impl.hash_parents(
+ &input,
+ 2 * num_parents, // num_cvs
+ &IV_BYTES,
+ 0,
+ test_left,
+ );
+ test_impl.hash_parents(
+ &input,
+ 2 * num_parents, // num_cvs
+ &TEST_KEY,
+ KEYED_HASH,
+ test_right,
+ );
+
+ check_transposed_eq(&portable_output, &test_output);
+ }
+}
+
+fn hash_with_chunks_and_parents_recurse(
+ test_impl: &Implementation,
+ input: &[u8],
+ counter: u64,
+ output: TransposedSplit,
+) -> usize {
+ assert!(input.len() > 0);
+ if input.len() <= test_impl.degree() * CHUNK_LEN {
+ return test_impl.hash_chunks(input, &IV_BYTES, counter, 0, output);
+ }
+ let (left_input, right_input) = input.split_at(left_len(input.len()));
+ let mut child_output = TransposedVectors::new();
+ let (left_output, right_output) = test_impl.split_transposed_vectors(&mut child_output);
+ let mut children =
+ hash_with_chunks_and_parents_recurse(test_impl, left_input, counter, left_output);
+ assert_eq!(children, test_impl.degree());
+ children += hash_with_chunks_and_parents_recurse(
+ test_impl,
+ right_input,
+ counter + (left_input.len() / CHUNK_LEN) as u64,
+ right_output,
+ );
+ test_impl.hash_parents(&child_output, children, &IV_BYTES, PARENT, output)
+}
+
+// Note: This test implementation doesn't support the 1-chunk-or-less case.
+fn root_hash_with_chunks_and_parents(test_impl: &Implementation, input: &[u8]) -> CVBytes {
+ // TODO: handle the 1-chunk case?
+ assert!(input.len() > CHUNK_LEN);
+ let mut cvs = TransposedVectors::new();
+    // The right half of these vectors is never used.
+ let (cvs_left, _) = test_impl.split_transposed_vectors(&mut cvs);
+ let mut num_cvs = hash_with_chunks_and_parents_recurse(test_impl, input, 0, cvs_left);
+ while num_cvs > 2 {
+ num_cvs = test_impl.reduce_parents(&mut cvs, num_cvs, &IV_BYTES, 0);
+ }
+ test_impl.compress(
+ &cvs.extract_parent_node(0),
+ BLOCK_LEN as u32,
+ &IV_BYTES,
+ 0,
+ PARENT | ROOT,
+ )
+}
+
+pub fn test_chunks_and_parents_vs_reference(test_impl: &Implementation) {
+ assert_eq!(test_impl.degree().count_ones(), 1, "power of 2");
+ const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * CHUNK_LEN;
+ let mut input_buf = [0u8; MAX_INPUT_LEN];
+ paint_test_input(&mut input_buf);
+ // Try just below, equal to, and just above every whole number of chunks, except that
+ // root_hash_with_chunks_and_parents doesn't support the 1-chunk-or-less case.
+ let mut test_lengths = vec![CHUNK_LEN + 1];
+ let mut next_len = 2 * CHUNK_LEN;
+ loop {
+ // 95 is one whole block plus one interesting part of another
+ test_lengths.push(next_len - 95);
+ test_lengths.push(next_len);
+ if next_len == MAX_INPUT_LEN {
+ break;
+ }
+ test_lengths.push(next_len + 95);
+ next_len += CHUNK_LEN;
+ }
+ for test_len in test_lengths {
+ dbg!(test_len);
+ let input = &input_buf[..test_len];
+
+ let mut ref_hasher = reference_impl::Hasher::new();
+ ref_hasher.update(&input);
+ let mut ref_hash = [0u8; 32];
+ ref_hasher.finalize(&mut ref_hash);
+
+ let test_hash = root_hash_with_chunks_and_parents(test_impl, input);
+
+ assert_eq!(ref_hash, test_hash);
+ }
+}
+
+pub fn test_xof_vs_portable(test_impl: &Implementation) {
+ let flags = CHUNK_START | CHUNK_END | KEYED_HASH;
+ for counter in INITIAL_COUNTERS {
+ dbg!(counter);
+ for input_len in [0, 1, BLOCK_LEN] {
+ dbg!(input_len);
+ let mut input_block = [0u8; BLOCK_LEN];
+ for byte_index in 0..input_len {
+ input_block[byte_index] = byte_index as u8 + 42;
+ }
+ // Try equal to and partway through every whole number of output blocks.
+ const MAX_OUTPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
+ let mut output_lengths = Vec::new();
+ let mut next_len = 0;
+ loop {
+ output_lengths.push(next_len);
+ if next_len == MAX_OUTPUT_LEN {
+ break;
+ }
+ output_lengths.push(next_len + 31);
+ next_len += BLOCK_LEN;
+ }
+ for output_len in output_lengths {
+ dbg!(output_len);
+ let mut portable_output = [0xff; MAX_OUTPUT_LEN];
+ portable::implementation().xof(
+ &input_block,
+ input_len as u32,
+ &TEST_KEY,
+ counter,
+ flags,
+ &mut portable_output[..output_len],
+ );
+ let mut test_output = [0xff; MAX_OUTPUT_LEN];
+ test_impl.xof(
+ &input_block,
+ input_len as u32,
+ &TEST_KEY,
+ counter,
+ flags,
+ &mut test_output[..output_len],
+ );
+ assert_eq!(portable_output, test_output);
+
+                // Double check that the implementation didn't overwrite past the end of the output.
+ assert!(test_output[output_len..].iter().all(|&b| b == 0xff));
+
+ // The first XOR cancels out the output.
+ test_impl.xof_xor(
+ &input_block,
+ input_len as u32,
+ &TEST_KEY,
+ counter,
+ flags,
+ &mut test_output[..output_len],
+ );
+ assert!(test_output[..output_len].iter().all(|&b| b == 0));
+ assert!(test_output[output_len..].iter().all(|&b| b == 0xff));
+
+                // The second XOR restores the output.
+ test_impl.xof_xor(
+ &input_block,
+ input_len as u32,
+ &TEST_KEY,
+ counter,
+ flags,
+ &mut test_output[..output_len],
+ );
+ assert_eq!(portable_output, test_output);
+ assert!(test_output[output_len..].iter().all(|&b| b == 0xff));
+ }
+ }
+ }
+}
+
+pub fn test_xof_vs_reference(test_impl: &Implementation) {
+ let input = b"hello world";
+ let mut input_block = [0; BLOCK_LEN];
+ input_block[..input.len()].copy_from_slice(input);
+
+ const MAX_OUTPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
+ let mut ref_output = [0; MAX_OUTPUT_LEN];
+ let mut ref_hasher = reference_impl::Hasher::new_keyed(&TEST_KEY);
+ ref_hasher.update(input);
+ ref_hasher.finalize(&mut ref_output);
+
+ // Try equal to and partway through every whole number of output blocks.
+ let mut output_lengths = vec![0, 1, 31];
+ let mut next_len = BLOCK_LEN;
+ loop {
+ output_lengths.push(next_len);
+ if next_len == MAX_OUTPUT_LEN {
+ break;
+ }
+ output_lengths.push(next_len + 31);
+ next_len += BLOCK_LEN;
+ }
+
+ for output_len in output_lengths {
+ dbg!(output_len);
+ let mut test_output = [0; MAX_OUTPUT_LEN];
+ test_impl.xof(
+ &input_block,
+ input.len() as u32,
+ &TEST_KEY,
+ 0,
+ KEYED_HASH | CHUNK_START | CHUNK_END,
+ &mut test_output[..output_len],
+ );
+ assert_eq!(ref_output[..output_len], test_output[..output_len]);
+
+        // Double check that the implementation didn't overwrite past the end of the output.
+ assert!(test_output[output_len..].iter().all(|&b| b == 0));
+
+ // Do it again starting from block 1.
+ if output_len >= BLOCK_LEN {
+ test_impl.xof(
+ &input_block,
+ input.len() as u32,
+ &TEST_KEY,
+ 1,
+ KEYED_HASH | CHUNK_START | CHUNK_END,
+ &mut test_output[..output_len - BLOCK_LEN],
+ );
+ assert_eq!(
+ ref_output[BLOCK_LEN..output_len],
+ test_output[..output_len - BLOCK_LEN],
+ );
+ }
+ }
+}
+
+pub fn test_universal_hash_vs_portable(test_impl: &Implementation) {
+ const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
+ let mut input_buf = [0; MAX_INPUT_LEN];
+ paint_test_input(&mut input_buf);
+ // Try equal to and partway through every whole number of input blocks.
+ let mut input_lengths = vec![0, 1, 31];
+ let mut next_len = BLOCK_LEN;
+ loop {
+ input_lengths.push(next_len);
+ if next_len == MAX_INPUT_LEN {
+ break;
+ }
+ input_lengths.push(next_len + 31);
+ next_len += BLOCK_LEN;
+ }
+ for input_len in input_lengths {
+ dbg!(input_len);
+ for counter in INITIAL_COUNTERS {
+ dbg!(counter);
+ let portable_output = portable::implementation().universal_hash(
+ &input_buf[..input_len],
+ &TEST_KEY,
+ counter,
+ );
+ let test_output = test_impl.universal_hash(&input_buf[..input_len], &TEST_KEY, counter);
+ assert_eq!(portable_output, test_output);
+ }
+ }
+}
+
+fn reference_impl_universal_hash(input: &[u8], key: &CVBytes) -> [u8; UNIVERSAL_HASH_LEN] {
+ // The reference_impl doesn't support XOF seeking, so we have to materialize an entire extended
+ // output to seek to a block.
+ const MAX_BLOCKS: usize = 2 * MAX_SIMD_DEGREE;
+ assert!(input.len() / BLOCK_LEN <= MAX_BLOCKS);
+ let mut output_buffer: [u8; BLOCK_LEN * MAX_BLOCKS] = [0u8; BLOCK_LEN * MAX_BLOCKS];
+ let mut result = [0u8; UNIVERSAL_HASH_LEN];
+ let mut block_start = 0;
+ while block_start < input.len() {
+ let block_len = cmp::min(input.len() - block_start, BLOCK_LEN);
+ let mut ref_hasher = reference_impl::Hasher::new_keyed(key);
+ ref_hasher.update(&input[block_start..block_start + block_len]);
+ ref_hasher.finalize(&mut output_buffer[..block_start + UNIVERSAL_HASH_LEN]);
+ for byte_index in 0..UNIVERSAL_HASH_LEN {
+ result[byte_index] ^= output_buffer[block_start + byte_index];
+ }
+ block_start += BLOCK_LEN;
+ }
+ result
+}
+
+pub fn test_universal_hash_vs_reference(test_impl: &Implementation) {
+ const MAX_INPUT_LEN: usize = 2 * MAX_SIMD_DEGREE * BLOCK_LEN;
+ let mut input_buf = [0; MAX_INPUT_LEN];
+ paint_test_input(&mut input_buf);
+ // Try equal to and partway through every whole number of input blocks.
+ let mut input_lengths = vec![0, 1, 31];
+ let mut next_len = BLOCK_LEN;
+ loop {
+ input_lengths.push(next_len);
+ if next_len == MAX_INPUT_LEN {
+ break;
+ }
+ input_lengths.push(next_len + 31);
+ next_len += BLOCK_LEN;
+ }
+ for input_len in input_lengths {
+ dbg!(input_len);
+ let ref_output = reference_impl_universal_hash(&input_buf[..input_len], &TEST_KEY);
+ let test_output = test_impl.universal_hash(&input_buf[..input_len], &TEST_KEY, 0);
+ assert_eq!(ref_output, test_output);
+ }
+}