[wallet] Refill the address pool whenever necessary
This commit is contained in:
		
							parent
							
								
									7a23b2b558
								
							
						
					
					
						commit
						b67bbeb202
					
				| @ -291,7 +291,7 @@ where | ||||
|     if let Some(_sub_matches) = matches.subcommand_matches("get_new_address") { | ||||
|         Ok(Some(format!("{}", wallet.get_new_address()?))) | ||||
|     } else if let Some(_sub_matches) = matches.subcommand_matches("sync") { | ||||
|         maybe_await!(wallet.sync(None, None))?; | ||||
|         maybe_await!(wallet.sync(None))?; | ||||
|         Ok(None) | ||||
|     } else if let Some(_sub_matches) = matches.subcommand_matches("list_unspent") { | ||||
|         let mut res = String::new(); | ||||
|  | ||||
| @ -22,8 +22,7 @@ pub mod tx_builder; | ||||
| pub mod utils; | ||||
| 
 | ||||
| pub use tx_builder::TxBuilder; | ||||
| 
 | ||||
| use self::utils::IsDust; | ||||
| use utils::IsDust; | ||||
| 
 | ||||
| use crate::blockchain::{noop_progress, Blockchain, OfflineBlockchain, OnlineBlockchain}; | ||||
| use crate::database::{BatchDatabase, BatchOperations, DatabaseUtils}; | ||||
| @ -33,6 +32,8 @@ use crate::psbt::{utils::PSBTUtils, PSBTSatisfier, PSBTSigner}; | ||||
| use crate::signer::Signer; | ||||
| use crate::types::*; | ||||
| 
 | ||||
| const CACHE_ADDR_BATCH_SIZE: u32 = 100; | ||||
| 
 | ||||
| pub type OfflineWallet<D> = Wallet<OfflineBlockchain, D>; | ||||
| 
 | ||||
| pub struct Wallet<B: Blockchain, D: BatchDatabase> { | ||||
| @ -93,11 +94,7 @@ where | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_new_address(&self) -> Result<Address, Error> { | ||||
|         let index = self | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .increment_last_index(ScriptType::External)?; | ||||
|         // TODO: refill the address pool if index is close to the last cached addr
 | ||||
|         let index = self.fetch_and_increment_index(ScriptType::External)?; | ||||
| 
 | ||||
|         self.descriptor | ||||
|             .derive(index)? | ||||
| @ -185,8 +182,10 @@ where | ||||
|         // script is unknown in the database
 | ||||
|         let input_witness_weight = std::cmp::max( | ||||
|             self.get_descriptor_for(ScriptType::Internal) | ||||
|                 .0 | ||||
|                 .max_satisfaction_weight(), | ||||
|             self.get_descriptor_for(ScriptType::External) | ||||
|                 .0 | ||||
|                 .max_satisfaction_weight(), | ||||
|         ); | ||||
| 
 | ||||
| @ -283,7 +282,7 @@ where | ||||
|                 None => continue, | ||||
|             }; | ||||
| 
 | ||||
|             let desc = self.get_descriptor_for(script_type); | ||||
|             let (desc, _) = self.get_descriptor_for(script_type); | ||||
|             psbt_input.hd_keypaths = desc.get_hd_keypaths(child)?; | ||||
|             let derived_descriptor = desc.derive(child)?; | ||||
| 
 | ||||
| @ -537,10 +536,13 @@ where | ||||
| 
 | ||||
|     // Internals
 | ||||
| 
 | ||||
|     fn get_descriptor_for(&self, script_type: ScriptType) -> &ExtendedDescriptor { | ||||
|     fn get_descriptor_for(&self, script_type: ScriptType) -> (&ExtendedDescriptor, ScriptType) { | ||||
|         let desc = match script_type { | ||||
|             ScriptType::External => &self.descriptor, | ||||
|             ScriptType::Internal => &self.change_descriptor.as_ref().unwrap_or(&self.descriptor), | ||||
|             ScriptType::Internal if self.change_descriptor.is_some() => ( | ||||
|                 self.change_descriptor.as_ref().unwrap(), | ||||
|                 ScriptType::Internal, | ||||
|             ), | ||||
|             _ => (&self.descriptor, ScriptType::External), | ||||
|         }; | ||||
| 
 | ||||
|         desc | ||||
| @ -557,24 +559,72 @@ where | ||||
|     } | ||||
| 
 | ||||
|     fn get_change_address(&self) -> Result<Script, Error> { | ||||
|         let (desc, script_type) = if self.change_descriptor.is_none() { | ||||
|             (&self.descriptor, ScriptType::External) | ||||
|         } else { | ||||
|             ( | ||||
|                 self.change_descriptor.as_ref().unwrap(), | ||||
|                 ScriptType::Internal, | ||||
|             ) | ||||
|         }; | ||||
| 
 | ||||
|         // TODO: refill the address pool if index is close to the last cached addr
 | ||||
|         let index = self | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .increment_last_index(script_type)?; | ||||
|         let (desc, script_type) = self.get_descriptor_for(ScriptType::Internal); | ||||
|         let index = self.fetch_and_increment_index(script_type)?; | ||||
| 
 | ||||
|         Ok(desc.derive(index)?.script_pubkey()) | ||||
|     } | ||||
| 
 | ||||
|     fn fetch_and_increment_index(&self, script_type: ScriptType) -> Result<u32, Error> { | ||||
|         let (descriptor, script_type) = self.get_descriptor_for(script_type); | ||||
|         let index = match descriptor.is_fixed() { | ||||
|             true => 0, | ||||
|             false => self | ||||
|                 .database | ||||
|                 .borrow_mut() | ||||
|                 .increment_last_index(script_type)?, | ||||
|         }; | ||||
| 
 | ||||
|         if self | ||||
|             .database | ||||
|             .borrow() | ||||
|             .get_script_pubkey_from_path(script_type, index)? | ||||
|             .is_none() | ||||
|         { | ||||
|             self.cache_addresses(script_type, index, CACHE_ADDR_BATCH_SIZE)?; | ||||
|         } | ||||
| 
 | ||||
|         Ok(index) | ||||
|     } | ||||
| 
 | ||||
|     fn cache_addresses( | ||||
|         &self, | ||||
|         script_type: ScriptType, | ||||
|         from: u32, | ||||
|         mut count: u32, | ||||
|     ) -> Result<(), Error> { | ||||
|         let (descriptor, script_type) = self.get_descriptor_for(script_type); | ||||
|         if descriptor.is_fixed() { | ||||
|             if from > 0 { | ||||
|                 return Ok(()); | ||||
|             } | ||||
| 
 | ||||
|             count = 1; | ||||
|         } | ||||
| 
 | ||||
|         let mut address_batch = self.database.borrow().begin_batch(); | ||||
| 
 | ||||
|         let start_time = time::Instant::new(); | ||||
|         for i in from..(from + count) { | ||||
|             address_batch.set_script_pubkey( | ||||
|                 &descriptor.derive(i)?.script_pubkey(), | ||||
|                 script_type, | ||||
|                 i, | ||||
|             )?; | ||||
|         } | ||||
| 
 | ||||
|         info!( | ||||
|             "Derivation of {} addresses from {} took {} ms", | ||||
|             count, | ||||
|             from, | ||||
|             start_time.elapsed().as_millis() | ||||
|         ); | ||||
| 
 | ||||
|         self.database.borrow_mut().commit_batch(address_batch)?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     fn get_available_utxos( | ||||
|         &self, | ||||
|         utxo: &Option<Vec<OutPoint>>, | ||||
| @ -621,25 +671,19 @@ where | ||||
| 
 | ||||
|         // try to add hd_keypaths if we've already seen the output
 | ||||
|         for (psbt_input, out) in psbt.inputs.iter_mut().zip(input_utxos.iter()) { | ||||
|             debug!("searching hd_keypaths for out: {:?}", out); | ||||
| 
 | ||||
|             if let Some(out) = out { | ||||
|                 let option_path = self | ||||
|                 if let Some((script_type, child)) = self | ||||
|                     .database | ||||
|                     .borrow() | ||||
|                     .get_path_from_script_pubkey(&out.script_pubkey)?; | ||||
|                     .get_path_from_script_pubkey(&out.script_pubkey)? | ||||
|                 { | ||||
|                     debug!("Found descriptor {:?}/{}", script_type, child); | ||||
| 
 | ||||
|                 debug!("found descriptor path {:?}", option_path); | ||||
| 
 | ||||
|                 let (script_type, child) = match option_path { | ||||
|                     None => continue, | ||||
|                     Some((script_type, child)) => (script_type, child), | ||||
|                 }; | ||||
| 
 | ||||
|                 // merge hd_keypaths
 | ||||
|                 let desc = self.get_descriptor_for(script_type); | ||||
|                 let mut hd_keypaths = desc.get_hd_keypaths(child)?; | ||||
|                 psbt_input.hd_keypaths.append(&mut hd_keypaths); | ||||
|                     // merge hd_keypaths
 | ||||
|                     let (desc, _) = self.get_descriptor_for(script_type); | ||||
|                     let mut hd_keypaths = desc.get_hd_keypaths(child)?; | ||||
|                     psbt_input.hd_keypaths.append(&mut hd_keypaths); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
| @ -669,61 +713,36 @@ where | ||||
|     } | ||||
| 
 | ||||
|     #[maybe_async] | ||||
|     pub fn sync( | ||||
|         &self, | ||||
|         max_address: Option<u32>, | ||||
|         _batch_query_size: Option<usize>, | ||||
|     ) -> Result<(), Error> { | ||||
|         debug!("begin sync..."); | ||||
|         // TODO: consider taking an RwLock as writere here to prevent other "read-only" calls to
 | ||||
|         // break because the db is in an inconsistent state
 | ||||
|     pub fn sync(&self, max_address_param: Option<u32>) -> Result<(), Error> { | ||||
|         debug!("Begin sync..."); | ||||
| 
 | ||||
|         let max_address = if self.descriptor.is_fixed() { | ||||
|             0 | ||||
|         } else { | ||||
|             max_address.unwrap_or(100) | ||||
|         let max_address = match self.descriptor.is_fixed() { | ||||
|             true => 0, | ||||
|             false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE), | ||||
|         }; | ||||
| 
 | ||||
|         // TODO:
 | ||||
|         // let batch_query_size = batch_query_size.unwrap_or(20);
 | ||||
| 
 | ||||
|         let last_addr = self | ||||
|         if self | ||||
|             .database | ||||
|             .borrow() | ||||
|             .get_script_pubkey_from_path(ScriptType::External, max_address)?; | ||||
|             .get_script_pubkey_from_path(ScriptType::External, max_address)? | ||||
|             .is_none() | ||||
|         { | ||||
|             self.cache_addresses(ScriptType::External, 0, max_address)?; | ||||
|         } | ||||
| 
 | ||||
|         // cache a few of our addresses
 | ||||
|         if last_addr.is_none() { | ||||
|             let mut address_batch = self.database.borrow().begin_batch(); | ||||
|             let start = time::Instant::new(); | ||||
|         if let Some(change_descriptor) = &self.change_descriptor { | ||||
|             let max_address = match change_descriptor.is_fixed() { | ||||
|                 true => 0, | ||||
|                 false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE), | ||||
|             }; | ||||
| 
 | ||||
|             for i in 0..=max_address { | ||||
|                 let derived = self.descriptor.derive(i).unwrap(); | ||||
| 
 | ||||
|                 address_batch.set_script_pubkey( | ||||
|                     &derived.script_pubkey(), | ||||
|                     ScriptType::External, | ||||
|                     i, | ||||
|                 )?; | ||||
|             if self | ||||
|                 .database | ||||
|                 .borrow() | ||||
|                 .get_script_pubkey_from_path(ScriptType::Internal, max_address)? | ||||
|                 .is_none() | ||||
|             { | ||||
|                 self.cache_addresses(ScriptType::Internal, 0, max_address)?; | ||||
|             } | ||||
|             if self.change_descriptor.is_some() { | ||||
|                 for i in 0..=max_address { | ||||
|                     let derived = self.change_descriptor.as_ref().unwrap().derive(i).unwrap(); | ||||
| 
 | ||||
|                     address_batch.set_script_pubkey( | ||||
|                         &derived.script_pubkey(), | ||||
|                         ScriptType::Internal, | ||||
|                         i, | ||||
|                     )?; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             info!( | ||||
|                 "derivation of {} addresses, took {} ms", | ||||
|                 max_address, | ||||
|                 start.elapsed().as_millis() | ||||
|             ); | ||||
|             self.database.borrow_mut().commit_batch(address_batch)?; | ||||
|         } | ||||
| 
 | ||||
|         maybe_await!(self.client.sync( | ||||
| @ -740,3 +759,104 @@ where | ||||
|         Ok(tx.txid()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod test { | ||||
|     use bitcoin::Network; | ||||
| 
 | ||||
|     use crate::database::memory::MemoryDatabase; | ||||
|     use crate::database::Database; | ||||
|     use crate::types::ScriptType; | ||||
| 
 | ||||
|     use super::*; | ||||
| 
 | ||||
|     #[test] | ||||
|     fn test_cache_addresses_fixed() { | ||||
|         let db = MemoryDatabase::new(); | ||||
|         let wallet: OfflineWallet<_> = Wallet::new_offline( | ||||
|             "wpkh(L5EZftvrYaSudiozVRzTqLcHLNDoVn7H5HSfM9BAN6tMJX8oTWz6)", | ||||
|             None, | ||||
|             Network::Testnet, | ||||
|             db, | ||||
|         ) | ||||
|         .unwrap(); | ||||
| 
 | ||||
|         assert_eq!( | ||||
|             wallet.get_new_address().unwrap().to_string(), | ||||
|             "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835" | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             wallet.get_new_address().unwrap().to_string(), | ||||
|             "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835" | ||||
|         ); | ||||
| 
 | ||||
|         assert!(wallet | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .get_script_pubkey_from_path(ScriptType::External, 0) | ||||
|             .unwrap() | ||||
|             .is_some()); | ||||
|         assert!(wallet | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .get_script_pubkey_from_path(ScriptType::Internal, 0) | ||||
|             .unwrap() | ||||
|             .is_none()); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn test_cache_addresses() { | ||||
|         let db = MemoryDatabase::new(); | ||||
|         let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap(); | ||||
| 
 | ||||
|         assert_eq!( | ||||
|             wallet.get_new_address().unwrap().to_string(), | ||||
|             "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a" | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             wallet.get_new_address().unwrap().to_string(), | ||||
|             "tb1q4er7kxx6sssz3q7qp7zsqsdx4erceahhax77d7" | ||||
|         ); | ||||
| 
 | ||||
|         assert!(wallet | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1) | ||||
|             .unwrap() | ||||
|             .is_some()); | ||||
|         assert!(wallet | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE) | ||||
|             .unwrap() | ||||
|             .is_none()); | ||||
|     } | ||||
| 
 | ||||
|     #[test] | ||||
|     fn test_cache_addresses_refill() { | ||||
|         let db = MemoryDatabase::new(); | ||||
|         let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap(); | ||||
| 
 | ||||
|         assert_eq!( | ||||
|             wallet.get_new_address().unwrap().to_string(), | ||||
|             "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a" | ||||
|         ); | ||||
|         assert!(wallet | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1) | ||||
|             .unwrap() | ||||
|             .is_some()); | ||||
| 
 | ||||
|         for _ in 0..CACHE_ADDR_BATCH_SIZE { | ||||
|             wallet.get_new_address().unwrap(); | ||||
|         } | ||||
| 
 | ||||
|         assert!(wallet | ||||
|             .database | ||||
|             .borrow_mut() | ||||
|             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE * 2 - 1) | ||||
|             .unwrap() | ||||
|             .is_some()); | ||||
|     } | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user