From 7493cdc5fa3d5c2a7a12763be7af699525ba6337 Mon Sep 17 00:00:00 2001 From: Andrew Gunnerson Date: Mon, 18 Sep 2023 00:12:09 -0400 Subject: [PATCH] Stopping spewing Arc everywhere for the cancel signal There's no reason all these functions need to care about the ownership of the cancel signal. Signed-off-by: Andrew Gunnerson --- avbroot/src/boot.rs | 12 ++++++------ avbroot/src/cli/args.rs | 4 ++-- avbroot/src/cli/avb.rs | 6 +++--- avbroot/src/cli/ota.rs | 22 +++++++++++----------- avbroot/src/format/avb.rs | 15 ++++++--------- avbroot/src/format/ota.rs | 7 ++----- avbroot/src/format/payload.rs | 10 +++++----- avbroot/src/stream.rs | 13 +++++-------- e2e/src/download.rs | 9 ++++----- e2e/src/main.rs | 30 +++++++++++++----------------- 10 files changed, 57 insertions(+), 71 deletions(-) diff --git a/avbroot/src/boot.rs b/avbroot/src/boot.rs index 46bcf1c..50e9414 100644 --- a/avbroot/src/boot.rs +++ b/avbroot/src/boot.rs @@ -11,7 +11,7 @@ use std::{ num::ParseIntError, ops::Range, path::{Path, PathBuf}, - sync::{atomic::AtomicBool, Arc}, + sync::atomic::AtomicBool, }; use regex::bytes::Regex; @@ -89,7 +89,7 @@ fn save_ramdisk(entries: &[CpioEntryNew], format: CompressedFormat) -> Result) -> Result<()>; + fn patch(&self, boot_image: &mut BootImage, cancel_signal: &AtomicBool) -> Result<()>; } /// Root a boot image with Magisk. @@ -254,7 +254,7 @@ impl MagiskRootPatcher { } impl BootImagePatcher for MagiskRootPatcher { - fn patch(&self, boot_image: &mut BootImage, cancel_signal: &Arc) -> Result<()> { + fn patch(&self, boot_image: &mut BootImage, cancel_signal: &AtomicBool) -> Result<()> { let zip_reader = File::open(&self.apk_path)?; let mut zip = ZipArchive::new(BufReader::new(zip_reader))?; @@ -466,7 +466,7 @@ impl OtaCertPatcher { } impl BootImagePatcher for OtaCertPatcher { - fn patch(&self, boot_image: &mut BootImage, _cancel_signal: &Arc) -> Result<()> { + fn patch(&self, boot_image: &mut BootImage, _cancel_signal: &AtomicBool) -> Result<()> { let patched_any = match boot_image { BootImage::V0Through2(b) => self.patch_ramdisk(&mut b.ramdisk)?, BootImage::V3Through4(b) => self.patch_ramdisk(&mut b.ramdisk)?, @@ -557,7 +557,7 @@ impl PrepatchedImagePatcher { } impl BootImagePatcher for PrepatchedImagePatcher { - fn patch(&self, boot_image: &mut BootImage, _cancel_signal: &Arc) -> Result<()> { + fn patch(&self, boot_image: &mut BootImage, _cancel_signal: &AtomicBool) -> Result<()> { let prepatched_image = { let raw_reader = File::open(&self.prepatched)?; BootImage::from_reader(BufReader::new(raw_reader))? @@ -735,7 +735,7 @@ pub fn patch_boot( writer: impl Write + Seek, key: &RsaPrivateKey, patchers: &[Box], - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { let (mut header, footer, image_size) = avb::load_image(&mut reader)?; let Some(footer) = footer else { diff --git a/avbroot/src/cli/args.rs b/avbroot/src/cli/args.rs index 32d30ae..7c75baa 100644 --- a/avbroot/src/cli/args.rs +++ b/avbroot/src/cli/args.rs @@ -3,7 +3,7 @@ * SPDX-License-Identifier: GPL-3.0-only */ -use std::sync::{atomic::AtomicBool, Arc}; +use std::sync::atomic::AtomicBool; use anyhow::Result; use clap::{Parser, Subcommand}; @@ -34,7 +34,7 @@ pub struct Cli { pub command: Command, } -pub fn main(cancel_signal: &Arc) -> Result<()> { +pub fn main(cancel_signal: &AtomicBool) -> Result<()> { let cli = Cli::parse(); match cli.command { diff --git a/avbroot/src/cli/avb.rs b/avbroot/src/cli/avb.rs index 4d5008c..5f48684 100644 --- a/avbroot/src/cli/avb.rs +++ b/avbroot/src/cli/avb.rs @@ -10,7 +10,7 @@ use std::{ io::{self, BufReader}, path::{Path, PathBuf}, str, - sync::{atomic::AtomicBool, Arc}, + sync::atomic::AtomicBool, }; use anyhow::{anyhow, bail, Context, Result}; @@ -107,7 +107,7 @@ pub fn verify_headers( pub fn verify_descriptors( directory: &Path, descriptors: &HashMap, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { descriptors .par_iter() @@ -149,7 +149,7 @@ pub fn verify_descriptors( .collect() } -pub fn avb_main(cli: &AvbCli, cancel_signal: &Arc) -> Result<()> { +pub fn avb_main(cli: &AvbCli, cancel_signal: &AtomicBool) -> Result<()> { match &cli.command { AvbCommand::Dump(c) => { let raw_reader = File::open(&c.input) diff --git a/avbroot/src/cli/ota.rs b/avbroot/src/cli/ota.rs index 8bdbda3..97512d6 100644 --- a/avbroot/src/cli/ota.rs +++ b/avbroot/src/cli/ota.rs @@ -11,7 +11,7 @@ use std::{ fs::{self, File}, io::{self, BufReader, BufWriter, Cursor, Read, Seek, SeekFrom, Write}, path::{Path, PathBuf}, - sync::{atomic::AtomicBool, Arc, Mutex}, + sync::{atomic::AtomicBool, Mutex}, time::Instant, }; @@ -143,7 +143,7 @@ fn open_input_streams( required_images: &HashMap, external_images: &HashMap, header: &PayloadHeader, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result>> { let mut input_streams = HashMap::>::new(); @@ -185,7 +185,7 @@ fn patch_boot_images( root_patcher: Option>, key_avb: &RsaPrivateKey, cert_ota: &Certificate, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { let mut boot_patchers = HashMap::<&str, Vec>>::new(); boot_patchers @@ -426,7 +426,7 @@ fn compress_image( mut stream: &mut Box, header: &Mutex, block_size: u32, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { stream.rewind()?; @@ -460,7 +460,7 @@ fn patch_ota_payload( key_avb: &RsaPrivateKey, key_ota: &RsaPrivateKey, cert_ota: &Certificate, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<(String, u64)> { let header = PayloadHeader::from_reader(open_payload()?).context("Failed to load OTA payload header")?; @@ -630,7 +630,7 @@ fn patch_ota_zip( key_avb: &RsaPrivateKey, key_ota: &RsaPrivateKey, cert_ota: &Certificate, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<(OtaMetadata, u64)> { let mut missing = BTreeSet::from([ ota::PATH_METADATA_PB, @@ -797,7 +797,7 @@ fn extract_ota_zip( payload_size: u64, header: &PayloadHeader, images: &BTreeSet, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { for name in images { if Path::new(name).file_name() != Some(OsStr::new(name)) { @@ -843,7 +843,7 @@ fn extract_ota_zip( Ok(()) } -pub fn patch_subcommand(cli: &PatchCli, cancel_signal: &Arc) -> Result<()> { +pub fn patch_subcommand(cli: &PatchCli, cancel_signal: &AtomicBool) -> Result<()> { let output = cli.output.as_ref().map_or_else( || { let mut s = cli.input.clone().into_os_string(); @@ -1008,7 +1008,7 @@ pub fn patch_subcommand(cli: &PatchCli, cancel_signal: &Arc) -> Resu Ok(()) } -pub fn extract_subcommand(cli: &ExtractCli, cancel_signal: &Arc) -> Result<()> { +pub fn extract_subcommand(cli: &ExtractCli, cancel_signal: &AtomicBool) -> Result<()> { let raw_reader = File::open(&cli.input) .map(PSeekFile::new) .with_context(|| format!("Failed to open for reading: {:?}", cli.input))?; @@ -1067,7 +1067,7 @@ pub fn extract_subcommand(cli: &ExtractCli, cancel_signal: &Arc) -> Ok(()) } -pub fn verify_subcommand(cli: &VerifyCli, cancel_signal: &Arc) -> Result<()> { +pub fn verify_subcommand(cli: &VerifyCli, cancel_signal: &AtomicBool) -> Result<()> { let raw_reader = File::open(&cli.input) .map(PSeekFile::new) .with_context(|| format!("Failed to open for reading: {:?}", cli.input))?; @@ -1184,7 +1184,7 @@ pub fn verify_subcommand(cli: &VerifyCli, cancel_signal: &Arc) -> Re Ok(()) } -pub fn ota_main(cli: &OtaCli, cancel_signal: &Arc) -> Result<()> { +pub fn ota_main(cli: &OtaCli, cancel_signal: &AtomicBool) -> Result<()> { match &cli.command { OtaCommand::Patch(c) => patch_subcommand(c, cancel_signal), OtaCommand::Extract(c) => extract_subcommand(c, cancel_signal), diff --git a/avbroot/src/format/avb.rs b/avbroot/src/format/avb.rs index 122fe96..5c40ea6 100644 --- a/avbroot/src/format/avb.rs +++ b/avbroot/src/format/avb.rs @@ -7,10 +7,7 @@ use std::{ cmp, fmt, io::{self, Cursor, Read, Seek, SeekFrom, Write}, str, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, }; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; @@ -337,7 +334,7 @@ impl HashtreeDescriptor { block_size: u32, algorithm: &'static Algorithm, salt: &[u8], - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> io::Result> { // Each digest must be a power of 2. let digest_padding = algorithm.output_len.next_power_of_two() - algorithm.output_len; @@ -384,7 +381,7 @@ impl HashtreeDescriptor { block_size: u32, algorithm: &'static Algorithm, salt: &[u8], - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> io::Result> { assert!( image_size > block_size as u64, @@ -418,7 +415,7 @@ impl HashtreeDescriptor { block_size: u32, algorithm: &'static Algorithm, salt: &[u8], - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> io::Result<(Vec, Vec)> { // Small files are hashed directly, exactly like a hash descriptor. if image_size <= u64::from(block_size) { @@ -486,7 +483,7 @@ impl HashtreeDescriptor { pub fn verify( &self, open_input: impl Fn() -> io::Result> + Sync, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { let algorithm = match self.hash_algorithm.as_str() { "sha256" => &ring::digest::SHA256, @@ -683,7 +680,7 @@ impl fmt::Debug for HashDescriptor { impl HashDescriptor { /// Verify the root hash against the input reader. - pub fn verify(&self, reader: impl Read, cancel_signal: &Arc) -> Result<()> { + pub fn verify(&self, reader: impl Read, cancel_signal: &AtomicBool) -> Result<()> { let algorithm = match self.hash_algorithm.as_str() { "sha256" => &ring::digest::SHA256, "sha512" => &ring::digest::SHA512, diff --git a/avbroot/src/format/ota.rs b/avbroot/src/format/ota.rs index 193b595..0360d08 100644 --- a/avbroot/src/format/ota.rs +++ b/avbroot/src/format/ota.rs @@ -7,7 +7,7 @@ use std::{ collections::BTreeMap, io::{self, Cursor, Read, Seek, SeekFrom, Write}, iter, - sync::{atomic::AtomicBool, Arc}, + sync::atomic::AtomicBool, }; use cms::signed_data::SignedData; @@ -452,10 +452,7 @@ fn parse_ota_sig(mut reader: impl Read + Seek) -> Result<(SignedData, u64)> { /// CMS signed attributes are intentionally not supported because AOSP recovery /// does not support them either. It expects the CMS [`SignedData`] structure to /// be used for nothing more than a raw signature transport mechanism. -pub fn verify_ota( - mut reader: impl Read + Seek, - cancel_signal: &Arc, -) -> Result { +pub fn verify_ota(mut reader: impl Read + Seek, cancel_signal: &AtomicBool) -> Result { let (sd, hashed_size) = parse_ota_sig(&mut reader)?; // Make sure the certificate in the CMS structure matches the otacert zip diff --git a/avbroot/src/format/payload.rs b/avbroot/src/format/payload.rs index 1a241d8..295d840 100644 --- a/avbroot/src/format/payload.rs +++ b/avbroot/src/format/payload.rs @@ -6,7 +6,7 @@ use std::{ collections::{HashMap, HashSet}, io::{self, Cursor, Read, Seek, SeekFrom, Write}, - sync::{atomic::AtomicBool, Arc}, + sync::atomic::AtomicBool, }; use base64::engine::general_purpose::STANDARD; @@ -619,7 +619,7 @@ pub fn verify_payload( mut reader: impl Read + Seek, cert: &Certificate, properties_raw: &str, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { let header = PayloadHeader::from_reader(&mut reader)?; reader.rewind()?; @@ -753,7 +753,7 @@ pub fn apply_operation( block_size: u32, blob_offset: u64, op: &InstallOperation, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { for extent in &op.dst_extents { let start_block = extent @@ -859,7 +859,7 @@ pub fn extract_image_to_memory( open_payload: impl Fn() -> io::Result> + Sync, header: &PayloadHeader, partition_name: &str, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result { let partition = header .manifest @@ -900,7 +900,7 @@ pub fn extract_images<'a>( open_output: impl Fn(&str) -> io::Result> + Sync, header: &PayloadHeader, partition_names: impl IntoIterator, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { let mut remaining = partition_names.into_iter().collect::>(); // We parallelize at the operation level or else one thread might get stuck diff --git a/avbroot/src/stream.rs b/avbroot/src/stream.rs index 745f11b..459c5e8 100644 --- a/avbroot/src/stream.rs +++ b/avbroot/src/stream.rs @@ -569,7 +569,7 @@ pub fn copy_n_inspect( mut writer: impl Write, mut size: u64, mut inspect: impl FnMut(&[u8]), - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> io::Result<()> { let mut buf = [0u8; 16384]; @@ -599,7 +599,7 @@ pub fn copy_n( reader: impl Read, writer: impl Write, size: u64, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> io::Result<()> { copy_n_inspect(reader, writer, size, |_| {}, cancel_signal) } @@ -610,7 +610,7 @@ pub fn copy_n( pub fn copy( mut reader: impl Read, mut writer: impl Write, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> io::Result { let mut buf = [0u8; 16384]; let mut copied = 0; @@ -640,10 +640,7 @@ pub fn copy( mod tests { use std::{ io::{self, Cursor, Read, Seek, SeekFrom, Write}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, }; use ring::digest::Context; @@ -869,7 +866,7 @@ mod tests { #[test] fn copy() { - let cancel_signal = Arc::new(AtomicBool::new(false)); + let cancel_signal = AtomicBool::new(false); let mut reader = Cursor::new(b"foobar"); let mut writer = Cursor::new([0u8; 6]); diff --git a/e2e/src/download.rs b/e2e/src/download.rs index 0bd921c..f1b922d 100644 --- a/e2e/src/download.rs +++ b/e2e/src/download.rs @@ -12,7 +12,6 @@ use std::{ sync::{ atomic::{AtomicBool, Ordering}, mpsc::{self, Sender}, - Arc, }, thread::{self, ThreadId}, time::{Duration, Instant}, @@ -140,7 +139,7 @@ fn download_range( mut file: PSeekFile, initial_range: Range, channel: Sender, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { assert!(initial_range.start < initial_range.end); @@ -215,7 +214,7 @@ fn download_thread( file: PSeekFile, initial_range: Range, channel: mpsc::Sender, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) { let result = download_range(url, file, initial_range, channel.clone(), cancel_signal); @@ -255,7 +254,7 @@ fn download_ranges( display: &mut dyn ProgressDisplay, max_threads: usize, max_errors: u8, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result>> { let file_size = get_content_length(url)?; @@ -433,7 +432,7 @@ pub fn download( display: &mut dyn ProgressDisplay, max_tasks: usize, max_errors: u8, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { let state_path = state_path(output); let ranges = match read_state(&state_path)? { diff --git a/e2e/src/main.rs b/e2e/src/main.rs index fa33936..a5ddf88 100644 --- a/e2e/src/main.rs +++ b/e2e/src/main.rs @@ -90,7 +90,7 @@ fn exclusion_to_inclusion(holes: &[Range], file_range: Range) -> Resul fn strip_image( input: &Path, output: &Path, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<(Vec>, [u8; 32])> { println!("Stripping {input:?} to {output:?}"); @@ -194,7 +194,7 @@ fn url_filename(url: &str) -> Result<&str> { .ok_or_else(|| anyhow!("Failed to determine filename from URL: {url}")) } -fn hash_file(path: &Path, cancel_signal: &Arc) -> Result<[u8; 32]> { +fn hash_file(path: &Path, cancel_signal: &AtomicBool) -> Result<[u8; 32]> { println!("Calculating hash of {path:?}"); let raw_reader = @@ -211,7 +211,7 @@ fn hash_file(path: &Path, cancel_signal: &Arc) -> Result<[u8; 32]> { Ok(digest.as_ref().try_into().unwrap()) } -fn verify_hash(path: &Path, sha256: &[u8; 32], cancel_signal: &Arc) -> Result<()> { +fn verify_hash(path: &Path, sha256: &[u8; 32], cancel_signal: &AtomicBool) -> Result<()> { let digest = hash_file(path, cancel_signal)?; if sha256 != digest.as_ref() { @@ -239,7 +239,7 @@ fn download_file( sections: Option<&[Range]>, path_is_dir: bool, validate: Validate, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result { let path = if path_is_dir { path.join(url_filename(url)?) @@ -285,7 +285,7 @@ fn download_magisk( config: &Config, work_dir: &Path, revalidate: bool, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result { download_file( &work_dir.join("magisk"), @@ -308,7 +308,7 @@ fn download_image( work_dir: &Path, stripped: bool, revalidate: bool, - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result { let info = &config.device[device]; let mut path = work_dir.join(device); @@ -396,7 +396,7 @@ fn patch_image( input_file: &Path, output_file: &Path, extra_args: &[OsString], - cancel_signal: &Arc, + cancel_signal: &AtomicBool, ) -> Result<()> { println!("Patching {input_file:?}"); @@ -426,11 +426,7 @@ fn patch_image( Ok(()) } -fn extract_image( - input_file: &Path, - output_dir: &Path, - cancel_signal: &Arc, -) -> Result<()> { +fn extract_image(input_file: &Path, output_dir: &Path, cancel_signal: &AtomicBool) -> Result<()> { println!("Extracting AVB partitions from {input_file:?}"); let cli = ExtractCli::try_parse_from([ @@ -445,7 +441,7 @@ fn extract_image( Ok(()) } -fn verify_image(input_file: &Path, cancel_signal: &Arc) -> Result<()> { +fn verify_image(input_file: &Path, cancel_signal: &AtomicBool) -> Result<()> { println!("Verifying signatures in {input_file:?}"); let (_temp_key_dir, _, key_args) = test_keys()?; @@ -510,7 +506,7 @@ fn filter_devices<'a>(config: &'a Config, cli: &'a DeviceGroup) -> Result) -> Result<()> { +fn strip_subcommand(cli: &StripCli, cancel_signal: &AtomicBool) -> Result<()> { let (sections, sha256) = strip_image(&cli.input, &cli.output, cancel_signal)?; println!("Preserved sections:"); @@ -523,7 +519,7 @@ fn strip_subcommand(cli: &StripCli, cancel_signal: &Arc) -> Result<( Ok(()) } -fn add_subcommand(cli: &AddCli, cancel_signal: &Arc) -> Result<()> { +fn add_subcommand(cli: &AddCli, cancel_signal: &AtomicBool) -> Result<()> { let (config, mut document) = config::load_config(&cli.config.config)?; let image_dir = cli.config.work_dir.join(&cli.device); @@ -637,7 +633,7 @@ fn add_subcommand(cli: &AddCli, cancel_signal: &Arc) -> Result<()> { Ok(()) } -fn download_subcommand(cli: &DownloadCli, cancel_signal: &Arc) -> Result<()> { +fn download_subcommand(cli: &DownloadCli, cancel_signal: &AtomicBool) -> Result<()> { let (config, _) = config::load_config(&cli.config.config)?; let devices = filter_devices(&config, &cli.device)?; @@ -668,7 +664,7 @@ fn download_subcommand(cli: &DownloadCli, cancel_signal: &Arc) -> Re Ok(()) } -fn test_subcommand(cli: &TestCli, cancel_signal: &Arc) -> Result<()> { +fn test_subcommand(cli: &TestCli, cancel_signal: &AtomicBool) -> Result<()> { let (config, _) = config::load_config(&cli.config.config)?; let devices = filter_devices(&config, &cli.device)?;