From c90c752f21d0df805e65a1b0ed0eae286b88b55e Mon Sep 17 00:00:00 2001 From: Alekos Filini Date: Sat, 8 Aug 2020 12:06:40 +0200 Subject: [PATCH] [wallet] Add `force_non_witness_utxo()` to TxBuilder --- src/blockchain/electrum.rs | 4 ++-- src/blockchain/esplora.rs | 4 ++-- src/blockchain/mod.rs | 2 +- src/cli.rs | 7 +++++-- src/wallet/mod.rs | 25 ++++++++++++++++--------- src/wallet/tx_builder.rs | 17 ++++++++++++++--- 6 files changed, 40 insertions(+), 19 deletions(-) diff --git a/src/blockchain/electrum.rs b/src/blockchain/electrum.rs index 5ac86151..8913feb8 100644 --- a/src/blockchain/electrum.rs +++ b/src/blockchain/electrum.rs @@ -68,7 +68,7 @@ impl OnlineBlockchain for ElectrumBlockchain { .map(|_| ())?) } - fn get_height(&self) -> Result { + fn get_height(&self) -> Result { // TODO: unsubscribe when added to the client, or is there a better call to use here? Ok(self @@ -76,7 +76,7 @@ impl OnlineBlockchain for ElectrumBlockchain { .as_ref() .ok_or(Error::OfflineClient)? .block_headers_subscribe() - .map(|data| data.height)?) + .map(|data| data.height as u32)?) } fn estimate_fee(&self, target: usize) -> Result { diff --git a/src/blockchain/esplora.rs b/src/blockchain/esplora.rs index 53dcb5b3..10eb5cb4 100644 --- a/src/blockchain/esplora.rs +++ b/src/blockchain/esplora.rs @@ -93,7 +93,7 @@ impl OnlineBlockchain for EsploraBlockchain { ._broadcast(tx))?) } - fn get_height(&self) -> Result { + fn get_height(&self) -> Result { Ok(await_or_block!(self .0 .as_ref() @@ -153,7 +153,7 @@ impl UrlClient { Ok(()) } - async fn _get_height(&self) -> Result { + async fn _get_height(&self) -> Result { let req = self .client .get(&format!("{}/api/blocks/tip/height", self.url)) diff --git a/src/blockchain/mod.rs b/src/blockchain/mod.rs index 08e46694..da267335 100644 --- a/src/blockchain/mod.rs +++ b/src/blockchain/mod.rs @@ -64,7 +64,7 @@ pub trait OnlineBlockchain: Blockchain { fn get_tx(&self, txid: &Txid) -> Result, Error>; fn broadcast(&self, tx: &Transaction) -> Result<(), Error>; - fn get_height(&self) -> Result; + fn get_height(&self) -> Result; fn estimate_fee(&self, target: usize) -> Result; } diff --git a/src/cli.rs b/src/cli.rs index 48f5e42b..164ae4a2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -326,8 +326,11 @@ where .map(|s| parse_addressee(s)) .collect::, _>>() .map_err(|s| Error::Generic(s))?; - let mut tx_builder = - TxBuilder::from_addressees(addressees).send_all(sub_matches.is_present("send_all")); + let mut tx_builder = TxBuilder::from_addressees(addressees); + + if sub_matches.is_present("send_all") { + tx_builder = tx_builder.send_all(); + } if let Some(fee_rate) = sub_matches.value_of("fee_rate") { let fee_rate = f32::from_str(fee_rate).map_err(|s| Error::Generic(s.to_string()))?; diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index b6cd8532..cd2bfbc8 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::HashMap; use std::collections::{BTreeMap, HashSet}; use std::ops::DerefMut; use std::str::FromStr; @@ -236,6 +237,12 @@ where fee_amount, )?; let (mut txin, prev_script_pubkeys): (Vec<_>, Vec<_>) = txin.into_iter().unzip(); + // map that allows us to lookup the prev_script_pubkey for a given previous_output + let prev_script_pubkeys = txin + .iter() + .zip(prev_script_pubkeys.into_iter()) + .map(|(txin, script)| (txin.previous_output, script)) + .collect::>(); txin.iter_mut().for_each(|i| i.sequence = n_sequence); tx.input = txin; @@ -285,12 +292,13 @@ where let mut psbt = PSBT::from_unsigned_tx(tx)?; // add metadata for the inputs - for ((psbt_input, prev_script), input) in psbt + for (psbt_input, input) in psbt .inputs .iter_mut() - .zip(prev_script_pubkeys.into_iter()) .zip(psbt.global.unsigned_tx.input.iter()) { + let prev_script = prev_script_pubkeys.get(&input.previous_output).unwrap(); + // Add sighash, default is obviously "ALL" psbt_input.sighash_type = builder.sighash.or(Some(SigHashType::All)); @@ -317,7 +325,8 @@ where if derived_descriptor.is_witness() { psbt_input.witness_utxo = Some(prev_tx.output[prev_output.vout as usize].clone()); - } else { + } + if !derived_descriptor.is_witness() || builder.force_non_witness_utxo { psbt_input.non_witness_utxo = Some(prev_tx); } } @@ -535,7 +544,6 @@ where n, input.previous_output, create_height, current_height ); - // TODO: use height once we sync headers let satisfier = PSBTSatisfier::new(&psbt.inputs[n], false, create_height, current_height); @@ -778,17 +786,16 @@ where )) } + pub fn client(&self) -> &B { + &self.client + } + #[maybe_async] pub fn broadcast(&self, tx: Transaction) -> Result { maybe_await!(self.client.broadcast(&tx))?; Ok(tx.txid()) } - - #[maybe_async] - pub fn estimate_fee(&self, target: usize) -> Result { - Ok(maybe_await!(self.client.estimate_fee(target))?) - } } #[cfg(test)] diff --git a/src/wallet/tx_builder.rs b/src/wallet/tx_builder.rs index 21fc932e..ea34fbcf 100644 --- a/src/wallet/tx_builder.rs +++ b/src/wallet/tx_builder.rs @@ -7,7 +7,6 @@ use super::coin_selection::{CoinSelectionAlgorithm, DefaultCoinSelectionAlgorith use super::utils::FeeRate; use crate::types::UTXO; -// TODO: add a flag to ignore change outputs (make them unspendable) #[derive(Debug, Default)] pub struct TxBuilder { pub(crate) addressees: Vec<(Address, u64)>, @@ -22,6 +21,7 @@ pub struct TxBuilder { pub(crate) rbf: Option, pub(crate) version: Version, pub(crate) change_policy: ChangeSpendPolicy, + pub(crate) force_non_witness_utxo: bool, pub(crate) coin_selection: Cs, } @@ -46,8 +46,8 @@ impl TxBuilder { self } - pub fn send_all(mut self, send_all: bool) -> Self { - self.send_all = send_all; + pub fn send_all(mut self) -> Self { + self.send_all = true; self } @@ -122,6 +122,16 @@ impl TxBuilder { self } + pub fn change_policy(mut self, change_policy: ChangeSpendPolicy) -> Self { + self.change_policy = change_policy; + self + } + + pub fn force_non_witness_utxo(mut self) -> Self { + self.force_non_witness_utxo = true; + self + } + pub fn coin_selection(self, coin_selection: P) -> TxBuilder

{ TxBuilder { addressees: self.addressees, @@ -136,6 +146,7 @@ impl TxBuilder { rbf: self.rbf, version: self.version, change_policy: self.change_policy, + force_non_witness_utxo: self.force_non_witness_utxo, coin_selection, } }