From b67bbeb202051380fed51781bbe46841e80449c5 Mon Sep 17 00:00:00 2001 From: Alekos Filini Date: Thu, 6 Aug 2020 18:11:07 +0200 Subject: [PATCH] [wallet] Refill the address pool whenever necessary --- src/cli.rs | 2 +- src/wallet/mod.rs | 296 ++++++++++++++++++++++++++++++++-------------- 2 files changed, 209 insertions(+), 89 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index 735c0af9..d4e45789 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -291,7 +291,7 @@ where if let Some(_sub_matches) = matches.subcommand_matches("get_new_address") { Ok(Some(format!("{}", wallet.get_new_address()?))) } else if let Some(_sub_matches) = matches.subcommand_matches("sync") { - maybe_await!(wallet.sync(None, None))?; + maybe_await!(wallet.sync(None))?; Ok(None) } else if let Some(_sub_matches) = matches.subcommand_matches("list_unspent") { let mut res = String::new(); diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index f59424d1..7948a672 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -22,8 +22,7 @@ pub mod tx_builder; pub mod utils; pub use tx_builder::TxBuilder; - -use self::utils::IsDust; +use utils::IsDust; use crate::blockchain::{noop_progress, Blockchain, OfflineBlockchain, OnlineBlockchain}; use crate::database::{BatchDatabase, BatchOperations, DatabaseUtils}; @@ -33,6 +32,8 @@ use crate::psbt::{utils::PSBTUtils, PSBTSatisfier, PSBTSigner}; use crate::signer::Signer; use crate::types::*; +const CACHE_ADDR_BATCH_SIZE: u32 = 100; + pub type OfflineWallet = Wallet; pub struct Wallet { @@ -93,11 +94,7 @@ where } pub fn get_new_address(&self) -> Result { - let index = self - .database - .borrow_mut() - .increment_last_index(ScriptType::External)?; - // TODO: refill the address pool if index is close to the last cached addr + let index = self.fetch_and_increment_index(ScriptType::External)?; self.descriptor .derive(index)? @@ -185,8 +182,10 @@ where // script is unknown in the database let input_witness_weight = std::cmp::max( self.get_descriptor_for(ScriptType::Internal) + .0 .max_satisfaction_weight(), self.get_descriptor_for(ScriptType::External) + .0 .max_satisfaction_weight(), ); @@ -283,7 +282,7 @@ where None => continue, }; - let desc = self.get_descriptor_for(script_type); + let (desc, _) = self.get_descriptor_for(script_type); psbt_input.hd_keypaths = desc.get_hd_keypaths(child)?; let derived_descriptor = desc.derive(child)?; @@ -537,10 +536,13 @@ where // Internals - fn get_descriptor_for(&self, script_type: ScriptType) -> &ExtendedDescriptor { + fn get_descriptor_for(&self, script_type: ScriptType) -> (&ExtendedDescriptor, ScriptType) { let desc = match script_type { - ScriptType::External => &self.descriptor, - ScriptType::Internal => &self.change_descriptor.as_ref().unwrap_or(&self.descriptor), + ScriptType::Internal if self.change_descriptor.is_some() => ( + self.change_descriptor.as_ref().unwrap(), + ScriptType::Internal, + ), + _ => (&self.descriptor, ScriptType::External), }; desc @@ -557,24 +559,72 @@ where } fn get_change_address(&self) -> Result { - let (desc, script_type) = if self.change_descriptor.is_none() { - (&self.descriptor, ScriptType::External) - } else { - ( - self.change_descriptor.as_ref().unwrap(), - ScriptType::Internal, - ) - }; - - // TODO: refill the address pool if index is close to the last cached addr - let index = self - .database - .borrow_mut() - .increment_last_index(script_type)?; + let (desc, script_type) = self.get_descriptor_for(ScriptType::Internal); + let index = self.fetch_and_increment_index(script_type)?; Ok(desc.derive(index)?.script_pubkey()) } + fn fetch_and_increment_index(&self, script_type: ScriptType) -> Result { + let (descriptor, script_type) = self.get_descriptor_for(script_type); + let index = match descriptor.is_fixed() { + true => 0, + false => self + .database + .borrow_mut() + .increment_last_index(script_type)?, + }; + + if self + .database + .borrow() + .get_script_pubkey_from_path(script_type, index)? + .is_none() + { + self.cache_addresses(script_type, index, CACHE_ADDR_BATCH_SIZE)?; + } + + Ok(index) + } + + fn cache_addresses( + &self, + script_type: ScriptType, + from: u32, + mut count: u32, + ) -> Result<(), Error> { + let (descriptor, script_type) = self.get_descriptor_for(script_type); + if descriptor.is_fixed() { + if from > 0 { + return Ok(()); + } + + count = 1; + } + + let mut address_batch = self.database.borrow().begin_batch(); + + let start_time = time::Instant::new(); + for i in from..(from + count) { + address_batch.set_script_pubkey( + &descriptor.derive(i)?.script_pubkey(), + script_type, + i, + )?; + } + + info!( + "Derivation of {} addresses from {} took {} ms", + count, + from, + start_time.elapsed().as_millis() + ); + + self.database.borrow_mut().commit_batch(address_batch)?; + + Ok(()) + } + fn get_available_utxos( &self, utxo: &Option>, @@ -621,25 +671,19 @@ where // try to add hd_keypaths if we've already seen the output for (psbt_input, out) in psbt.inputs.iter_mut().zip(input_utxos.iter()) { - debug!("searching hd_keypaths for out: {:?}", out); - if let Some(out) = out { - let option_path = self + if let Some((script_type, child)) = self .database .borrow() - .get_path_from_script_pubkey(&out.script_pubkey)?; + .get_path_from_script_pubkey(&out.script_pubkey)? + { + debug!("Found descriptor {:?}/{}", script_type, child); - debug!("found descriptor path {:?}", option_path); - - let (script_type, child) = match option_path { - None => continue, - Some((script_type, child)) => (script_type, child), - }; - - // merge hd_keypaths - let desc = self.get_descriptor_for(script_type); - let mut hd_keypaths = desc.get_hd_keypaths(child)?; - psbt_input.hd_keypaths.append(&mut hd_keypaths); + // merge hd_keypaths + let (desc, _) = self.get_descriptor_for(script_type); + let mut hd_keypaths = desc.get_hd_keypaths(child)?; + psbt_input.hd_keypaths.append(&mut hd_keypaths); + } } } @@ -669,61 +713,36 @@ where } #[maybe_async] - pub fn sync( - &self, - max_address: Option, - _batch_query_size: Option, - ) -> Result<(), Error> { - debug!("begin sync..."); - // TODO: consider taking an RwLock as writere here to prevent other "read-only" calls to - // break because the db is in an inconsistent state + pub fn sync(&self, max_address_param: Option) -> Result<(), Error> { + debug!("Begin sync..."); - let max_address = if self.descriptor.is_fixed() { - 0 - } else { - max_address.unwrap_or(100) + let max_address = match self.descriptor.is_fixed() { + true => 0, + false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE), }; - - // TODO: - // let batch_query_size = batch_query_size.unwrap_or(20); - - let last_addr = self + if self .database .borrow() - .get_script_pubkey_from_path(ScriptType::External, max_address)?; + .get_script_pubkey_from_path(ScriptType::External, max_address)? + .is_none() + { + self.cache_addresses(ScriptType::External, 0, max_address)?; + } - // cache a few of our addresses - if last_addr.is_none() { - let mut address_batch = self.database.borrow().begin_batch(); - let start = time::Instant::new(); + if let Some(change_descriptor) = &self.change_descriptor { + let max_address = match change_descriptor.is_fixed() { + true => 0, + false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE), + }; - for i in 0..=max_address { - let derived = self.descriptor.derive(i).unwrap(); - - address_batch.set_script_pubkey( - &derived.script_pubkey(), - ScriptType::External, - i, - )?; + if self + .database + .borrow() + .get_script_pubkey_from_path(ScriptType::Internal, max_address)? + .is_none() + { + self.cache_addresses(ScriptType::Internal, 0, max_address)?; } - if self.change_descriptor.is_some() { - for i in 0..=max_address { - let derived = self.change_descriptor.as_ref().unwrap().derive(i).unwrap(); - - address_batch.set_script_pubkey( - &derived.script_pubkey(), - ScriptType::Internal, - i, - )?; - } - } - - info!( - "derivation of {} addresses, took {} ms", - max_address, - start.elapsed().as_millis() - ); - self.database.borrow_mut().commit_batch(address_batch)?; } maybe_await!(self.client.sync( @@ -740,3 +759,104 @@ where Ok(tx.txid()) } } + +#[cfg(test)] +mod test { + use bitcoin::Network; + + use crate::database::memory::MemoryDatabase; + use crate::database::Database; + use crate::types::ScriptType; + + use super::*; + + #[test] + fn test_cache_addresses_fixed() { + let db = MemoryDatabase::new(); + let wallet: OfflineWallet<_> = Wallet::new_offline( + "wpkh(L5EZftvrYaSudiozVRzTqLcHLNDoVn7H5HSfM9BAN6tMJX8oTWz6)", + None, + Network::Testnet, + db, + ) + .unwrap(); + + assert_eq!( + wallet.get_new_address().unwrap().to_string(), + "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835" + ); + assert_eq!( + wallet.get_new_address().unwrap().to_string(), + "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835" + ); + + assert!(wallet + .database + .borrow_mut() + .get_script_pubkey_from_path(ScriptType::External, 0) + .unwrap() + .is_some()); + assert!(wallet + .database + .borrow_mut() + .get_script_pubkey_from_path(ScriptType::Internal, 0) + .unwrap() + .is_none()); + } + + #[test] + fn test_cache_addresses() { + let db = MemoryDatabase::new(); + let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap(); + + assert_eq!( + wallet.get_new_address().unwrap().to_string(), + "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a" + ); + assert_eq!( + wallet.get_new_address().unwrap().to_string(), + "tb1q4er7kxx6sssz3q7qp7zsqsdx4erceahhax77d7" + ); + + assert!(wallet + .database + .borrow_mut() + .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1) + .unwrap() + .is_some()); + assert!(wallet + .database + .borrow_mut() + .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE) + .unwrap() + .is_none()); + } + + #[test] + fn test_cache_addresses_refill() { + let db = MemoryDatabase::new(); + let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap(); + + assert_eq!( + wallet.get_new_address().unwrap().to_string(), + "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a" + ); + assert!(wallet + .database + .borrow_mut() + .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1) + .unwrap() + .is_some()); + + for _ in 0..CACHE_ADDR_BATCH_SIZE { + wallet.get_new_address().unwrap(); + } + + assert!(wallet + .database + .borrow_mut() + .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE * 2 - 1) + .unwrap() + .is_some()); + } +}