diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index fe1255a5..fe0f2b9b 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -135,9 +135,36 @@ where policy.get_requirements(builder.policy_path.as_ref().unwrap_or(&BTreeMap::new()))?; debug!("requirements: {:?}", requirements); + let version = match builder.version { + tx_builder::Version(0) => return Err(Error::Generic("Invalid version `0`".into())), + tx_builder::Version(1) if requirements.csv.is_some() => { + return Err(Error::Generic( + "TxBuilder requested version `1`, but at least `2` is needed to use OP_CSV" + .into(), + )) + } + tx_builder::Version(x) => x, + }; + + let lock_time = match builder.locktime { + None => requirements.timelock.unwrap_or(0), + Some(x) if requirements.timelock.is_none() => x, + Some(x) if requirements.timelock.unwrap() <= x => x, + Some(x) => return Err(Error::Generic(format!("TxBuilder requested timelock of `{}`, but at least `{}` is required to spend from this script", x, requirements.timelock.unwrap()))) + }; + + let n_sequence = match (builder.rbf, requirements.csv) { + (None, Some(csv)) => csv, + (Some(rbf), Some(csv)) if rbf < csv => return Err(Error::Generic(format!("Cannot enable RBF with nSequence `{}`, since at least `{}` is required to spend with OP_CSV", rbf, csv))), + (None, _) if requirements.timelock.is_some() => 0xFFFFFFFE, + (Some(rbf), _) if rbf >= 0xFFFFFFFE => return Err(Error::Generic("Cannot enable RBF with anumber >= 0xFFFFFFFE".into())), + (Some(rbf), _) => rbf, + (None, _) => 0xFFFFFFFF, + }; + let mut tx = Transaction { - version: 2, - lock_time: requirements.timelock.unwrap_or(0), + version, + lock_time, input: vec![], output: vec![], }; @@ -206,11 +233,6 @@ where )?; let (mut txin, prev_script_pubkeys): (Vec<_>, Vec<_>) = txin.into_iter().unzip(); - let n_sequence = match requirements.csv { - Some(csv) => csv, - _ if requirements.timelock.is_some() => 0xFFFFFFFE, - _ => 0xFFFFFFFF, - }; txin.iter_mut().for_each(|i| i.sequence = n_sequence); tx.input = txin; diff --git a/src/wallet/tx_builder.rs b/src/wallet/tx_builder.rs index 74c69468..1b43fbdb 100644 --- a/src/wallet/tx_builder.rs +++ b/src/wallet/tx_builder.rs @@ -18,6 +18,8 @@ pub struct TxBuilder { pub(crate) sighash: Option, pub(crate) ordering: TxOrdering, pub(crate) locktime: Option, + pub(crate) rbf: Option, + pub(crate) version: Version, pub(crate) coin_selection: Cs, } @@ -92,6 +94,20 @@ impl TxBuilder { self } + pub fn enable_rbf(self) -> Self { + self.enable_rbf_with_sequence(0xFFFFFFFD) + } + + pub fn enable_rbf_with_sequence(mut self, nsequence: u32) -> Self { + self.rbf = Some(nsequence); + self + } + + pub fn version(mut self, version: u32) -> Self { + self.version = Version(version); + self + } + pub fn coin_selection(self, coin_selection: P) -> TxBuilder

{ TxBuilder { addressees: self.addressees, @@ -103,6 +119,8 @@ impl TxBuilder { sighash: self.sighash, ordering: self.ordering, locktime: self.locktime, + rbf: self.rbf, + version: self.version, coin_selection, } } @@ -148,6 +166,16 @@ impl TxOrdering { } } +// Helper type that wraps u32 and has a default value of 1 +#[derive(Debug)] +pub(crate) struct Version(pub(crate) u32); + +impl Default for Version { + fn default() -> Self { + Version(1) + } +} + #[cfg(test)] mod test { const ORDERING_TEST_TX: &'static str = "0200000003c26f3eb7932f7acddc5ddd26602b77e7516079b03090a16e2c2f54\