diff --git a/crates/bdk/src/wallet/mod.rs b/crates/bdk/src/wallet/mod.rs index ea76ad65..c9a1c28c 100644 --- a/crates/bdk/src/wallet/mod.rs +++ b/crates/bdk/src/wallet/mod.rs @@ -293,6 +293,8 @@ pub enum LoadError { Descriptor(crate::descriptor::DescriptorError), /// Loading data from the persistence backend failed. Load(L), + /// Wallet not initialized, persistence backend is empty. + NotInitialized, /// Data loaded from persistence is missing network type. MissingNetwork, /// Data loaded from persistence is missing genesis hash. @@ -307,6 +309,9 @@ where match self { LoadError::Descriptor(e) => e.fmt(f), LoadError::Load(e) => e.fmt(f), + LoadError::NotInitialized => { + write!(f, "wallet is not initialized, persistence backend is empty") + } LoadError::MissingNetwork => write!(f, "loaded data is missing network type"), LoadError::MissingGenesis => write!(f, "loaded data is missing genesis hash"), } @@ -330,6 +335,8 @@ pub enum NewOrLoadError { Write(W), /// Loading from the persistence backend failed. Load(L), + /// Wallet is not initialized, persistence backend is empty. + NotInitialized, /// The loaded genesis hash does not match what was provided. LoadedGenesisDoesNotMatch { /// The expected genesis block hash. @@ -356,6 +363,9 @@ where NewOrLoadError::Descriptor(e) => e.fmt(f), NewOrLoadError::Write(e) => write!(f, "failed to write to persistence: {}", e), NewOrLoadError::Load(e) => write!(f, "failed to load from persistence: {}", e), + NewOrLoadError::NotInitialized => { + write!(f, "wallet is not initialized, persistence backend is empty") + } NewOrLoadError::LoadedGenesisDoesNotMatch { expected, got } => { write!(f, "loaded genesis hash is not {}, got {:?}", expected, got) } @@ -451,11 +461,26 @@ impl Wallet { change_descriptor: Option, mut db: D, ) -> Result> + where + D: PersistBackend, + { + let changeset = db + .load_from_persistence() + .map_err(LoadError::Load)? + .ok_or(LoadError::NotInitialized)?; + Self::load_from_changeset(descriptor, change_descriptor, db, changeset) + } + + fn load_from_changeset( + descriptor: E, + change_descriptor: Option, + db: D, + changeset: ChangeSet, + ) -> Result> where D: PersistBackend, { let secp = Secp256k1::new(); - let changeset = db.load_from_persistence().map_err(LoadError::Load)?; let network = changeset.network.ok_or(LoadError::MissingNetwork)?; let chain = LocalChain::from_changeset(changeset.chain).map_err(|_| LoadError::MissingGenesis)?; @@ -517,8 +542,43 @@ impl Wallet { where D: PersistBackend, { - if db.is_empty().map_err(NewOrLoadError::Load)? { - return Self::new_with_genesis_hash( + let changeset = db.load_from_persistence().map_err(NewOrLoadError::Load)?; + match changeset { + Some(changeset) => { + let wallet = + Self::load_from_changeset(descriptor, change_descriptor, db, changeset) + .map_err(|e| match e { + LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e), + LoadError::Load(e) => NewOrLoadError::Load(e), + LoadError::NotInitialized => NewOrLoadError::NotInitialized, + LoadError::MissingNetwork => { + NewOrLoadError::LoadedNetworkDoesNotMatch { + expected: network, + got: None, + } + } + LoadError::MissingGenesis => { + NewOrLoadError::LoadedGenesisDoesNotMatch { + expected: genesis_hash, + got: None, + } + } + })?; + if wallet.network != network { + return Err(NewOrLoadError::LoadedNetworkDoesNotMatch { + expected: network, + got: Some(wallet.network), + }); + } + if wallet.chain.genesis_hash() != genesis_hash { + return Err(NewOrLoadError::LoadedGenesisDoesNotMatch { + expected: genesis_hash, + got: Some(wallet.chain.genesis_hash()), + }); + } + Ok(wallet) + } + None => Self::new_with_genesis_hash( descriptor, change_descriptor, db, @@ -528,34 +588,8 @@ impl Wallet { .map_err(|e| match e { NewError::Descriptor(e) => NewOrLoadError::Descriptor(e), NewError::Write(e) => NewOrLoadError::Write(e), - }); + }), } - - let wallet = Self::load(descriptor, change_descriptor, db).map_err(|e| match e { - LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e), - LoadError::Load(e) => NewOrLoadError::Load(e), - LoadError::MissingNetwork => NewOrLoadError::LoadedNetworkDoesNotMatch { - expected: network, - got: None, - }, - LoadError::MissingGenesis => NewOrLoadError::LoadedGenesisDoesNotMatch { - expected: genesis_hash, - got: None, - }, - })?; - if wallet.network != network { - return Err(NewOrLoadError::LoadedNetworkDoesNotMatch { - expected: network, - got: Some(wallet.network), - }); - } - if wallet.chain.genesis_hash() != genesis_hash { - return Err(NewOrLoadError::LoadedGenesisDoesNotMatch { - expected: genesis_hash, - got: Some(wallet.chain.genesis_hash()), - }); - } - Ok(wallet) } /// Get the Bitcoin network the wallet is using. diff --git a/crates/chain/src/persist.rs b/crates/chain/src/persist.rs index 634e369e..3c8c8b9e 100644 --- a/crates/chain/src/persist.rs +++ b/crates/chain/src/persist.rs @@ -79,19 +79,10 @@ pub trait PersistBackend { fn write_changes(&mut self, changeset: &C) -> Result<(), Self::WriteError>; /// Return the aggregate changeset `C` from persistence. - fn load_from_persistence(&mut self) -> Result; - - /// Returns whether the persistence backend contains no data. - fn is_empty(&mut self) -> Result - where - C: Append, - { - self.load_from_persistence() - .map(|changeset| changeset.is_empty()) - } + fn load_from_persistence(&mut self) -> Result, Self::LoadError>; } -impl PersistBackend for () { +impl PersistBackend for () { type WriteError = Infallible; type LoadError = Infallible; @@ -100,11 +91,7 @@ impl PersistBackend for () { Ok(()) } - fn load_from_persistence(&mut self) -> Result { - Ok(C::default()) - } - - fn is_empty(&mut self) -> Result { - Ok(true) + fn load_from_persistence(&mut self) -> Result, Self::LoadError> { + Ok(None) } } diff --git a/crates/file_store/src/store.rs b/crates/file_store/src/store.rs index 8af10cbd..bf88b8d3 100644 --- a/crates/file_store/src/store.rs +++ b/crates/file_store/src/store.rs @@ -23,7 +23,7 @@ pub struct Store<'a, C> { impl<'a, C> PersistBackend for Store<'a, C> where - C: Default + Append + serde::Serialize + serde::de::DeserializeOwned, + C: Append + serde::Serialize + serde::de::DeserializeOwned, { type WriteError = std::io::Error; @@ -33,23 +33,14 @@ where self.append_changeset(changeset) } - fn load_from_persistence(&mut self) -> Result { - let (changeset, result) = self.aggregate_changesets(); - result.map(|_| changeset) - } - - fn is_empty(&mut self) -> Result { - let init_pos = self.db_file.stream_position()?; - let stream_len = self.db_file.seek(io::SeekFrom::End(0))?; - let magic_len = self.magic.len() as u64; - self.db_file.seek(io::SeekFrom::Start(init_pos))?; - Ok(stream_len == magic_len) + fn load_from_persistence(&mut self) -> Result, Self::LoadError> { + self.aggregate_changesets().map_err(|e| e.iter_error) } } impl<'a, C> Store<'a, C> where - C: Default + Append + serde::Serialize + serde::de::DeserializeOwned, + C: Append + serde::Serialize + serde::de::DeserializeOwned, { /// Create a new [`Store`] file in write-only mode; error if the file exists. /// @@ -156,16 +147,24 @@ where /// /// **WARNING**: This method changes the write position of the underlying file. The next /// changeset will be written over the erroring entry (or the end of the file if none existed). - pub fn aggregate_changesets(&mut self) -> (C, Result<(), IterError>) { - let mut changeset = C::default(); - let result = (|| { - for next_changeset in self.iter_changesets() { - changeset.append(next_changeset?); + pub fn aggregate_changesets(&mut self) -> Result, AggregateChangesetsError> { + let mut changeset = Option::::None; + for next_changeset in self.iter_changesets() { + let next_changeset = match next_changeset { + Ok(next_changeset) => next_changeset, + Err(iter_error) => { + return Err(AggregateChangesetsError { + changeset, + iter_error, + }) + } + }; + match &mut changeset { + Some(changeset) => changeset.append(next_changeset), + changeset => *changeset = Some(next_changeset), } - Ok(()) - })(); - - (changeset, result) + } + Ok(changeset) } /// Append a new changeset to the file and truncate the file to the end of the appended @@ -196,6 +195,24 @@ where } } +/// Error type for [`Store::aggregate_changesets`]. +#[derive(Debug)] +pub struct AggregateChangesetsError { + /// The partially-aggregated changeset. + pub changeset: Option, + + /// The error returned by [`EntryIter`]. + pub iter_error: IterError, +} + +impl std::fmt::Display for AggregateChangesetsError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self.iter_error, f) + } +} + +impl std::error::Error for AggregateChangesetsError {} + #[cfg(test)] mod test { use super::*; @@ -248,25 +265,11 @@ mod test { { let mut db = Store::::open_or_create_new(&TEST_MAGIC_BYTES, &file_path) .expect("must recover"); - let (recovered_changeset, r) = db.aggregate_changesets(); - r.expect("must succeed"); - assert_eq!(recovered_changeset, changeset); + let recovered_changeset = db.aggregate_changesets().expect("must succeed"); + assert_eq!(recovered_changeset, Some(changeset)); } } - #[test] - fn is_empty() { - let mut file = NamedTempFile::new().unwrap(); - file.write_all(&TEST_MAGIC_BYTES).expect("should write"); - - let mut db = - Store::::open(&TEST_MAGIC_BYTES, file.path()).expect("must open"); - assert!(db.is_empty().expect("must read")); - db.write_changes(&vec!["hello".to_string(), "world".to_string()]) - .expect("must write"); - assert!(!db.is_empty().expect("must read")); - } - #[test] fn new_fails_if_file_is_too_short() { let mut file = NamedTempFile::new().unwrap(); diff --git a/example-crates/example_cli/src/lib.rs b/example-crates/example_cli/src/lib.rs index 0b5d9cd3..f9574c0e 100644 --- a/example-crates/example_cli/src/lib.rs +++ b/example-crates/example_cli/src/lib.rs @@ -687,7 +687,7 @@ where Err(err) => return Err(anyhow::anyhow!("failed to init db backend: {:?}", err)), }; - let init_changeset = db_backend.load_from_persistence()?; + let init_changeset = db_backend.load_from_persistence()?.unwrap_or_default(); Ok(( args,