authorJack O'Connor <[email protected]>2020-02-03 11:35:50 -0500
committerJack O'Connor <[email protected]>2020-02-06 15:07:15 -0500
commitfc219f4f8d92f721d6444bb7420d42d88ee4b43c (patch)
tree0601cbc527e4c681e4847f959c099052c0993d9b
parent24071db3463f29a6ad6173e3aea62b0f1497b5bc (diff)
Hasher::update_with_join
This is a new interface that allows the caller to provide a multi-threading implementation. It's defined in terms of a new `Join` trait, for which we provide two implementations, `SerialJoin` and `RayonJoin`. This lets the caller control when multi-threading is used, rather than the previous all-or-nothing design of the "rayon" feature. Although existing callers should keep compiling and producing correct output, this is a behavioral compatibility break: callers who were relying on automatic multi-threading before will now be single-threaded. Thus the next release of this crate will need to be version 0.2. See https://github.com/BLAKE3-team/BLAKE3/issues/25 and https://github.com/BLAKE3-team/BLAKE3/issues/54.
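For illustration, a minimal sketch of the new call pattern (the `RayonJoin` line assumes the crate is built with the `rayon` feature):

```rust
// Sketch: the caller now chooses the Join implementation explicitly.
// SerialJoin is always available; RayonJoin requires the "rayon" feature.
let mut hasher = blake3::Hasher::new();
// Equivalent to plain update(): single-threaded.
hasher.update_with_join::<blake3::join::SerialJoin>(b"hello ");
// Potentially multi-threaded, on the global Rayon thread pool.
hasher.update_with_join::<blake3::join::RayonJoin>(b"world");
let hash = hasher.finalize();
```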
-rw-r--r--Cargo.toml4
-rw-r--r--b3sum/src/main.rs52
-rw-r--r--benches/bench.rs82
-rw-r--r--src/join.rs114
-rw-r--r--src/lib.rs112
-rw-r--r--src/test.rs63
6 files changed, 377 insertions, 50 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 3324d7f..4d8e7cf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,6 +21,10 @@ c_avx512 = []
c_neon = []
std = ["digest/std"]
+[package.metadata.docs.rs]
+# Document blake3::join::RayonJoin on docs.rs.
+features = ["rayon"]
+
[dependencies]
arrayref = "0.3.5"
arrayvec = { version = "0.5.1", default-features = false, features = ["array-sizes-33-128"] }
diff --git a/b3sum/src/main.rs b/b3sum/src/main.rs
index beb8c38..246f24d 100644
--- a/b3sum/src/main.rs
+++ b/b3sum/src/main.rs
@@ -3,6 +3,7 @@ use clap::{App, Arg};
use std::cmp;
use std::convert::TryInto;
use std::fs::File;
+use std::io;
use std::io::prelude::*;
const FILE_ARG: &str = "file";
@@ -57,21 +58,37 @@ fn clap_parse_argv() -> clap::ArgMatches<'static> {
.get_matches()
}
+// A 16 KiB buffer is enough to take advantage of all the SIMD instruction sets
+// that we support, but `std::io::copy` currently uses only 8 KiB. Most
+// platforms can afford a 64 KiB buffer, and there's some performance benefit
+// to using bigger reads, so that's what we use here.
+fn copy_wide(mut reader: impl Read, hasher: &mut blake3::Hasher) -> io::Result<u64> {
+ let mut buffer = [0; 65536];
+ let mut total = 0;
+ loop {
+ match reader.read(&mut buffer) {
+ Ok(0) => return Ok(total),
+ Ok(n) => {
+ hasher.update(&buffer[..n]);
+ total += n as u64;
+ }
+ Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
+ Err(e) => return Err(e),
+ }
+ }
+}
+
// The slow path, for inputs that we can't memmap.
-fn hash_reader(
- base_hasher: &blake3::Hasher,
- mut reader: impl Read,
-) -> Result<blake3::OutputReader> {
+fn hash_reader(base_hasher: &blake3::Hasher, reader: impl Read) -> Result<blake3::OutputReader> {
let mut hasher = base_hasher.clone();
- // TODO: This is a narrow copy, so it might not take advantage of SIMD or
- // threads. With a larger buffer size, most of that performance can be
- // recovered. However, this requires some platform-specific tuning, based
- // on both the SIMD degree and the number of cores. A double-buffering
- // strategy is also helpful, where a dedicated background thread reads
- // input into one buffer while another thread is calling update() on a
- // second buffer. Since this is the slow path anyway, do the simple thing
- // for now.
- std::io::copy(&mut reader, &mut hasher)?;
+ // This is currently all single-threaded. Doing multi-threaded hashing
+ // without memory mapping is tricky, since all your worker threads have to
+ // stop every time you refill the buffer, and that ends up being a lot of
+ // overhead. To solve that, we need a more complicated double-buffering
+ // strategy where a background thread fills one buffer while the worker
+ // threads are hashing the other one. We might implement that in the
+ // future, but since this is the slow path anyway, it's not high priority.
+ copy_wide(reader, &mut hasher)?;
Ok(hasher.finalize_xof())
}
@@ -114,7 +131,14 @@ fn maybe_hash_memmap(
#[cfg(feature = "memmap")]
{
if let Some(map) = maybe_memmap_file(_file)? {
- return Ok(Some(_base_hasher.clone().update(&map).finalize_xof()));
+ // Memory mapping worked. Use Rayon-based multi-threading to split
+ // up the whole file across many worker threads.
+ return Ok(Some(
+ _base_hasher
+ .clone()
+ .update_with_join::<blake3::join::RayonJoin>(&map)
+ .finalize_xof(),
+ ));
}
}
Ok(None)
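As an aside, the double-buffering strategy described in the hash_reader comment above might look something like the following sketch. Nothing here is part of this commit: the function name and buffer size are made up, the `rayon` feature is assumed, and it uses `std::thread::scope` for brevity.

```rust
use std::io::{self, Read};
use std::sync::mpsc;
use std::thread;

fn hash_reader_double_buffered(
    mut reader: impl Read + Send,
    hasher: &mut blake3::Hasher,
) -> io::Result<u64> {
    // Two 1 MiB buffers in flight; the size is a guess that needs tuning.
    const BUF_LEN: usize = 1 << 20;
    let (full_tx, full_rx) = mpsc::sync_channel::<io::Result<Vec<u8>>>(2);
    let (empty_tx, empty_rx) = mpsc::sync_channel::<Vec<u8>>(2);
    empty_tx.send(vec![0; BUF_LEN]).unwrap();
    empty_tx.send(vec![0; BUF_LEN]).unwrap();
    let mut total = 0u64;
    thread::scope(|s| {
        // Background reader: refill each recycled buffer and send it back.
        s.spawn(move || {
            while let Ok(mut buf) = empty_rx.recv() {
                buf.resize(BUF_LEN, 0);
                let n = loop {
                    match reader.read(&mut buf) {
                        Ok(n) => break n,
                        Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                        Err(e) => {
                            let _ = full_tx.send(Err(e));
                            return;
                        }
                    }
                };
                if n == 0 {
                    break; // EOF; dropping full_tx ends the loop below
                }
                buf.truncate(n);
                if full_tx.send(Ok(buf)).is_err() {
                    break; // the hashing side hung up
                }
            }
        });
        // Foreground: hash each full buffer while the next one is read.
        for result in full_rx {
            let buf = result?;
            hasher.update_with_join::<blake3::join::RayonJoin>(&buf);
            total += buf.len() as u64;
            let _ = empty_tx.send(buf); // recycle; reader may have exited
        }
        Ok(total)
    })
}
```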
diff --git a/benches/bench.rs b/benches/bench.rs
index 77e1210..0d73970 100644
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -417,3 +417,85 @@ fn bench_reference_0512_kib(b: &mut Bencher) {
fn bench_reference_1024_kib(b: &mut Bencher) {
bench_reference(b, 1024 * KIB);
}
+
+#[cfg(feature = "rayon")]
+fn bench_rayon(b: &mut Bencher, len: usize) {
+ let mut input = RandomInput::new(b, len);
+ b.iter(|| {
+ blake3::Hasher::new()
+ .update_with_join::<blake3::join::RayonJoin>(input.get())
+ .finalize()
+ });
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0001_block(b: &mut Bencher) {
+ bench_rayon(b, BLOCK_LEN);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0001_kib(b: &mut Bencher) {
+ bench_rayon(b, 1 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0002_kib(b: &mut Bencher) {
+ bench_rayon(b, 2 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0004_kib(b: &mut Bencher) {
+ bench_rayon(b, 4 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0008_kib(b: &mut Bencher) {
+ bench_rayon(b, 8 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0016_kib(b: &mut Bencher) {
+ bench_rayon(b, 16 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0032_kib(b: &mut Bencher) {
+ bench_rayon(b, 32 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0064_kib(b: &mut Bencher) {
+ bench_rayon(b, 64 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0128_kib(b: &mut Bencher) {
+ bench_rayon(b, 128 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0256_kib(b: &mut Bencher) {
+ bench_rayon(b, 256 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_0512_kib(b: &mut Bencher) {
+ bench_rayon(b, 512 * KIB);
+}
+
+#[bench]
+#[cfg(feature = "rayon")]
+fn bench_rayon_1024_kib(b: &mut Bencher) {
+ bench_rayon(b, 1024 * KIB);
+}
diff --git a/src/join.rs b/src/join.rs
new file mode 100644
index 0000000..8442172
--- /dev/null
+++ b/src/join.rs
@@ -0,0 +1,114 @@
+//! The multi-threading abstractions used by [`Hasher::update_with_join`].
+//!
+//! Different implementations of the `Join` trait determine whether
+//! [`Hasher::update_with_join`] performs multi-threading on sufficiently large
+//! inputs. The `SerialJoin` implementation is single-threaded, and the
+//! `RayonJoin` implementation (gated by the `rayon` feature) is
+//! multi-threaded. Interfaces other than [`Hasher::update_with_join`], like
+//! [`hash`] and [`Hasher::update`], always use `SerialJoin` internally.
+//!
+//! The `Join` trait is an almost exact copy of the [`rayon::join`] API, and
+//! `RayonJoin` is the only non-trivial implementation provided. The only
+//! difference between the function signature in the `Join` trait and the
+//! underlying one in Rayon is that the trait method includes two length
+//! parameters. This gives an implementation the option of e.g. setting a
+//! subtree size threshold below which it keeps splits on the same thread.
+//! However, neither of the two provided implementations currently makes use of
+//! those parameters. Note that in Rayon, the very first `join` call is more
+//! expensive than subsequent calls, because it moves work from the calling
+//! thread into the thread pool. That makes a coarse-grained input length
+//! threshold in the caller more effective than a fine-grained subtree size
+//! threshold after the implementation has already started recursing.
+//!
+//! # Example
+//!
+//! ```
+//! // Hash a large input using multi-threading. Note that multi-threading
+//! // comes with some overhead, and it can actually hurt performance for small
+//! // inputs. The meaning of "small" varies, however, depending on the
+//! // platform and the number of threads. (On x86_64, the cutoff tends to be
+//! // around 128 KiB.) You should benchmark your own use case to see whether
+//! // multi-threading helps.
+//! # #[cfg(feature = "rayon")]
+//! # {
+//! # fn some_large_input() -> &'static [u8] { b"foo" }
+//! let input: &[u8] = some_large_input();
+//! let mut hasher = blake3::Hasher::new();
+//! hasher.update_with_join::<blake3::join::RayonJoin>(input);
+//! let hash = hasher.finalize();
+//! # }
+//! ```
+//!
+//! [`Hasher::update_with_join`]: ../struct.Hasher.html#method.update_with_join
+//! [`Hasher::update`]: ../struct.Hasher.html#method.update
+//! [`hash`]: ../fn.hash.html
+//! [`rayon::join`]: https://docs.rs/rayon/1.3.0/rayon/fn.join.html
+
+/// The trait that abstracts over single-threaded and multi-threaded recursion.
+pub trait Join {
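+ /// Execute `oper_a` and `oper_b`, either serially or potentially in
+ /// parallel, and return both results. `len_a` and `len_b` are the number
+ /// of input bytes hashed by the left and right sides, which an
+ /// implementation may use as a parallelism threshold.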
+ fn join<A, B, RA, RB>(oper_a: A, oper_b: B, len_a: usize, len_b: usize) -> (RA, RB)
+ where
+ A: FnOnce() -> RA + Send,
+ B: FnOnce() -> RB + Send,
+ RA: Send,
+ RB: Send;
+}
+
+/// The trivial, serial implementation of `Join`. The left and right sides are
+/// executed one after the other, on the calling thread. The standalone hashing
+/// functions and the `Hasher::update` method use this implementation
+/// internally.
+pub enum SerialJoin {}
+
+impl Join for SerialJoin {
+ #[inline]
+ fn join<A, B, RA, RB>(oper_a: A, oper_b: B, _len_a: usize, _len_b: usize) -> (RA, RB)
+ where
+ A: FnOnce() -> RA + Send,
+ B: FnOnce() -> RB + Send,
+ RA: Send,
+ RB: Send,
+ {
+ (oper_a(), oper_b())
+ }
+}
+
+/// The Rayon-based implementation of `Join`. The left and right sides are
+/// executed on the Rayon thread pool, potentially in parallel. This
+/// implementation is gated by the `rayon` feature, which is off by default.
+#[cfg(feature = "rayon")]
+pub enum RayonJoin {}
+
+#[cfg(feature = "rayon")]
+impl Join for RayonJoin {
+ #[inline]
+ fn join<A, B, RA, RB>(oper_a: A, oper_b: B, _len_a: usize, _len_b: usize) -> (RA, RB)
+ where
+ A: FnOnce() -> RA + Send,
+ B: FnOnce() -> RB + Send,
+ RA: Send,
+ RB: Send,
+ {
+ rayon::join(oper_a, oper_b)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_serial_join() {
+ let oper_a = || 1 + 1;
+ let oper_b = || 2 + 2;
+ assert_eq!((2, 4), SerialJoin::join(oper_a, oper_b, 3, 4));
+ }
+
+ #[test]
+ #[cfg(feature = "rayon")]
+ fn test_rayon_join() {
+ let oper_a = || 1 + 1;
+ let oper_b = || 2 + 2;
+ assert_eq!((2, 4), RayonJoin::join(oper_a, oper_b, 3, 4));
+ }
+}
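Because the trait is public, callers can also supply their own implementations. A hypothetical sketch of a threshold-based one (not part of this commit; the name and cutoff are illustrative, and as the module docs note, a coarse input-length check in the caller is usually more effective than a per-split threshold):

```rust
use blake3::join::Join;

enum ThresholdRayonJoin {}

impl Join for ThresholdRayonJoin {
    #[inline]
    fn join<A, B, RA, RB>(oper_a: A, oper_b: B, len_a: usize, len_b: usize) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        // Illustrative cutoff only; real values need benchmarking.
        const THRESHOLD: usize = 128 * 1024;
        if len_a < THRESHOLD && len_b < THRESHOLD {
            // Small subtrees: stay on the current thread.
            (oper_a(), oper_b())
        } else {
            // Large subtrees: hand the right side to the Rayon pool.
            rayon::join(oper_a, oper_b)
        }
    }
}
```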
diff --git a/src/lib.rs b/src/lib.rs
index 996a865..7fa3510 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -60,10 +60,13 @@ pub mod sse41;
pub mod traits;
+pub mod join;
+
use arrayref::{array_mut_ref, array_ref};
use arrayvec::{ArrayString, ArrayVec};
use core::cmp;
use core::fmt;
+use join::{Join, SerialJoin};
use platform::{Platform, MAX_SIMD_DEGREE, MAX_SIMD_DEGREE_OR_2};
/// The number of bytes in a [`Hash`](struct.Hash.html), 32.
@@ -414,25 +417,6 @@ fn left_len(content_len: usize) -> usize {
largest_power_of_two_leq(full_chunks) * CHUNK_LEN
}
-// Recurse in parallel with rayon::join() if the "rayon" feature is active.
-// Rayon uses a global thread pool and a work-stealing algorithm to hand the
-// right side off to another thread, if idle threads are available. If the
-// "rayon" feature is disabled, just make ordinary function calls for the left
-// and the right.
-#[inline]
-fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
-where
- A: FnOnce() -> RA + Send,
- B: FnOnce() -> RB + Send,
- RA: Send,
- RB: Send,
-{
- #[cfg(feature = "rayon")]
- return rayon::join(oper_a, oper_b);
- #[cfg(not(feature = "rayon"))]
- return (oper_a(), oper_b());
-}
-
// Use SIMD parallelism to hash up to MAX_SIMD_DEGREE chunks at the same time
// on a single thread. Write out the chunk chaining values and return the
// number of chunks hashed. These chunks are never the root and never empty;
@@ -541,7 +525,7 @@ fn compress_parents_parallel(
// Why not just have the caller split the input on the first update(), instead
// of implementing this special rule? Because we don't want to limit SIMD or
// multi-threading parallelism for that update().
-fn compress_subtree_wide(
+fn compress_subtree_wide<J: Join>(
input: &[u8],
key: &CVWords,
chunk_counter: u64,
@@ -578,9 +562,11 @@ fn compress_subtree_wide(
let (left_out, right_out) = cv_array.split_at_mut(degree * OUT_LEN);
// Recurse! This uses multiple threads if the Join implementation does
// (e.g. RayonJoin).
- let (left_n, right_n) = join(
- || compress_subtree_wide(left, key, chunk_counter, flags, platform, left_out),
- || compress_subtree_wide(right, key, right_chunk_counter, flags, platform, right_out),
+ let (left_n, right_n) = J::join(
+ || compress_subtree_wide::<J>(left, key, chunk_counter, flags, platform, left_out),
+ || compress_subtree_wide::<J>(right, key, right_chunk_counter, flags, platform, right_out),
+ left.len(),
+ right.len(),
);
// The special case again. If simd_degree=1, then we'll have left_n=1 and
@@ -614,7 +600,7 @@ fn compress_subtree_wide(
//
// As with compress_subtree_wide(), this function is not used on inputs of 1
// chunk or less. That's a different codepath.
-fn compress_subtree_to_parent_node(
+fn compress_subtree_to_parent_node<J: Join>(
input: &[u8],
key: &CVWords,
chunk_counter: u64,
@@ -624,7 +610,7 @@ fn compress_subtree_to_parent_node(
debug_assert!(input.len() > CHUNK_LEN);
let mut cv_array = [0; 2 * MAX_SIMD_DEGREE_OR_2 * OUT_LEN];
let mut num_cvs =
- compress_subtree_wide(input, &key, chunk_counter, flags, platform, &mut cv_array);
+ compress_subtree_wide::<J>(input, &key, chunk_counter, flags, platform, &mut cv_array);
debug_assert!(num_cvs >= 2);
// If MAX_SIMD_DEGREE is greater than 2 and there's enough input,
@@ -641,6 +627,7 @@ fn compress_subtree_to_parent_node(
// Hash a complete input all at once. Unlike compress_subtree_wide() and
// compress_subtree_to_parent_node(), this function handles the 1 chunk case.
+// Note that we use SerialJoin here, so this is always single-threaded.
fn hash_all_at_once(input: &[u8], key: &CVWords, flags: u8) -> Output {
let platform = Platform::detect();
@@ -655,7 +642,7 @@ fn hash_all_at_once(input: &[u8], key: &CVWords, flags: u8) -> Output {
// compress_subtree_to_parent_node().
Output {
input_chaining_value: *key,
- block: compress_subtree_to_parent_node(input, key, 0, flags, platform),
+ block: compress_subtree_to_parent_node::<SerialJoin>(input, key, 0, flags, platform),
block_len: BLOCK_LEN as u8,
counter: 0,
flags: flags | PARENT,
@@ -665,9 +652,13 @@ fn hash_all_at_once(input: &[u8], key: &CVWords, flags: u8) -> Output {
/// The default hash function.
///
-/// For an incremental version that accepts multiple writes, see [`Hasher`].
+/// For an incremental version that accepts multiple writes, see [`Hasher::update`].
///
-/// [`Hasher`]: struct.Hasher.html
+/// This function is always single-threaded. For multi-threading support, see
+/// [`Hasher::update_with_join`].
+///
+/// [`Hasher::update`]: struct.Hasher.html#method.update
+/// [`Hasher::update_with_join`]: struct.Hasher.html#method.update_with_join
pub fn hash(input: &[u8]) -> Hash {
hash_all_at_once(input, IV, 0).root_hash()
}
@@ -679,6 +670,11 @@ pub fn hash(input: &[u8]) -> Hash {
/// In that use case, the constant-time equality checking provided by
/// [`Hash`](struct.Hash.html) is almost always a security requirement, and
/// callers need to be careful not to compare MACs as raw bytes.
+///
+/// This function is always single-threaded. For multi-threading support, see
+/// [`Hasher::update_with_join`].
+///
+/// [`Hasher::update_with_join`]: struct.Hasher.html#method.update_with_join
pub fn keyed_hash(key: &[u8; KEY_LEN], input: &[u8]) -> Hash {
let key_words = platform::words_from_le_bytes_32(key);
hash_all_at_once(input, &key_words, KEYED_HASH).root_hash()
@@ -710,9 +706,13 @@ pub fn keyed_hash(key: &[u8; KEY_LEN], input: &[u8]) -> Hash {
/// [Argon2]. Password hashes are entirely different from generic hash
/// functions, with opposite design requirements.
///
+/// This function is always single-threaded. For multi-threading support, see
+/// [`Hasher::update_with_join`].
+///
/// [`Hasher::new_derive_key`]: struct.Hasher.html#method.new_derive_key
/// [`Hasher::finalize_xof`]: struct.Hasher.html#method.finalize_xof
/// [Argon2]: https://en.wikipedia.org/wiki/Argon2
+/// [`Hasher::update_with_join`]: struct.Hasher.html#method.update_with_join
pub fn derive_key(context: &str, key_material: &[u8], output: &mut [u8]) {
let context_key = hash_all_at_once(context.as_bytes(), IV, DERIVE_KEY_CONTEXT).root_hash();
let context_key_words = platform::words_from_le_bytes_32(context_key.as_bytes());
@@ -877,15 +877,55 @@ impl Hasher {
/// Add input bytes to the hash state. You can call this any number of
/// times.
///
- /// Note that the degree of SIMD and multi-threading parallelism that
- /// `Hasher` can use is limited by the size of this input buffer. The 8 KiB
- /// buffer currently used by [`std::io::copy`] is enough to leverage AVX2,
- /// for example, but not enough to leverage AVX-512. If multi-threading is
- /// enabled (the `rayon` feature), the optimal input buffer size will vary
- /// considerably across different CPUs, and it may be several mebibytes.
+ /// This method is always single-threaded. For multi-threading support, see
+ /// `update_with_join` below.
+ ///
+ /// Note that the degree of SIMD parallelism that `update` can use is
+ /// limited by the size of this input buffer. The 8 KiB buffer currently
+ /// used by [`std::io::copy`] is enough to leverage AVX2, for example, but
+ /// not enough to leverage AVX-512. A 16 KiB buffer is large enough to
+ /// leverage all currently supported SIMD instruction sets.
///
/// [`std::io::copy`]: https://doc.rust-lang.org/std/io/fn.copy.html
- pub fn update(&mut self, mut input: &[u8]) -> &mut Self {
+ pub fn update(&mut self, input: &[u8]) -> &mut Self {
+ self.update_with_join::<SerialJoin>(input)
+ }
+
+ /// Add input bytes to the hash state, as with `update`, but potentially
+ /// using multi-threading. See the example below, and the
+ /// [`join`](join/index.html) module for a more detailed explanation.
+ ///
+ /// To get any performance benefit from multi-threading, the input buffer
+ /// size needs to be very large. As a rule of thumb on x86_64, there is no
+ /// benefit to multi-threading inputs less than 128 KiB. Other platforms
+ /// have different thresholds, and in general you need to benchmark your
+ /// specific use case. Where possible, memory mapping an entire input file
+ /// is recommended, to take maximum advantage of multi-threading without
+ /// needing to tune a specific buffer size. Where memory mapping is not
+ /// possible, good multi-threading performance requires doing IO on a
+ /// background thread, to avoid sleeping all your worker threads while the
+ /// input buffer is (serially) refilled. This is quite complicated compared
+ /// to memory mapping.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// // Hash a large input using multi-threading. Note that multi-threading
+ /// // comes with some overhead, and it can actually hurt performance for small
+ /// // inputs. The meaning of "small" varies, however, depending on the
+ /// // platform and the number of threads. (On x86_64, the cutoff tends to be
+ /// // around 128 KiB.) You should benchmark your own use case to see whether
+ /// // multi-threading helps.
+ /// # #[cfg(feature = "rayon")]
+ /// # {
+ /// # fn some_large_input() -> &'static [u8] { b"foo" }
+ /// let input: &[u8] = some_large_input();
+ /// let mut hasher = blake3::Hasher::new();
+ /// hasher.update_with_join::<blake3::join::RayonJoin>(input);
+ /// let hash = hasher.finalize();
+ /// # }
+ /// ```
+ pub fn update_with_join<J: Join>(&mut self, mut input: &[u8]) -> &mut Self {
// If we have some partial chunk bytes in the internal chunk_state, we
// need to finish that chunk first.
if self.chunk_state.len() > 0 {
@@ -963,7 +1003,7 @@ impl Hasher {
} else {
// This is the high-performance happy path, though getting here
// depends on the caller giving us a long enough input.
- let cv_pair = compress_subtree_to_parent_node(
+ let cv_pair = compress_subtree_to_parent_node::<J>(
&input[..subtree_len],
&self.key,
self.chunk_state.chunk_counter,
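As a usage note, the buffer-size guidance in the `update_with_join` docs can be turned into a simple caller-side dispatch. A hypothetical sketch (the helper name is made up, the 128 KiB cutoff is just the x86_64 rule of thumb from the docs, and the `rayon` feature is assumed):

```rust
// Sketch: choose serial or Rayon hashing per input, using the rule of
// thumb from the update_with_join docs. Benchmark the cutoff per platform.
fn update_adaptive<'a>(
    hasher: &'a mut blake3::Hasher,
    input: &[u8],
) -> &'a mut blake3::Hasher {
    // Below ~128 KiB on x86_64, multi-threading overhead tends to dominate.
    const MT_THRESHOLD: usize = 128 * 1024;
    if input.len() >= MT_THRESHOLD {
        hasher.update_with_join::<blake3::join::RayonJoin>(input)
    } else {
        hasher.update(input)
    }
}
```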
diff --git a/src/test.rs b/src/test.rs
index 83fd3ae..bc6f136 100644
--- a/src/test.rs
+++ b/src/test.rs
@@ -1,6 +1,7 @@
use crate::{CVBytes, CVWords, IncrementCounter, BLOCK_LEN, CHUNK_LEN, OUT_LEN};
use arrayref::array_ref;
use arrayvec::ArrayVec;
+use core::sync::atomic::{AtomicUsize, Ordering};
use core::usize;
use rand::prelude::*;
@@ -469,3 +470,65 @@ fn test_reset() {
hasher.update(&[42; CHUNK_LEN + 3]);
assert_eq!(hasher.finalize(), crate::hash(&[42; CHUNK_LEN + 3]));
}
+
+#[test]
+#[cfg(feature = "rayon")]
+fn test_update_with_rayon_join() {
+ let mut input = [0; TEST_CASES_MAX];
+ paint_test_input(&mut input);
+ let rayon_hash = crate::Hasher::new()
+ .update_with_join::<crate::join::RayonJoin>(&input)
+ .finalize();
+ assert_eq!(crate::hash(&input), rayon_hash);
+}
+
+// Test that the length values given to Join::join are what they're supposed to
+// be.
+#[test]
+fn test_join_lengths() {
+ // Use static atomics to let us safely get a couple of values in and out of
+ // CustomJoin. This avoids depending on std, though it assumes that this
+ // test will only run once in the lifetime of the runner process.
+ static SINGLE_THREAD_LEN: AtomicUsize = AtomicUsize::new(0);
+ static CUSTOM_JOIN_CALLS: AtomicUsize = AtomicUsize::new(0);
+
+ // Use an input that's exactly (simd_degree * CHUNK_LEN) + 1. That should
+ // guarantee that compress_subtree_wide does exactly one split, with the
+ // last byte on the right side. Note that if we used
+ // Hasher::update_with_join, we would end up buffering that last byte,
+ // rather than splitting and joining it.
+ let single_thread_len = crate::platform::Platform::detect().simd_degree() * CHUNK_LEN;
+ SINGLE_THREAD_LEN.store(single_thread_len, Ordering::SeqCst);
+ let mut input_buf = [0; 2 * crate::platform::MAX_SIMD_DEGREE * CHUNK_LEN];
+ paint_test_input(&mut input_buf);
+ let input = &input_buf[..single_thread_len + 1];
+
+ enum CustomJoin {}
+
+ impl crate::join::Join for CustomJoin {
+ fn join<A, B, RA, RB>(oper_a: A, oper_b: B, len_a: usize, len_b: usize) -> (RA, RB)
+ where
+ A: FnOnce() -> RA + Send,
+ B: FnOnce() -> RB + Send,
+ RA: Send,
+ RB: Send,
+ {
+ let prev_calls = CUSTOM_JOIN_CALLS.fetch_add(1, Ordering::SeqCst);
+ assert_eq!(prev_calls, 0);
+ assert_eq!(len_a, SINGLE_THREAD_LEN.load(Ordering::SeqCst));
+ assert_eq!(len_b, 1);
+ (oper_a(), oper_b())
+ }
+ }
+
+ let mut out_buf = [0; crate::platform::MAX_SIMD_DEGREE_OR_2 * CHUNK_LEN];
+ crate::compress_subtree_wide::<CustomJoin>(
+ input,
+ crate::IV,
+ 0,
+ 0,
+ crate::platform::Platform::detect(),
+ &mut out_buf,
+ );
+ assert_eq!(CUSTOM_JOIN_CALLS.load(Ordering::SeqCst), 1);
+}