From 991db281702535110998f0fd574e652f3ad2e6b1 Mon Sep 17 00:00:00 2001 From: Alekos Filini Date: Mon, 17 Aug 2020 12:10:51 +0200 Subject: [PATCH] [wallet] Add explicit ordering for the signers --- src/wallet/mod.rs | 18 +++++---- src/wallet/signer.rs | 87 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 89 insertions(+), 16 deletions(-) diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 46a93b86..57d822ad 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -24,7 +24,7 @@ pub mod tx_builder; pub mod utils; use address_validator::AddressValidator; -use signer::{Signer, SignerId, SignersContainer}; +use signer::{Signer, SignerId, SignerOrdering, SignersContainer}; use tx_builder::TxBuilder; use utils::{After, FeeRate, IsDust, Older}; @@ -142,6 +142,7 @@ where &mut self, script_type: ScriptType, id: SignerId, + ordering: SignerOrdering, signer: Arc>, ) { let signers = match script_type { @@ -149,7 +150,7 @@ where ScriptType::Internal => Arc::make_mut(&mut self.change_signers), }; - signers.add_external(id, signer); + signers.add_external(id, ordering, signer); } pub fn add_address_validator(&mut self, validator: Arc>) { @@ -575,15 +576,18 @@ where Ok((psbt, details)) } - // TODO: define an enum for signing errors pub fn sign(&self, mut psbt: PSBT, assume_height: Option) -> Result<(PSBT, bool), Error> { // this helps us doing our job later self.add_input_hd_keypaths(&mut psbt)?; - for index in 0..psbt.inputs.len() { - self.signers.sign(&mut psbt, index)?; - if self.change_descriptor.is_some() { - self.change_signers.sign(&mut psbt, index)?; + for signer in self + .signers + .signers() + .iter() + .chain(self.change_signers.signers().iter()) + { + for index in 0..psbt.inputs.len() { + signer.sign(&mut psbt, index)?; } } diff --git a/src/wallet/signer.rs b/src/wallet/signer.rs index 687bc8c6..d94589c1 100644 --- a/src/wallet/signer.rs +++ b/src/wallet/signer.rs @@ -1,6 +1,8 @@ use std::any::Any; -use std::collections::HashMap; +use std::cmp::Ordering; +use std::collections::BTreeMap; use std::fmt; +use std::ops::Bound::Included; use std::sync::Arc; use bitcoin::blockdata::opcodes; @@ -150,9 +152,35 @@ impl Signer for PrivateKey { } } +#[derive(Debug, Clone, PartialOrd, PartialEq, Ord, Eq)] +pub struct SignerOrdering(pub usize); + +impl std::default::Default for SignerOrdering { + fn default() -> Self { + SignerOrdering(100) + } +} + +#[derive(Debug, Clone)] +struct SignersContainerKey { + id: SignerId, + ordering: SignerOrdering, +} + +impl From<(SignerId, SignerOrdering)> for SignersContainerKey { + fn from(tuple: (SignerId, SignerOrdering)) -> Self { + SignersContainerKey { + id: tuple.0, + ordering: tuple.1, + } + } +} + /// Container for multiple signers #[derive(Debug, Default, Clone)] -pub struct SignersContainer(HashMap, Arc>>); +pub struct SignersContainer( + BTreeMap, Arc>>, +); impl SignersContainer { pub fn as_key_map(&self) -> KeyMap { @@ -190,10 +218,12 @@ impl From for SignersContainer { .public_key(&Secp256k1::signing_only()) .to_pubkeyhash(), ), + SignerOrdering::default(), Arc::new(Box::new(private_key)), ), DescriptorSecretKey::XPrv(xprv) => container.add_external( SignerId::from(xprv.root_fingerprint()), + SignerOrdering::default(), Arc::new(Box::new(xprv)), ), }; @@ -206,7 +236,7 @@ impl From for SignersContainer { impl SignersContainer { /// Default constructor pub fn new() -> Self { - SignersContainer(HashMap::new()) + SignersContainer(Default::default()) } /// Adds an external signer to the container for the specified id. Optionally returns the @@ -214,24 +244,43 @@ impl SignersContainer { pub fn add_external( &mut self, id: SignerId, + ordering: SignerOrdering, signer: Arc>, ) -> Option>> { - self.0.insert(id, signer) + self.0.insert((id, ordering).into(), signer) } /// Removes a signer from the container and returns it - pub fn remove(&mut self, id: SignerId) -> Option>> { - self.0.remove(&id) + pub fn remove( + &mut self, + id: SignerId, + ordering: SignerOrdering, + ) -> Option>> { + self.0.remove(&(id, ordering).into()) } /// Returns the list of identifiers of all the signers in the container pub fn ids(&self) -> Vec<&SignerId> { - self.0.keys().collect() + self.0 + .keys() + .map(|SignersContainerKey { id, .. }| id) + .collect() } - /// Finds the signer with a given id in the container + /// Returns the list of signers in the container, sorted by lowest to highest `ordering` + pub fn signers(&self) -> Vec<&Arc>> { + self.0.values().collect() + } + + /// Finds the signer with lowest ordering for a given id in the container. pub fn find(&self, id: SignerId) -> Option<&Arc>> { - self.0.get(&id) + self.0 + .range(( + Included(&(id.clone(), SignerOrdering(0)).into()), + Included(&(id, SignerOrdering(usize::MAX)).into()), + )) + .map(|(_, v)| v) + .nth(0) } } @@ -327,3 +376,23 @@ impl ComputeSighash for Segwitv0 { )) } } + +impl PartialOrd for SignersContainerKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SignersContainerKey { + fn cmp(&self, other: &Self) -> Ordering { + self.ordering.cmp(&other.ordering) + } +} + +impl PartialEq for SignersContainerKey { + fn eq(&self, other: &Self) -> bool { + self.ordering == other.ordering + } +} + +impl Eq for SignersContainerKey {}