[wallet] Abstract coin selection in a separate trait

This commit is contained in:
Alekos Filini
2020-08-06 16:56:41 +02:00
parent 499e579824
commit 7a23b2b558
5 changed files with 314 additions and 139 deletions

View File

@@ -8,7 +8,7 @@ use bitcoin::blockdata::script::Builder;
use bitcoin::consensus::encode::serialize;
use bitcoin::util::psbt::PartiallySignedTransaction as PSBT;
use bitcoin::{
Address, Network, OutPoint, PublicKey, Script, SigHashType, Transaction, TxIn, TxOut, Txid,
Address, Network, OutPoint, PublicKey, Script, SigHashType, Transaction, TxOut, Txid,
};
use miniscript::BitcoinSig;
@@ -16,6 +16,7 @@ use miniscript::BitcoinSig;
#[allow(unused_imports)]
use log::{debug, error, info, trace};
pub mod coin_selection;
pub mod time;
pub mod tx_builder;
pub mod utils;
@@ -123,8 +124,11 @@ where
.fold(0, |sum, i| sum + i.txout.value))
}
// TODO: add a flag to ignore change in coin selection
pub fn create_tx(&self, builder: &TxBuilder) -> Result<(PSBT, TransactionDetails), Error> {
pub fn create_tx<Cs: coin_selection::CoinSelectionAlgorithm>(
&self,
builder: TxBuilder<Cs>,
) -> Result<(PSBT, TransactionDetails), Error> {
// TODO: fetch both internal and external policies
let policy = self.descriptor.extract_policy()?.unwrap();
if policy.requires_path() && builder.policy_path.is_none() {
return Err(Error::SpendingPolicyRequired);
@@ -146,12 +150,12 @@ where
}
// we keep it as a float while we accumulate it, and only round it at the end
let mut fee_val: f32 = 0.0;
let mut fee_amount: f32 = 0.0;
let mut outgoing: u64 = 0;
let mut received: u64 = 0;
let calc_fee_bytes = |wu| (wu as f32) * fee_rate / 4.0;
fee_val += calc_fee_bytes(tx.get_weight());
fee_amount += calc_fee_bytes(tx.get_weight());
for (index, (address, satoshi)) in builder.addressees.iter().enumerate() {
let value = match builder.send_all {
@@ -170,35 +174,45 @@ where
script_pubkey: address.script_pubkey(),
value,
};
fee_val += calc_fee_bytes(serialize(&new_out).len() * 4);
fee_amount += calc_fee_bytes(serialize(&new_out).len() * 4);
tx.output.push(new_out);
outgoing += value;
}
// TODO: assumes same weight to spend external and internal
let input_witness_weight = self.descriptor.max_satisfaction_weight();
// TODO: use the right weight instead of the maximum, and only fall-back to it if the
// script is unknown in the database
let input_witness_weight = std::cmp::max(
self.get_descriptor_for(ScriptType::Internal)
.max_satisfaction_weight(),
self.get_descriptor_for(ScriptType::External)
.max_satisfaction_weight(),
);
let (available_utxos, use_all_utxos) =
self.get_available_utxos(&builder.utxos, &builder.unspendable, builder.send_all)?;
let (mut inputs, paths, selected_amount, mut fee_val) = self.coin_select(
let coin_selection::CoinSelectionResult {
txin,
total_amount,
mut fee_amount,
} = builder.coin_selection.coin_select(
available_utxos,
use_all_utxos,
fee_rate,
outgoing,
input_witness_weight,
fee_val,
fee_amount,
)?;
let n_sequence = if let Some(csv) = requirements.csv {
csv
} else if requirements.timelock.is_some() {
0xFFFFFFFE
} else {
0xFFFFFFFF
let (mut txin, prev_script_pubkeys): (Vec<_>, Vec<_>) = txin.into_iter().unzip();
let n_sequence = match requirements.csv {
Some(csv) => csv,
_ if requirements.timelock.is_some() => 0xFFFFFFFE,
_ => 0xFFFFFFFF,
};
inputs.iter_mut().for_each(|i| i.sequence = n_sequence);
tx.input.append(&mut inputs);
txin.iter_mut().for_each(|i| i.sequence = n_sequence);
tx.input = txin;
// prepare the change output
let change_output = match builder.send_all {
@@ -211,12 +225,12 @@ where
};
// take the change into account for fees
fee_val += calc_fee_bytes(serialize(&change_output).len() * 4);
fee_amount += calc_fee_bytes(serialize(&change_output).len() * 4);
Some(change_output)
}
};
let change_val = selected_amount - outgoing - (fee_val.ceil() as u64);
let change_val = total_amount - outgoing - (fee_amount.ceil() as u64);
if !builder.send_all && !change_val.is_dust() {
let mut change_output = change_output.unwrap();
change_output.value = change_val;
@@ -225,7 +239,7 @@ where
tx.output.push(change_output);
} else if builder.send_all && !change_val.is_dust() {
// set the outgoing value to whatever we've put in
outgoing = selected_amount;
outgoing = total_amount;
// there's only one output, send everything to it
tx.output[0].value = change_val;
@@ -238,58 +252,57 @@ where
return Err(Error::InsufficientFunds); // TODO: or OutputBelowDustLimit?
}
// TODO: shuffle the outputs
if builder.shuffle_outputs.unwrap_or(true) {
use rand::seq::SliceRandom;
let mut rng = rand::thread_rng();
tx.output.shuffle(&mut rng);
}
let txid = tx.txid();
let mut psbt = PSBT::from_unsigned_tx(tx)?;
// add metadata for the inputs
for ((psbt_input, (script_type, child)), input) in psbt
for ((psbt_input, prev_script), input) in psbt
.inputs
.iter_mut()
.zip(paths.into_iter())
.zip(prev_script_pubkeys.into_iter())
.zip(psbt.global.unsigned_tx.input.iter())
{
let desc = self.get_descriptor_for(script_type);
psbt_input.hd_keypaths = desc.get_hd_keypaths(child).unwrap();
let derived_descriptor = desc.derive(child).unwrap();
// Add sighash, default is obviously "ALL"
psbt_input.sighash_type = builder.sighash.or(Some(SigHashType::All));
// Try to find the prev_script in our db to figure out if this is internal or external,
// and the derivation index
let (script_type, child) = match self
.database
.borrow()
.get_path_from_script_pubkey(&prev_script)?
{
Some(x) => x,
None => continue,
};
let desc = self.get_descriptor_for(script_type);
psbt_input.hd_keypaths = desc.get_hd_keypaths(child)?;
let derived_descriptor = desc.derive(child)?;
// TODO: figure out what do redeem_script and witness_script mean
psbt_input.redeem_script = derived_descriptor.psbt_redeem_script();
psbt_input.witness_script = derived_descriptor.psbt_witness_script();
let prev_output = input.previous_output;
let prev_tx = self
.database
.borrow()
.get_raw_tx(&prev_output.txid)?
.unwrap(); // TODO: remove unwrap
if derived_descriptor.is_witness() {
psbt_input.witness_utxo = Some(prev_tx.output[prev_output.vout as usize].clone());
} else {
psbt_input.non_witness_utxo = Some(prev_tx);
};
// we always sign with SIGHASH_ALL
psbt_input.sighash_type = Some(SigHashType::All);
}
for (psbt_output, tx_output) in psbt
.outputs
.iter_mut()
.zip(psbt.global.unsigned_tx.output.iter())
{
if let Some((script_type, child)) = self
.database
.borrow()
.get_path_from_script_pubkey(&tx_output.script_pubkey)?
{
let desc = self.get_descriptor_for(script_type);
psbt_output.hd_keypaths = desc.get_hd_keypaths(child)?;
if let Some(prev_tx) = self.database.borrow().get_raw_tx(&prev_output.txid)? {
if derived_descriptor.is_witness() {
psbt_input.witness_utxo =
Some(prev_tx.output[prev_output.vout as usize].clone());
} else {
psbt_input.non_witness_utxo = Some(prev_tx);
}
}
}
self.add_hd_keypaths(&mut psbt)?;
let transaction_details = TransactionDetails {
transaction: None,
txid,
@@ -600,61 +613,6 @@ where
}
}
fn coin_select(
&self,
mut utxos: Vec<UTXO>,
use_all_utxos: bool,
fee_rate: f32,
outgoing: u64,
input_witness_weight: usize,
mut fee_val: f32,
) -> Result<(Vec<TxIn>, Vec<(ScriptType, u32)>, u64, f32), Error> {
let mut answer = Vec::new();
let mut deriv_indexes = Vec::new();
let calc_fee_bytes = |wu| (wu as f32) * fee_rate / 4.0;
debug!(
"coin select: outgoing = `{}`, fee_val = `{}`, fee_rate = `{}`",
outgoing, fee_val, fee_rate
);
// sort so that we pick them starting from the larger. TODO: proper coin selection
utxos.sort_by(|a, b| a.txout.value.partial_cmp(&b.txout.value).unwrap());
let mut selected_amount: u64 = 0;
while use_all_utxos || selected_amount < outgoing + (fee_val.ceil() as u64) {
let utxo = match utxos.pop() {
Some(utxo) => utxo,
None if selected_amount < outgoing + (fee_val.ceil() as u64) => {
return Err(Error::InsufficientFunds)
}
None if use_all_utxos => break,
None => return Err(Error::InsufficientFunds),
};
let new_in = TxIn {
previous_output: utxo.outpoint,
script_sig: Script::default(),
sequence: 0xFFFFFFFD, // TODO: change according to rbf/csv
witness: vec![],
};
fee_val += calc_fee_bytes(serialize(&new_in).len() * 4 + input_witness_weight);
debug!("coin select new fee_val = `{}`", fee_val);
answer.push(new_in);
selected_amount += utxo.txout.value;
let child = self
.database
.borrow()
.get_path_from_script_pubkey(&utxo.txout.script_pubkey)?
.unwrap(); // TODO: remove unrwap
deriv_indexes.push(child);
}
Ok((answer, deriv_indexes, selected_amount, fee_val))
}
fn add_hd_keypaths(&self, psbt: &mut PSBT) -> Result<(), Error> {
let mut input_utxos = Vec::with_capacity(psbt.inputs.len());
for n in 0..psbt.inputs.len() {