From fd34956c2980295e1f0cc32340fb2b99a6b245ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E5=AE=87?= Date: Thu, 29 Sep 2022 13:06:03 +0800 Subject: [PATCH] `get_checksum_bytes` now checks input data for checksum If `exclude_hash` is set, we split the input data, and if a checksum already existed within the original data, we check the calculated checksum against the original checksum. Additionally, the implementation of `IntoWalletDescriptor` for `&str` has been refactored for clarity. --- src/descriptor/checksum.rs | 22 +++++++++++++++++++--- src/descriptor/mod.rs | 21 +++++++++------------ src/wallet/mod.rs | 13 ++++--------- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/descriptor/checksum.rs b/src/descriptor/checksum.rs index 5ed1151b..8dfdac49 100644 --- a/src/descriptor/checksum.rs +++ b/src/descriptor/checksum.rs @@ -41,12 +41,21 @@ fn poly_mod(mut c: u64, val: u64) -> u64 { c } -/// Computes the checksum bytes of a descriptor -pub fn get_checksum_bytes(desc: &str) -> Result<[u8; 8], DescriptorError> { +/// Computes the checksum bytes of a descriptor. +/// `exclude_hash = true` ignores all data after the first '#' (inclusive). +pub fn get_checksum_bytes(mut desc: &str, exclude_hash: bool) -> Result<[u8; 8], DescriptorError> { let mut c = 1; let mut cls = 0; let mut clscount = 0; + let mut original_checksum = None; + if exclude_hash { + if let Some(split) = desc.split_once('#') { + desc = split.0; + original_checksum = Some(split.1); + } + } + for ch in desc.as_bytes() { let pos = INPUT_CHARSET .iter() @@ -72,13 +81,20 @@ pub fn get_checksum_bytes(desc: &str) -> Result<[u8; 8], DescriptorError> { checksum[j] = CHECKSUM_CHARSET[((c >> (5 * (7 - j))) & 31) as usize]; } + // if input data already had a checksum, check calculated checksum against original checksum + if let Some(original_checksum) = original_checksum { + if original_checksum.as_bytes() != &checksum { + return Err(DescriptorError::InvalidDescriptorChecksum); + } + } + Ok(checksum) } /// Compute the checksum of a descriptor pub fn get_checksum(desc: &str) -> Result { // unsafe is okay here as the checksum only uses bytes in `CHECKSUM_CHARSET` - get_checksum_bytes(desc).map(|b| unsafe { String::from_utf8_unchecked(b.to_vec()) }) + get_checksum_bytes(desc, true).map(|b| unsafe { String::from_utf8_unchecked(b.to_vec()) }) } #[cfg(test)] diff --git a/src/descriptor/mod.rs b/src/descriptor/mod.rs index 802ccd19..7c51d27f 100644 --- a/src/descriptor/mod.rs +++ b/src/descriptor/mod.rs @@ -40,6 +40,7 @@ pub mod policy; pub mod template; pub use self::checksum::get_checksum; +use self::checksum::get_checksum_bytes; pub use self::derived::{AsDerived, DerivedDescriptorKey}; pub use self::error::Error as DescriptorError; pub use self::policy::Policy; @@ -84,19 +85,15 @@ impl IntoWalletDescriptor for &str { secp: &SecpCtx, network: Network, ) -> Result<(ExtendedDescriptor, KeyMap), DescriptorError> { - let descriptor = if self.contains('#') { - let parts: Vec<&str> = self.splitn(2, '#').collect(); - if !get_checksum(parts[0]) - .ok() - .map(|computed| computed == parts[1]) - .unwrap_or(false) - { - return Err(DescriptorError::InvalidDescriptorChecksum); + let descriptor = match self.split_once('#') { + Some((desc, original_checksum)) => { + let checksum = get_checksum_bytes(desc, false)?; + if original_checksum.as_bytes() != &checksum { + return Err(DescriptorError::InvalidDescriptorChecksum); + } + desc } - - parts[0] - } else { - self + None => self, }; ExtendedDescriptor::parse_descriptor(secp, descriptor)? diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 2e3d9fdf..776e1740 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -1943,15 +1943,10 @@ pub(crate) mod test { let (wallet, _, _) = get_funded_wallet(get_test_wpkh()); let checksum = wallet.descriptor_checksum(KeychainKind::External); assert_eq!(checksum.len(), 8); - - let raw_descriptor = wallet - .descriptor - .to_string() - .split_once('#') - .unwrap() - .0 - .to_string(); - assert_eq!(get_checksum(&raw_descriptor).unwrap(), checksum); + assert_eq!( + get_checksum(&wallet.descriptor.to_string()).unwrap(), + checksum + ); } #[test]