From 9480faa5d3cefae40e8f9b54d9b31af93ee40ba0 Mon Sep 17 00:00:00 2001 From: codeShark149 Date: Thu, 1 Jul 2021 13:12:37 +0530 Subject: [PATCH] Add PeerError structure in peer module This adds a new PeerError structure in the peer module. To handle all the peer related errors. PeerErrors contains all the mempool errors too, for now. Later if we have a more complex mempool, we might decide to have its own dedicated error. PeerError is to be included in the global CompactFiltersError type. --- src/blockchain/compact_filters/peer.rs | 317 +++++++++++++++++-------- 1 file changed, 213 insertions(+), 104 deletions(-) diff --git a/src/blockchain/compact_filters/peer.rs b/src/blockchain/compact_filters/peer.rs index 683e25db..a4b7bd0f 100644 --- a/src/blockchain/compact_filters/peer.rs +++ b/src/blockchain/compact_filters/peer.rs @@ -10,11 +10,15 @@ // licenses. use std::collections::HashMap; -use std::net::{TcpStream, ToSocketAddrs}; +use std::fmt; +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; use std::sync::{Arc, Condvar, Mutex, RwLock}; use std::thread; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::sync::PoisonError; +use std::sync::{MutexGuard, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult}; + use socks::{Socks5Stream, ToTargetAddr}; use rand::{thread_rng, Rng}; @@ -30,8 +34,6 @@ use bitcoin::network::stream_reader::StreamReader; use bitcoin::network::Address; use bitcoin::{Block, Network, Transaction, Txid, Wtxid}; -use super::CompactFiltersError; - type ResponsesMap = HashMap<&'static str, Arc<(Mutex>, Condvar)>>; pub(crate) const TIMEOUT_SECS: u64 = 30; @@ -65,17 +67,18 @@ impl Mempool { /// /// Note that this doesn't propagate the transaction to other /// peers. To do that, [`broadcast`](crate::blockchain::Blockchain::broadcast) should be used. - pub fn add_tx(&self, tx: Transaction) { - let mut guard = self.0.write().unwrap(); + pub fn add_tx(&self, tx: Transaction) -> Result<(), PeerError> { + let mut guard = self.0.write()?; guard.wtxids.insert(tx.wtxid(), tx.txid()); guard.txs.insert(tx.txid(), tx); + Ok(()) } /// Look-up a transaction in the mempool given an [`Inventory`] request - pub fn get_tx(&self, inventory: &Inventory) -> Option { + pub fn get_tx(&self, inventory: &Inventory) -> Result, PeerError> { let identifer = match inventory { - Inventory::Error | Inventory::Block(_) | Inventory::WitnessBlock(_) => return None, + Inventory::Error | Inventory::Block(_) | Inventory::WitnessBlock(_) => return Ok(None), Inventory::Transaction(txid) => TxIdentifier::Txid(*txid), Inventory::WitnessTransaction(txid) => TxIdentifier::Txid(*txid), Inventory::WTx(wtxid) => TxIdentifier::Wtxid(*wtxid), @@ -85,27 +88,34 @@ impl Mempool { inv_type, hash ); - return None; + return Ok(None); } }; let txid = match identifer { TxIdentifier::Txid(txid) => Some(txid), - TxIdentifier::Wtxid(wtxid) => self.0.read().unwrap().wtxids.get(&wtxid).cloned(), + TxIdentifier::Wtxid(wtxid) => self.0.read()?.wtxids.get(&wtxid).cloned(), }; - txid.map(|txid| self.0.read().unwrap().txs.get(&txid).cloned()) - .flatten() + let result = match txid { + Some(txid) => { + let read_lock = self.0.read()?; + read_lock.txs.get(&txid).cloned() + } + None => None, + }; + + Ok(result) } /// Return whether or not the mempool contains a transaction with a given txid - pub fn has_tx(&self, txid: &Txid) -> bool { - self.0.read().unwrap().txs.contains_key(txid) + pub fn has_tx(&self, txid: &Txid) -> Result { + Ok(self.0.read()?.txs.contains_key(txid)) } /// Return the list of transactions contained in the mempool - pub fn iter_txs(&self) -> Vec { - self.0.read().unwrap().txs.values().cloned().collect() + pub fn iter_txs(&self) -> Result, PeerError> { + Ok(self.0.read()?.txs.values().cloned().collect()) } } @@ -133,12 +143,31 @@ impl Peer { address: A, mempool: Arc, network: Network, - ) -> Result { + ) -> Result { let stream = TcpStream::connect(address)?; Peer::from_stream(stream, mempool, network) } + /// Connect to a peer over a plaintext TCP connection with a timeout + /// + /// This function behaves exactly the same as `connect` except for two differences + /// 1) It assumes your ToSocketAddrs will resolve to a single address + /// 2) It lets you specify a connection timeout + pub fn connect_with_timeout( + address: A, + timeout: Duration, + mempool: Arc, + network: Network, + ) -> Result { + let socket_addr = address + .to_socket_addrs()? + .next() + .ok_or(PeerError::AddresseResolution)?; + let stream = TcpStream::connect_timeout(&socket_addr, timeout)?; + Peer::from_stream(stream, mempool, network) + } + /// Connect to a peer through a SOCKS5 proxy, optionally by using some credentials, specified /// as a tuple of `(username, password)` /// @@ -150,7 +179,7 @@ impl Peer { credentials: Option<(&str, &str)>, mempool: Arc, network: Network, - ) -> Result { + ) -> Result { let socks_stream = if let Some((username, password)) = credentials { Socks5Stream::connect_with_password(proxy, target, username, password)? } else { @@ -165,12 +194,12 @@ impl Peer { stream: TcpStream, mempool: Arc, network: Network, - ) -> Result { + ) -> Result { let writer = Arc::new(Mutex::new(stream.try_clone()?)); let responses: Arc> = Arc::new(RwLock::new(HashMap::new())); let connected = Arc::new(RwLock::new(true)); - let mut locked_writer = writer.lock().unwrap(); + let mut locked_writer = writer.lock()?; let reader_thread_responses = Arc::clone(&responses); let reader_thread_writer = Arc::clone(&writer); @@ -185,6 +214,7 @@ impl Peer { reader_thread_mempool, reader_thread_connected, ) + .unwrap() }); let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64; @@ -209,18 +239,20 @@ impl Peer { 0, )), )?; - let version = if let NetworkMessage::Version(version) = - Self::_recv(&responses, "version", None).unwrap() - { - version - } else { - return Err(CompactFiltersError::InvalidResponse); + + let version = match Self::_recv(&responses, "version", Some(Duration::from_secs(1)))? { + Some(NetworkMessage::Version(version)) => version, + _ => { + return Err(PeerError::InvalidResponse(locked_writer.peer_addr()?)); + } }; - if let NetworkMessage::Verack = Self::_recv(&responses, "verack", None).unwrap() { + if let Some(NetworkMessage::Verack) = + Self::_recv(&responses, "verack", Some(Duration::from_secs(1)))? + { Self::_send(&mut locked_writer, network.magic(), NetworkMessage::Verack)?; } else { - return Err(CompactFiltersError::InvalidResponse); + return Err(PeerError::InvalidResponse(locked_writer.peer_addr()?)); } std::mem::drop(locked_writer); @@ -236,19 +268,26 @@ impl Peer { }) } + /// Close the peer connection + // Consume Self + pub fn close(self) -> Result<(), PeerError> { + let locked_writer = self.writer.lock()?; + Ok((*locked_writer).shutdown(std::net::Shutdown::Both)?) + } + + /// Get the socket address of the remote peer + pub fn get_address(&self) -> Result { + let locked_writer = self.writer.lock()?; + Ok(locked_writer.peer_addr()?) + } + /// Send a Bitcoin network message - fn _send( - writer: &mut TcpStream, - magic: u32, - payload: NetworkMessage, - ) -> Result<(), CompactFiltersError> { + fn _send(writer: &mut TcpStream, magic: u32, payload: NetworkMessage) -> Result<(), PeerError> { log::trace!("==> {:?}", payload); let raw_message = RawNetworkMessage { magic, payload }; - raw_message - .consensus_encode(writer) - .map_err(|_| CompactFiltersError::DataCorruption)?; + raw_message.consensus_encode(writer)?; Ok(()) } @@ -258,30 +297,30 @@ impl Peer { responses: &Arc>, wait_for: &'static str, timeout: Option, - ) -> Option { + ) -> Result, PeerError> { let message_resp = { - let mut lock = responses.write().unwrap(); + let mut lock = responses.write()?; let message_resp = lock.entry(wait_for).or_default(); Arc::clone(&message_resp) }; let (lock, cvar) = &*message_resp; - let mut messages = lock.lock().unwrap(); + let mut messages = lock.lock()?; while messages.is_empty() { match timeout { - None => messages = cvar.wait(messages).unwrap(), + None => messages = cvar.wait(messages)?, Some(t) => { - let result = cvar.wait_timeout(messages, t).unwrap(); + let result = cvar.wait_timeout(messages, t)?; if result.1.timed_out() { - return None; + return Ok(None); } messages = result.0; } } } - messages.pop() + Ok(messages.pop()) } /// Return the [`VersionMessage`] sent by the peer @@ -300,8 +339,8 @@ impl Peer { } /// Return whether or not the peer is still connected - pub fn is_connected(&self) -> bool { - *self.connected.read().unwrap() + pub fn is_connected(&self) -> Result { + Ok(*self.connected.read()?) } /// Internal function called once the `reader_thread` is spawned @@ -312,14 +351,14 @@ impl Peer { reader_thread_writer: Arc>, reader_thread_mempool: Arc, reader_thread_connected: Arc>, - ) { + ) -> Result<(), PeerError> { macro_rules! check_disconnect { ($call:expr) => { match $call { Ok(good) => good, Err(e) => { log::debug!("Error {:?}", e); - *reader_thread_connected.write().unwrap() = false; + *reader_thread_connected.write()? = false; break; } @@ -328,7 +367,7 @@ impl Peer { } let mut reader = StreamReader::new(connection, None); - loop { + while *reader_thread_connected.read()? { let raw_message: RawNetworkMessage = check_disconnect!(reader.read_next()); let in_message = if raw_message.magic != network.magic() { @@ -342,7 +381,7 @@ impl Peer { match in_message { NetworkMessage::Ping(nonce) => { check_disconnect!(Self::_send( - &mut reader_thread_writer.lock().unwrap(), + &mut *reader_thread_writer.lock()?, network.magic(), NetworkMessage::Pong(nonce), )); @@ -353,19 +392,21 @@ impl Peer { NetworkMessage::GetData(ref inv) => { let (found, not_found): (Vec<_>, Vec<_>) = inv .iter() - .map(|item| (*item, reader_thread_mempool.get_tx(item))) + .map(|item| (*item, reader_thread_mempool.get_tx(item).unwrap())) .partition(|(_, d)| d.is_some()); for (_, found_tx) in found { check_disconnect!(Self::_send( - &mut reader_thread_writer.lock().unwrap(), + &mut *reader_thread_writer.lock()?, network.magic(), - NetworkMessage::Tx(found_tx.unwrap()), + NetworkMessage::Tx(found_tx.ok_or_else(|| PeerError::Generic( + "Got None while expecting Transaction".to_string() + ))?), )); } if !not_found.is_empty() { check_disconnect!(Self::_send( - &mut reader_thread_writer.lock().unwrap(), + &mut *reader_thread_writer.lock()?, network.magic(), NetworkMessage::NotFound( not_found.into_iter().map(|(i, _)| i).collect(), @@ -377,21 +418,23 @@ impl Peer { } let message_resp = { - let mut lock = reader_thread_responses.write().unwrap(); + let mut lock = reader_thread_responses.write()?; let message_resp = lock.entry(in_message.cmd()).or_default(); Arc::clone(&message_resp) }; let (lock, cvar) = &*message_resp; - let mut messages = lock.lock().unwrap(); + let mut messages = lock.lock()?; messages.push(in_message); cvar.notify_all(); } + + Ok(()) } /// Send a raw Bitcoin message to the peer - pub fn send(&self, payload: NetworkMessage) -> Result<(), CompactFiltersError> { - let mut writer = self.writer.lock().unwrap(); + pub fn send(&self, payload: NetworkMessage) -> Result<(), PeerError> { + let mut writer = self.writer.lock()?; Self::_send(&mut writer, self.network.magic(), payload) } @@ -400,30 +443,27 @@ impl Peer { &self, wait_for: &'static str, timeout: Option, - ) -> Result, CompactFiltersError> { - Ok(Self::_recv(&self.responses, wait_for, timeout)) + ) -> Result, PeerError> { + Self::_recv(&self.responses, wait_for, timeout) } } pub trait CompactFiltersPeer { - fn get_cf_checkpt( - &self, - filter_type: u8, - stop_hash: BlockHash, - ) -> Result; + fn get_cf_checkpt(&self, filter_type: u8, stop_hash: BlockHash) + -> Result; fn get_cf_headers( &self, filter_type: u8, start_height: u32, stop_hash: BlockHash, - ) -> Result; + ) -> Result; fn get_cf_filters( &self, filter_type: u8, start_height: u32, stop_hash: BlockHash, - ) -> Result<(), CompactFiltersError>; - fn pop_cf_filter_resp(&self) -> Result; + ) -> Result<(), PeerError>; + fn pop_cf_filter_resp(&self) -> Result; } impl CompactFiltersPeer for Peer { @@ -431,22 +471,20 @@ impl CompactFiltersPeer for Peer { &self, filter_type: u8, stop_hash: BlockHash, - ) -> Result { + ) -> Result { self.send(NetworkMessage::GetCFCheckpt(GetCFCheckpt { filter_type, stop_hash, }))?; - let response = self - .recv("cfcheckpt", Some(Duration::from_secs(TIMEOUT_SECS)))? - .ok_or(CompactFiltersError::Timeout)?; + let response = self.recv("cfcheckpt", Some(Duration::from_secs(TIMEOUT_SECS)))?; let response = match response { - NetworkMessage::CFCheckpt(response) => response, - _ => return Err(CompactFiltersError::InvalidResponse), + Some(NetworkMessage::CFCheckpt(response)) => response, + _ => return Err(PeerError::InvalidResponse(self.get_address()?)), }; if response.filter_type != filter_type { - return Err(CompactFiltersError::InvalidResponse); + return Err(PeerError::InvalidResponse(self.get_address()?)); } Ok(response) @@ -457,35 +495,31 @@ impl CompactFiltersPeer for Peer { filter_type: u8, start_height: u32, stop_hash: BlockHash, - ) -> Result { + ) -> Result { self.send(NetworkMessage::GetCFHeaders(GetCFHeaders { filter_type, start_height, stop_hash, }))?; - let response = self - .recv("cfheaders", Some(Duration::from_secs(TIMEOUT_SECS)))? - .ok_or(CompactFiltersError::Timeout)?; + let response = self.recv("cfheaders", Some(Duration::from_secs(TIMEOUT_SECS)))?; let response = match response { - NetworkMessage::CFHeaders(response) => response, - _ => return Err(CompactFiltersError::InvalidResponse), + Some(NetworkMessage::CFHeaders(response)) => response, + _ => return Err(PeerError::InvalidResponse(self.get_address()?)), }; if response.filter_type != filter_type { - return Err(CompactFiltersError::InvalidResponse); + return Err(PeerError::InvalidResponse(self.get_address()?)); } Ok(response) } - fn pop_cf_filter_resp(&self) -> Result { - let response = self - .recv("cfilter", Some(Duration::from_secs(TIMEOUT_SECS)))? - .ok_or(CompactFiltersError::Timeout)?; + fn pop_cf_filter_resp(&self) -> Result { + let response = self.recv("cfilter", Some(Duration::from_secs(TIMEOUT_SECS)))?; let response = match response { - NetworkMessage::CFilter(response) => response, - _ => return Err(CompactFiltersError::InvalidResponse), + Some(NetworkMessage::CFilter(response)) => response, + _ => return Err(PeerError::InvalidResponse(self.get_address()?)), }; Ok(response) @@ -496,7 +530,7 @@ impl CompactFiltersPeer for Peer { filter_type: u8, start_height: u32, stop_hash: BlockHash, - ) -> Result<(), CompactFiltersError> { + ) -> Result<(), PeerError> { self.send(NetworkMessage::GetCFilters(GetCFilters { filter_type, start_height, @@ -508,13 +542,13 @@ impl CompactFiltersPeer for Peer { } pub trait InvPeer { - fn get_block(&self, block_hash: BlockHash) -> Result, CompactFiltersError>; - fn ask_for_mempool(&self) -> Result<(), CompactFiltersError>; - fn broadcast_tx(&self, tx: Transaction) -> Result<(), CompactFiltersError>; + fn get_block(&self, block_hash: BlockHash) -> Result, PeerError>; + fn ask_for_mempool(&self) -> Result<(), PeerError>; + fn broadcast_tx(&self, tx: Transaction) -> Result<(), PeerError>; } impl InvPeer for Peer { - fn get_block(&self, block_hash: BlockHash) -> Result, CompactFiltersError> { + fn get_block(&self, block_hash: BlockHash) -> Result, PeerError> { self.send(NetworkMessage::GetData(vec![Inventory::WitnessBlock( block_hash, )]))?; @@ -522,51 +556,126 @@ impl InvPeer for Peer { match self.recv("block", Some(Duration::from_secs(TIMEOUT_SECS)))? { None => Ok(None), Some(NetworkMessage::Block(response)) => Ok(Some(response)), - _ => Err(CompactFiltersError::InvalidResponse), + _ => Err(PeerError::InvalidResponse(self.get_address()?)), } } - fn ask_for_mempool(&self) -> Result<(), CompactFiltersError> { + fn ask_for_mempool(&self) -> Result<(), PeerError> { if !self.version.services.has(ServiceFlags::BLOOM) { - return Err(CompactFiltersError::PeerBloomDisabled); + return Err(PeerError::PeerBloomDisabled(self.get_address()?)); } self.send(NetworkMessage::MemPool)?; let inv = match self.recv("inv", Some(Duration::from_secs(5)))? { None => return Ok(()), // empty mempool Some(NetworkMessage::Inv(inv)) => inv, - _ => return Err(CompactFiltersError::InvalidResponse), + _ => return Err(PeerError::InvalidResponse(self.get_address()?)), }; let getdata = inv .iter() .cloned() .filter( - |item| matches!(item, Inventory::Transaction(txid) if !self.mempool.has_tx(txid)), + |item| matches!(item, Inventory::Transaction(txid) if !self.mempool.has_tx(txid).unwrap()), ) .collect::>(); let num_txs = getdata.len(); self.send(NetworkMessage::GetData(getdata))?; for _ in 0..num_txs { - let tx = self - .recv("tx", Some(Duration::from_secs(TIMEOUT_SECS)))? - .ok_or(CompactFiltersError::Timeout)?; + let tx = self.recv("tx", Some(Duration::from_secs(TIMEOUT_SECS)))?; let tx = match tx { - NetworkMessage::Tx(tx) => tx, - _ => return Err(CompactFiltersError::InvalidResponse), + Some(NetworkMessage::Tx(tx)) => tx, + _ => return Err(PeerError::InvalidResponse(self.get_address()?)), }; - self.mempool.add_tx(tx); + self.mempool.add_tx(tx)?; } Ok(()) } - fn broadcast_tx(&self, tx: Transaction) -> Result<(), CompactFiltersError> { - self.mempool.add_tx(tx.clone()); + fn broadcast_tx(&self, tx: Transaction) -> Result<(), PeerError> { + self.mempool.add_tx(tx.clone())?; self.send(NetworkMessage::Tx(tx))?; Ok(()) } } + +/// Peer Errors +#[derive(Debug)] +pub enum PeerError { + /// Internal I/O error + Io(std::io::Error), + + /// Internal system time error + Time(std::time::SystemTimeError), + + /// A peer sent an invalid or unexpected response + InvalidResponse(SocketAddr), + + /// Peer had bloom filter disabled + PeerBloomDisabled(SocketAddr), + + /// Internal Mutex poisoning error + MutexPoisoned, + + /// Internal Mutex wait timed out + MutexTimedout, + + /// Internal RW read lock poisoned + RwReadLockPoisined, + + /// Internal RW write lock poisoned + RwWriteLockPoisoned, + + /// Mempool Mutex poisoned + MempoolPoisoned, + + /// Network address resolution Error + AddresseResolution, + + /// Generic Errors + Generic(String), +} + +impl std::fmt::Display for PeerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for PeerError {} + +impl_error!(std::io::Error, Io, PeerError); +impl_error!(std::time::SystemTimeError, Time, PeerError); + +impl From>> for PeerError { + fn from(_: PoisonError>) -> Self { + PeerError::MutexPoisoned + } +} + +impl From>> for PeerError { + fn from(_: PoisonError>) -> Self { + PeerError::RwWriteLockPoisoned + } +} + +impl From>> for PeerError { + fn from(_: PoisonError>) -> Self { + PeerError::RwReadLockPoisined + } +} + +impl From, WaitTimeoutResult)>> for PeerError { + fn from(err: PoisonError<(MutexGuard<'_, T>, WaitTimeoutResult)>) -> Self { + let (_, wait_result) = err.into_inner(); + if wait_result.timed_out() { + PeerError::MutexTimedout + } else { + PeerError::MutexPoisoned + } + } +}