Merge commit 'refs/pull/228/head' of github.com:bitcoindevkit/bdk

This commit is contained in:
Alekos Filini
2020-12-15 14:33:59 +01:00
19 changed files with 334 additions and 345 deletions

View File

@@ -124,7 +124,7 @@ where
) -> Result<Self, Error> {
let (descriptor, keymap) = descriptor.to_wallet_descriptor(network)?;
database.check_descriptor_checksum(
ScriptType::External,
KeychainKind::External,
get_checksum(&descriptor.to_string())?.as_bytes(),
)?;
let signers = Arc::new(SignersContainer::from(keymap));
@@ -132,7 +132,7 @@ where
Some(desc) => {
let (change_descriptor, change_keymap) = desc.to_wallet_descriptor(network)?;
database.check_descriptor_checksum(
ScriptType::Internal,
KeychainKind::Internal,
get_checksum(&change_descriptor.to_string())?.as_bytes(),
)?;
@@ -166,7 +166,7 @@ where
/// Return a newly generated address using the external descriptor
pub fn get_new_address(&self) -> Result<Address, Error> {
let index = self.fetch_and_increment_index(ScriptType::External)?;
let index = self.fetch_and_increment_index(KeychainKind::External)?;
let deriv_ctx = descriptor_to_pk_ctx(&self.secp);
self.descriptor
@@ -215,14 +215,14 @@ where
/// See [the `signer` module](signer) for an example.
pub fn add_signer(
&mut self,
script_type: ScriptType,
keychain: KeychainKind,
id: SignerId,
ordering: SignerOrdering,
signer: Arc<dyn Signer>,
) {
let signers = match script_type {
ScriptType::External => Arc::make_mut(&mut self.signers),
ScriptType::Internal => Arc::make_mut(&mut self.change_signers),
let signers = match keychain {
KeychainKind::External => Arc::make_mut(&mut self.signers),
KeychainKind::Internal => Arc::make_mut(&mut self.change_signers),
};
signers.add_external(id, ordering, signer);
@@ -278,7 +278,7 @@ where
&& external_policy.requires_path()
&& builder.external_policy_path.is_none()
{
return Err(Error::SpendingPolicyRequired(ScriptType::External));
return Err(Error::SpendingPolicyRequired(KeychainKind::External));
};
// Same for the internal_policy path, if present
if let Some(internal_policy) = &internal_policy {
@@ -286,7 +286,7 @@ where
&& internal_policy.requires_path()
&& builder.internal_policy_path.is_none()
{
return Err(Error::SpendingPolicyRequired(ScriptType::Internal));
return Err(Error::SpendingPolicyRequired(KeychainKind::Internal));
};
}
@@ -600,18 +600,17 @@ where
None => {
let mut change_output = None;
for (index, txout) in tx.output.iter().enumerate() {
// look for an output that we know and that has the right ScriptType. We use
// `get_descriptor_for` to find what's the ScriptType for `Internal`
// look for an output that we know and that has the right KeychainKind. We use
// `get_descriptor_for` to find what's the KeychainKind for `Internal`
// addresses really is, because if there's no change_descriptor it's actually equal
// to "External"
let (_, change_type) =
self.get_descriptor_for_script_type(ScriptType::Internal);
let (_, change_type) = self.get_descriptor_for_keychain(KeychainKind::Internal);
match self
.database
.borrow()
.get_path_from_script_pubkey(&txout.script_pubkey)?
{
Some((script_type, _)) if script_type == change_type => {
Some((keychain, _)) if keychain == change_type => {
change_output = Some(index);
break;
}
@@ -657,31 +656,31 @@ where
.get_previous_output(&txin.previous_output)?
.ok_or(Error::UnknownUTXO)?;
let (weight, script_type) = match self
let (weight, keychain) = match self
.database
.borrow()
.get_path_from_script_pubkey(&txout.script_pubkey)?
{
Some((script_type, _)) => (
self.get_descriptor_for_script_type(script_type)
Some((keychain, _)) => (
self.get_descriptor_for_keychain(keychain)
.0
.max_satisfaction_weight(deriv_ctx)
.unwrap(),
script_type,
keychain,
),
None => {
// estimate the weight based on the scriptsig/witness size present in the
// original transaction
let weight =
serialize(&txin.script_sig).len() * 4 + serialize(&txin.witness).len();
(weight, ScriptType::External)
(weight, KeychainKind::External)
}
};
let utxo = UTXO {
outpoint: txin.previous_output,
txout,
script_type,
keychain,
};
Ok((utxo, weight))
@@ -853,13 +852,13 @@ where
}
/// Return the spending policies for the wallet's descriptor
pub fn policies(&self, script_type: ScriptType) -> Result<Option<Policy>, Error> {
match (script_type, self.change_descriptor.as_ref()) {
(ScriptType::External, _) => {
pub fn policies(&self, keychain: KeychainKind) -> Result<Option<Policy>, Error> {
match (keychain, self.change_descriptor.as_ref()) {
(KeychainKind::External, _) => {
Ok(self.descriptor.extract_policy(&self.signers, &self.secp)?)
}
(ScriptType::Internal, None) => Ok(None),
(ScriptType::Internal, Some(desc)) => {
(KeychainKind::Internal, None) => Ok(None),
(KeychainKind::Internal, Some(desc)) => {
Ok(desc.extract_policy(&self.change_signers, &self.secp)?)
}
}
@@ -871,12 +870,12 @@ where
/// This can be used to build a watch-only version of a wallet
pub fn public_descriptor(
&self,
script_type: ScriptType,
keychain: KeychainKind,
) -> Result<Option<ExtendedDescriptor>, Error> {
match (script_type, self.change_descriptor.as_ref()) {
(ScriptType::External, _) => Ok(Some(self.descriptor.clone())),
(ScriptType::Internal, None) => Ok(None),
(ScriptType::Internal, Some(desc)) => Ok(Some(desc.clone())),
match (keychain, self.change_descriptor.as_ref()) {
(KeychainKind::External, _) => Ok(Some(self.descriptor.clone())),
(KeychainKind::Internal, None) => Ok(None),
(KeychainKind::Internal, Some(desc)) => Ok(Some(desc.clone())),
}
}
@@ -909,7 +908,7 @@ where
);
// - Try to derive the descriptor by looking at the txout. If it's in our database, we
// know exactly which `script_type` to use, and which derivation index it is
// know exactly which `keychain` to use, and which derivation index it is
// - If that fails, try to derive it by looking at the psbt input: the complete logic
// is in `src/descriptor/mod.rs`, but it will basically look at `hd_keypaths`,
// `redeem_script` and `witness_script` to determine the right derivation
@@ -970,16 +969,16 @@ where
// Internals
fn get_descriptor_for_script_type(
fn get_descriptor_for_keychain(
&self,
script_type: ScriptType,
) -> (&ExtendedDescriptor, ScriptType) {
match script_type {
ScriptType::Internal if self.change_descriptor.is_some() => (
keychain: KeychainKind,
) -> (&ExtendedDescriptor, KeychainKind) {
match keychain {
KeychainKind::Internal if self.change_descriptor.is_some() => (
self.change_descriptor.as_ref().unwrap(),
ScriptType::Internal,
KeychainKind::Internal,
),
_ => (&self.descriptor, ScriptType::External),
_ => (&self.descriptor, KeychainKind::External),
}
}
@@ -988,38 +987,35 @@ where
.database
.borrow()
.get_path_from_script_pubkey(&txout.script_pubkey)?
.map(|(script_type, child)| (self.get_descriptor_for_script_type(script_type).0, child))
.map(|(keychain, child)| (self.get_descriptor_for_keychain(keychain).0, child))
.map(|(desc, child)| desc.derive(ChildNumber::from_normal_idx(child).unwrap())))
}
fn get_change_address(&self) -> Result<Script, Error> {
let deriv_ctx = descriptor_to_pk_ctx(&self.secp);
let (desc, script_type) = self.get_descriptor_for_script_type(ScriptType::Internal);
let index = self.fetch_and_increment_index(script_type)?;
let (desc, keychain) = self.get_descriptor_for_keychain(KeychainKind::Internal);
let index = self.fetch_and_increment_index(keychain)?;
Ok(desc
.derive(ChildNumber::from_normal_idx(index)?)
.script_pubkey(deriv_ctx))
}
fn fetch_and_increment_index(&self, script_type: ScriptType) -> Result<u32, Error> {
let (descriptor, script_type) = self.get_descriptor_for_script_type(script_type);
fn fetch_and_increment_index(&self, keychain: KeychainKind) -> Result<u32, Error> {
let (descriptor, keychain) = self.get_descriptor_for_keychain(keychain);
let index = match descriptor.is_fixed() {
true => 0,
false => self
.database
.borrow_mut()
.increment_last_index(script_type)?,
false => self.database.borrow_mut().increment_last_index(keychain)?,
};
if self
.database
.borrow()
.get_script_pubkey_from_path(script_type, index)?
.get_script_pubkey_from_path(keychain, index)?
.is_none()
{
self.cache_addresses(script_type, index, CACHE_ADDR_BATCH_SIZE)?;
self.cache_addresses(keychain, index, CACHE_ADDR_BATCH_SIZE)?;
}
let deriv_ctx = descriptor_to_pk_ctx(&self.secp);
@@ -1029,7 +1025,7 @@ where
.derive(ChildNumber::from_normal_idx(index)?)
.script_pubkey(deriv_ctx);
for validator in &self.address_validators {
validator.validate(script_type, &hd_keypaths, &script)?;
validator.validate(keychain, &hd_keypaths, &script)?;
}
Ok(index)
@@ -1037,11 +1033,11 @@ where
fn cache_addresses(
&self,
script_type: ScriptType,
keychain: KeychainKind,
from: u32,
mut count: u32,
) -> Result<(), Error> {
let (descriptor, script_type) = self.get_descriptor_for_script_type(script_type);
let (descriptor, keychain) = self.get_descriptor_for_keychain(keychain);
if descriptor.is_fixed() {
if from > 0 {
return Ok(());
@@ -1060,7 +1056,7 @@ where
&descriptor
.derive(ChildNumber::from_normal_idx(i)?)
.script_pubkey(deriv_ctx),
script_type,
keychain,
i,
)?;
}
@@ -1083,10 +1079,10 @@ where
.list_unspent()?
.into_iter()
.map(|utxo| {
let script_type = utxo.script_type;
let keychain = utxo.keychain;
(
utxo,
self.get_descriptor_for_script_type(script_type)
self.get_descriptor_for_keychain(keychain)
.0
.max_satisfaction_weight(deriv_ctx)
.unwrap(),
@@ -1230,7 +1226,7 @@ where
// Try to find the prev_script in our db to figure out if this is internal or external,
// and the derivation index
let (script_type, child) = match self
let (keychain, child) = match self
.database
.borrow()
.get_path_from_script_pubkey(&utxo.txout.script_pubkey)?
@@ -1239,7 +1235,7 @@ where
None => continue,
};
let (desc, _) = self.get_descriptor_for_script_type(script_type);
let (desc, _) = self.get_descriptor_for_keychain(keychain);
psbt_input.hd_keypaths = desc.get_hd_keypaths(child, &self.secp)?;
let derived_descriptor = desc.derive(ChildNumber::from_normal_idx(child)?);
@@ -1267,12 +1263,12 @@ where
.iter_mut()
.zip(psbt.global.unsigned_tx.output.iter())
{
if let Some((script_type, child)) = self
if let Some((keychain, child)) = self
.database
.borrow()
.get_path_from_script_pubkey(&tx_output.script_pubkey)?
{
let (desc, _) = self.get_descriptor_for_script_type(script_type);
let (desc, _) = self.get_descriptor_for_keychain(keychain);
psbt_output.hd_keypaths = desc.get_hd_keypaths(child, &self.secp)?;
if builder.include_output_redeem_witness_script {
let derived_descriptor = desc.derive(ChildNumber::from_normal_idx(child)?);
@@ -1294,15 +1290,15 @@ where
// try to add hd_keypaths if we've already seen the output
for (psbt_input, out) in psbt.inputs.iter_mut().zip(input_utxos.iter()) {
if let Some(out) = out {
if let Some((script_type, child)) = self
if let Some((keychain, child)) = self
.database
.borrow()
.get_path_from_script_pubkey(&out.script_pubkey)?
{
debug!("Found descriptor {:?}/{}", script_type, child);
debug!("Found descriptor {:?}/{}", keychain, child);
// merge hd_keypaths
let (desc, _) = self.get_descriptor_for_script_type(script_type);
let (desc, _) = self.get_descriptor_for_keychain(keychain);
let mut hd_keypaths = desc.get_hd_keypaths(child, &self.secp)?;
psbt_input.hd_keypaths.append(&mut hd_keypaths);
}
@@ -1353,11 +1349,11 @@ where
if self
.database
.borrow()
.get_script_pubkey_from_path(ScriptType::External, max_address.saturating_sub(1))?
.get_script_pubkey_from_path(KeychainKind::External, max_address.saturating_sub(1))?
.is_none()
{
run_setup = true;
self.cache_addresses(ScriptType::External, 0, max_address)?;
self.cache_addresses(KeychainKind::External, 0, max_address)?;
}
if let Some(change_descriptor) = &self.change_descriptor {
@@ -1369,11 +1365,11 @@ where
if self
.database
.borrow()
.get_script_pubkey_from_path(ScriptType::Internal, max_address.saturating_sub(1))?
.get_script_pubkey_from_path(KeychainKind::Internal, max_address.saturating_sub(1))?
.is_none()
{
run_setup = true;
self.cache_addresses(ScriptType::Internal, 0, max_address)?;
self.cache_addresses(KeychainKind::Internal, 0, max_address)?;
}
}
@@ -1425,7 +1421,7 @@ mod test {
use crate::database::memory::MemoryDatabase;
use crate::database::Database;
use crate::types::ScriptType;
use crate::types::KeychainKind;
use super::*;
@@ -1452,13 +1448,13 @@ mod test {
assert!(wallet
.database
.borrow_mut()
.get_script_pubkey_from_path(ScriptType::External, 0)
.get_script_pubkey_from_path(KeychainKind::External, 0)
.unwrap()
.is_some());
assert!(wallet
.database
.borrow_mut()
.get_script_pubkey_from_path(ScriptType::Internal, 0)
.get_script_pubkey_from_path(KeychainKind::Internal, 0)
.unwrap()
.is_none());
}
@@ -1480,13 +1476,13 @@ mod test {
assert!(wallet
.database
.borrow_mut()
.get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1)
.get_script_pubkey_from_path(KeychainKind::External, CACHE_ADDR_BATCH_SIZE - 1)
.unwrap()
.is_some());
assert!(wallet
.database
.borrow_mut()
.get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE)
.get_script_pubkey_from_path(KeychainKind::External, CACHE_ADDR_BATCH_SIZE)
.unwrap()
.is_none());
}
@@ -1503,7 +1499,7 @@ mod test {
assert!(wallet
.database
.borrow_mut()
.get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1)
.get_script_pubkey_from_path(KeychainKind::External, CACHE_ADDR_BATCH_SIZE - 1)
.unwrap()
.is_some());
@@ -1514,7 +1510,7 @@ mod test {
assert!(wallet
.database
.borrow_mut()
.get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE * 2 - 1)
.get_script_pubkey_from_path(KeychainKind::External, CACHE_ADDR_BATCH_SIZE * 2 - 1)
.unwrap()
.is_some());
}
@@ -2311,7 +2307,7 @@ mod test {
fn test_create_tx_policy_path_no_csv() {
let (wallet, _, _) = get_funded_wallet(get_test_a_or_b_plus_csv());
let external_policy = wallet.policies(ScriptType::External).unwrap().unwrap();
let external_policy = wallet.policies(KeychainKind::External).unwrap().unwrap();
let root_id = external_policy.id;
// child #0 is just the key "A"
let path = vec![(root_id, vec![0])].into_iter().collect();
@@ -2320,7 +2316,7 @@ mod test {
let (psbt, _) = wallet
.create_tx(
TxBuilder::with_recipients(vec![(addr.script_pubkey(), 30_000)])
.policy_path(path, ScriptType::External),
.policy_path(path, KeychainKind::External),
)
.unwrap();
@@ -2331,7 +2327,7 @@ mod test {
fn test_create_tx_policy_path_use_csv() {
let (wallet, _, _) = get_funded_wallet(get_test_a_or_b_plus_csv());
let external_policy = wallet.policies(ScriptType::External).unwrap().unwrap();
let external_policy = wallet.policies(KeychainKind::External).unwrap().unwrap();
let root_id = external_policy.id;
// child #1 is or(pk(B),older(144))
let path = vec![(root_id, vec![1])].into_iter().collect();
@@ -2340,7 +2336,7 @@ mod test {
let (psbt, _) = wallet
.create_tx(
TxBuilder::with_recipients(vec![(addr.script_pubkey(), 30_000)])
.policy_path(path, ScriptType::External),
.policy_path(path, KeychainKind::External),
)
.unwrap();