[wallet] Refill the address pool whenever necessary
This commit is contained in:
		
							parent
							
								
									7a23b2b558
								
							
						
					
					
						commit
						b67bbeb202
					
				| @ -291,7 +291,7 @@ where | |||||||
|     if let Some(_sub_matches) = matches.subcommand_matches("get_new_address") { |     if let Some(_sub_matches) = matches.subcommand_matches("get_new_address") { | ||||||
|         Ok(Some(format!("{}", wallet.get_new_address()?))) |         Ok(Some(format!("{}", wallet.get_new_address()?))) | ||||||
|     } else if let Some(_sub_matches) = matches.subcommand_matches("sync") { |     } else if let Some(_sub_matches) = matches.subcommand_matches("sync") { | ||||||
|         maybe_await!(wallet.sync(None, None))?; |         maybe_await!(wallet.sync(None))?; | ||||||
|         Ok(None) |         Ok(None) | ||||||
|     } else if let Some(_sub_matches) = matches.subcommand_matches("list_unspent") { |     } else if let Some(_sub_matches) = matches.subcommand_matches("list_unspent") { | ||||||
|         let mut res = String::new(); |         let mut res = String::new(); | ||||||
|  | |||||||
| @ -22,8 +22,7 @@ pub mod tx_builder; | |||||||
| pub mod utils; | pub mod utils; | ||||||
| 
 | 
 | ||||||
| pub use tx_builder::TxBuilder; | pub use tx_builder::TxBuilder; | ||||||
| 
 | use utils::IsDust; | ||||||
| use self::utils::IsDust; |  | ||||||
| 
 | 
 | ||||||
| use crate::blockchain::{noop_progress, Blockchain, OfflineBlockchain, OnlineBlockchain}; | use crate::blockchain::{noop_progress, Blockchain, OfflineBlockchain, OnlineBlockchain}; | ||||||
| use crate::database::{BatchDatabase, BatchOperations, DatabaseUtils}; | use crate::database::{BatchDatabase, BatchOperations, DatabaseUtils}; | ||||||
| @ -33,6 +32,8 @@ use crate::psbt::{utils::PSBTUtils, PSBTSatisfier, PSBTSigner}; | |||||||
| use crate::signer::Signer; | use crate::signer::Signer; | ||||||
| use crate::types::*; | use crate::types::*; | ||||||
| 
 | 
 | ||||||
|  | const CACHE_ADDR_BATCH_SIZE: u32 = 100; | ||||||
|  | 
 | ||||||
| pub type OfflineWallet<D> = Wallet<OfflineBlockchain, D>; | pub type OfflineWallet<D> = Wallet<OfflineBlockchain, D>; | ||||||
| 
 | 
 | ||||||
| pub struct Wallet<B: Blockchain, D: BatchDatabase> { | pub struct Wallet<B: Blockchain, D: BatchDatabase> { | ||||||
| @ -93,11 +94,7 @@ where | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn get_new_address(&self) -> Result<Address, Error> { |     pub fn get_new_address(&self) -> Result<Address, Error> { | ||||||
|         let index = self |         let index = self.fetch_and_increment_index(ScriptType::External)?; | ||||||
|             .database |  | ||||||
|             .borrow_mut() |  | ||||||
|             .increment_last_index(ScriptType::External)?; |  | ||||||
|         // TODO: refill the address pool if index is close to the last cached addr
 |  | ||||||
| 
 | 
 | ||||||
|         self.descriptor |         self.descriptor | ||||||
|             .derive(index)? |             .derive(index)? | ||||||
| @ -185,8 +182,10 @@ where | |||||||
|         // script is unknown in the database
 |         // script is unknown in the database
 | ||||||
|         let input_witness_weight = std::cmp::max( |         let input_witness_weight = std::cmp::max( | ||||||
|             self.get_descriptor_for(ScriptType::Internal) |             self.get_descriptor_for(ScriptType::Internal) | ||||||
|  |                 .0 | ||||||
|                 .max_satisfaction_weight(), |                 .max_satisfaction_weight(), | ||||||
|             self.get_descriptor_for(ScriptType::External) |             self.get_descriptor_for(ScriptType::External) | ||||||
|  |                 .0 | ||||||
|                 .max_satisfaction_weight(), |                 .max_satisfaction_weight(), | ||||||
|         ); |         ); | ||||||
| 
 | 
 | ||||||
| @ -283,7 +282,7 @@ where | |||||||
|                 None => continue, |                 None => continue, | ||||||
|             }; |             }; | ||||||
| 
 | 
 | ||||||
|             let desc = self.get_descriptor_for(script_type); |             let (desc, _) = self.get_descriptor_for(script_type); | ||||||
|             psbt_input.hd_keypaths = desc.get_hd_keypaths(child)?; |             psbt_input.hd_keypaths = desc.get_hd_keypaths(child)?; | ||||||
|             let derived_descriptor = desc.derive(child)?; |             let derived_descriptor = desc.derive(child)?; | ||||||
| 
 | 
 | ||||||
| @ -537,10 +536,13 @@ where | |||||||
| 
 | 
 | ||||||
|     // Internals
 |     // Internals
 | ||||||
| 
 | 
 | ||||||
|     fn get_descriptor_for(&self, script_type: ScriptType) -> &ExtendedDescriptor { |     fn get_descriptor_for(&self, script_type: ScriptType) -> (&ExtendedDescriptor, ScriptType) { | ||||||
|         let desc = match script_type { |         let desc = match script_type { | ||||||
|             ScriptType::External => &self.descriptor, |             ScriptType::Internal if self.change_descriptor.is_some() => ( | ||||||
|             ScriptType::Internal => &self.change_descriptor.as_ref().unwrap_or(&self.descriptor), |                 self.change_descriptor.as_ref().unwrap(), | ||||||
|  |                 ScriptType::Internal, | ||||||
|  |             ), | ||||||
|  |             _ => (&self.descriptor, ScriptType::External), | ||||||
|         }; |         }; | ||||||
| 
 | 
 | ||||||
|         desc |         desc | ||||||
| @ -557,24 +559,72 @@ where | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn get_change_address(&self) -> Result<Script, Error> { |     fn get_change_address(&self) -> Result<Script, Error> { | ||||||
|         let (desc, script_type) = if self.change_descriptor.is_none() { |         let (desc, script_type) = self.get_descriptor_for(ScriptType::Internal); | ||||||
|             (&self.descriptor, ScriptType::External) |         let index = self.fetch_and_increment_index(script_type)?; | ||||||
|         } else { |  | ||||||
|             ( |  | ||||||
|                 self.change_descriptor.as_ref().unwrap(), |  | ||||||
|                 ScriptType::Internal, |  | ||||||
|             ) |  | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         // TODO: refill the address pool if index is close to the last cached addr
 |  | ||||||
|         let index = self |  | ||||||
|             .database |  | ||||||
|             .borrow_mut() |  | ||||||
|             .increment_last_index(script_type)?; |  | ||||||
| 
 | 
 | ||||||
|         Ok(desc.derive(index)?.script_pubkey()) |         Ok(desc.derive(index)?.script_pubkey()) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     fn fetch_and_increment_index(&self, script_type: ScriptType) -> Result<u32, Error> { | ||||||
|  |         let (descriptor, script_type) = self.get_descriptor_for(script_type); | ||||||
|  |         let index = match descriptor.is_fixed() { | ||||||
|  |             true => 0, | ||||||
|  |             false => self | ||||||
|  |                 .database | ||||||
|  |                 .borrow_mut() | ||||||
|  |                 .increment_last_index(script_type)?, | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         if self | ||||||
|  |             .database | ||||||
|  |             .borrow() | ||||||
|  |             .get_script_pubkey_from_path(script_type, index)? | ||||||
|  |             .is_none() | ||||||
|  |         { | ||||||
|  |             self.cache_addresses(script_type, index, CACHE_ADDR_BATCH_SIZE)?; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Ok(index) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn cache_addresses( | ||||||
|  |         &self, | ||||||
|  |         script_type: ScriptType, | ||||||
|  |         from: u32, | ||||||
|  |         mut count: u32, | ||||||
|  |     ) -> Result<(), Error> { | ||||||
|  |         let (descriptor, script_type) = self.get_descriptor_for(script_type); | ||||||
|  |         if descriptor.is_fixed() { | ||||||
|  |             if from > 0 { | ||||||
|  |                 return Ok(()); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             count = 1; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         let mut address_batch = self.database.borrow().begin_batch(); | ||||||
|  | 
 | ||||||
|  |         let start_time = time::Instant::new(); | ||||||
|  |         for i in from..(from + count) { | ||||||
|  |             address_batch.set_script_pubkey( | ||||||
|  |                 &descriptor.derive(i)?.script_pubkey(), | ||||||
|  |                 script_type, | ||||||
|  |                 i, | ||||||
|  |             )?; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         info!( | ||||||
|  |             "Derivation of {} addresses from {} took {} ms", | ||||||
|  |             count, | ||||||
|  |             from, | ||||||
|  |             start_time.elapsed().as_millis() | ||||||
|  |         ); | ||||||
|  | 
 | ||||||
|  |         self.database.borrow_mut().commit_batch(address_batch)?; | ||||||
|  | 
 | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     fn get_available_utxos( |     fn get_available_utxos( | ||||||
|         &self, |         &self, | ||||||
|         utxo: &Option<Vec<OutPoint>>, |         utxo: &Option<Vec<OutPoint>>, | ||||||
| @ -621,27 +671,21 @@ where | |||||||
| 
 | 
 | ||||||
|         // try to add hd_keypaths if we've already seen the output
 |         // try to add hd_keypaths if we've already seen the output
 | ||||||
|         for (psbt_input, out) in psbt.inputs.iter_mut().zip(input_utxos.iter()) { |         for (psbt_input, out) in psbt.inputs.iter_mut().zip(input_utxos.iter()) { | ||||||
|             debug!("searching hd_keypaths for out: {:?}", out); |  | ||||||
| 
 |  | ||||||
|             if let Some(out) = out { |             if let Some(out) = out { | ||||||
|                 let option_path = self |                 if let Some((script_type, child)) = self | ||||||
|                     .database |                     .database | ||||||
|                     .borrow() |                     .borrow() | ||||||
|                     .get_path_from_script_pubkey(&out.script_pubkey)?; |                     .get_path_from_script_pubkey(&out.script_pubkey)? | ||||||
| 
 |                 { | ||||||
|                 debug!("found descriptor path {:?}", option_path); |                     debug!("Found descriptor {:?}/{}", script_type, child); | ||||||
| 
 |  | ||||||
|                 let (script_type, child) = match option_path { |  | ||||||
|                     None => continue, |  | ||||||
|                     Some((script_type, child)) => (script_type, child), |  | ||||||
|                 }; |  | ||||||
| 
 | 
 | ||||||
|                     // merge hd_keypaths
 |                     // merge hd_keypaths
 | ||||||
|                 let desc = self.get_descriptor_for(script_type); |                     let (desc, _) = self.get_descriptor_for(script_type); | ||||||
|                     let mut hd_keypaths = desc.get_hd_keypaths(child)?; |                     let mut hd_keypaths = desc.get_hd_keypaths(child)?; | ||||||
|                     psbt_input.hd_keypaths.append(&mut hd_keypaths); |                     psbt_input.hd_keypaths.append(&mut hd_keypaths); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
| @ -669,61 +713,36 @@ where | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[maybe_async] |     #[maybe_async] | ||||||
|     pub fn sync( |     pub fn sync(&self, max_address_param: Option<u32>) -> Result<(), Error> { | ||||||
|         &self, |         debug!("Begin sync..."); | ||||||
|         max_address: Option<u32>, |  | ||||||
|         _batch_query_size: Option<usize>, |  | ||||||
|     ) -> Result<(), Error> { |  | ||||||
|         debug!("begin sync..."); |  | ||||||
|         // TODO: consider taking an RwLock as writere here to prevent other "read-only" calls to
 |  | ||||||
|         // break because the db is in an inconsistent state
 |  | ||||||
| 
 | 
 | ||||||
|         let max_address = if self.descriptor.is_fixed() { |         let max_address = match self.descriptor.is_fixed() { | ||||||
|             0 |             true => 0, | ||||||
|         } else { |             false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE), | ||||||
|             max_address.unwrap_or(100) |  | ||||||
|         }; |         }; | ||||||
| 
 |         if self | ||||||
|         // TODO:
 |  | ||||||
|         // let batch_query_size = batch_query_size.unwrap_or(20);
 |  | ||||||
| 
 |  | ||||||
|         let last_addr = self |  | ||||||
|             .database |             .database | ||||||
|             .borrow() |             .borrow() | ||||||
|             .get_script_pubkey_from_path(ScriptType::External, max_address)?; |             .get_script_pubkey_from_path(ScriptType::External, max_address)? | ||||||
| 
 |             .is_none() | ||||||
|         // cache a few of our addresses
 |         { | ||||||
|         if last_addr.is_none() { |             self.cache_addresses(ScriptType::External, 0, max_address)?; | ||||||
|             let mut address_batch = self.database.borrow().begin_batch(); |  | ||||||
|             let start = time::Instant::new(); |  | ||||||
| 
 |  | ||||||
|             for i in 0..=max_address { |  | ||||||
|                 let derived = self.descriptor.derive(i).unwrap(); |  | ||||||
| 
 |  | ||||||
|                 address_batch.set_script_pubkey( |  | ||||||
|                     &derived.script_pubkey(), |  | ||||||
|                     ScriptType::External, |  | ||||||
|                     i, |  | ||||||
|                 )?; |  | ||||||
|             } |  | ||||||
|             if self.change_descriptor.is_some() { |  | ||||||
|                 for i in 0..=max_address { |  | ||||||
|                     let derived = self.change_descriptor.as_ref().unwrap().derive(i).unwrap(); |  | ||||||
| 
 |  | ||||||
|                     address_batch.set_script_pubkey( |  | ||||||
|                         &derived.script_pubkey(), |  | ||||||
|                         ScriptType::Internal, |  | ||||||
|                         i, |  | ||||||
|                     )?; |  | ||||||
|                 } |  | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|             info!( |         if let Some(change_descriptor) = &self.change_descriptor { | ||||||
|                 "derivation of {} addresses, took {} ms", |             let max_address = match change_descriptor.is_fixed() { | ||||||
|                 max_address, |                 true => 0, | ||||||
|                 start.elapsed().as_millis() |                 false => max_address_param.unwrap_or(CACHE_ADDR_BATCH_SIZE), | ||||||
|             ); |             }; | ||||||
|             self.database.borrow_mut().commit_batch(address_batch)?; | 
 | ||||||
|  |             if self | ||||||
|  |                 .database | ||||||
|  |                 .borrow() | ||||||
|  |                 .get_script_pubkey_from_path(ScriptType::Internal, max_address)? | ||||||
|  |                 .is_none() | ||||||
|  |             { | ||||||
|  |                 self.cache_addresses(ScriptType::Internal, 0, max_address)?; | ||||||
|  |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         maybe_await!(self.client.sync( |         maybe_await!(self.client.sync( | ||||||
| @ -740,3 +759,104 @@ where | |||||||
|         Ok(tx.txid()) |         Ok(tx.txid()) | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | #[cfg(test)] | ||||||
|  | mod test { | ||||||
|  |     use bitcoin::Network; | ||||||
|  | 
 | ||||||
|  |     use crate::database::memory::MemoryDatabase; | ||||||
|  |     use crate::database::Database; | ||||||
|  |     use crate::types::ScriptType; | ||||||
|  | 
 | ||||||
|  |     use super::*; | ||||||
|  | 
 | ||||||
|  |     #[test] | ||||||
|  |     fn test_cache_addresses_fixed() { | ||||||
|  |         let db = MemoryDatabase::new(); | ||||||
|  |         let wallet: OfflineWallet<_> = Wallet::new_offline( | ||||||
|  |             "wpkh(L5EZftvrYaSudiozVRzTqLcHLNDoVn7H5HSfM9BAN6tMJX8oTWz6)", | ||||||
|  |             None, | ||||||
|  |             Network::Testnet, | ||||||
|  |             db, | ||||||
|  |         ) | ||||||
|  |         .unwrap(); | ||||||
|  | 
 | ||||||
|  |         assert_eq!( | ||||||
|  |             wallet.get_new_address().unwrap().to_string(), | ||||||
|  |             "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835" | ||||||
|  |         ); | ||||||
|  |         assert_eq!( | ||||||
|  |             wallet.get_new_address().unwrap().to_string(), | ||||||
|  |             "tb1qj08ys4ct2hzzc2hcz6h2hgrvlmsjynaw43s835" | ||||||
|  |         ); | ||||||
|  | 
 | ||||||
|  |         assert!(wallet | ||||||
|  |             .database | ||||||
|  |             .borrow_mut() | ||||||
|  |             .get_script_pubkey_from_path(ScriptType::External, 0) | ||||||
|  |             .unwrap() | ||||||
|  |             .is_some()); | ||||||
|  |         assert!(wallet | ||||||
|  |             .database | ||||||
|  |             .borrow_mut() | ||||||
|  |             .get_script_pubkey_from_path(ScriptType::Internal, 0) | ||||||
|  |             .unwrap() | ||||||
|  |             .is_none()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[test] | ||||||
|  |     fn test_cache_addresses() { | ||||||
|  |         let db = MemoryDatabase::new(); | ||||||
|  |         let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap(); | ||||||
|  | 
 | ||||||
|  |         assert_eq!( | ||||||
|  |             wallet.get_new_address().unwrap().to_string(), | ||||||
|  |             "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a" | ||||||
|  |         ); | ||||||
|  |         assert_eq!( | ||||||
|  |             wallet.get_new_address().unwrap().to_string(), | ||||||
|  |             "tb1q4er7kxx6sssz3q7qp7zsqsdx4erceahhax77d7" | ||||||
|  |         ); | ||||||
|  | 
 | ||||||
|  |         assert!(wallet | ||||||
|  |             .database | ||||||
|  |             .borrow_mut() | ||||||
|  |             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1) | ||||||
|  |             .unwrap() | ||||||
|  |             .is_some()); | ||||||
|  |         assert!(wallet | ||||||
|  |             .database | ||||||
|  |             .borrow_mut() | ||||||
|  |             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE) | ||||||
|  |             .unwrap() | ||||||
|  |             .is_none()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     #[test] | ||||||
|  |     fn test_cache_addresses_refill() { | ||||||
|  |         let db = MemoryDatabase::new(); | ||||||
|  |         let wallet: OfflineWallet<_> = Wallet::new_offline("wpkh(tpubEBr4i6yk5nf5DAaJpsi9N2pPYBeJ7fZ5Z9rmN4977iYLCGco1VyjB9tvvuvYtfZzjD5A8igzgw3HeWeeKFmanHYqksqZXYXGsw5zjnj7KM9/*)", None, Network::Testnet, db).unwrap(); | ||||||
|  | 
 | ||||||
|  |         assert_eq!( | ||||||
|  |             wallet.get_new_address().unwrap().to_string(), | ||||||
|  |             "tb1q6yn66vajcctph75pvylgkksgpp6nq04ppwct9a" | ||||||
|  |         ); | ||||||
|  |         assert!(wallet | ||||||
|  |             .database | ||||||
|  |             .borrow_mut() | ||||||
|  |             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE - 1) | ||||||
|  |             .unwrap() | ||||||
|  |             .is_some()); | ||||||
|  | 
 | ||||||
|  |         for _ in 0..CACHE_ADDR_BATCH_SIZE { | ||||||
|  |             wallet.get_new_address().unwrap(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         assert!(wallet | ||||||
|  |             .database | ||||||
|  |             .borrow_mut() | ||||||
|  |             .get_script_pubkey_from_path(ScriptType::External, CACHE_ADDR_BATCH_SIZE * 2 - 1) | ||||||
|  |             .unwrap() | ||||||
|  |             .is_some()); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user