about summary refs log tree commit diff
diff options
context:
space:
mode:
author Jack O'Connor <[email protected]> 2023-05-28 19:36:32 +0200
committer Jack O'Connor <[email protected]> 2024-08-15 16:02:10 -0700
commit e74172acc35f599497243946e2f3cb3aa3f83f40 (patch)
tree 7c417e5e4edc4795a9e0d528929adcf09d9e069e
parent 80b83effbd50425939483e5503b186db4dac4d9d (diff)
integrate xof_many with the Rust implementation and with Rust and C tests
-rw-r--r-- c/blake3.c 3
-rw-r--r-- c/blake3_c_rust_bindings/src/lib.rs 18
-rw-r--r-- c/blake3_c_rust_bindings/src/test.rs 75
-rw-r--r-- c/blake3_impl.h 11
-rw-r--r-- src/ffi_avx512.rs 37
-rw-r--r-- src/lib.rs 54
-rw-r--r-- src/platform.rs 20
-rw-r--r-- src/portable.rs 20
-rw-r--r-- src/test.rs 70
9 files changed, 291 insertions, 17 deletions
diff --git a/c/blake3.c b/c/blake3.c
index b6768d2..7e6d01e 100644
--- a/c/blake3.c
+++ b/c/blake3.c
@@ -88,6 +88,9 @@ INLINE void output_chaining_value(const output_t *self, uint8_t cv[32]) {
INLINE void output_root_bytes(const output_t *self, uint64_t seek, uint8_t *out,
size_t out_len) {
+ if (out_len == 0) {
+ return;
+ }
uint64_t output_block_counter = seek / 64;
size_t offset_within_block = seek % 64;
uint8_t wide_buf[64];
diff --git a/c/blake3_c_rust_bindings/src/lib.rs b/c/blake3_c_rust_bindings/src/lib.rs
index 41e4938..ac7880a 100644
--- a/c/blake3_c_rust_bindings/src/lib.rs
+++ b/c/blake3_c_rust_bindings/src/lib.rs
@@ -177,6 +177,15 @@ pub mod ffi {
flags_end: u8,
out: *mut u8,
);
+ pub fn blake3_xof_many_portable( // foreign function; this crate's xof_many test helper uses it as the reference output
+ cv: *const u32, // the tests in this crate pass a pointer to 8 u32 words
+ block: *const u8, // the tests pass a pointer to a BLOCK_LEN-byte array
+ block_len: u8, // the tests pass 42, the number of painted input bytes in `block`
+ counter: u64,
+ flags: u8,
+ out: *mut u8, // destination buffer; the tests size it as `outblocks * BLOCK_LEN` bytes
+ outblocks: usize, // number of blocks requested; the tests pass `OUTPUT_SIZE / BLOCK_LEN`
+ );
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
@@ -282,6 +291,15 @@ pub mod ffi {
flags_end: u8,
out: *mut u8,
);
+ pub fn blake3_xof_many_avx512( // foreign function; same parameter list as blake3_xof_many_portable so both fit XofManyFunction
+ cv: *const u32, // the tests in this crate pass a pointer to 8 u32 words
+ block: *const u8, // the tests pass a pointer to a BLOCK_LEN-byte array
+ block_len: u8,
+ counter: u64,
+ flags: u8,
+ out: *mut u8, // destination buffer; the tests size it as `outblocks * BLOCK_LEN` bytes
+ outblocks: usize, // number of blocks requested
+ );
}
}
diff --git a/c/blake3_c_rust_bindings/src/test.rs b/c/blake3_c_rust_bindings/src/test.rs
index 0730d93..e6baff1 100644
--- a/c/blake3_c_rust_bindings/src/test.rs
+++ b/c/blake3_c_rust_bindings/src/test.rs
@@ -359,6 +359,81 @@ fn test_hash_many_neon() {
test_hash_many_fn(crate::ffi::neon::blake3_hash_many_neon);
}
+type XofManyFunction = unsafe extern "C" fn( // pointer type accepted by test_xof_many_fn; matches the blake3_xof_many_* declarations
+ cv: *const u32,
+ block: *const u8,
+ block_len: u8,
+ counter: u64,
+ flags: u8,
+ out: *mut u8,
+ outblocks: usize,
+);
+
+// Shared driver for the platform-specific `xof_many` tests. It runs the implementation under
+// test and `blake3_xof_many_portable` on identical inputs and requires byte-identical output.
+pub fn test_xof_many_fn(xof_many_function: XofManyFunction) {
+    // 31 (16 + 8 + 4 + 2 + 1) outputs
+    const OUTPUT_SIZE: usize = 31 * BLOCK_LEN;
+    const NUM_BLOCKS: usize = OUTPUT_SIZE / BLOCK_LEN;
+
+    // Starting counters worth exercising:
+    // - 0: The base case.
+    // - u32::MAX: The low word of the counter overflows for all inputs except the first.
+    // - i32::MAX: *No* overflow. But carry bugs in tricky SIMD code can screw this up, if you
+    //   XOR when you're supposed to ANDNOT...
+    for counter in [0, u32::MAX as u64, i32::MAX as u64] {
+        #[cfg(feature = "std")]
+        dbg!(counter);
+
+        // Only the first 42 bytes of the block carry painted input; the rest stay zero.
+        let input_len = 42;
+        let mut block = [0; BLOCK_LEN];
+        crate::test::paint_test_input(&mut block[..input_len]);
+        let cv = [40, 41, 42, 43, 44, 45, 46, 47];
+        let flags = KEYED_HASH;
+
+        let mut expected = [0u8; OUTPUT_SIZE];
+        // SAFETY: every pointer comes from a live local array, and `expected` holds exactly
+        // NUM_BLOCKS * BLOCK_LEN bytes.
+        unsafe {
+            crate::ffi::blake3_xof_many_portable(
+                cv.as_ptr(),
+                block.as_ptr(),
+                input_len as u8,
+                counter,
+                flags,
+                expected.as_mut_ptr(),
+                NUM_BLOCKS,
+            );
+        }
+
+        let mut actual = [0u8; OUTPUT_SIZE];
+        // SAFETY: same buffers and sizes as above; callers of this helper are responsible for
+        // only passing an implementation the current CPU supports.
+        unsafe {
+            xof_many_function(
+                cv.as_ptr(),
+                block.as_ptr(),
+                input_len as u8,
+                counter,
+                flags,
+                actual.as_mut_ptr(),
+                NUM_BLOCKS,
+            );
+        }
+
+        assert_eq!(expected, actual);
+    }
+}
+
+// Testing the portable implementation against itself is circular, but why not. (test_xof_many_fn always compares against blake3_xof_many_portable.)
+#[test]
+fn test_xof_many_portable() {
+ test_xof_many_fn(crate::ffi::blake3_xof_many_portable);
+}
+
+// Checks the AVX-512 `xof_many` against the portable one. On a CPU without AVX-512 the test
+// body does nothing and the test passes.
+#[test]
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+fn test_xof_many_avx512() {
+    if crate::avx512_detected() {
+        test_xof_many_fn(crate::ffi::x86::blake3_xof_many_avx512);
+    }
+}
+
#[test]
fn test_compare_reference_impl() {
const OUT: usize = 303; // more than 64, not a multiple of 4
diff --git a/c/blake3_impl.h b/c/blake3_impl.h
index b3abce2..66a1707 100644
--- a/c/blake3_impl.h
+++ b/c/blake3_impl.h
@@ -282,17 +282,16 @@ void blake3_compress_xof_avx512(const uint32_t cv[8],
uint8_t block_len, uint64_t counter,
uint8_t flags, uint8_t out[64]);
-
-void blake3_xof_many_avx512(const uint32_t cv[8],
- const uint8_t block[BLAKE3_BLOCK_LEN],
- uint8_t block_len, uint64_t counter, uint8_t flags,
- uint8_t* out, size_t outblocks);
-
void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
size_t blocks, const uint32_t key[8],
uint64_t counter, bool increment_counter,
uint8_t flags, uint8_t flags_start,
uint8_t flags_end, uint8_t *out);
+
+void blake3_xof_many_avx512(const uint32_t cv[8],
+ const uint8_t block[BLAKE3_BLOCK_LEN],
+ uint8_t block_len, uint64_t counter, uint8_t flags,
+ uint8_t* out, size_t outblocks);
#endif
#endif
diff --git a/src/ffi_avx512.rs b/src/ffi_avx512.rs
index 884f481..d34e53c 100644
--- a/src/ffi_avx512.rs
+++ b/src/ffi_avx512.rs
@@ -60,6 +60,26 @@ pub unsafe fn hash_many<const N: usize>(
)
}
+// Fill `out` with extended-output blocks using the AVX-512 C implementation.
+//
+// `out.len() / BLOCK_LEN` whole blocks are requested from `blake3_xof_many_avx512`, with
+// `counter` as the starting counter value. Bytes of `out` past the last whole block are not
+// part of the request. If `out` is shorter than one block, this returns without calling into
+// C at all, so the wrapper never depends on how the foreign code treats a block count of zero.
+//
+// Unsafe because this may only be called on platforms supporting AVX-512.
+pub unsafe fn xof_many(
+    cv: &CVWords,
+    block: &[u8; BLOCK_LEN],
+    block_len: u8,
+    counter: u64,
+    flags: u8,
+    out: &mut [u8],
+) {
+    let outblocks = out.len() / BLOCK_LEN;
+    if outblocks == 0 {
+        // Nothing to write; don't cross the FFI boundary with an empty request.
+        return;
+    }
+    ffi::blake3_xof_many_avx512(
+        cv.as_ptr(),
+        block.as_ptr(),
+        block_len,
+        counter,
+        flags,
+        out.as_mut_ptr(),
+        outblocks,
+    );
+}
+
pub mod ffi {
extern "C" {
pub fn blake3_compress_in_place_avx512(
@@ -89,6 +109,15 @@ pub mod ffi {
flags_end: u8,
out: *mut u8,
);
+ pub fn blake3_xof_many_avx512( // C function; the safe-signature wrapper for it is `xof_many` in this file
+ cv: *const u32, // the wrapper passes `CVWords::as_ptr()`
+ block: *const u8, // the wrapper passes a pointer to a `[u8; BLOCK_LEN]`
+ block_len: u8,
+ counter: u64,
+ flags: u8,
+ out: *mut u8, // start of the output slice
+ outblocks: usize, // the wrapper passes `out.len() / BLOCK_LEN`
+ );
}
}
@@ -111,4 +140,12 @@ mod test {
}
crate::test::test_hash_many_fn(hash_many, hash_many);
}
+
+    // Compares the AVX-512 `xof_many` with the portable one; does nothing when the CPU lacks
+    // AVX-512.
+    #[test]
+    fn test_xof_many() {
+        if crate::platform::avx512_detected() {
+            crate::test::test_xof_many_fn(xof_many);
+        }
+    }
}
diff --git a/src/lib.rs b/src/lib.rs
index 066b925..7239edb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1634,6 +1634,21 @@ impl OutputReader {
}
}
+ // It would be more natural if this helper took &mut &mut [u8], but that doesn't seem to be allowed. So it takes `buf` by value, copies out of the current output block, and returns the unfilled rest of `buf`.
+ fn fill_one_block<'a>(&mut self, mut buf: &'a mut [u8]) -> &'a mut [u8] {
+ let block: [u8; BLOCK_LEN] = self.inner.root_output_block();
+ let output_bytes = &block[self.position_within_block as usize..]; // the part of this block not yet handed out
+ let take = cmp::min(buf.len(), output_bytes.len()); // never more than the caller asked for or the block has left
+ buf[..take].copy_from_slice(&output_bytes[..take]);
+ buf = &mut buf[take..];
+ self.position_within_block += take as u8; // take <= BLOCK_LEN - position, so the sum is at most BLOCK_LEN
+ if self.position_within_block == BLOCK_LEN as u8 { // block used up: advance to the start of the next one
+ self.inner.counter += 1;
+ self.position_within_block = 0;
+ }
+ buf // what is still unfilled; empty if the caller's buffer ran out first
+ }
+
/// Fill a buffer with output bytes and advance the position of the
/// `OutputReader`. This is equivalent to [`Read::read`], except that it
/// doesn't return a `Result`. Both methods always fill the entire buffer.
@@ -1650,17 +1665,34 @@ impl OutputReader {
///
/// [`Read::read`]: #method.read
pub fn fill(&mut self, mut buf: &mut [u8]) {
- while !buf.is_empty() {
- let block: [u8; BLOCK_LEN] = self.inner.root_output_block();
- let output_bytes = &block[self.position_within_block as usize..];
- let take = cmp::min(buf.len(), output_bytes.len());
- buf[..take].copy_from_slice(&output_bytes[..take]);
- buf = &mut buf[take..];
- self.position_within_block += take as u8;
- if self.position_within_block == BLOCK_LEN as u8 {
- self.inner.counter += 1;
- self.position_within_block = 0;
- }
+ if buf.is_empty() {
+ return;
+ }
+
+ // If we're partway through a block, try to get to a block boundary.
+ if self.position_within_block != 0 {
+ buf = self.fill_one_block(buf);
+ }
+
+ // Bulk path: produce all whole blocks with a single xof_many call. Only the whole-block
+ // prefix of `buf` is handed over, so no backend has to know to ignore a trailing partial
+ // block.
+ let full_blocks = buf.len() / BLOCK_LEN;
+ let full_blocks_len = full_blocks * BLOCK_LEN;
+ if full_blocks > 0 {
+ debug_assert_eq!(0, self.position_within_block);
+ self.inner.platform.xof_many(
+ &self.inner.input_chaining_value,
+ &self.inner.block,
+ self.inner.block_len,
+ self.inner.counter,
+ self.inner.flags | ROOT,
+ &mut buf[..full_blocks_len],
+ );
+ self.inner.counter += full_blocks as u64;
+ buf = &mut buf[full_blocks_len..];
+ }
+
+ // Whatever is left is shorter than one block; serve it from the next output block.
+ if !buf.is_empty() {
+ debug_assert!(buf.len() < BLOCK_LEN);
+ buf = self.fill_one_block(buf);
+ debug_assert!(buf.is_empty());
}
}
diff --git a/src/platform.rs b/src/platform.rs
index 79bc9a3..1858add 100644
--- a/src/platform.rs
+++ b/src/platform.rs
@@ -277,6 +277,26 @@ impl Platform {
}
}
+ pub fn xof_many( // forwards to the AVX-512 implementation when that variant is selected and compiled in, otherwise to the portable one
+ &self,
+ cv: &CVWords,
+ block: &[u8; BLOCK_LEN],
+ block_len: u8,
+ counter: u64,
+ flags: u8,
+ out: &mut [u8],
+ ) {
+ match self {
+ // Safe because detect() checked for platform support.
+ #[cfg(blake3_avx512_ffi)]
+ #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+ Platform::AVX512 => unsafe {
+ crate::avx512::xof_many(cv, block, block_len, counter, flags, out)
+ },
+ _ => crate::portable::xof_many(cv, block, block_len, counter, flags, out), // any variant without an arm above falls back to the portable code
+ }
+ }
+
// Explicit platform constructors, for benchmarks.
pub fn portable() -> Self {
diff --git a/src/portable.rs b/src/portable.rs
index 7af6828..822cd0d 100644
--- a/src/portable.rs
+++ b/src/portable.rs
@@ -177,6 +177,20 @@ pub fn hash_many<const N: usize>(
}
}
+// Portable implementation of `xof_many`: fills `out` one block at a time with `compress_xof`,
+// using the counter values `counter`, `counter + 1`, and so on. Only whole BLOCK_LEN-byte
+// blocks are written; if `out.len()` is not a multiple of BLOCK_LEN, the trailing bytes are
+// left untouched.
+pub fn xof_many(
+    cv: &CVWords,
+    block: &[u8; BLOCK_LEN],
+    block_len: u8,
+    counter: u64,
+    flags: u8,
+    out: &mut [u8],
+) {
+    for (i, out_block) in out.chunks_exact_mut(BLOCK_LEN).enumerate() {
+        // Compute each block's counter from its index rather than bumping a running counter
+        // after every block: the running form overflows after the final block when that block
+        // uses counter u64::MAX, even though the incremented value is never used.
+        let block_counter = counter + i as u64;
+        out_block.copy_from_slice(&compress_xof(cv, block, block_len, block_counter, flags));
+    }
+}
+
#[cfg(test)]
pub mod test {
use super::*;
@@ -195,4 +209,10 @@ pub mod test {
fn test_hash_many() {
crate::test::test_hash_many_fn(hash_many, hash_many);
}
+
+ // Ditto. (test_xof_many_fn compares its argument against portable::xof_many, so this checks the portable code against itself.)
+ #[test]
+ fn test_xof_many() {
+ crate::test::test_xof_many_fn(xof_many);
+ }
}
diff --git a/src/test.rs b/src/test.rs
index b716e1b..39dfe46 100644
--- a/src/test.rs
+++ b/src/test.rs
@@ -206,6 +206,54 @@ pub fn test_hash_many_fn(
}
}
+type XofManyFunction = unsafe fn( // pointer type accepted by test_xof_many_fn; both the safe portable xof_many and the unsafe AVX-512 one are passed as this type
+ cv: &CVWords,
+ block: &[u8; BLOCK_LEN],
+ block_len: u8,
+ counter: u64,
+ flags: u8,
+ out: &mut [u8],
+);
+
+// Shared driver for the platform-specific `xof_many` tests. It runs the implementation under
+// test and `crate::portable::xof_many` on identical inputs and requires byte-identical output.
+pub fn test_xof_many_fn(xof_many_function: XofManyFunction) {
+    // 31 (16 + 8 + 4 + 2 + 1) outputs
+    const OUTPUT_SIZE: usize = 31 * BLOCK_LEN;
+
+    // Starting counters worth exercising:
+    // - 0: The base case.
+    // - u32::MAX: The low word of the counter overflows for all inputs except the first.
+    // - i32::MAX: *No* overflow. But carry bugs in tricky SIMD code can screw this up, if you
+    //   XOR when you're supposed to ANDNOT...
+    for counter in [0, u32::MAX as u64, i32::MAX as u64] {
+        #[cfg(feature = "std")]
+        dbg!(counter);
+
+        // Only the first 42 bytes of the block carry painted input; the rest stay zero.
+        let input_len = 42;
+        let mut block = [0; BLOCK_LEN];
+        crate::test::paint_test_input(&mut block[..input_len]);
+        let cv = [40, 41, 42, 43, 44, 45, 46, 47];
+        let flags = crate::KEYED_HASH;
+
+        let mut expected = [0u8; OUTPUT_SIZE];
+        crate::portable::xof_many(&cv, &block, input_len as u8, counter, flags, &mut expected);
+
+        let mut actual = [0u8; OUTPUT_SIZE];
+        // SAFETY: callers of this helper are responsible for only passing an implementation
+        // that the current CPU supports.
+        unsafe {
+            xof_many_function(&cv, &block, input_len as u8, counter, flags, &mut actual);
+        }
+
+        assert_eq!(expected, actual);
+    }
+}
+
#[test]
fn test_key_bytes_equal_key_words() {
assert_eq!(
@@ -373,6 +421,28 @@ fn test_compare_reference_impl() {
}
}
+// An `OutputReader` must yield the same bytes as the reference implementation whether the
+// output is read in one `fill` call or in several calls that begin and end partway through a
+// block.
+#[test]
+fn test_xof_partial_blocks() {
+    const OUT_LEN: usize = 6 * BLOCK_LEN;
+
+    let mut expected = [0u8; OUT_LEN];
+    reference_impl::Hasher::new().finalize(&mut expected);
+
+    // One call covering the whole output.
+    let mut one_shot = [0u8; OUT_LEN];
+    let mut reader = crate::Hasher::new().finalize_xof();
+    reader.fill(&mut one_shot);
+    assert_eq!(expected, one_shot);
+
+    // Three calls: the first 32 bytes, then everything up to the last 32 bytes, then the rest.
+    let mut pieces = [0u8; OUT_LEN];
+    let mut reader = crate::Hasher::new().finalize_xof();
+    let (head, rest) = pieces.split_at_mut(32);
+    let (middle, tail) = rest.split_at_mut(OUT_LEN - 64);
+    reader.fill(head);
+    reader.fill(middle);
+    reader.fill(tail);
+    assert_eq!(expected, pieces);
+}
+
fn reference_hash(input: &[u8]) -> crate::Hash {
let mut hasher = reference_impl::Hasher::new();
hasher.update(input);