Fix vbytes and fee rate code

It was just pointed out that we are calculating the virtual bytes
incorrectly by forgetting to take the ceiling after division by 4 [1]

Add helper functions to encapsulate all weight unit -> virtual byte
calculations including fee to and from fee rate. This makes the code
easier to read, easier to write, and gives us a better chance that bugs
like this will be easier to see.

As an added bonus we can also stop using f32 values for fee amount,
which is by definition an amount in sats so should be a u64. This
removes a bunch of casts and the need for epsilon comparisons and just
deep down feels nice :)

[1] https://github.com/bitcoindevkit/bdk/pull/386#discussion_r670882678
This commit is contained in:
Tobin Harding
2021-07-16 15:14:20 +10:00
parent 474620e6a5
commit 2986fce7c6
3 changed files with 102 additions and 78 deletions

View File

@@ -543,7 +543,7 @@ where
});
}
}
(FeeRate::from_sat_per_vb(0.0), *fee as f32)
(FeeRate::from_sat_per_vb(0.0), *fee)
}
FeePolicy::FeeRate(rate) => {
if let Some(previous_fee) = params.bumping_fee {
@@ -554,7 +554,7 @@ where
});
}
}
(*rate, 0.0)
(*rate, 0)
}
};
@@ -573,8 +573,7 @@ where
let mut outgoing: u64 = 0;
let mut received: u64 = 0;
let calc_fee_bytes = |wu| (wu as f32) * fee_rate.as_sat_vb() / 4.0;
fee_amount += calc_fee_bytes(tx.get_weight());
fee_amount += fee_rate.fee_wu(tx.get_weight());
let recipients = params.recipients.iter().map(|(r, v)| (r, *v));
@@ -591,7 +590,7 @@ where
script_pubkey: script_pubkey.clone(),
value,
};
fee_amount += calc_fee_bytes(serialize(&new_out).len() * 4);
fee_amount += fee_rate.fee_vb(serialize(&new_out).len());
tx.output.push(new_out);
@@ -649,9 +648,8 @@ where
}
};
fee_amount += calc_fee_bytes(serialize(&drain_output).len() * 4);
fee_amount += fee_rate.fee_vb(serialize(&drain_output).len());
let mut fee_amount = fee_amount.ceil() as u64;
let drain_val = (coin_selection.selected_amount() - outgoing).saturating_sub(fee_amount);
if tx.output.is_empty() {
@@ -754,8 +752,10 @@ where
return Err(Error::IrreplaceableTransaction);
}
let vbytes = tx.get_weight().vbytes();
let feerate = details.fee.ok_or(Error::FeeRateUnavailable)? as f32 / vbytes;
let feerate = FeeRate::from_wu(
details.fee.ok_or(Error::FeeRateUnavailable)?,
tx.get_weight(),
);
// remove the inputs from the tx and process them
let original_txin = tx.input.drain(..).collect::<Vec<_>>();
@@ -832,7 +832,7 @@ where
utxos: original_utxos,
bumping_fee: Some(tx_builder::PreviousFee {
absolute: details.fee.ok_or(Error::FeeRateUnavailable)?,
rate: feerate,
rate: feerate.as_sat_vb(),
}),
..Default::default()
};
@@ -1548,18 +1548,6 @@ where
}
}
/// Trait implemented by types that can be used to measure weight units.
pub trait Vbytes {
/// Convert weight units to virtual bytes.
fn vbytes(self) -> f32;
}
impl Vbytes for usize {
fn vbytes(self) -> f32 {
self as f32 / 4.0
}
}
#[cfg(test)]
pub(crate) mod test {
use std::str::FromStr;
@@ -1746,13 +1734,13 @@ pub(crate) mod test {
dust_change = true;
)*
let tx_fee_rate = $fees as f32 / (tx.get_weight().vbytes());
let fee_rate = $fee_rate.as_sat_vb();
let tx_fee_rate = FeeRate::from_wu($fees, tx.get_weight());
let fee_rate = $fee_rate;
if !dust_change {
assert!((tx_fee_rate - fee_rate).abs() < 0.5, "Expected fee rate of {}, the tx has {}", fee_rate, tx_fee_rate);
assert!((tx_fee_rate - fee_rate).as_sat_vb().abs() < 0.5, "Expected fee rate of {:?}, the tx has {:?}", fee_rate, tx_fee_rate);
} else {
assert!(tx_fee_rate >= fee_rate, "Expected fee rate of at least {}, the tx has {}", fee_rate, tx_fee_rate);
assert!(tx_fee_rate >= fee_rate, "Expected fee rate of at least {:?}, the tx has {:?}", fee_rate, tx_fee_rate);
}
});
}