diff --git a/CHANGELOG.md b/CHANGELOG.md index 319fb628..cad538d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed default verification from `wallet::sync`. sync-time verification is added in `script_sync` and is activated by `verify` feature flag. - `verify` flag removed from `TransactionDetails`. +- Add `get_internal_address` to allow you to get internal addresses just as you get external addresses. +- added `ensure_addresses_cached` to `Wallet` to let offline wallets load and cache addresses in their database ### Sync API change diff --git a/src/testutils/blockchain_tests.rs b/src/testutils/blockchain_tests.rs index b54da105..6ac91e88 100644 --- a/src/testutils/blockchain_tests.rs +++ b/src/testutils/blockchain_tests.rs @@ -817,7 +817,7 @@ macro_rules! bdk_blockchain_tests { let mut builder = wallet.build_fee_bump(details.txid).unwrap(); builder.fee_rate(FeeRate::from_sat_per_vb(2.1)); - let (mut new_psbt, new_details) = builder.finish().unwrap(); + let (mut new_psbt, new_details) = builder.finish().expect("fee bump tx"); let finalized = wallet.sign(&mut new_psbt, Default::default()).unwrap(); assert!(finalized, "Cannot finalize transaction"); blockchain.broadcast(&new_psbt.extract_tx()).unwrap(); diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 4006ccb2..97f7bbb8 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -236,11 +236,11 @@ where } // Return a newly derived address using the external descriptor - fn get_new_address(&self) -> Result { - let incremented_index = self.fetch_and_increment_index(KeychainKind::External)?; + fn get_new_address(&self, keychain: KeychainKind) -> Result { + let incremented_index = self.fetch_and_increment_index(keychain)?; let address_result = self - .descriptor + .get_descriptor_for_keychain(keychain) .as_derived(incremented_index, &self.secp) .address(self.network); @@ -252,12 +252,14 @@ where .map_err(|_| Error::ScriptDoesntHaveAddressForm) } - // Return the the last previously derived address if it has not been used in a received - // transaction. Otherwise return a new address using [`Wallet::get_new_address`]. - fn get_unused_address(&self) -> Result { - let current_index = self.fetch_index(KeychainKind::External)?; + // Return the the last previously derived address for `keychain` if it has not been used in a + // received transaction. Otherwise return a new address using [`Wallet::get_new_address`]. + fn get_unused_address(&self, keychain: KeychainKind) -> Result { + let current_index = self.fetch_index(keychain)?; - let derived_key = self.descriptor.as_derived(current_index, &self.secp); + let derived_key = self + .get_descriptor_for_keychain(keychain) + .as_derived(current_index, &self.secp); let script_pubkey = derived_key.script_pubkey(); @@ -269,7 +271,7 @@ where .any(|o| o.script_pubkey == script_pubkey); if found_used { - self.get_new_address() + self.get_new_address(keychain) } else { derived_key .address(self.network) @@ -281,21 +283,21 @@ where } } - // Return derived address for the external descriptor at a specific index - fn peek_address(&self, index: u32) -> Result { - self.descriptor + // Return derived address for the descriptor of given [`KeychainKind`] at a specific index + fn peek_address(&self, index: u32, keychain: KeychainKind) -> Result { + self.get_descriptor_for_keychain(keychain) .as_derived(index, &self.secp) .address(self.network) .map(|address| AddressInfo { index, address }) .map_err(|_| Error::ScriptDoesntHaveAddressForm) } - // Return derived address for the external descriptor at a specific index and reset current + // Return derived address for `keychain` at a specific index and reset current // address index - fn reset_address(&self, index: u32) -> Result { - self.set_index(KeychainKind::External, index)?; + fn reset_address(&self, index: u32, keychain: KeychainKind) -> Result { + self.set_index(keychain, index)?; - self.descriptor + self.get_descriptor_for_keychain(keychain) .as_derived(index, &self.secp) .address(self.network) .map(|address| AddressInfo { index, address }) @@ -306,14 +308,77 @@ where /// available address index selection strategies. If none of the keys in the descriptor are derivable /// (ie. does not end with /*) then the same address will always be returned for any [`AddressIndex`]. pub fn get_address(&self, address_index: AddressIndex) -> Result { + self._get_address(address_index, KeychainKind::External) + } + + /// Return a derived address using the internal (change) descriptor. + /// + /// If the wallet doesn't have an internal descriptor it will use the external descriptor. + /// + /// see [`AddressIndex`] for available address index selection strategies. If none of the keys + /// in the descriptor are derivable (ie. does not end with /*) then the same address will always + /// be returned for any [`AddressIndex`]. + pub fn get_internal_address(&self, address_index: AddressIndex) -> Result { + self._get_address(address_index, KeychainKind::Internal) + } + + fn _get_address( + &self, + address_index: AddressIndex, + keychain: KeychainKind, + ) -> Result { match address_index { - AddressIndex::New => self.get_new_address(), - AddressIndex::LastUnused => self.get_unused_address(), - AddressIndex::Peek(index) => self.peek_address(index), - AddressIndex::Reset(index) => self.reset_address(index), + AddressIndex::New => self.get_new_address(keychain), + AddressIndex::LastUnused => self.get_unused_address(keychain), + AddressIndex::Peek(index) => self.peek_address(index, keychain), + AddressIndex::Reset(index) => self.reset_address(index, keychain), } } + /// Ensures that there are at least `max_addresses` addresses cached in the database if the + /// descriptor is derivable, or 1 address if it is not. + /// Will return `Ok(true)` if there are new addresses generated (either external or internal), + /// and `Ok(false)` if all the required addresses are already cached. This function is useful to + /// explicitly cache addresses in a wallet to do things like check [`Wallet::is_mine`] on + /// transaction output scripts. + pub fn ensure_addresses_cached(&self, max_addresses: u32) -> Result { + let mut new_addresses_cached = false; + let max_address = match self.descriptor.is_deriveable() { + false => 0, + true => max_addresses, + }; + debug!("max_address {}", max_address); + if self + .database + .borrow() + .get_script_pubkey_from_path(KeychainKind::External, max_address.saturating_sub(1))? + .is_none() + { + debug!("caching external addresses"); + new_addresses_cached = true; + self.cache_addresses(KeychainKind::External, 0, max_address)?; + } + + if let Some(change_descriptor) = &self.change_descriptor { + let max_address = match change_descriptor.is_deriveable() { + false => 0, + true => max_addresses, + }; + + if self + .database + .borrow() + .get_script_pubkey_from_path(KeychainKind::Internal, max_address.saturating_sub(1))? + .is_none() + { + debug!("caching internal addresses"); + new_addresses_cached = true; + self.cache_addresses(KeychainKind::Internal, 0, max_address)?; + } + } + Ok(new_addresses_cached) + } + /// Return whether or not a `script` is part of this wallet (either internal or external) pub fn is_mine(&self, script: &Script) -> Result { self.database.borrow().is_mine(script) @@ -660,7 +725,10 @@ where let mut drain_output = { let script_pubkey = match params.drain_to { Some(ref drain_recipient) => drain_recipient.clone(), - None => self.get_change_address()?, + None => self + .get_internal_address(AddressIndex::New)? + .address + .script_pubkey(), }; TxOut { @@ -1093,13 +1161,6 @@ where .map(|(desc, child)| desc.as_derived(child, &self.secp))) } - fn get_change_address(&self) -> Result { - let (desc, keychain) = self._get_descriptor_for_keychain(KeychainKind::Internal); - let index = self.fetch_and_increment_index(keychain)?; - - Ok(desc.as_derived(index, &self.secp).script_pubkey()) - } - fn fetch_and_increment_index(&self, keychain: KeychainKind) -> Result { let (descriptor, keychain) = self._get_descriptor_for_keychain(keychain); let index = match descriptor.is_deriveable() { @@ -1463,46 +1524,14 @@ where ) -> Result<(), Error> { debug!("Begin sync..."); - let mut run_setup = false; let SyncOptions { max_addresses, progress, } = sync_opts; let progress = progress.unwrap_or_else(|| Box::new(NoopProgress)); - let max_address = match self.descriptor.is_deriveable() { - false => 0, - true => max_addresses.unwrap_or(CACHE_ADDR_BATCH_SIZE), - }; - debug!("max_address {}", max_address); - if self - .database - .borrow() - .get_script_pubkey_from_path(KeychainKind::External, max_address.saturating_sub(1))? - .is_none() - { - debug!("caching external addresses"); - run_setup = true; - self.cache_addresses(KeychainKind::External, 0, max_address)?; - } - - if let Some(change_descriptor) = &self.change_descriptor { - let max_address = match change_descriptor.is_deriveable() { - false => 0, - true => max_addresses.unwrap_or(CACHE_ADDR_BATCH_SIZE), - }; - - if self - .database - .borrow() - .get_script_pubkey_from_path(KeychainKind::Internal, max_address.saturating_sub(1))? - .is_none() - { - debug!("caching internal addresses"); - run_setup = true; - self.cache_addresses(KeychainKind::Internal, 0, max_address)?; - } - } + let run_setup = + self.ensure_addresses_cached(max_addresses.unwrap_or(CACHE_ADDR_BATCH_SIZE))?; debug!("run_setup: {}", run_setup); // TODO: what if i generate an address first and cache some addresses? @@ -3953,6 +3982,48 @@ pub(crate) mod test { builder.add_recipient(addr.script_pubkey(), 45_000); builder.finish().unwrap(); } + + #[test] + fn test_get_address() { + use crate::descriptor::template::Bip84; + let key = bitcoin::util::bip32::ExtendedPrivKey::from_str("tprv8ZgxMBicQKsPcx5nBGsR63Pe8KnRUqmbJNENAfGftF3yuXoMMoVJJcYeUw5eVkm9WBPjWYt6HMWYJNesB5HaNVBaFc1M6dRjWSYnmewUMYy").unwrap(); + let wallet = Wallet::new( + Bip84(key, KeychainKind::External), + Some(Bip84(key, KeychainKind::Internal)), + Network::Regtest, + MemoryDatabase::default(), + ) + .unwrap(); + + assert_eq!( + wallet.get_address(AddressIndex::New).unwrap().address, + Address::from_str("bcrt1qkmvk2nadgplmd57ztld8nf8v2yxkzmdvwtjf8s").unwrap() + ); + assert_eq!( + wallet + .get_internal_address(AddressIndex::New) + .unwrap() + .address, + Address::from_str("bcrt1qtrwtz00wxl69e5xex7amy4xzlxkaefg3gfdkxa").unwrap() + ); + + let wallet = Wallet::new( + Bip84(key, KeychainKind::External), + None, + Network::Regtest, + MemoryDatabase::default(), + ) + .unwrap(); + + assert_eq!( + wallet + .get_internal_address(AddressIndex::New) + .unwrap() + .address, + Address::from_str("bcrt1qkmvk2nadgplmd57ztld8nf8v2yxkzmdvwtjf8s").unwrap(), + "when there's no internal descriptor it should just use external" + ); + } } /// Deterministically generate a unique name given the descriptors defining the wallet