diff options
| author | Jack O'Connor <[email protected]> | 2020-05-13 15:33:17 -0400 |
|---|---|---|
| committer | Jack O'Connor <[email protected]> | 2020-05-13 18:23:59 -0400 |
| commit | c5c07bb337d0af7522666d05308aaf24eef3709c (patch) | |
| tree | f0e34666df9bcd7a358cb5f2a3e4aeb9dee507aa /b3sum | |
| parent | 5030c0f1c345b8ccf40ecd7cf540d1bce562895c (diff) | |
refactor b3sum to support --check
This is an overall cleanup of everything that b3sum is doing, especially
file opening and memory mapping, which makes it easier for the regular
hashing mode to share code with the checking mode.
Diffstat (limited to 'b3sum')
| -rw-r--r-- | b3sum/src/main.rs | 553 | ||||
| -rw-r--r-- | b3sum/src/unit_tests.rs | 93 | ||||
| -rw-r--r-- | b3sum/tests/cli_tests.rs | 149 |
3 files changed, 561 insertions, 234 deletions
diff --git a/b3sum/src/main.rs b/b3sum/src/main.rs index 04ed6b5..ebcb928 100644 --- a/b3sum/src/main.rs +++ b/b3sum/src/main.rs @@ -1,17 +1,17 @@ use anyhow::{bail, ensure, Context, Result}; use clap::{App, Arg}; -use std::borrow::Cow; use std::cmp; use std::convert::TryInto; -use std::ffi::OsStr; use std::fs::File; use std::io; use std::io::prelude::*; -use std::path::Path; +use std::path::{Path, PathBuf}; #[cfg(test)] mod unit_tests; +const NAME: &str = "b3sum"; + const FILE_ARG: &str = "file"; const DERIVE_KEY_ARG: &str = "derive-key"; const KEYED_ARG: &str = "keyed"; @@ -20,70 +20,219 @@ const NO_MMAP_ARG: &str = "no-mmap"; const NO_NAMES_ARG: &str = "no-names"; const NUM_THREADS_ARG: &str = "num-threads"; const RAW_ARG: &str = "raw"; +const CHECK_ARG: &str = "check"; -fn clap_parse_argv() -> clap::ArgMatches<'static> { - App::new("b3sum") - .version(env!("CARGO_PKG_VERSION")) - .arg(Arg::with_name(FILE_ARG).multiple(true)) - .arg( - Arg::with_name(LENGTH_ARG) - .long(LENGTH_ARG) - .short("l") - .takes_value(true) - .value_name("LEN") - .help( - "The number of output bytes, prior to hex\n\ - encoding (default 32)", - ), - ) - .arg( - Arg::with_name(NUM_THREADS_ARG) - .long(NUM_THREADS_ARG) - .takes_value(true) - .value_name("NUM") - .help( - "The maximum number of threads to use. By\n\ - default, this is the number of logical cores.\n\ - If this flag is omitted, or if its value is 0,\n\ - RAYON_NUM_THREADS is also respected.", - ), - ) - .arg( - Arg::with_name(KEYED_ARG) - .long(KEYED_ARG) - .requires(FILE_ARG) - .help( - "Uses the keyed mode. The secret key is read from standard\n\ - input, and it must be exactly 32 raw bytes.", - ), - ) - .arg( - Arg::with_name(DERIVE_KEY_ARG) - .long(DERIVE_KEY_ARG) - .conflicts_with(KEYED_ARG) - .takes_value(true) - .value_name("CONTEXT") - .help( - "Uses the key derivation mode, with the given\n\ - context string. Cannot be used with --keyed.", - ), - ) - .arg( - Arg::with_name(NO_MMAP_ARG) - .long(NO_MMAP_ARG) - .help("Disables memory mapping"), - ) - .arg( - Arg::with_name(NO_NAMES_ARG) - .long(NO_NAMES_ARG) - .help("Omits filenames in the output"), - ) - .arg(Arg::with_name(RAW_ARG).long(RAW_ARG).help( - "Writes raw output bytes to stdout, rather than hex.\n\ - --no-names is implied. In this case, only a single\n\ - input is allowed.", - )) - .get_matches() +struct Args { + inner: clap::ArgMatches<'static>, + file_args: Vec<PathBuf>, + base_hasher: blake3::Hasher, +} + +impl Args { + fn parse() -> Result<Self> { + let inner = App::new("b3sum") + .version(env!("CARGO_PKG_VERSION")) + .arg(Arg::with_name(FILE_ARG).multiple(true)) + .arg( + Arg::with_name(LENGTH_ARG) + .long(LENGTH_ARG) + .short("l") + .takes_value(true) + .value_name("LEN") + .help( + "The number of output bytes, prior to hex\n\ + encoding (default 32)", + ), + ) + .arg( + Arg::with_name(NUM_THREADS_ARG) + .long(NUM_THREADS_ARG) + .takes_value(true) + .value_name("NUM") + .help( + "The maximum number of threads to use. By\n\ + default, this is the number of logical cores.\n\ + If this flag is omitted, or if its value is 0,\n\ + RAYON_NUM_THREADS is also respected.", + ), + ) + .arg( + Arg::with_name(KEYED_ARG) + .long(KEYED_ARG) + .requires(FILE_ARG) + .help( + "Uses the keyed mode. The secret key is read from standard\n\ + input, and it must be exactly 32 raw bytes.", + ), + ) + .arg( + Arg::with_name(DERIVE_KEY_ARG) + .long(DERIVE_KEY_ARG) + .conflicts_with(KEYED_ARG) + .takes_value(true) + .value_name("CONTEXT") + .help( + "Uses the key derivation mode, with the given\n\ + context string. Cannot be used with --keyed.", + ), + ) + .arg( + Arg::with_name(NO_MMAP_ARG) + .long(NO_MMAP_ARG) + .help("Disables memory mapping"), + ) + .arg( + Arg::with_name(NO_NAMES_ARG) + .long(NO_NAMES_ARG) + .help("Omits filenames in the output"), + ) + .arg(Arg::with_name(RAW_ARG).long(RAW_ARG).help( + "Writes raw output bytes to stdout, rather than hex.\n\ + --no-names is implied. In this case, only a single\n\ + input is allowed.", + )) + .arg( + Arg::with_name(CHECK_ARG) + .long(CHECK_ARG) + .short("c") + .conflicts_with(DERIVE_KEY_ARG) + .conflicts_with(KEYED_ARG) + .conflicts_with(LENGTH_ARG) + .conflicts_with(RAW_ARG) + .conflicts_with(NO_NAMES_ARG) + .help("Reads BLAKE3 sums from the [file]s and checks them"), + ) + .get_matches(); + let file_args = if let Some(iter) = inner.values_of_os(FILE_ARG) { + iter.map(|s| s.into()).collect() + } else { + vec!["-".into()] + }; + if inner.is_present(RAW_ARG) && file_args.len() > 1 { + bail!("Only one filename can be provided when using --raw"); + } + let base_hasher = if inner.is_present(KEYED_ARG) { + // In keyed mode, since stdin is used for the key, we can't handle + // `-` arguments. Input::open handles that case below. + blake3::Hasher::new_keyed(&read_key_from_stdin()?) + } else if let Some(context) = inner.value_of(DERIVE_KEY_ARG) { + blake3::Hasher::new_derive_key(context) + } else { + blake3::Hasher::new() + }; + Ok(Self { + inner, + file_args, + base_hasher, + }) + } + + fn num_threads(&self) -> Result<Option<usize>> { + if let Some(num_threads_str) = self.inner.value_of(NUM_THREADS_ARG) { + Ok(Some( + num_threads_str + .parse() + .context("Failed to parse num threads.")?, + )) + } else { + Ok(None) + } + } + + fn check(&self) -> bool { + self.inner.is_present(CHECK_ARG) + } + + fn raw(&self) -> bool { + self.inner.is_present(RAW_ARG) + } + + fn no_mmap(&self) -> bool { + self.inner.is_present(NO_MMAP_ARG) + } + + fn no_names(&self) -> bool { + self.inner.is_present(NO_NAMES_ARG) + } + + fn len(&self) -> Result<u64> { + if let Some(length) = self.inner.value_of(LENGTH_ARG) { + length.parse::<u64>().context("Failed to parse length.") + } else { + Ok(blake3::OUT_LEN as u64) + } + } + + fn keyed(&self) -> bool { + self.inner.is_present(KEYED_ARG) + } +} + +enum Input { + Mmap(io::Cursor<memmap::Mmap>), + File(File), + Stdin, +} + +impl Input { + // Open an input file, using mmap if appropriate. "-" means stdin. Note + // that this convention applies both to command line arguments, and to + // filepaths that appear in a checkfile. + fn open(path: &Path, args: &Args) -> Result<Self> { + if path == Path::new("-") { + if args.keyed() { + bail!("Cannot open `-` in keyed mode"); + } + return Ok(Self::Stdin); + } + let file = File::open(path)?; + if !args.no_mmap() { + if let Some(mmap) = maybe_memmap_file(&file)? { + return Ok(Self::Mmap(io::Cursor::new(mmap))); + } + } + Ok(Self::File(file)) + } + + fn hash(&mut self, args: &Args) -> Result<blake3::OutputReader> { + let mut hasher = args.base_hasher.clone(); + match self { + // The fast path: If we mmapped the file successfully, hash using + // multiple threads. This doesn't work on stdin, or on some files, + // and it can also be disabled with --no-mmap. + Self::Mmap(cursor) => { + hasher.update_with_join::<blake3::join::RayonJoin>(cursor.get_ref()); + } + // The slower paths, for stdin or files we didn't/couldn't mmap. + // This is currently all single-threaded. Doing multi-threaded + // hashing without memory mapping is tricky, since all your worker + // threads have to stop every time you refill the buffer, and that + // ends up being a lot of overhead. To solve that, we need a more + // complicated double-buffering strategy where a background thread + // fills one buffer while the worker threads are hashing the other + // one. We might implement that in the future, but since this is + // the slow path anyway, it's not high priority. + Self::File(file) => { + copy_wide(file, &mut hasher)?; + } + Self::Stdin => { + let stdin = io::stdin(); + let lock = stdin.lock(); + copy_wide(lock, &mut hasher)?; + } + } + Ok(hasher.finalize_xof()) + } +} + +impl Read for Input { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + match self { + Self::Mmap(cursor) => cursor.read(buf), + Self::File(file) => file.read(buf), + Self::Stdin => io::stdin().read(buf), + } + } } // A 16 KiB buffer is enough to take advantage of all the SIMD instruction sets @@ -106,20 +255,9 @@ fn copy_wide(mut reader: impl Read, hasher: &mut blake3::Hasher) -> io::Result<u } } -// The slow path, for inputs that we can't memmap. -fn hash_reader(base_hasher: &blake3::Hasher, reader: impl Read) -> Result<blake3::OutputReader> { - let mut hasher = base_hasher.clone(); - // This is currently all single-threaded. Doing multi-threaded hashing - // without memory mapping is tricky, since all your worker threads have to - // stop every time you refill the buffer, and that ends up being a lot of - // overhead. To solve that, we need a more complicated double-buffering - // strategy where a background thread fills one buffer while the worker - // threads are hashing the other one. We might implement that in the - // future, but since this is the slow path anyway, it's not high priority. - copy_wide(reader, &mut hasher)?; - Ok(hasher.finalize_xof()) -} - +// Mmap a file, if it looks like a good idea. Return None in cases where we +// know mmap will fail, or if the file is short enough that mmapping isn't +// worth it. However, if we do try to mmap and it fails, return the error. fn maybe_memmap_file(file: &File) -> Result<Option<memmap::Mmap>> { let metadata = file.metadata()?; let file_size = metadata.len(); @@ -149,27 +287,9 @@ fn maybe_memmap_file(file: &File) -> Result<Option<memmap::Mmap>> { }) } -// The fast path: Try to hash a file by mem-mapping it first. This is faster if -// it works, but it's not always possible. -fn maybe_hash_memmap( - _base_hasher: &blake3::Hasher, - _file: &File, -) -> Result<Option<blake3::OutputReader>> { - if let Some(map) = maybe_memmap_file(_file)? { - // Memory mapping worked. Use Rayon-based multi-threading to split - // up the whole file across many worker threads. - return Ok(Some( - _base_hasher - .clone() - .update_with_join::<blake3::join::RayonJoin>(&map) - .finalize_xof(), - )); - } - Ok(None) -} - -fn write_hex_output(mut output: blake3::OutputReader, mut len: u64) -> Result<()> { +fn write_hex_output(mut output: blake3::OutputReader, args: &Args) -> Result<()> { // Encoding multiples of the block size is most efficient. + let mut len = args.len()?; let mut block = [0; blake3::BLOCK_LEN]; while len > 0 { output.fill(&mut block); @@ -181,8 +301,8 @@ fn write_hex_output(mut output: blake3::OutputReader, mut len: u64) -> Result<() Ok(()) } -fn write_raw_output(output: blake3::OutputReader, len: u64) -> Result<()> { - let mut output = output.take(len); +fn write_raw_output(output: blake3::OutputReader, args: &Args) -> Result<()> { + let mut output = output.take(args.len()?); let stdout = std::io::stdout(); let mut handler = stdout.lock(); std::io::copy(&mut output, &mut handler)?; @@ -190,22 +310,6 @@ fn write_raw_output(output: blake3::OutputReader, len: u64) -> Result<()> { Ok(()) } -// Errors from this function get handled by the file loop and printed per-file. -fn hash_file( - base_hasher: &blake3::Hasher, - filepath: &std::ffi::OsStr, - mmap_disabled: bool, -) -> Result<blake3::OutputReader> { - let file = File::open(filepath)?; - if !mmap_disabled { - if let Some(output) = maybe_hash_memmap(&base_hasher, &file)? { - return Ok(output); // the fast path - } - } - // the slow path - hash_reader(&base_hasher, file) -} - fn read_key_from_stdin() -> Result<[u8; blake3::KEY_LEN]> { let mut bytes = Vec::with_capacity(blake3::KEY_LEN + 1); let n = std::io::stdin() @@ -227,12 +331,12 @@ fn read_key_from_stdin() -> Result<[u8; blake3::KEY_LEN]> { struct FilepathString { filepath_string: String, - has_escapes: bool, + is_escaped: bool, } // returns (string, did_escape) -fn filepath_to_string(filepath_osstr: &OsStr) -> FilepathString { - let unicode_cow = filepath_osstr.to_string_lossy(); +fn filepath_to_string(filepath: &Path) -> FilepathString { + let unicode_cow = filepath.to_string_lossy(); let mut filepath_string = unicode_cow.to_string(); // If we're on Windows, normalize backslashes to forward slashes. This // avoids a lot of ugly escaping in the common case, and it makes @@ -243,14 +347,14 @@ fn filepath_to_string(filepath_osstr: &OsStr) -> FilepathString { if cfg!(windows) { filepath_string = filepath_string.replace('\\', "/"); } - let mut has_escapes = false; + let mut is_escaped = false; if filepath_string.contains('\\') || filepath_string.contains('\n') { filepath_string = filepath_string.replace('\\', "\\\\").replace('\n', "\\n"); - has_escapes = true; + is_escaped = true; } FilepathString { filepath_string, - has_escapes, + is_escaped, } } @@ -263,7 +367,7 @@ fn hex_half_byte(c: char) -> Result<u8> { if 'a' <= c && c <= 'f' { return Ok(c as u8 - 'a' as u8 + 10); } - bail!("b3sum: Invalid hex"); + bail!("Invalid hex"); } // The `check` command is a security tool. That means it's much better for a @@ -271,17 +375,17 @@ fn hex_half_byte(c: char) -> Result<u8> { // to ever succeed when it shouldn't (a false positive). By forbidding certain // characters in checked filepaths, we avoid a class of false positives where // two different filepaths can get confused with each other. -fn check_for_invalid_characters(path: &str) -> Result<()> { +fn check_for_invalid_characters(utf8_path: &str) -> Result<()> { // Null characters in paths should never happen, but they can result in a // path getting silently truncated on Unix. - if path.contains('\0') { - bail!("b3sum: Null character in path"); + if utf8_path.contains('\0') { + bail!("Null character in path"); } // Because we convert invalid UTF-8 sequences in paths to the Unicode // replacement character, multiple different invalid paths can map to the // same UTF-8 string. - if path.contains('�') { - bail!("b3sum: Unicode replacement character in path"); + if utf8_path.contains('�') { + bail!("Unicode replacement character in path"); } // We normalize all Windows backslashes to forward slashes in our output, // so the only natural way to get a backslash in a checkfile on Windows is @@ -289,8 +393,8 @@ fn check_for_invalid_characters(path: &str) -> Result<()> { // doctor it by hand.) To avoid confusing this with a directory separator, // we forbid backslashes entirely on Windows. Note that this check comes // after unescaping has been done. - if cfg!(windows) && path.contains('\\') { - bail!("b3sum: Backslash in path"); + if cfg!(windows) && utf8_path.contains('\\') { + bail!("Backslash in path"); } Ok(()) } @@ -298,13 +402,13 @@ fn check_for_invalid_characters(path: &str) -> Result<()> { fn unescape(mut path: &str) -> Result<String> { let mut unescaped = String::with_capacity(2 * path.len()); while let Some(i) = path.find('\\') { - ensure!(i < path.len() - 1, "b3sum: Invalid backslash escape"); + ensure!(i < path.len() - 1, "Invalid backslash escape"); unescaped.push_str(&path[..i]); match path[i + 1..].chars().next().unwrap() { // Anything other than a recognized escape sequence is an error. 'n' => unescaped.push_str("\n"), '\\' => unescaped.push_str("\\"), - _ => bail!("b3sum: Invalid backslash escape"), + _ => bail!("Invalid backslash escape"), } path = &path[i + 2..]; } @@ -312,7 +416,15 @@ fn unescape(mut path: &str) -> Result<String> { Ok(unescaped) } -fn parse_check_line(mut line: &str) -> Result<(blake3::Hash, Cow<str>)> { +#[derive(Debug)] +struct ParsedCheckLine { + file_string: String, + is_escaped: bool, + file_path: PathBuf, + expected_hash: blake3::Hash, +} + +fn parse_check_line(mut line: &str) -> Result<ParsedCheckLine> { // Trim off the trailing newline, if any. line = line.trim_end_matches('\n'); // If there's a backslash at the front of the line, that means we need to @@ -320,11 +432,11 @@ fn parse_check_line(mut line: &str) -> Result<(blake3::Hash, Cow<str>)> { let first = if let Some(c) = line.chars().next() { c } else { - bail!("b3sum: Empty line"); + bail!("Empty line"); }; - let mut escaped = false; + let mut is_escaped = false; if first == '\\' { - escaped = true; + is_escaped = true; line = &line[1..]; } // The front of the line must be a hash of the usual length, followed by @@ -333,12 +445,12 @@ fn parse_check_line(mut line: &str) -> Result<(blake3::Hash, Cow<str>)> { let hash_hex_len = 2 * blake3::OUT_LEN; let num_spaces = 2; let prefix_len = hash_hex_len + num_spaces; - ensure!(line.len() > prefix_len, "b3sum: Short line"); + ensure!(line.len() > prefix_len, "Short line"); ensure!( line.chars().take(prefix_len).all(|c| c.is_ascii()), - "b3sum: Non-ASCII prefix" + "Non-ASCII prefix" ); - ensure!(&line[hash_hex_len..][..2] == " ", "b3sum: Invalid space"); + ensure!(&line[hash_hex_len..][..2] == " ", "Invalid space"); // Decode the hash hex. let mut hash_bytes = [0; blake3::OUT_LEN]; let mut hex_chars = line[..hash_hex_len].chars(); @@ -347,89 +459,110 @@ fn parse_check_line(mut line: &str) -> Result<(blake3::Hash, Cow<str>)> { let low_char = hex_chars.next().unwrap(); *byte = 16 * hex_half_byte(high_char)? + hex_half_byte(low_char)?; } - let hash: blake3::Hash = hash_bytes.into(); - let path_str = &line[prefix_len..]; - let path_cow: Cow<str> = if escaped { + let expected_hash: blake3::Hash = hash_bytes.into(); + let file_string = line[prefix_len..].to_string(); + let file_path_string = if is_escaped { // If we detected a backslash at the start of the line earlier, now we // need to unescape backslashes and newlines. - Cow::Owned(unescape(path_str)?) + unescape(&file_string)? } else { - Cow::Borrowed(path_str) + file_string.clone().into() }; - check_for_invalid_characters(&path_cow)?; - Ok((hash, path_cow)) + check_for_invalid_characters(&file_path_string)?; + Ok(ParsedCheckLine { + file_string, + is_escaped, + file_path: file_path_string.into(), + expected_hash, + }) +} + +fn hash_one_input(path: &Path, args: &Args) -> Result<()> { + let mut input = Input::open(path, args)?; + let output = input.hash(args)?; + if args.raw() { + write_raw_output(output, args)?; + return Ok(()); + } + if args.no_names() { + write_hex_output(output, args)?; + println!(); + return Ok(()); + } + let FilepathString { + filepath_string, + is_escaped, + } = filepath_to_string(path); + if is_escaped { + print!("\\"); + } + write_hex_output(output, args)?; + println!(" {}", filepath_string); + Ok(()) +} + +fn check_one_checkfile(path: &Path, args: &Args, some_file_failed: &mut bool) -> Result<()> { + let checkfile_input = Input::open(path, args)?; + let mut bufreader = io::BufReader::new(checkfile_input); + let mut line = String::new(); + loop { + line.clear(); + let n = bufreader.read_line(&mut line)?; + if n == 0 { + return Ok(()); + } + let ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = parse_check_line(&line)?; + let mut hash_input = Input::open(&file_path, args)?; + let mut found_hash_bytes = [0; blake3::OUT_LEN]; + let mut hash_output = hash_input.hash(args)?; + hash_output.fill(&mut found_hash_bytes); + let found_hash: blake3::Hash = found_hash_bytes.into(); + if is_escaped { + print!("\\"); + } + print!("{}: ", file_string); + // This is a constant-time comparison. + if expected_hash == found_hash { + println!("OK"); + } else { + *some_file_failed = true; + println!("FAILED"); + } + } } fn main() -> Result<()> { - let args = clap_parse_argv(); - let len = if let Some(length) = args.value_of(LENGTH_ARG) { - length.parse::<u64>().context("Failed to parse length.")? - } else { - blake3::OUT_LEN as u64 - }; - let base_hasher = if args.is_present(KEYED_ARG) { - blake3::Hasher::new_keyed(&read_key_from_stdin()?) - } else if let Some(context) = args.value_of(DERIVE_KEY_ARG) { - blake3::Hasher::new_derive_key(context) - } else { - blake3::Hasher::new() - }; - let mmap_disabled = args.is_present(NO_MMAP_ARG); - let print_names = !args.is_present(NO_NAMES_ARG); - let raw_output = args.is_present(RAW_ARG); + let args = Args::parse()?; let mut thread_pool_builder = rayon::ThreadPoolBuilder::new(); - if let Some(num_threads_str) = args.value_of(NUM_THREADS_ARG) { - let num_threads: usize = num_threads_str - .parse() - .context("Failed to parse num threads.")?; + if let Some(num_threads) = args.num_threads()? { thread_pool_builder = thread_pool_builder.num_threads(num_threads); } - let thread_pool = thread_pool_builder.build()?; thread_pool.install(|| { - let mut did_error = false; - if let Some(files) = args.values_of_os(FILE_ARG) { - if raw_output && files.len() > 1 { - bail!("b3sum: Only one filename can be provided when using --raw"); - } - for filepath_osstr in files { - let FilepathString { - filepath_string, - has_escapes, - } = filepath_to_string(filepath_osstr); - match hash_file(&base_hasher, filepath_osstr, mmap_disabled) { - Ok(output) => { - if raw_output { - write_raw_output(output, len)?; - } else { - if has_escapes { - print!("\\"); - } - write_hex_output(output, len)?; - if print_names { - println!(" {}", filepath_string); - } else { - println!(); - } - } - } - Err(e) => { - did_error = true; - eprintln!("b3sum: {}: {}", filepath_string, e); - } - } - } - } else { - let stdin = std::io::stdin(); - let stdin = stdin.lock(); - let output = hash_reader(&base_hasher, stdin)?; - if raw_output { - write_raw_output(output, len)?; + let mut some_file_failed = false; + // Note that file_args automatically includes `-` if nothing is given. + for path in &args.file_args { + if args.check() { + // Errors encountered in checking (that is, any failure other + // than "bad checksum") bring down the whole process. + check_one_checkfile(path, &args, &mut some_file_failed)?; } else { - write_hex_output(output, len)?; - println!(); + // Errors encountered in hashing are tolerated and printed to + // stderr. This allows e.g. `b3sum *` to print errors for + // non-files and keep going. However, if we encounter any + // errors we'll still return non-zero at the end. + let result = hash_one_input(path, &args); + if let Err(e) = result { + some_file_failed = true; + eprintln!("{}: {}", NAME, e); + } } } - std::process::exit(if did_error { 1 } else { 0 }); + std::process::exit(if some_file_failed { 1 } else { 0 }); }) } diff --git a/b3sum/src/unit_tests.rs b/b3sum/src/unit_tests.rs index f65ed67..1fa1a17 100644 --- a/b3sum/src/unit_tests.rs +++ b/b3sum/src/unit_tests.rs @@ -7,68 +7,117 @@ fn test_parse_check_line() { // ========================= // the basic case - let (h, p) = crate::parse_check_line( + let crate::ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = crate::parse_check_line( "0909090909090909090909090909090909090909090909090909090909090909 foo", ) .unwrap(); - assert_eq!(h, blake3::Hash::from([0x09; 32])); - assert_eq!(p, "foo"); + assert_eq!(expected_hash, blake3::Hash::from([0x09; 32])); + assert!(!is_escaped); + assert_eq!(file_string, "foo"); + assert_eq!(file_path, Path::new("foo")); // regular whitespace - let (h, p) = crate::parse_check_line( + let crate::ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = crate::parse_check_line( "fafafafafafafafafafafafafafafafafafafafafafafafafafafafafafafafa fo \to\n\n\n", ) .unwrap(); - assert_eq!(h, blake3::Hash::from([0xfa; 32])); - assert_eq!(p, "fo \to"); + assert_eq!(expected_hash, blake3::Hash::from([0xfa; 32])); + assert!(!is_escaped); + assert_eq!(file_string, "fo \to"); + assert_eq!(file_path, Path::new("fo \to")); // path is one space - let (h, p) = crate::parse_check_line( + let crate::ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = crate::parse_check_line( "4242424242424242424242424242424242424242424242424242424242424242 ", ) .unwrap(); - assert_eq!(h, blake3::Hash::from([0x42; 32])); - assert_eq!(p, " "); + assert_eq!(expected_hash, blake3::Hash::from([0x42; 32])); + assert!(!is_escaped); + assert_eq!(file_string, " "); + assert_eq!(file_path, Path::new(" ")); // *Unescaped* backslashes. Note that this line does *not* start with a // backslash, so something like "\" + "n" is interpreted as *two* // characters. We forbid all backslashes on Windows, so this test is // Unix-only. if cfg!(not(windows)) { - let (h, p) = crate::parse_check_line( + let crate::ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = crate::parse_check_line( "4343434343434343434343434343434343434343434343434343434343434343 fo\\a\\no", ) .unwrap(); - assert_eq!(h, blake3::Hash::from([0x43; 32])); - assert_eq!(p, "fo\\a\\no"); + assert_eq!(expected_hash, blake3::Hash::from([0x43; 32])); + assert!(!is_escaped); + assert_eq!(file_string, "fo\\a\\no"); + assert_eq!(file_path, Path::new("fo\\a\\no")); } // escaped newline - let (h, p) = crate::parse_check_line( + let crate::ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = crate::parse_check_line( "\\4444444444444444444444444444444444444444444444444444444444444444 fo\\n\\no", ) .unwrap(); - assert_eq!(h, blake3::Hash::from([0x44; 32])); - assert_eq!(p, "fo\n\no"); + assert_eq!(expected_hash, blake3::Hash::from([0x44; 32])); + assert!(is_escaped); + assert_eq!(file_string, "fo\\n\\no"); + assert_eq!(file_path, Path::new("fo\n\no")); // Escaped newline and backslash. Again because backslash is not allowed on // Windows, this test is Unix-only. if cfg!(not(windows)) { - let (h, p) = crate::parse_check_line( + let crate::ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = crate::parse_check_line( "\\4545454545454545454545454545454545454545454545454545454545454545 fo\\n\\\\o", ) .unwrap(); - assert_eq!(h, blake3::Hash::from([0x45; 32])); - assert_eq!(p, "fo\n\\o"); + assert_eq!(expected_hash, blake3::Hash::from([0x45; 32])); + assert!(is_escaped); + assert_eq!(file_string, "fo\\n\\\\o"); + assert_eq!(file_path, Path::new("fo\n\\o")); } // non-ASCII path - let (h, p) = crate::parse_check_line( - "\\4646464646464646464646464646464646464646464646464646464646464646 否认", + let crate::ParsedCheckLine { + file_string, + is_escaped, + file_path, + expected_hash, + } = crate::parse_check_line( + "4646464646464646464646464646464646464646464646464646464646464646 否认", ) .unwrap(); - assert_eq!(h, blake3::Hash::from([0x46; 32])); - assert_eq!(p, "否认"); + assert_eq!(expected_hash, blake3::Hash::from([0x46; 32])); + assert!(!is_escaped); + assert_eq!(file_string, "否认"); + assert_eq!(file_path, Path::new("否认")); // ========================= // ===== Failure Cases ===== diff --git a/b3sum/tests/cli_tests.rs b/b3sum/tests/cli_tests.rs index fd01a70..749d941 100644 --- a/b3sum/tests/cli_tests.rs +++ b/b3sum/tests/cli_tests.rs @@ -10,7 +10,7 @@ pub fn b3sum_exe() -> PathBuf { #[test] fn test_hash_one() { - let expected = blake3::hash(b"foo").to_hex(); + let expected = format!("{} -", blake3::hash(b"foo").to_hex()); let output = cmd!(b3sum_exe()).stdin_bytes("foo").read().unwrap(); assert_eq!(&*expected, output); } @@ -62,7 +62,7 @@ fn test_hash_length() { .update(b"foo") .finalize_xof() .fill(&mut buf); - let expected = hex::encode(&buf[..]); + let expected = format!("{} -", hex::encode(&buf[..])); let output = cmd!(b3sum_exe(), "--length=100") .stdin_bytes("foo") .read() @@ -301,3 +301,148 @@ fn test_invalid_unicode_on_windows() { println!(); assert_eq!(expected, output); } + +#[test] +fn test_check() { + // Make a directory full of files, and make sure the b3sum output in that + // directory is what we expect. + let a_hash = blake3::hash(b"a").to_hex(); + let b_hash = blake3::hash(b"b").to_hex(); + let cd_hash = blake3::hash(b"cd").to_hex(); + let dir = tempfile::tempdir().unwrap(); + fs::write(dir.path().join("a"), b"a").unwrap(); + fs::write(dir.path().join("b"), b"b").unwrap(); + fs::create_dir(dir.path().join("c")).unwrap(); + fs::write(dir.path().join("c/d"), b"cd").unwrap(); + let output = cmd!(b3sum_exe(), "a", "b", "c/d") + .dir(dir.path()) + .stdout_capture() + .stderr_capture() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + let expected_checkfile = format!( + "{} a\n\ + {} b\n\ + {} c/d\n", + a_hash, b_hash, cd_hash, + ); + assert_eq!(expected_checkfile, stdout); + assert_eq!("", stderr); + + // Now use the output we just validated as a checkfile, passed to stdin. + let output = cmd!(b3sum_exe(), "--check") + .stdin_bytes(expected_checkfile.as_bytes()) + .dir(dir.path()) + .stdout_capture() + .stderr_capture() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + let expected_check_output = "\ + a: OK\n\ + b: OK\n\ + c/d: OK\n"; + assert_eq!(expected_check_output, stdout); + assert_eq!("", stderr); + + // Now pass the same checkfile twice on the command line just for fun. + let checkfile_path = dir.path().join("checkfile"); + fs::write(&checkfile_path, &expected_checkfile).unwrap(); + let output = cmd!(b3sum_exe(), "--check", &checkfile_path, &checkfile_path) + .dir(dir.path()) + .stdout_capture() + .stderr_capture() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + let mut double_check_output = String::new(); + double_check_output.push_str(&expected_check_output); + double_check_output.push_str(&expected_check_output); + assert_eq!(double_check_output, stdout); + assert_eq!("", stderr); + + // Finally, corrupt one of the files and check again. + fs::write(dir.path().join("b"), b"CORRUPTION").unwrap(); + let output = cmd!(b3sum_exe(), "--check", &checkfile_path) + .dir(dir.path()) + .stdout_capture() + .stderr_capture() + .unchecked() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + let expected_check_failure = "\ + a: OK\n\ + b: FAILED\n\ + c/d: OK\n"; + assert!(!output.status.success()); + assert_eq!(expected_check_failure, stdout); + assert_eq!("", stderr); +} + +#[test] +fn test_check_invalid_characters() { + // Check that a null character in the path fails. + let output = cmd!(b3sum_exe(), "--check") + .stdin_bytes("0000000000000000000000000000000000000000000000000000000000000000 \0") + .stdout_capture() + .stderr_capture() + .unchecked() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + assert!(!output.status.success()); + assert_eq!("", stdout); + assert_eq!("Error: Null character in path\n", stderr); + + // Check that a Unicode replacement character in the path fails. + let output = cmd!(b3sum_exe(), "--check") + .stdin_bytes("0000000000000000000000000000000000000000000000000000000000000000 �") + .stdout_capture() + .stderr_capture() + .unchecked() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + assert!(!output.status.success()); + assert_eq!("", stdout); + assert_eq!("Error: Unicode replacement character in path\n", stderr); + + // Check that an invalid escape sequence in the path fails. + let output = cmd!(b3sum_exe(), "--check") + .stdin_bytes("\\0000000000000000000000000000000000000000000000000000000000000000 \\a") + .stdout_capture() + .stderr_capture() + .unchecked() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + assert!(!output.status.success()); + assert_eq!("", stdout); + assert_eq!("Error: Invalid backslash escape\n", stderr); + + // Windows also forbids literal backslashes. Check for that if and only if + // we're on Windows. + if cfg!(windows) { + let output = cmd!(b3sum_exe(), "--check") + .stdin_bytes("0000000000000000000000000000000000000000000000000000000000000000 \\") + .stdout_capture() + .stderr_capture() + .unchecked() + .run() + .unwrap(); + let stdout = std::str::from_utf8(&output.stdout).unwrap(); + let stderr = std::str::from_utf8(&output.stderr).unwrap(); + assert!(!output.status.success()); + assert_eq!("", stdout); + assert_eq!("Error: Backslash in path\n", stderr); + } +} |
