Add PeerError structure in peer module

This adds a new PeerError structure in the peer module. To handle all
the peer related errors. PeerErrors contains all the mempool errors too,
for now. Later if we have a more complex mempool, we might decide to
have its own dedicated error.

PeerError is to be included in the global CompactFiltersError type.
This commit is contained in:
codeShark149 2021-07-01 13:12:37 +05:30 committed by rajarshimaitra
parent 474620e6a5
commit 9480faa5d3
No known key found for this signature in database
GPG Key ID: 558ACE7DBB4377C8

View File

@ -10,11 +10,15 @@
// licenses.
use std::collections::HashMap;
use std::net::{TcpStream, ToSocketAddrs};
use std::fmt;
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::sync::PoisonError;
use std::sync::{MutexGuard, RwLockReadGuard, RwLockWriteGuard, WaitTimeoutResult};
use socks::{Socks5Stream, ToTargetAddr};
use rand::{thread_rng, Rng};
@ -30,8 +34,6 @@ use bitcoin::network::stream_reader::StreamReader;
use bitcoin::network::Address;
use bitcoin::{Block, Network, Transaction, Txid, Wtxid};
use super::CompactFiltersError;
type ResponsesMap = HashMap<&'static str, Arc<(Mutex<Vec<NetworkMessage>>, Condvar)>>;
pub(crate) const TIMEOUT_SECS: u64 = 30;
@ -65,17 +67,18 @@ impl Mempool {
///
/// Note that this doesn't propagate the transaction to other
/// peers. To do that, [`broadcast`](crate::blockchain::Blockchain::broadcast) should be used.
pub fn add_tx(&self, tx: Transaction) {
let mut guard = self.0.write().unwrap();
pub fn add_tx(&self, tx: Transaction) -> Result<(), PeerError> {
let mut guard = self.0.write()?;
guard.wtxids.insert(tx.wtxid(), tx.txid());
guard.txs.insert(tx.txid(), tx);
Ok(())
}
/// Look-up a transaction in the mempool given an [`Inventory`] request
pub fn get_tx(&self, inventory: &Inventory) -> Option<Transaction> {
pub fn get_tx(&self, inventory: &Inventory) -> Result<Option<Transaction>, PeerError> {
let identifer = match inventory {
Inventory::Error | Inventory::Block(_) | Inventory::WitnessBlock(_) => return None,
Inventory::Error | Inventory::Block(_) | Inventory::WitnessBlock(_) => return Ok(None),
Inventory::Transaction(txid) => TxIdentifier::Txid(*txid),
Inventory::WitnessTransaction(txid) => TxIdentifier::Txid(*txid),
Inventory::WTx(wtxid) => TxIdentifier::Wtxid(*wtxid),
@ -85,27 +88,34 @@ impl Mempool {
inv_type,
hash
);
return None;
return Ok(None);
}
};
let txid = match identifer {
TxIdentifier::Txid(txid) => Some(txid),
TxIdentifier::Wtxid(wtxid) => self.0.read().unwrap().wtxids.get(&wtxid).cloned(),
TxIdentifier::Wtxid(wtxid) => self.0.read()?.wtxids.get(&wtxid).cloned(),
};
txid.map(|txid| self.0.read().unwrap().txs.get(&txid).cloned())
.flatten()
let result = match txid {
Some(txid) => {
let read_lock = self.0.read()?;
read_lock.txs.get(&txid).cloned()
}
None => None,
};
Ok(result)
}
/// Return whether or not the mempool contains a transaction with a given txid
pub fn has_tx(&self, txid: &Txid) -> bool {
self.0.read().unwrap().txs.contains_key(txid)
pub fn has_tx(&self, txid: &Txid) -> Result<bool, PeerError> {
Ok(self.0.read()?.txs.contains_key(txid))
}
/// Return the list of transactions contained in the mempool
pub fn iter_txs(&self) -> Vec<Transaction> {
self.0.read().unwrap().txs.values().cloned().collect()
pub fn iter_txs(&self) -> Result<Vec<Transaction>, PeerError> {
Ok(self.0.read()?.txs.values().cloned().collect())
}
}
@ -133,12 +143,31 @@ impl Peer {
address: A,
mempool: Arc<Mempool>,
network: Network,
) -> Result<Self, CompactFiltersError> {
) -> Result<Self, PeerError> {
let stream = TcpStream::connect(address)?;
Peer::from_stream(stream, mempool, network)
}
/// Connect to a peer over a plaintext TCP connection with a timeout
///
/// This function behaves exactly the same as `connect` except for two differences
/// 1) It assumes your ToSocketAddrs will resolve to a single address
/// 2) It lets you specify a connection timeout
pub fn connect_with_timeout<A: ToSocketAddrs>(
address: A,
timeout: Duration,
mempool: Arc<Mempool>,
network: Network,
) -> Result<Self, PeerError> {
let socket_addr = address
.to_socket_addrs()?
.next()
.ok_or(PeerError::AddresseResolution)?;
let stream = TcpStream::connect_timeout(&socket_addr, timeout)?;
Peer::from_stream(stream, mempool, network)
}
/// Connect to a peer through a SOCKS5 proxy, optionally by using some credentials, specified
/// as a tuple of `(username, password)`
///
@ -150,7 +179,7 @@ impl Peer {
credentials: Option<(&str, &str)>,
mempool: Arc<Mempool>,
network: Network,
) -> Result<Self, CompactFiltersError> {
) -> Result<Self, PeerError> {
let socks_stream = if let Some((username, password)) = credentials {
Socks5Stream::connect_with_password(proxy, target, username, password)?
} else {
@ -165,12 +194,12 @@ impl Peer {
stream: TcpStream,
mempool: Arc<Mempool>,
network: Network,
) -> Result<Self, CompactFiltersError> {
) -> Result<Self, PeerError> {
let writer = Arc::new(Mutex::new(stream.try_clone()?));
let responses: Arc<RwLock<ResponsesMap>> = Arc::new(RwLock::new(HashMap::new()));
let connected = Arc::new(RwLock::new(true));
let mut locked_writer = writer.lock().unwrap();
let mut locked_writer = writer.lock()?;
let reader_thread_responses = Arc::clone(&responses);
let reader_thread_writer = Arc::clone(&writer);
@ -185,6 +214,7 @@ impl Peer {
reader_thread_mempool,
reader_thread_connected,
)
.unwrap()
});
let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64;
@ -209,18 +239,20 @@ impl Peer {
0,
)),
)?;
let version = if let NetworkMessage::Version(version) =
Self::_recv(&responses, "version", None).unwrap()
{
version
} else {
return Err(CompactFiltersError::InvalidResponse);
let version = match Self::_recv(&responses, "version", Some(Duration::from_secs(1)))? {
Some(NetworkMessage::Version(version)) => version,
_ => {
return Err(PeerError::InvalidResponse(locked_writer.peer_addr()?));
}
};
if let NetworkMessage::Verack = Self::_recv(&responses, "verack", None).unwrap() {
if let Some(NetworkMessage::Verack) =
Self::_recv(&responses, "verack", Some(Duration::from_secs(1)))?
{
Self::_send(&mut locked_writer, network.magic(), NetworkMessage::Verack)?;
} else {
return Err(CompactFiltersError::InvalidResponse);
return Err(PeerError::InvalidResponse(locked_writer.peer_addr()?));
}
std::mem::drop(locked_writer);
@ -236,19 +268,26 @@ impl Peer {
})
}
/// Close the peer connection
// Consume Self
pub fn close(self) -> Result<(), PeerError> {
let locked_writer = self.writer.lock()?;
Ok((*locked_writer).shutdown(std::net::Shutdown::Both)?)
}
/// Get the socket address of the remote peer
pub fn get_address(&self) -> Result<SocketAddr, PeerError> {
let locked_writer = self.writer.lock()?;
Ok(locked_writer.peer_addr()?)
}
/// Send a Bitcoin network message
fn _send(
writer: &mut TcpStream,
magic: u32,
payload: NetworkMessage,
) -> Result<(), CompactFiltersError> {
fn _send(writer: &mut TcpStream, magic: u32, payload: NetworkMessage) -> Result<(), PeerError> {
log::trace!("==> {:?}", payload);
let raw_message = RawNetworkMessage { magic, payload };
raw_message
.consensus_encode(writer)
.map_err(|_| CompactFiltersError::DataCorruption)?;
raw_message.consensus_encode(writer)?;
Ok(())
}
@ -258,30 +297,30 @@ impl Peer {
responses: &Arc<RwLock<ResponsesMap>>,
wait_for: &'static str,
timeout: Option<Duration>,
) -> Option<NetworkMessage> {
) -> Result<Option<NetworkMessage>, PeerError> {
let message_resp = {
let mut lock = responses.write().unwrap();
let mut lock = responses.write()?;
let message_resp = lock.entry(wait_for).or_default();
Arc::clone(&message_resp)
};
let (lock, cvar) = &*message_resp;
let mut messages = lock.lock().unwrap();
let mut messages = lock.lock()?;
while messages.is_empty() {
match timeout {
None => messages = cvar.wait(messages).unwrap(),
None => messages = cvar.wait(messages)?,
Some(t) => {
let result = cvar.wait_timeout(messages, t).unwrap();
let result = cvar.wait_timeout(messages, t)?;
if result.1.timed_out() {
return None;
return Ok(None);
}
messages = result.0;
}
}
}
messages.pop()
Ok(messages.pop())
}
/// Return the [`VersionMessage`] sent by the peer
@ -300,8 +339,8 @@ impl Peer {
}
/// Return whether or not the peer is still connected
pub fn is_connected(&self) -> bool {
*self.connected.read().unwrap()
pub fn is_connected(&self) -> Result<bool, PeerError> {
Ok(*self.connected.read()?)
}
/// Internal function called once the `reader_thread` is spawned
@ -312,14 +351,14 @@ impl Peer {
reader_thread_writer: Arc<Mutex<TcpStream>>,
reader_thread_mempool: Arc<Mempool>,
reader_thread_connected: Arc<RwLock<bool>>,
) {
) -> Result<(), PeerError> {
macro_rules! check_disconnect {
($call:expr) => {
match $call {
Ok(good) => good,
Err(e) => {
log::debug!("Error {:?}", e);
*reader_thread_connected.write().unwrap() = false;
*reader_thread_connected.write()? = false;
break;
}
@ -328,7 +367,7 @@ impl Peer {
}
let mut reader = StreamReader::new(connection, None);
loop {
while *reader_thread_connected.read()? {
let raw_message: RawNetworkMessage = check_disconnect!(reader.read_next());
let in_message = if raw_message.magic != network.magic() {
@ -342,7 +381,7 @@ impl Peer {
match in_message {
NetworkMessage::Ping(nonce) => {
check_disconnect!(Self::_send(
&mut reader_thread_writer.lock().unwrap(),
&mut *reader_thread_writer.lock()?,
network.magic(),
NetworkMessage::Pong(nonce),
));
@ -353,19 +392,21 @@ impl Peer {
NetworkMessage::GetData(ref inv) => {
let (found, not_found): (Vec<_>, Vec<_>) = inv
.iter()
.map(|item| (*item, reader_thread_mempool.get_tx(item)))
.map(|item| (*item, reader_thread_mempool.get_tx(item).unwrap()))
.partition(|(_, d)| d.is_some());
for (_, found_tx) in found {
check_disconnect!(Self::_send(
&mut reader_thread_writer.lock().unwrap(),
&mut *reader_thread_writer.lock()?,
network.magic(),
NetworkMessage::Tx(found_tx.unwrap()),
NetworkMessage::Tx(found_tx.ok_or_else(|| PeerError::Generic(
"Got None while expecting Transaction".to_string()
))?),
));
}
if !not_found.is_empty() {
check_disconnect!(Self::_send(
&mut reader_thread_writer.lock().unwrap(),
&mut *reader_thread_writer.lock()?,
network.magic(),
NetworkMessage::NotFound(
not_found.into_iter().map(|(i, _)| i).collect(),
@ -377,21 +418,23 @@ impl Peer {
}
let message_resp = {
let mut lock = reader_thread_responses.write().unwrap();
let mut lock = reader_thread_responses.write()?;
let message_resp = lock.entry(in_message.cmd()).or_default();
Arc::clone(&message_resp)
};
let (lock, cvar) = &*message_resp;
let mut messages = lock.lock().unwrap();
let mut messages = lock.lock()?;
messages.push(in_message);
cvar.notify_all();
}
Ok(())
}
/// Send a raw Bitcoin message to the peer
pub fn send(&self, payload: NetworkMessage) -> Result<(), CompactFiltersError> {
let mut writer = self.writer.lock().unwrap();
pub fn send(&self, payload: NetworkMessage) -> Result<(), PeerError> {
let mut writer = self.writer.lock()?;
Self::_send(&mut writer, self.network.magic(), payload)
}
@ -400,30 +443,27 @@ impl Peer {
&self,
wait_for: &'static str,
timeout: Option<Duration>,
) -> Result<Option<NetworkMessage>, CompactFiltersError> {
Ok(Self::_recv(&self.responses, wait_for, timeout))
) -> Result<Option<NetworkMessage>, PeerError> {
Self::_recv(&self.responses, wait_for, timeout)
}
}
pub trait CompactFiltersPeer {
fn get_cf_checkpt(
&self,
filter_type: u8,
stop_hash: BlockHash,
) -> Result<CFCheckpt, CompactFiltersError>;
fn get_cf_checkpt(&self, filter_type: u8, stop_hash: BlockHash)
-> Result<CFCheckpt, PeerError>;
fn get_cf_headers(
&self,
filter_type: u8,
start_height: u32,
stop_hash: BlockHash,
) -> Result<CFHeaders, CompactFiltersError>;
) -> Result<CFHeaders, PeerError>;
fn get_cf_filters(
&self,
filter_type: u8,
start_height: u32,
stop_hash: BlockHash,
) -> Result<(), CompactFiltersError>;
fn pop_cf_filter_resp(&self) -> Result<CFilter, CompactFiltersError>;
) -> Result<(), PeerError>;
fn pop_cf_filter_resp(&self) -> Result<CFilter, PeerError>;
}
impl CompactFiltersPeer for Peer {
@ -431,22 +471,20 @@ impl CompactFiltersPeer for Peer {
&self,
filter_type: u8,
stop_hash: BlockHash,
) -> Result<CFCheckpt, CompactFiltersError> {
) -> Result<CFCheckpt, PeerError> {
self.send(NetworkMessage::GetCFCheckpt(GetCFCheckpt {
filter_type,
stop_hash,
}))?;
let response = self
.recv("cfcheckpt", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
let response = self.recv("cfcheckpt", Some(Duration::from_secs(TIMEOUT_SECS)))?;
let response = match response {
NetworkMessage::CFCheckpt(response) => response,
_ => return Err(CompactFiltersError::InvalidResponse),
Some(NetworkMessage::CFCheckpt(response)) => response,
_ => return Err(PeerError::InvalidResponse(self.get_address()?)),
};
if response.filter_type != filter_type {
return Err(CompactFiltersError::InvalidResponse);
return Err(PeerError::InvalidResponse(self.get_address()?));
}
Ok(response)
@ -457,35 +495,31 @@ impl CompactFiltersPeer for Peer {
filter_type: u8,
start_height: u32,
stop_hash: BlockHash,
) -> Result<CFHeaders, CompactFiltersError> {
) -> Result<CFHeaders, PeerError> {
self.send(NetworkMessage::GetCFHeaders(GetCFHeaders {
filter_type,
start_height,
stop_hash,
}))?;
let response = self
.recv("cfheaders", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
let response = self.recv("cfheaders", Some(Duration::from_secs(TIMEOUT_SECS)))?;
let response = match response {
NetworkMessage::CFHeaders(response) => response,
_ => return Err(CompactFiltersError::InvalidResponse),
Some(NetworkMessage::CFHeaders(response)) => response,
_ => return Err(PeerError::InvalidResponse(self.get_address()?)),
};
if response.filter_type != filter_type {
return Err(CompactFiltersError::InvalidResponse);
return Err(PeerError::InvalidResponse(self.get_address()?));
}
Ok(response)
}
fn pop_cf_filter_resp(&self) -> Result<CFilter, CompactFiltersError> {
let response = self
.recv("cfilter", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
fn pop_cf_filter_resp(&self) -> Result<CFilter, PeerError> {
let response = self.recv("cfilter", Some(Duration::from_secs(TIMEOUT_SECS)))?;
let response = match response {
NetworkMessage::CFilter(response) => response,
_ => return Err(CompactFiltersError::InvalidResponse),
Some(NetworkMessage::CFilter(response)) => response,
_ => return Err(PeerError::InvalidResponse(self.get_address()?)),
};
Ok(response)
@ -496,7 +530,7 @@ impl CompactFiltersPeer for Peer {
filter_type: u8,
start_height: u32,
stop_hash: BlockHash,
) -> Result<(), CompactFiltersError> {
) -> Result<(), PeerError> {
self.send(NetworkMessage::GetCFilters(GetCFilters {
filter_type,
start_height,
@ -508,13 +542,13 @@ impl CompactFiltersPeer for Peer {
}
pub trait InvPeer {
fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, CompactFiltersError>;
fn ask_for_mempool(&self) -> Result<(), CompactFiltersError>;
fn broadcast_tx(&self, tx: Transaction) -> Result<(), CompactFiltersError>;
fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, PeerError>;
fn ask_for_mempool(&self) -> Result<(), PeerError>;
fn broadcast_tx(&self, tx: Transaction) -> Result<(), PeerError>;
}
impl InvPeer for Peer {
fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, CompactFiltersError> {
fn get_block(&self, block_hash: BlockHash) -> Result<Option<Block>, PeerError> {
self.send(NetworkMessage::GetData(vec![Inventory::WitnessBlock(
block_hash,
)]))?;
@ -522,51 +556,126 @@ impl InvPeer for Peer {
match self.recv("block", Some(Duration::from_secs(TIMEOUT_SECS)))? {
None => Ok(None),
Some(NetworkMessage::Block(response)) => Ok(Some(response)),
_ => Err(CompactFiltersError::InvalidResponse),
_ => Err(PeerError::InvalidResponse(self.get_address()?)),
}
}
fn ask_for_mempool(&self) -> Result<(), CompactFiltersError> {
fn ask_for_mempool(&self) -> Result<(), PeerError> {
if !self.version.services.has(ServiceFlags::BLOOM) {
return Err(CompactFiltersError::PeerBloomDisabled);
return Err(PeerError::PeerBloomDisabled(self.get_address()?));
}
self.send(NetworkMessage::MemPool)?;
let inv = match self.recv("inv", Some(Duration::from_secs(5)))? {
None => return Ok(()), // empty mempool
Some(NetworkMessage::Inv(inv)) => inv,
_ => return Err(CompactFiltersError::InvalidResponse),
_ => return Err(PeerError::InvalidResponse(self.get_address()?)),
};
let getdata = inv
.iter()
.cloned()
.filter(
|item| matches!(item, Inventory::Transaction(txid) if !self.mempool.has_tx(txid)),
|item| matches!(item, Inventory::Transaction(txid) if !self.mempool.has_tx(txid).unwrap()),
)
.collect::<Vec<_>>();
let num_txs = getdata.len();
self.send(NetworkMessage::GetData(getdata))?;
for _ in 0..num_txs {
let tx = self
.recv("tx", Some(Duration::from_secs(TIMEOUT_SECS)))?
.ok_or(CompactFiltersError::Timeout)?;
let tx = self.recv("tx", Some(Duration::from_secs(TIMEOUT_SECS)))?;
let tx = match tx {
NetworkMessage::Tx(tx) => tx,
_ => return Err(CompactFiltersError::InvalidResponse),
Some(NetworkMessage::Tx(tx)) => tx,
_ => return Err(PeerError::InvalidResponse(self.get_address()?)),
};
self.mempool.add_tx(tx);
self.mempool.add_tx(tx)?;
}
Ok(())
}
fn broadcast_tx(&self, tx: Transaction) -> Result<(), CompactFiltersError> {
self.mempool.add_tx(tx.clone());
fn broadcast_tx(&self, tx: Transaction) -> Result<(), PeerError> {
self.mempool.add_tx(tx.clone())?;
self.send(NetworkMessage::Tx(tx))?;
Ok(())
}
}
/// Peer Errors
#[derive(Debug)]
pub enum PeerError {
/// Internal I/O error
Io(std::io::Error),
/// Internal system time error
Time(std::time::SystemTimeError),
/// A peer sent an invalid or unexpected response
InvalidResponse(SocketAddr),
/// Peer had bloom filter disabled
PeerBloomDisabled(SocketAddr),
/// Internal Mutex poisoning error
MutexPoisoned,
/// Internal Mutex wait timed out
MutexTimedout,
/// Internal RW read lock poisoned
RwReadLockPoisined,
/// Internal RW write lock poisoned
RwWriteLockPoisoned,
/// Mempool Mutex poisoned
MempoolPoisoned,
/// Network address resolution Error
AddresseResolution,
/// Generic Errors
Generic(String),
}
impl std::fmt::Display for PeerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for PeerError {}
impl_error!(std::io::Error, Io, PeerError);
impl_error!(std::time::SystemTimeError, Time, PeerError);
impl<T> From<PoisonError<MutexGuard<'_, T>>> for PeerError {
fn from(_: PoisonError<MutexGuard<'_, T>>) -> Self {
PeerError::MutexPoisoned
}
}
impl<T> From<PoisonError<RwLockWriteGuard<'_, T>>> for PeerError {
fn from(_: PoisonError<RwLockWriteGuard<'_, T>>) -> Self {
PeerError::RwWriteLockPoisoned
}
}
impl<T> From<PoisonError<RwLockReadGuard<'_, T>>> for PeerError {
fn from(_: PoisonError<RwLockReadGuard<'_, T>>) -> Self {
PeerError::RwReadLockPoisined
}
}
impl<T> From<PoisonError<(MutexGuard<'_, T>, WaitTimeoutResult)>> for PeerError {
fn from(err: PoisonError<(MutexGuard<'_, T>, WaitTimeoutResult)>) -> Self {
let (_, wait_result) = err.into_inner();
if wait_result.timed_out() {
PeerError::MutexTimedout
} else {
PeerError::MutexPoisoned
}
}
}