diff --git a/Cargo.toml b/Cargo.toml index 8f417044..8db2878b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ bitcoin = { version = "0.23", features = ["use-serde"] } miniscript = { version = "1.0" } serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1.0" } +rand = "^0.7" # Optional dependencies sled = { version = "0.31.0", optional = true } diff --git a/src/cli.rs b/src/cli.rs index c946f22e..735c0af9 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -326,21 +326,19 @@ where .map(|s| parse_addressee(s)) .collect::, _>>() .map_err(|s| Error::Generic(s))?; - let mut tx_builder = TxBuilder::from_addressees(addressees); + let mut tx_builder = + TxBuilder::from_addressees(addressees).send_all(sub_matches.is_present("send_all")); - if sub_matches.is_present("send_all") { - tx_builder.send_all(); - } if let Some(fee_rate) = sub_matches.value_of("fee_rate") { let fee_rate = f32::from_str(fee_rate).map_err(|s| Error::Generic(s.to_string()))?; - tx_builder.fee_rate(fee_rate); + tx_builder = tx_builder.fee_rate(fee_rate); } if let Some(utxos) = sub_matches.values_of("utxos") { let utxos = utxos .map(|i| parse_outpoint(i)) .collect::, _>>() .map_err(|s| Error::Generic(s.to_string()))?; - tx_builder.utxos(utxos); + tx_builder = tx_builder.utxos(utxos); } if let Some(unspendable) = sub_matches.values_of("unspendable") { @@ -348,15 +346,15 @@ where .map(|i| parse_outpoint(i)) .collect::, _>>() .map_err(|s| Error::Generic(s.to_string()))?; - tx_builder.unspendable(unspendable); + tx_builder = tx_builder.unspendable(unspendable); } if let Some(policy) = sub_matches.value_of("policy") { let policy = serde_json::from_str::>>(&policy) .map_err(|s| Error::Generic(s.to_string()))?; - tx_builder.policy_path(policy); + tx_builder = tx_builder.policy_path(policy); } - let result = wallet.create_tx(&tx_builder)?; + let result = wallet.create_tx(tx_builder)?; Ok(Some(format!( "{:#?}\nPSBT: {}", result.1, diff --git a/src/wallet/coin_selection.rs b/src/wallet/coin_selection.rs new file mode 100644 index 00000000..fc7211e1 --- /dev/null +++ b/src/wallet/coin_selection.rs @@ -0,0 +1,184 @@ +use bitcoin::consensus::encode::serialize; +use bitcoin::{Script, TxIn}; + +use crate::error::Error; +use crate::types::UTXO; + +pub type DefaultCoinSelectionAlgorithm = DumbCoinSelection; + +#[derive(Debug)] +pub struct CoinSelectionResult { + pub txin: Vec<(TxIn, Script)>, + pub total_amount: u64, + pub fee_amount: f32, +} + +pub trait CoinSelectionAlgorithm: std::fmt::Debug { + fn coin_select( + &self, + utxos: Vec, + use_all_utxos: bool, + fee_rate: f32, + outgoing_amount: u64, + input_witness_weight: usize, + fee_amount: f32, + ) -> Result; +} + +#[derive(Debug, Default)] +pub struct DumbCoinSelection; + +impl CoinSelectionAlgorithm for DumbCoinSelection { + fn coin_select( + &self, + mut utxos: Vec, + use_all_utxos: bool, + fee_rate: f32, + outgoing_amount: u64, + input_witness_weight: usize, + mut fee_amount: f32, + ) -> Result { + let mut txin = Vec::new(); + let calc_fee_bytes = |wu| (wu as f32) * fee_rate / 4.0; + + log::debug!( + "outgoing_amount = `{}`, fee_amount = `{}`, fee_rate = `{}`", + outgoing_amount, + fee_amount, + fee_rate + ); + + // sort so that we pick them starting from the larger. + utxos.sort_by(|a, b| a.txout.value.partial_cmp(&b.txout.value).unwrap()); + + let mut total_amount: u64 = 0; + while use_all_utxos || total_amount < outgoing_amount + (fee_amount.ceil() as u64) { + let utxo = match utxos.pop() { + Some(utxo) => utxo, + None if total_amount < outgoing_amount + (fee_amount.ceil() as u64) => { + return Err(Error::InsufficientFunds) + } + None if use_all_utxos => break, + None => return Err(Error::InsufficientFunds), + }; + + let new_in = TxIn { + previous_output: utxo.outpoint, + script_sig: Script::default(), + sequence: 0, // Let the caller choose the right nSequence + witness: vec![], + }; + fee_amount += calc_fee_bytes(serialize(&new_in).len() * 4 + input_witness_weight); + log::debug!( + "Selected {}, updated fee_amount = `{}`", + new_in.previous_output, + fee_amount + ); + + txin.push((new_in, utxo.txout.script_pubkey)); + total_amount += utxo.txout.value; + } + + Ok(CoinSelectionResult { + txin, + fee_amount, + total_amount, + }) + } +} + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use bitcoin::{OutPoint, Script, TxOut}; + + use super::*; + use crate::types::*; + + const P2WPKH_WITNESS_SIZE: usize = 73 + 33 + 2; + + fn get_test_utxos() -> Vec { + vec![ + UTXO { + outpoint: OutPoint::from_str( + "ebd9813ecebc57ff8f30797de7c205e3c7498ca950ea4341ee51a685ff2fa30a:0", + ) + .unwrap(), + txout: TxOut { + value: 100_000, + script_pubkey: Script::new(), + }, + }, + UTXO { + outpoint: OutPoint::from_str( + "65d92ddff6b6dc72c89624a6491997714b90f6004f928d875bc0fd53f264fa85:0", + ) + .unwrap(), + txout: TxOut { + value: 200_000, + script_pubkey: Script::new(), + }, + }, + ] + } + + #[test] + fn test_dumb_coin_selection_success() { + let utxos = get_test_utxos(); + + let result = DumbCoinSelection + .coin_select(utxos, false, 1.0, 250_000, P2WPKH_WITNESS_SIZE, 50.0) + .unwrap(); + + assert_eq!(result.txin.len(), 2); + assert_eq!(result.total_amount, 300_000); + assert_eq!(result.fee_amount, 186.0); + } + + #[test] + fn test_dumb_coin_selection_use_all() { + let utxos = get_test_utxos(); + + let result = DumbCoinSelection + .coin_select(utxos, true, 1.0, 20_000, P2WPKH_WITNESS_SIZE, 50.0) + .unwrap(); + + assert_eq!(result.txin.len(), 2); + assert_eq!(result.total_amount, 300_000); + assert_eq!(result.fee_amount, 186.0); + } + + #[test] + fn test_dumb_coin_selection_use_only_necessary() { + let utxos = get_test_utxos(); + + let result = DumbCoinSelection + .coin_select(utxos, false, 1.0, 20_000, P2WPKH_WITNESS_SIZE, 50.0) + .unwrap(); + + assert_eq!(result.txin.len(), 1); + assert_eq!(result.total_amount, 200_000); + assert_eq!(result.fee_amount, 118.0); + } + + #[test] + #[should_panic(expected = "InsufficientFunds")] + fn test_dumb_coin_selection_insufficient_funds() { + let utxos = get_test_utxos(); + + DumbCoinSelection + .coin_select(utxos, false, 1.0, 500_000, P2WPKH_WITNESS_SIZE, 50.0) + .unwrap(); + } + + #[test] + #[should_panic(expected = "InsufficientFunds")] + fn test_dumb_coin_selection_insufficient_funds_high_fees() { + let utxos = get_test_utxos(); + + DumbCoinSelection + .coin_select(utxos, false, 1000.0, 250_000, P2WPKH_WITNESS_SIZE, 50.0) + .unwrap(); + } +} diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 55aa9fbf..f59424d1 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -8,7 +8,7 @@ use bitcoin::blockdata::script::Builder; use bitcoin::consensus::encode::serialize; use bitcoin::util::psbt::PartiallySignedTransaction as PSBT; use bitcoin::{ - Address, Network, OutPoint, PublicKey, Script, SigHashType, Transaction, TxIn, TxOut, Txid, + Address, Network, OutPoint, PublicKey, Script, SigHashType, Transaction, TxOut, Txid, }; use miniscript::BitcoinSig; @@ -16,6 +16,7 @@ use miniscript::BitcoinSig; #[allow(unused_imports)] use log::{debug, error, info, trace}; +pub mod coin_selection; pub mod time; pub mod tx_builder; pub mod utils; @@ -123,8 +124,11 @@ where .fold(0, |sum, i| sum + i.txout.value)) } - // TODO: add a flag to ignore change in coin selection - pub fn create_tx(&self, builder: &TxBuilder) -> Result<(PSBT, TransactionDetails), Error> { + pub fn create_tx( + &self, + builder: TxBuilder, + ) -> Result<(PSBT, TransactionDetails), Error> { + // TODO: fetch both internal and external policies let policy = self.descriptor.extract_policy()?.unwrap(); if policy.requires_path() && builder.policy_path.is_none() { return Err(Error::SpendingPolicyRequired); @@ -146,12 +150,12 @@ where } // we keep it as a float while we accumulate it, and only round it at the end - let mut fee_val: f32 = 0.0; + let mut fee_amount: f32 = 0.0; let mut outgoing: u64 = 0; let mut received: u64 = 0; let calc_fee_bytes = |wu| (wu as f32) * fee_rate / 4.0; - fee_val += calc_fee_bytes(tx.get_weight()); + fee_amount += calc_fee_bytes(tx.get_weight()); for (index, (address, satoshi)) in builder.addressees.iter().enumerate() { let value = match builder.send_all { @@ -170,35 +174,45 @@ where script_pubkey: address.script_pubkey(), value, }; - fee_val += calc_fee_bytes(serialize(&new_out).len() * 4); + fee_amount += calc_fee_bytes(serialize(&new_out).len() * 4); tx.output.push(new_out); outgoing += value; } - // TODO: assumes same weight to spend external and internal - let input_witness_weight = self.descriptor.max_satisfaction_weight(); + // TODO: use the right weight instead of the maximum, and only fall-back to it if the + // script is unknown in the database + let input_witness_weight = std::cmp::max( + self.get_descriptor_for(ScriptType::Internal) + .max_satisfaction_weight(), + self.get_descriptor_for(ScriptType::External) + .max_satisfaction_weight(), + ); let (available_utxos, use_all_utxos) = self.get_available_utxos(&builder.utxos, &builder.unspendable, builder.send_all)?; - let (mut inputs, paths, selected_amount, mut fee_val) = self.coin_select( + let coin_selection::CoinSelectionResult { + txin, + total_amount, + mut fee_amount, + } = builder.coin_selection.coin_select( available_utxos, use_all_utxos, fee_rate, outgoing, input_witness_weight, - fee_val, + fee_amount, )?; - let n_sequence = if let Some(csv) = requirements.csv { - csv - } else if requirements.timelock.is_some() { - 0xFFFFFFFE - } else { - 0xFFFFFFFF + let (mut txin, prev_script_pubkeys): (Vec<_>, Vec<_>) = txin.into_iter().unzip(); + + let n_sequence = match requirements.csv { + Some(csv) => csv, + _ if requirements.timelock.is_some() => 0xFFFFFFFE, + _ => 0xFFFFFFFF, }; - inputs.iter_mut().for_each(|i| i.sequence = n_sequence); - tx.input.append(&mut inputs); + txin.iter_mut().for_each(|i| i.sequence = n_sequence); + tx.input = txin; // prepare the change output let change_output = match builder.send_all { @@ -211,12 +225,12 @@ where }; // take the change into account for fees - fee_val += calc_fee_bytes(serialize(&change_output).len() * 4); + fee_amount += calc_fee_bytes(serialize(&change_output).len() * 4); Some(change_output) } }; - let change_val = selected_amount - outgoing - (fee_val.ceil() as u64); + let change_val = total_amount - outgoing - (fee_amount.ceil() as u64); if !builder.send_all && !change_val.is_dust() { let mut change_output = change_output.unwrap(); change_output.value = change_val; @@ -225,7 +239,7 @@ where tx.output.push(change_output); } else if builder.send_all && !change_val.is_dust() { // set the outgoing value to whatever we've put in - outgoing = selected_amount; + outgoing = total_amount; // there's only one output, send everything to it tx.output[0].value = change_val; @@ -238,58 +252,57 @@ where return Err(Error::InsufficientFunds); // TODO: or OutputBelowDustLimit? } - // TODO: shuffle the outputs + if builder.shuffle_outputs.unwrap_or(true) { + use rand::seq::SliceRandom; + + let mut rng = rand::thread_rng(); + tx.output.shuffle(&mut rng); + } let txid = tx.txid(); let mut psbt = PSBT::from_unsigned_tx(tx)?; // add metadata for the inputs - for ((psbt_input, (script_type, child)), input) in psbt + for ((psbt_input, prev_script), input) in psbt .inputs .iter_mut() - .zip(paths.into_iter()) + .zip(prev_script_pubkeys.into_iter()) .zip(psbt.global.unsigned_tx.input.iter()) { - let desc = self.get_descriptor_for(script_type); - psbt_input.hd_keypaths = desc.get_hd_keypaths(child).unwrap(); - let derived_descriptor = desc.derive(child).unwrap(); + // Add sighash, default is obviously "ALL" + psbt_input.sighash_type = builder.sighash.or(Some(SigHashType::All)); + + // Try to find the prev_script in our db to figure out if this is internal or external, + // and the derivation index + let (script_type, child) = match self + .database + .borrow() + .get_path_from_script_pubkey(&prev_script)? + { + Some(x) => x, + None => continue, + }; + + let desc = self.get_descriptor_for(script_type); + psbt_input.hd_keypaths = desc.get_hd_keypaths(child)?; + let derived_descriptor = desc.derive(child)?; - // TODO: figure out what do redeem_script and witness_script mean psbt_input.redeem_script = derived_descriptor.psbt_redeem_script(); psbt_input.witness_script = derived_descriptor.psbt_witness_script(); let prev_output = input.previous_output; - let prev_tx = self - .database - .borrow() - .get_raw_tx(&prev_output.txid)? - .unwrap(); // TODO: remove unwrap - - if derived_descriptor.is_witness() { - psbt_input.witness_utxo = Some(prev_tx.output[prev_output.vout as usize].clone()); - } else { - psbt_input.non_witness_utxo = Some(prev_tx); - }; - - // we always sign with SIGHASH_ALL - psbt_input.sighash_type = Some(SigHashType::All); - } - - for (psbt_output, tx_output) in psbt - .outputs - .iter_mut() - .zip(psbt.global.unsigned_tx.output.iter()) - { - if let Some((script_type, child)) = self - .database - .borrow() - .get_path_from_script_pubkey(&tx_output.script_pubkey)? - { - let desc = self.get_descriptor_for(script_type); - psbt_output.hd_keypaths = desc.get_hd_keypaths(child)?; + if let Some(prev_tx) = self.database.borrow().get_raw_tx(&prev_output.txid)? { + if derived_descriptor.is_witness() { + psbt_input.witness_utxo = + Some(prev_tx.output[prev_output.vout as usize].clone()); + } else { + psbt_input.non_witness_utxo = Some(prev_tx); + } } } + self.add_hd_keypaths(&mut psbt)?; + let transaction_details = TransactionDetails { transaction: None, txid, @@ -600,61 +613,6 @@ where } } - fn coin_select( - &self, - mut utxos: Vec, - use_all_utxos: bool, - fee_rate: f32, - outgoing: u64, - input_witness_weight: usize, - mut fee_val: f32, - ) -> Result<(Vec, Vec<(ScriptType, u32)>, u64, f32), Error> { - let mut answer = Vec::new(); - let mut deriv_indexes = Vec::new(); - let calc_fee_bytes = |wu| (wu as f32) * fee_rate / 4.0; - - debug!( - "coin select: outgoing = `{}`, fee_val = `{}`, fee_rate = `{}`", - outgoing, fee_val, fee_rate - ); - - // sort so that we pick them starting from the larger. TODO: proper coin selection - utxos.sort_by(|a, b| a.txout.value.partial_cmp(&b.txout.value).unwrap()); - - let mut selected_amount: u64 = 0; - while use_all_utxos || selected_amount < outgoing + (fee_val.ceil() as u64) { - let utxo = match utxos.pop() { - Some(utxo) => utxo, - None if selected_amount < outgoing + (fee_val.ceil() as u64) => { - return Err(Error::InsufficientFunds) - } - None if use_all_utxos => break, - None => return Err(Error::InsufficientFunds), - }; - - let new_in = TxIn { - previous_output: utxo.outpoint, - script_sig: Script::default(), - sequence: 0xFFFFFFFD, // TODO: change according to rbf/csv - witness: vec![], - }; - fee_val += calc_fee_bytes(serialize(&new_in).len() * 4 + input_witness_weight); - debug!("coin select new fee_val = `{}`", fee_val); - - answer.push(new_in); - selected_amount += utxo.txout.value; - - let child = self - .database - .borrow() - .get_path_from_script_pubkey(&utxo.txout.script_pubkey)? - .unwrap(); // TODO: remove unrwap - deriv_indexes.push(child); - } - - Ok((answer, deriv_indexes, selected_amount, fee_val)) - } - fn add_hd_keypaths(&self, psbt: &mut PSBT) -> Result<(), Error> { let mut input_utxos = Vec::with_capacity(psbt.inputs.len()); for n in 0..psbt.inputs.len() { diff --git a/src/wallet/tx_builder.rs b/src/wallet/tx_builder.rs index 964b69ca..dab488fe 100644 --- a/src/wallet/tx_builder.rs +++ b/src/wallet/tx_builder.rs @@ -1,71 +1,105 @@ use std::collections::BTreeMap; -use bitcoin::{Address, OutPoint}; +use bitcoin::{Address, OutPoint, SigHashType}; +use super::coin_selection::{CoinSelectionAlgorithm, DefaultCoinSelectionAlgorithm}; + +// TODO: add a flag to ignore change outputs (make them unspendable) #[derive(Debug, Default)] -pub struct TxBuilder { +pub struct TxBuilder { pub(crate) addressees: Vec<(Address, u64)>, pub(crate) send_all: bool, pub(crate) fee_perkb: Option, pub(crate) policy_path: Option>>, pub(crate) utxos: Option>, pub(crate) unspendable: Option>, + pub(crate) sighash: Option, + pub(crate) shuffle_outputs: Option, + pub(crate) coin_selection: Cs, } -impl TxBuilder { - pub fn new() -> TxBuilder { - TxBuilder::default() +impl TxBuilder { + pub fn new() -> Self { + Self::default() } - pub fn from_addressees(addressees: Vec<(Address, u64)>) -> TxBuilder { - let mut tx_builder = TxBuilder::default(); - tx_builder.addressees = addressees; + pub fn from_addressees(addressees: Vec<(Address, u64)>) -> Self { + Self::default().set_addressees(addressees) + } +} - tx_builder +impl TxBuilder { + pub fn set_addressees(mut self, addressees: Vec<(Address, u64)>) -> Self { + self.addressees = addressees; + self } - pub fn add_addressee(&mut self, address: Address, amount: u64) -> &mut TxBuilder { + pub fn add_addressee(mut self, address: Address, amount: u64) -> Self { self.addressees.push((address, amount)); self } - pub fn send_all(&mut self) -> &mut TxBuilder { - self.send_all = true; + pub fn send_all(mut self, send_all: bool) -> Self { + self.send_all = send_all; self } - pub fn fee_rate(&mut self, satoshi_per_vbyte: f32) -> &mut TxBuilder { + pub fn fee_rate(mut self, satoshi_per_vbyte: f32) -> Self { self.fee_perkb = Some(satoshi_per_vbyte * 1e3); self } - pub fn fee_rate_perkb(&mut self, satoshi_per_kb: f32) -> &mut TxBuilder { + pub fn fee_rate_perkb(mut self, satoshi_per_kb: f32) -> Self { self.fee_perkb = Some(satoshi_per_kb); self } - pub fn policy_path(&mut self, policy_path: BTreeMap>) -> &mut TxBuilder { + pub fn policy_path(mut self, policy_path: BTreeMap>) -> Self { self.policy_path = Some(policy_path); self } - pub fn utxos(&mut self, utxos: Vec) -> &mut TxBuilder { + pub fn utxos(mut self, utxos: Vec) -> Self { self.utxos = Some(utxos); self } - pub fn add_utxo(&mut self, utxo: OutPoint) -> &mut TxBuilder { + pub fn add_utxo(mut self, utxo: OutPoint) -> Self { self.utxos.get_or_insert(vec![]).push(utxo); self } - pub fn unspendable(&mut self, unspendable: Vec) -> &mut TxBuilder { + pub fn unspendable(mut self, unspendable: Vec) -> Self { self.unspendable = Some(unspendable); self } - pub fn add_unspendable(&mut self, unspendable: OutPoint) -> &mut TxBuilder { + pub fn add_unspendable(mut self, unspendable: OutPoint) -> Self { self.unspendable.get_or_insert(vec![]).push(unspendable); self } + + pub fn sighash(mut self, sighash: SigHashType) -> Self { + self.sighash = Some(sighash); + self + } + + pub fn do_not_shuffle_outputs(mut self) -> Self { + self.shuffle_outputs = Some(false); + self + } + + pub fn coin_selection(self, coin_selection: P) -> TxBuilder

{ + TxBuilder { + addressees: self.addressees, + send_all: self.send_all, + fee_perkb: self.fee_perkb, + policy_path: self.policy_path, + utxos: self.utxos, + unspendable: self.unspendable, + sighash: self.sighash, + shuffle_outputs: self.shuffle_outputs, + coin_selection, + } + } }