diff --git a/CHANGELOG.md b/CHANGELOG.md index ddf2b300..60d0241f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Wallet - Added an option that must be explicitly enabled to allow signing using non-`SIGHASH_ALL` sighashes (#350) +#### Changed +`get_address` now returns an `AddressInfo` struct that includes the index and derefs to `Address`. ## [v0.7.0] - [v0.6.0] diff --git a/src/testutils/blockchain_tests.rs b/src/testutils/blockchain_tests.rs index 30999990..3f7402cd 100644 --- a/src/testutils/blockchain_tests.rs +++ b/src/testutils/blockchain_tests.rs @@ -815,7 +815,7 @@ macro_rules! bdk_blockchain_tests { #[serial] fn test_sync_receive_coinbase() { let (wallet, _, mut test_client) = init_single_sig(); - let wallet_addr = wallet.get_address(New).unwrap(); + let wallet_addr = wallet.get_address(New).unwrap().address; wallet.sync(noop_progress(), None).unwrap(); assert_eq!(wallet.get_balance().unwrap(), 0); diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 8f4b85f9..e1c33ddc 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -16,6 +16,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::collections::{BTreeMap, HashSet}; +use std::fmt; use std::ops::{Deref, DerefMut}; use std::sync::Arc; @@ -196,24 +197,52 @@ pub enum AddressIndex { Reset(u32), } +/// A derived address and the index it was found at +/// For convenience this automatically derefs to `Address` +#[derive(Debug, PartialEq)] +pub struct AddressInfo { + /// Child index of this address + pub index: u32, + /// Address + pub address: Address, +} + +impl Deref for AddressInfo { + type Target = Address; + + fn deref(&self) -> &Self::Target { + &self.address + } +} + +impl fmt::Display for AddressInfo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.address) + } +} + // offline actions, always available impl Wallet where D: BatchDatabase, { // Return a newly derived address using the external descriptor - fn get_new_address(&self) -> Result { + fn get_new_address(&self) -> Result<(Address, u32), Error> { let incremented_index = self.fetch_and_increment_index(KeychainKind::External)?; - self.descriptor + let address_result = self + .descriptor .as_derived(incremented_index, &self.secp) - .address(self.network) + .address(self.network); + + address_result + .map(|address| (address, incremented_index)) .map_err(|_| Error::ScriptDoesntHaveAddressForm) } // Return the the last previously derived address if it has not been used in a received // transaction. Otherwise return a new address using [`Wallet::get_new_address`]. - fn get_unused_address(&self) -> Result { + fn get_unused_address(&self) -> Result<(Address, u32), Error> { let current_index = self.fetch_index(KeychainKind::External)?; let derived_key = self.descriptor.as_derived(current_index, &self.secp); @@ -232,39 +261,44 @@ where } else { derived_key .address(self.network) + .map(|address| (address, current_index)) .map_err(|_| Error::ScriptDoesntHaveAddressForm) } } // Return derived address for the external descriptor at a specific index - fn peek_address(&self, index: u32) -> Result { + fn peek_address(&self, index: u32) -> Result<(Address, u32), Error> { self.descriptor .as_derived(index, &self.secp) .address(self.network) + .map(|address| (address, index)) .map_err(|_| Error::ScriptDoesntHaveAddressForm) } // Return derived address for the external descriptor at a specific index and reset current // address index - fn reset_address(&self, index: u32) -> Result { + fn reset_address(&self, index: u32) -> Result<(Address, u32), Error> { self.set_index(KeychainKind::External, index)?; self.descriptor .as_derived(index, &self.secp) .address(self.network) + .map(|address| (address, index)) .map_err(|_| Error::ScriptDoesntHaveAddressForm) } /// Return a derived address using the external descriptor, see [`AddressIndex`] for /// available address index selection strategies. If none of the keys in the descriptor are derivable /// (ie. does not end with /*) then the same address will always be returned for any [`AddressIndex`]. - pub fn get_address(&self, address_index: AddressIndex) -> Result { - match address_index { + pub fn get_address(&self, address_index: AddressIndex) -> Result { + let result = match address_index { AddressIndex::New => self.get_new_address(), AddressIndex::LastUnused => self.get_unused_address(), AddressIndex::Peek(index) => self.peek_address(index), AddressIndex::Reset(index) => self.reset_address(index), - } + }; + + result.map(|(address, index)| AddressInfo { index, address }) } /// Return whether or not a `script` is part of this wallet (either internal or external) @@ -3867,4 +3901,65 @@ pub(crate) mod test { "tb1qzntf2mqex4ehwkjlfdyy3ewdlk08qkvkvrz7x2" ); } + + #[test] + fn test_returns_index_and_address() { + let db = MemoryDatabase::new(); + let wallet = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", + None, Network::Testnet, db).unwrap(); + + // new index 0 + assert_eq!( + wallet.get_address(New).unwrap(), + AddressInfo { + index: 0, + address: Address::from_str("tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a").unwrap(), + } + ); + + // new index 1 + assert_eq!( + wallet.get_address(New).unwrap(), + AddressInfo { + index: 1, + address: Address::from_str("tb1q4er7kxx6sssz3q7qp7zsqsdx4erceahhax77d7").unwrap() + } + ); + + // peek index 25 + assert_eq!( + wallet.get_address(Peek(25)).unwrap(), + AddressInfo { + index: 25, + address: Address::from_str("tb1qsp7qu0knx3sl6536dzs0703u2w2ag6ppl9d0c2").unwrap() + } + ); + + // new index 2 + assert_eq!( + wallet.get_address(New).unwrap(), + AddressInfo { + index: 2, + address: Address::from_str("tb1qzntf2mqex4ehwkjlfdyy3ewdlk08qkvkvrz7x2").unwrap() + } + ); + + // reset index 1 again + assert_eq!( + wallet.get_address(Reset(1)).unwrap(), + AddressInfo { + index: 1, + address: Address::from_str("tb1q4er7kxx6sssz3q7qp7zsqsdx4erceahhax77d7").unwrap() + } + ); + + // new index 2 again + assert_eq!( + wallet.get_address(New).unwrap(), + AddressInfo { + index: 2, + address: Address::from_str("tb1qzntf2mqex4ehwkjlfdyy3ewdlk08qkvkvrz7x2").unwrap() + } + ); + } }