diff --git a/src/database/memory.rs b/src/database/memory.rs index 495d9393..29bc0df1 100644 --- a/src/database/memory.rs +++ b/src/database/memory.rs @@ -86,12 +86,14 @@ fn after(key: &Vec) -> Vec { #[derive(Debug)] pub struct MemoryDatabase { map: BTreeMap, Box>, + deleted_keys: Vec>, } impl MemoryDatabase { pub fn new() -> Self { MemoryDatabase { map: BTreeMap::new(), + deleted_keys: Vec::new(), } } } @@ -160,6 +162,7 @@ impl BatchOperations for MemoryDatabase { let deriv_path = DerivationPath::from(path.as_ref()); let key = MapKey::Path((Some(script_type), Some(&deriv_path))).as_map_key(); let res = self.map.remove(&key); + self.deleted_keys.push(key); Ok(res.map(|x| x.downcast_ref().cloned().unwrap())) } @@ -169,6 +172,7 @@ impl BatchOperations for MemoryDatabase { ) -> Result, Error> { let key = MapKey::Script(Some(script)).as_map_key(); let res = self.map.remove(&key); + self.deleted_keys.push(key); match res { None => Ok(None), @@ -184,6 +188,7 @@ impl BatchOperations for MemoryDatabase { fn del_utxo(&mut self, outpoint: &OutPoint) -> Result, Error> { let key = MapKey::UTXO(Some(outpoint)).as_map_key(); let res = self.map.remove(&key); + self.deleted_keys.push(key); match res { None => Ok(None), @@ -199,6 +204,7 @@ impl BatchOperations for MemoryDatabase { fn del_raw_tx(&mut self, txid: &Txid) -> Result, Error> { let key = MapKey::RawTx(Some(txid)).as_map_key(); let res = self.map.remove(&key); + self.deleted_keys.push(key); Ok(res.map(|x| x.downcast_ref().cloned().unwrap())) } @@ -215,6 +221,7 @@ impl BatchOperations for MemoryDatabase { let key = MapKey::Transaction(Some(txid)).as_map_key(); let res = self.map.remove(&key); + self.deleted_keys.push(key); match res { None => Ok(None), @@ -229,6 +236,7 @@ impl BatchOperations for MemoryDatabase { fn del_last_index(&mut self, script_type: ScriptType) -> Result, Error> { let key = MapKey::LastIndex(script_type).as_map_key(); let res = self.map.remove(&key); + self.deleted_keys.push(key); match res { None => Ok(None), @@ -391,6 +399,10 @@ impl BatchDatabase for MemoryDatabase { } fn commit_batch(&mut self, mut batch: Self::Batch) -> Result<(), Error> { + for key in batch.deleted_keys { + self.map.remove(&key); + } + Ok(self.map.append(&mut batch.map)) } } @@ -503,6 +515,30 @@ mod test { assert_eq!(tree.iter_script_pubkeys(None).unwrap().len(), 0); } + #[test] + fn test_del_script_pubkey_batch() { + let mut tree = get_tree(); + + let script = Script::from( + Vec::::from_hex("76a91402306a7c23f3e8010de41e9e591348bb83f11daa88ac").unwrap(), + ); + let path = DerivationPath::from_str("m/0/1/2/3").unwrap(); + let script_type = ScriptType::External; + + tree.set_script_pubkey(&script, script_type, &path).unwrap(); + assert_eq!(tree.iter_script_pubkeys(None).unwrap().len(), 1); + + let mut batch = tree.begin_batch(); + batch + .del_script_pubkey_from_path(script_type, &path) + .unwrap(); + + assert_eq!(tree.iter_script_pubkeys(None).unwrap().len(), 1); + + tree.commit_batch(batch); + assert_eq!(tree.iter_script_pubkeys(None).unwrap().len(), 0); + } + #[test] fn test_utxo() { let mut tree = get_tree();