Add PeerError structure in peer module

This adds a new PeerError structure in the peer module. To handle all
the peer related errors. PeerErrors contains all the mempool errors too,
for now. Later if we have a more complex mempool, we might decide to
have its own dedicated error.

PeerError is to be included in the global CompactFiltersError type.
This commit is contained in:
codeShark149 2021-07-01 13:12:37 +05:30 committed by rajarshimaitra
parent 474620e6a5
commit 9480faa5d3
No known key found for this signature in database
GPG Key ID: 558ACE7DBB4377C8

View File

@ -10,11 +10,15 @@
// licenses. // licenses.
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{TcpStream, ToSocketAddrs}; use std::fmt;
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::sync::{Arc, Condvar, Mutex, RwLock}; use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::thread; use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::sync::PoisonError;
use std::sync::{MutexGuard, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult};
use socks::{Socks5Stream, ToTargetAddr}; use socks::{Socks5Stream, ToTargetAddr};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
@ -30,8 +34,6 @@ use bitcoin::network::stream_reader::StreamReader;
use bitcoin::network::Address; use bitcoin::network::Address;
use bitcoin::{Block, Network, Transaction, Txid, Wtxid}; use bitcoin::{Block, Network, Transaction, Txid, Wtxid};
use super::CompactFiltersError;
type ResponsesMap = HashMap<&'static str, Arc<(Mutex<Vec<NetworkMessage>>, Condvar)>>; type ResponsesMap = HashMap<&'static str, Arc<(Mutex<Vec<NetworkMessage>>, Condvar)>>;
pub(crate) const TIMEOUT_SECS: u64 = 30; pub(crate) const TIMEOUT_SECS: u64 = 30;
@ -65,17 +67,18 @@ impl Mempool {
/// ///
/// Note that this doesn't propagate the transaction to other /// Note that this doesn't propagate the transaction to other
/// peers. To do that, [`broadcast`](crate::blockchain::Blockchain::broadcast) should be used. /// peers. To do that, [`broadcast`](crate::blockchain::Blockchain::broadcast) should be used.
pub fn add_tx(&self, tx: Transaction) { pub fn add_tx(&self, tx: Transaction) -> Result<(), PeerError> {
let mut guard = self.0.write().unwrap(); let mut guard = self.0.write()?;
guard.wtxids.insert(tx.wtxid(), tx.txid()); guard.wtxids.insert(tx.wtxid(), tx.txid());
guard.txs.insert(tx.txid(), tx); guard.txs.insert(tx.txid(), tx);
Ok(())
} }
/// Look-up a transaction in the mempool given an [`Inventory`] request /// Look-up a transaction in the mempool given an [`Inventory`] request
pub fn get_tx(&self, inventory: &Inventory) -> Option<Transaction> { pub fn get_tx(&self, inventory: &Inventory) -> Result<Option<Transaction>, PeerError> {
let identifer = match inventory { let identifer = match inventory {
Inventory::Error | Inventory::Block(_) | Inventory::WitnessBlock(_) => return None, Inventory::Error | Inventory::Block(_) | Inventory::WitnessBlock(_) => return Ok(None),
Inventory::Transaction(txid) => TxIdentifier::Txid(*txid), Inventory::Transaction(txid) => TxIdentifier::Txid(*txid),
Inventory::WitnessTransaction(txid) => TxIdentifier::Txid(*txid), Inventory::WitnessTransaction(txid) => TxIdentifier::Txid(*txid),
Inventory::WTx(wtxid) => TxIdentifier::Wtxid(*wtxid), Inventory::WTx(wtxid) => TxIdentifier::Wtxid(*wtxid),
@ -85,27 +88,34 @@ impl Mempool {
inv_type, inv_type,
hash hash
); );
return None; return Ok(None);
} }
}; };
let txid = match identifer { let txid = match identifer {
TxIdentifier::Txid(txid) => Some(txid), TxIdentifier::Txid(txid) => Some(txid),
TxIdentifier::Wtxid(wtxid) => self.0.read().unwrap().wtxids.get(&wtxid).cloned(), TxIdentifier::Wtxid(wtxid) => self.0.read()?.wtxids.get(&wtxid).cloned(),
}; };
txid.map(|txid| self.0.read().unwrap().txs.get(&txid).cloned()) let result = match txid {
.flatten() Some(txid) => {
let read_lock = self.0.read()?;
read_lock.txs.get(&txid).cloned()
}
None => None,
};
Ok(result)
} }
/// Return whether or not the mempool contains a transaction with a given txid /// Return whether or not the mempool contains a transaction with a given txid
pub fn has_tx(&self, txid: &Txid) -> bool { pub fn has_tx(&self, txid: &Txid) -> Result<bool, PeerError> {
self.0.read().unwrap().txs.contains_key(txid) Ok(self.0.read()?.txs.contains_key(txid))
} }
/// Return the list of transactions contained in the mempool /// Return the list of transactions contained in the mempool
pub fn iter_txs(&self) -> Vec<Transaction> { pub fn iter_txs(&self) -> Result<Vec<Transaction>, PeerError> {
self.0.read().unwrap().txs.values().cloned().collect() Ok(self.0.read()?.txs.values().cloned().collect())
} }
} }
@ -133,12 +143,31 @@ impl Peer {
address: A, address: A,
mempool: Arc<Mempool>, mempool: Arc<Mempool>,
network: Network, network: Network,
) -> Result<Self, CompactFiltersError> { ) -> Result<Self, PeerError> {
let stream = TcpStream::connect(address)?; let stream = TcpStream::connect(address)?;
Peer::from_stream(stream, mempool, network) Peer::from_stream(stream, mempool, network)
} }
/// Connect to a peer over a plaintext TCP connection with a timeout
///
/// This function behaves exactly the same as `connect` except for two differences
/// 1) It assumes your ToSocketAddrs will resolve to a single address
/// 2) It lets you specify a connection timeout
pub fn connect_with_timeout<A: ToSocketAddrs>(
address: A,
timeout: Duration,
mempool: Arc<Mempool>,
network: Network,
) -> Result<Self, PeerError> {
let socket_addr = address
.to_socket_addrs()?
.next()
.ok_or(PeerError::AddresseResolution)?;
let stream = TcpStream::connect_timeout(&socket_addr, timeout)?;
Peer::from_stream(stream, mempool, network)
}
/// Connect to a peer through a SOCKS5 proxy, optionally by using some credentials, specified /// Connect to a peer through a SOCKS5 proxy, optionally by using some credentials, specified
/// as a tuple of `(username, password)` /// as a tuple of `(username, password)`
/// ///
@ -150,7 +179,7 @@ impl Peer {
credentials: Option<(&str, &str)>, credentials: Option<(&str, &str)>,
mempool: Arc<Mempool>, mempool: Arc<Mempool>,
network: Network, network: Network,
) -> Result<Self, CompactFiltersError> { ) -> Result<Self, PeerError> {
let socks_stream = if let Some((username, password)) = credentials { let socks_stream = if let Some((username, password)) = credentials {
Socks5Stream::connect_with_password(proxy, target, username, password)? Socks5Stream::connect_with_password(proxy, target, username, password)?
} else { } else {
@ -165,12 +194,12 @@ impl Peer {
stream: TcpStream, stream: TcpStream,
mempool: Arc<Mempool>, mempool: Arc<Mempool>,
network: Network, network: Network,
) -> Result<Self, CompactFiltersError> { ) -> Result<Self, PeerError> {
let writer = Arc::new(Mutex::new(stream.try_clone()?)); let writer = Arc::new(Mutex::new(stream.try_clone()?));
let responses: Arc<RwLock<ResponsesMap>> = Arc::new(RwLock::new(HashMap::new())); let responses: Arc<RwLock<ResponsesMap>> = Arc::new(RwLock::new(HashMap::new()));
let connected = Arc::new(RwLock::new(true)); let connected = Arc::new(RwLock::new(true));
let mut locked_writer = writer.lock().unwrap(); let mut locked_writer = writer.lock()?;
let reader_thread_responses = Arc::clone(&responses); let reader_thread_responses = Arc::clone(&responses);
let reader_thread_writer = Arc::clone(&writer); let reader_thread_writer = Arc::clone(&writer);
@ -185,6 +214,7 @@ impl Peer {
reader_thread_mempool, reader_thread_mempool,
reader_thread_connected, reader_thread_connected,
) )
.unwrap()
}); });
let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64; let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64;
@ -209,18 +239,20 @@ impl Peer {
0, 0,
)), )),
)?; )?;
let version = if let NetworkMessage::Version(version) =
Self::_recv(&responses, "version", None).unwrap() let version = match Self::_recv(&responses, "version", Some(Duration::from_secs(1)))? {
{ Some(NetworkMessage::Version(version)) => version,
version _ => {
} else { return Err(PeerError::InvalidResponse(locked_writer.peer_addr()?));
return Err(CompactFiltersError::InvalidResponse); }
}; };
if let NetworkMessage::Verack = Self::_recv(&responses, "verack", None).unwrap() { if let Some(NetworkMessage::Verack) =
Self::_recv(&responses, "verack", Some(Duration::from_secs(1)))?
{
Self::_send(&mut locked_writer, network.magic(), NetworkMessage::Verack)?; Self::_send(&mut locked_writer, network.magic(), NetworkMessage::Verack)?;
} else { } else {
return Err(CompactFiltersError::InvalidResponse); return Err(PeerError::InvalidResponse(locked_writer.peer_addr()?));
} }
std::mem::drop(locked_writer); std::mem::drop(locked_writer);
@ -236,19 +268,26 @@ impl Peer {
}) })
} }
/// Close the peer connection
// Consume Self
pub fn close(self) -> Result<(), PeerError> {
let locked_writer = self.writer.lock()?;
Ok((*locked_writer).shutdown(std::net::Shutdown::Both)?)
}
/// Get the socket address of the remote peer
pub fn get_address(&self) -> Result<SocketAddr, PeerError> {
let locked_writer = self.writer.lock()?;
Ok(locked_writer.peer_addr()?)
}
/// Send a Bitcoin network message /// Send a Bitcoin network message
fn _send( fn _send(writer: &mut TcpStream, magic: u32, payload: NetworkMessage) -> Result<(), PeerError> {
writer: &mut TcpStream,
magic: u32,
payload: NetworkMessage,
) -> Result<(), CompactFiltersError> {
log::trace!("==> {:?}", payload); log::trace!("==> {:?}", payload);
let raw_message = RawNetworkMessage { magic, payload }; let raw_message = RawNetworkMessage { magic, payload };
raw_message raw_message.consensus_encode(writer)?;
.consensus_encode(writer)
.map_err(|_| CompactFiltersError::DataCorruption)?;
Ok(()) Ok(())
} }
@ -258,30 +297,30 @@ impl Peer {
responses: &Arc<RwLock<ResponsesMap>>, responses: &Arc<RwLock<ResponsesMap>>,
wait_for: &'static str, wait_for: &'static str,
timeout: Option<Duration>, timeout: Option<Duration>,
) -> Option<NetworkMessage> { ) -> Result<Option<NetworkMessage>, PeerError> {
let message_resp = { let message_resp = {
let mut lock = responses.write().unwrap(); let mut lock = responses.write()?;
let message_resp = lock.entry(wait_for).or_default(); let message_resp = lock.entry(wait_for).or_default();
Arc::clone(&message_resp) Arc::clone(&message_resp)
}; };
let (lock, cvar) = &*message_resp; let (lock, cvar) = &*message_resp;
let mut messages = lock.lock().unwrap(); let mut messages = lock.lock()?;
while messages.is_empty() { while messages.is_empty() {
match timeout { match timeout {
None => messages = cvar.wait(messages).unwrap(), None => messages = cvar.wait(messages)?,
Some(t) => { Some(t) => {
let result = cvar.wait_timeout(messages, t).unwrap(); let result = cvar.wait_timeout(messages, t)?;
if result.1.timed_out() { if result.1.timed_out() {
return None; return Ok(None);
} }
messages = result.0; messages = result.0;
} }
} }
} }
messages.pop() Ok(messages.pop())
} }
/// Return the [`VersionMessage`] sent by the peer /// Return the [`VersionMessage`] sent by the peer
@ -300,8 +339,8 @@ impl Peer {
} }
/// Return whether or not the peer is still connected /// Return whether or not the peer is still connected
pub fn is_connected(&self) -> bool { pub fn is_connected(&self) -> Result<bool, PeerError> {
*self.connected.read().unwrap() Ok(*self.connected.read()?)
} }
/// Internal function called once the `reader_thread` is spawned /// Internal function called once the `reader_thread` is spawned
@ -312,14 +351,14 @@ impl Peer {
reader_thread_writer: Arc<Mutex<TcpStream>>, reader_thread_writer: Arc<Mutex<TcpStream>>,
reader_thread_mempool: Arc<Mempool>, reader_thread_mempool: Arc<Mempool>,
reader_thread_connected: Arc<RwLock<bool>>, reader_thread_connected: Arc<RwLock<bool>>,
) { ) -> Result<(), PeerError> {
macro_rules! check_disconnect { macro_rules! check_disconnect {
($call:expr) => { ($call:expr) => {
match $call { match $call {
Ok(good) => good, Ok(good) => good,
Err(e) => { Err(e) => {
log::debug!("Error {:?}", e); log::debug!("Error {:?}", e);
*reader_thread_connected.write().unwrap() = false; *reader_thread_connected.write()? = false;
break; break;
} }
@ -328,7 +367,7 @@ impl Peer {
} }
let mut reader = StreamReader::new(connection, None); let mut reader = StreamReader::new(connection, None);
loop { while *reader_thread_connected.read()? {
let raw_message: RawNetworkMessage = check_disconnect!(reader.read_next()); let raw_message: RawNetworkMessage = check_disconnect!(reader.read_next());
let in_message = if raw_message.magic != network.magic() { let in_message = if raw_message.magic != network.magic() {
@ -342,7 +381,7 @@ impl Peer {
match in_message { match in_message {
NetworkMessage::Ping(nonce) => { NetworkMessage::Ping(nonce) => {
check_disconnect!(Self::_send( check_disconnect!(Self::_send(
&mut reader_thread_writer.lock().unwrap(), &mut *reader_thread_writer.lock()?,
network.magic(), network.magic(),
NetworkMessage::Pong(nonce), NetworkMessage::Pong(nonce),
)); ));
@ -353,19 +392,21 @@ impl Peer {
NetworkMessage::GetData(ref inv) => { NetworkMessage::GetData(ref inv) => {
let (found, not_found): (Vec<_>, Vec<_>) = inv let (found, not_found): (Vec<_>, Vec<_>) = inv
.iter() .iter()
.map(|item| (*item, reader_thread_mempool.get_tx(item))) .map(|item| (*item, reader_thread_mempool.get_tx(item).unwrap()))
.partition(|(_, d)| d.is_some()); .partition(|(_, d)| d.is_some());
for (_, found_tx) in found { for (_, found_tx) in found {
check_disconnect!(Self::_send( check_disconnect!(Self::_send(
&mut reader_thread_writer.lock().unwrap(), &mut *reader_thread_writer.lock()?,
network.magic(), network.magic(),
NetworkMessage::Tx(found_tx.unwrap()), NetworkMessage::Tx(found_tx.ok_or_else(|| PeerError::Generic(
"Got None while expecting Transaction".to_string()
))?),
)); ));
} }
if !not_found.is_empty() { if !not_found.is_empty() {
check_disconnect!(Self::_send( check_disconnect!(Self::_send(
&mut reader_thread_writer.lock().unwrap(), &mut *reader_thread_writer.lock()?,
network.magic(), network.magic(),
NetworkMessage::NotFound( NetworkMessage::NotFound(
not_found.into_iter().map(|(i, _)| i).collect(), not_found.into_iter().map(|(i, _)| i).collect(),
@ -377,21 +418,23 @@ impl Peer {
} }
let message_resp = { let message_resp = {
let mut lock = reader_thread_responses.write().unwrap(); let mut lock = reader_thread_responses.write()?;
let message_resp = lock.entry(in_message.cmd()).or_default(); let message_resp = lock.entry(in_message.cmd()).or_default();
Arc::clone(&message_resp) Arc::clone(&message_resp)
}; };
let (lock, cvar) = &*message_resp; let (lock, cvar) = &*message_resp;
let mut messages = lock.lock().unwrap(); let mut messages = lock.lock()?;
messages.push(in_message); messages.push(in_message);
cvar.notify_all(); cvar.notify_all();
} }
Ok(())
} }
/// Send a raw Bitcoin message to the peer /// Send a raw Bitcoin message to the peer
pub fn send(&self, payload: NetworkMessage) -> Result<(), CompactFiltersError> { pub fn send(&self, payload: NetworkMessage) -> Result<(), PeerError> {
let mut writer = self.writer.lock().unwrap(); let mut writer = self.writer.lock()?;
Self::_send(&mut writer, self.network.magic(), payload) Self::_send(&mut writer, self.network.magic(), payload)
} }
@ -400,30 +443,27 @@ impl Peer {
&self, &self,
wait_for: &'static str, wait_for: &'static str,
timeout: Option<Duration>, timeout: Option<Duration>,
) -> Result<Option<NetworkMessage>, CompactFiltersError> { ) -> Result<Option<NetworkMessage>, PeerError> {
Ok(Self::_recv(&self.responses, wait_for, timeout)) Self::_recv(&self.responses, wait_for, timeout)
} }
} }
pub trait CompactFiltersPeer { pub trait CompactFiltersPeer {
fn get_cf_checkpt( fn get_cf_checkpt(&self, filter_type: u8, stop_hash: BlockHash)
&self, -> Result<CFCheckpt, PeerError>;
filter_type: u8,
stop_hash: BlockHash,
) -> Result<CFCheckpt, CompactFiltersError>;
fn get_cf_headers( fn get_cf_headers(
&self, &self,
filter_type: u8, filter_type: u8,
start_height: u32, start_height: u32,
stop_hash: BlockHash, stop_hash: BlockHash,
) -> Result<CFHeaders, CompactFiltersError>; ) -> Result<CFHeaders, PeerError>;
fn get_cf_filters( fn get_cf_filters(
&self, &self,
filter_type: u8, filter_type: u8,
start_height: u32, start_height: u32,
stop_hash: BlockHash, stop_hash: BlockHash,
) -> Result<(), CompactFiltersError>; ) -> Result<(), PeerError>;
fn pop_cf_filter_resp(&self) -> Result<CFilter, CompactFiltersError>; fn pop_cf_filter_resp(&self) -> Result<CFilter, PeerError>;
} }
impl CompactFiltersPeer for Peer { impl CompactFiltersPeer for Peer {
@ -431,22 +471,20 @@ impl CompactFiltersPeer for Peer {
&self, &self,
filter_type: u8, filter_type: u8,
stop_hash: BlockHash, stop_hash: BlockHash,
) -> Result<CFCheckpt, CompactFiltersError> { ) -> Result<CFCheckpt, PeerError> {
self.send(NetworkMessage::GetCFCheckpt(GetCFCheckpt { self.send(NetworkMessage::GetCFCheckpt(GetCFCheckpt {
filter_type, filter_type,
stop_hash, stop_hash,
}))?; }))?;
let response = self let response = self.recv("cfcheckpt", Some(Duration::from_secs(TIMEOUT_SECS)))?;
.recv("cfcheckpt", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
let response = match response { let response = match response {
NetworkMessage::CFCheckpt(response) => response, Some(NetworkMessage::CFCheckpt(response)) => response,
_ => return Err(CompactFiltersError::InvalidResponse), _ => return Err(PeerError::InvalidResponse(self.get_address()?)),
}; };
if response.filter_type != filter_type { if response.filter_type != filter_type {
return Err(CompactFiltersError::InvalidResponse); return Err(PeerError::InvalidResponse(self.get_address()?));
} }
Ok(response) Ok(response)
@ -457,35 +495,31 @@ impl CompactFiltersPeer for Peer {
filter_type: u8, filter_type: u8,
start_height: u32, start_height: u32,
stop_hash: BlockHash, stop_hash: BlockHash,
) -> Result<CFHeaders, CompactFiltersError> { ) -> Result<CFHeaders, PeerError> {
self.send(NetworkMessage::GetCFHeaders(GetCFHeaders { self.send(NetworkMessage::GetCFHeaders(GetCFHeaders {
filter_type, filter_type,
start_height, start_height,
stop_hash, stop_hash,
}))?; }))?;
let response = self let response = self.recv("cfheaders", Some(Duration::from_secs(TIMEOUT_SECS)))?;
.recv("cfheaders", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
let response = match response { let response = match response {
NetworkMessage::CFHeaders(response) => response, Some(NetworkMessage::CFHeaders(response)) => response,
_ => return Err(CompactFiltersError::InvalidResponse), _ => return Err(PeerError::InvalidResponse(self.get_address()?)),
}; };
if response.filter_type != filter_type { if response.filter_type != filter_type {
return Err(CompactFiltersError::InvalidResponse); return Err(PeerError::InvalidResponse(self.get_address()?));
} }
Ok(response) Ok(response)
} }
fn pop_cf_filter_resp(&self) -> Result<CFilter, CompactFiltersError> { fn pop_cf_filter_resp(&self) -> Result<CFilter, PeerError> {
let response = self let response = self.recv("cfilter", Some(Duration::from_secs(TIMEOUT_SECS)))?;
.recv("cfilter", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
let response = match response { let response = match response {
NetworkMessage::CFilter(response) => response, Some(NetworkMessage::CFilter(response)) => response,
_ => return Err(CompactFiltersError::InvalidResponse), _ => return Err(PeerError::InvalidResponse(self.get_address()?)),
}; };
Ok(response) Ok(response)
@ -496,7 +530,7 @@ impl CompactFiltersPeer for Peer {
filter_type: u8, filter_type: u8,
start_height: u32, start_height: u32,
stop_hash: BlockHash, stop_hash: BlockHash,
) -> Result<(), CompactFiltersError> { ) -> Result<(), PeerError> {
self.send(NetworkMessage::GetCFilters(GetCFilters { self.send(NetworkMessage::GetCFilters(GetCFilters {
filter_type, filter_type,
start_height, start_height,
@ -508,13 +542,13 @@ impl CompactFiltersPeer for Peer {
} }
pub trait InvPeer { pub trait InvPeer {
fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, CompactFiltersError>; fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, PeerError>;
fn ask_for_mempool(&self) -> Result<(), CompactFiltersError>; fn ask_for_mempool(&self) -> Result<(), PeerError>;
fn broadcast_tx(&self, tx: Transaction) -> Result<(), CompactFiltersError>; fn broadcast_tx(&self, tx: Transaction) -> Result<(), PeerError>;
} }
impl InvPeer for Peer { impl InvPeer for Peer {
fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, CompactFiltersError> { fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, PeerError> {
self.send(NetworkMessage::GetData(vec![Inventory::WitnessBlock( self.send(NetworkMessage::GetData(vec![Inventory::WitnessBlock(
block_hash, block_hash,
)]))?; )]))?;
@ -522,51 +556,126 @@ impl InvPeer for Peer {
match self.recv("block", Some(Duration::from_secs(TIMEOUT_SECS)))? { match self.recv("block", Some(Duration::from_secs(TIMEOUT_SECS)))? {
None => Ok(None), None => Ok(None),
Some(NetworkMessage::Block(response)) => Ok(Some(response)), Some(NetworkMessage::Block(response)) => Ok(Some(response)),
_ => Err(CompactFiltersError::InvalidResponse), _ => Err(PeerError::InvalidResponse(self.get_address()?)),
} }
} }
fn ask_for_mempool(&self) -> Result<(), CompactFiltersError> { fn ask_for_mempool(&self) -> Result<(), PeerError> {
if !self.version.services.has(ServiceFlags::BLOOM) { if !self.version.services.has(ServiceFlags::BLOOM) {
return Err(CompactFiltersError::PeerBloomDisabled); return Err(PeerError::PeerBloomDisabled(self.get_address()?));
} }
self.send(NetworkMessage::MemPool)?; self.send(NetworkMessage::MemPool)?;
let inv = match self.recv("inv", Some(Duration::from_secs(5)))? { let inv = match self.recv("inv", Some(Duration::from_secs(5)))? {
None => return Ok(()), // empty mempool None => return Ok(()), // empty mempool
Some(NetworkMessage::Inv(inv)) => inv, Some(NetworkMessage::Inv(inv)) => inv,
_ => return Err(CompactFiltersError::InvalidResponse), _ => return Err(PeerError::InvalidResponse(self.get_address()?)),
}; };
let getdata = inv let getdata = inv
.iter() .iter()
.cloned() .cloned()
.filter( .filter(
|item| matches!(item, Inventory::Transaction(txid) if !self.mempool.has_tx(txid)), |item| matches!(item, Inventory::Transaction(txid) if !self.mempool.has_tx(txid).unwrap()),
) )
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let num_txs = getdata.len(); let num_txs = getdata.len();
self.send(NetworkMessage::GetData(getdata))?; self.send(NetworkMessage::GetData(getdata))?;
for _ in 0..num_txs { for _ in 0..num_txs {
let tx = self let tx = self.recv("tx", Some(Duration::from_secs(TIMEOUT_SECS)))?;
.recv("tx", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
let tx = match tx { let tx = match tx {
NetworkMessage::Tx(tx) => tx, Some(NetworkMessage::Tx(tx)) => tx,
_ => return Err(CompactFiltersError::InvalidResponse), _ => return Err(PeerError::InvalidResponse(self.get_address()?)),
}; };
self.mempool.add_tx(tx); self.mempool.add_tx(tx)?;
} }
Ok(()) Ok(())
} }
fn broadcast_tx(&self, tx: Transaction) -> Result<(), CompactFiltersError> { fn broadcast_tx(&self, tx: Transaction) -> Result<(), PeerError> {
self.mempool.add_tx(tx.clone()); self.mempool.add_tx(tx.clone())?;
self.send(NetworkMessage::Tx(tx))?; self.send(NetworkMessage::Tx(tx))?;
Ok(()) Ok(())
} }
} }
/// Peer Errors
#[derive(Debug)]
pub enum PeerError {
/// Internal I/O error
Io(std::io::Error),
/// Internal system time error
Time(std::time::SystemTimeError),
/// A peer sent an invalid or unexpected response
InvalidResponse(SocketAddr),
/// Peer had bloom filter disabled
PeerBloomDisabled(SocketAddr),
/// Internal Mutex poisoning error
MutexPoisoned,
/// Internal Mutex wait timed out
MutexTimedout,
/// Internal RW read lock poisoned
RwReadLockPoisined,
/// Internal RW write lock poisoned
RwWriteLockPoisoned,
/// Mempool Mutex poisoned
MempoolPoisoned,
/// Network address resolution Error
AddresseResolution,
/// Generic Errors
Generic(String),
}
impl std::fmt::Display for PeerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for PeerError {}
impl_error!(std::io::Error, Io, PeerError);
impl_error!(std::time::SystemTimeError, Time, PeerError);
impl<T> From<PoisonError<MutexGuard<'_, T>>> for PeerError {
fn from(_: PoisonError<MutexGuard<'_, T>>) -> Self {
PeerError::MutexPoisoned
}
}
impl<T> From<PoisonError<RwLockWriteGuard<'_, T>>> for PeerError {
fn from(_: PoisonError<RwLockWriteGuard<'_, T>>) -> Self {
PeerError::RwWriteLockPoisoned
}
}
impl<T> From<PoisonError<RwLockReadGuard<'_, T>>> for PeerError {
fn from(_: PoisonError<RwLockReadGuard<'_, T>>) -> Self {
PeerError::RwReadLockPoisined
}
}
impl<T> From<PoisonError<(MutexGuard<'_, T>, WaitTimeoutResult)>> for PeerError {
fn from(err: PoisonError<(MutexGuard<'_, T>, WaitTimeoutResult)>) -> Self {
let (_, wait_result) = err.into_inner();
if wait_result.timed_out() {
PeerError::MutexTimedout
} else {
PeerError::MutexPoisoned
}
}
}