Refactor rust-gbt

This commit is contained in:
junderw
2024-03-10 13:27:09 +09:00
parent 7bedb9488b
commit 92a5fc8159
22 changed files with 187 additions and 540 deletions

View File

@@ -0,0 +1,225 @@
use crate::{
u32_hasher_types::{u32hashset_new, U32HasherState},
ThreadTransaction, thread_acceleration::ThreadAcceleration,
};
use std::{
cmp::Ordering,
collections::HashSet,
hash::{Hash, Hasher},
};
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, Debug)]
pub struct AuditTransaction {
pub uid: u32,
order: u32,
pub fee: u64,
pub weight: u32,
// exact sigop-adjusted weight
pub sigop_adjusted_weight: u32,
// sigop-adjusted vsize rounded up the the next integer
pub sigop_adjusted_vsize: u32,
pub sigops: u32,
adjusted_fee_per_vsize: f64,
pub effective_fee_per_vsize: f64,
pub dependency_rate: f64,
pub inputs: Vec<u32>,
pub relatives_set_flag: bool,
pub ancestors: HashSet<u32, U32HasherState>,
pub children: HashSet<u32, U32HasherState>,
ancestor_fee: u64,
ancestor_sigop_adjusted_weight: u32,
ancestor_sigop_adjusted_vsize: u32,
ancestor_sigops: u32,
// Safety: Must be private to prevent NaN breaking Ord impl.
score: f64,
pub used: bool,
/// whether this transaction has been moved to the "modified" priority queue
pub modified: bool,
pub dirty: bool,
}
impl Hash for AuditTransaction {
fn hash<H: Hasher>(&self, state: &mut H) {
self.uid.hash(state);
}
}
impl PartialEq for AuditTransaction {
fn eq(&self, other: &Self) -> bool {
self.uid == other.uid
}
}
impl Eq for AuditTransaction {}
#[inline]
pub fn partial_cmp_uid_score(a: (u32, u32, f64), b: (u32, u32, f64)) -> Option<Ordering> {
// If either score is NaN, this is false,
// and partial_cmp will return None
if a.2 != b.2 {
// compare by score (sorts by ascending score)
a.2.partial_cmp(&b.2)
} else if a.1 != b.1 {
// tie-break by comparing partial txids (sorts by descending txid)
Some(b.1.cmp(&a.1))
} else {
// tie-break partial txid collisions by comparing uids (sorts by descending uid)
Some(b.0.cmp(&a.0))
}
}
impl PartialOrd for AuditTransaction {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
partial_cmp_uid_score(
(self.uid, self.order, self.score),
(other.uid, other.order, other.score),
)
}
}
impl Ord for AuditTransaction {
fn cmp(&self, other: &Self) -> Ordering {
// Safety: The only possible values for score are f64
// that are not NaN. This is because outside code can not
// freely assign score. Also, calc_new_score guarantees no NaN.
self.partial_cmp(other).expect("score will never be NaN")
}
}
#[inline]
fn calc_fee_rate(fee: u64, vsize: f64) -> f64 {
(fee as f64) / (if vsize == 0.0 { 1.0 } else { vsize })
}
impl AuditTransaction {
pub fn from_thread_transaction(tx: &ThreadTransaction, maybe_acceleration: Option<Option<&ThreadAcceleration>>) -> Self {
let fee_delta = match maybe_acceleration {
Some(Some(acceleration)) => acceleration.delta,
_ => 0.0
};
let fee = (tx.fee as u64) + (fee_delta as u64);
// rounded up to the nearest integer
let is_adjusted = tx.weight < (tx.sigops * 20);
let sigop_adjusted_vsize = ((tx.weight + 3) / 4).max(tx.sigops * 5);
let sigop_adjusted_weight = tx.weight.max(tx.sigops * 20);
let effective_fee_per_vsize = if is_adjusted || fee_delta > 0.0 {
calc_fee_rate(fee, f64::from(sigop_adjusted_weight) / 4.0)
} else {
tx.effective_fee_per_vsize
};
Self {
uid: tx.uid,
order: tx.order,
fee,
weight: tx.weight,
sigop_adjusted_weight,
sigop_adjusted_vsize,
sigops: tx.sigops,
adjusted_fee_per_vsize: calc_fee_rate(fee, f64::from(sigop_adjusted_vsize)),
effective_fee_per_vsize,
dependency_rate: f64::INFINITY,
inputs: tx.inputs.clone(),
relatives_set_flag: false,
ancestors: u32hashset_new(),
children: u32hashset_new(),
ancestor_fee: fee,
ancestor_sigop_adjusted_weight: sigop_adjusted_weight,
ancestor_sigop_adjusted_vsize: sigop_adjusted_vsize,
ancestor_sigops: tx.sigops,
score: 0.0,
used: false,
modified: false,
dirty: effective_fee_per_vsize != tx.effective_fee_per_vsize || fee_delta > 0.0,
}
}
#[inline]
pub const fn score(&self) -> f64 {
self.score
}
#[inline]
pub const fn order(&self) -> u32 {
self.order
}
#[inline]
pub const fn ancestor_sigop_adjusted_vsize(&self) -> u32 {
self.ancestor_sigop_adjusted_vsize
}
#[inline]
pub const fn ancestor_sigops(&self) -> u32 {
self.ancestor_sigops
}
#[inline]
pub fn cluster_rate(&self) -> f64 {
// Safety: self.ancestor_weight can never be 0.
// Even if it could, as it approaches 0, the value inside the min() call
// grows, so if we think of 0 as "grew infinitely" then dependency_rate would be
// the smaller of the two. If either side is NaN, the other side is returned.
self.dependency_rate.min(calc_fee_rate(
self.ancestor_fee,
f64::from(self.ancestor_sigop_adjusted_weight) / 4.0,
))
}
pub fn set_dirty_if_different(&mut self, cluster_rate: f64) {
if self.effective_fee_per_vsize != cluster_rate {
self.effective_fee_per_vsize = cluster_rate;
self.dirty = true;
}
}
/// Safety: This function must NEVER set score to NaN.
#[inline]
fn calc_new_score(&mut self) {
self.score = self.adjusted_fee_per_vsize.min(calc_fee_rate(
self.ancestor_fee,
f64::from(self.ancestor_sigop_adjusted_vsize),
));
}
#[inline]
pub fn set_ancestors(
&mut self,
ancestors: HashSet<u32, U32HasherState>,
total_fee: u64,
total_sigop_adjusted_weight: u32,
total_sigop_adjusted_vsize: u32,
total_sigops: u32,
) {
self.ancestors = ancestors;
self.ancestor_fee = self.fee + total_fee;
self.ancestor_sigop_adjusted_weight =
self.sigop_adjusted_weight + total_sigop_adjusted_weight;
self.ancestor_sigop_adjusted_vsize = self.sigop_adjusted_vsize + total_sigop_adjusted_vsize;
self.ancestor_sigops = self.sigops + total_sigops;
self.calc_new_score();
self.relatives_set_flag = true;
}
#[inline]
pub fn remove_root(
&mut self,
root_txid: u32,
root_fee: u64,
root_sigop_adjusted_weight: u32,
root_sigop_adjusted_vsize: u32,
root_sigops: u32,
cluster_rate: f64,
) -> f64 {
let old_score = self.score();
self.dependency_rate = self.dependency_rate.min(cluster_rate);
if self.ancestors.remove(&root_txid) {
self.ancestor_fee -= root_fee;
self.ancestor_sigop_adjusted_weight -= root_sigop_adjusted_weight;
self.ancestor_sigop_adjusted_vsize -= root_sigop_adjusted_vsize;
self.ancestor_sigops -= root_sigops;
self.calc_new_score();
}
old_score
}
}

437
rust/gbt/src/gbt.rs Normal file
View File

@@ -0,0 +1,437 @@
use priority_queue::PriorityQueue;
use std::{cmp::Ordering, collections::HashSet, mem::ManuallyDrop};
use tracing::{info, trace};
use crate::{
audit_transaction::{partial_cmp_uid_score, AuditTransaction},
u32_hasher_types::{u32hashset_new, u32priority_queue_with_capacity, U32HasherState},
GbtResult, ThreadTransactionsMap, thread_acceleration::ThreadAcceleration,
};
const MAX_BLOCK_WEIGHT_UNITS: u32 = 4_000_000 - 4_000;
const BLOCK_SIGOPS: u32 = 80_000;
const BLOCK_RESERVED_WEIGHT: u32 = 4_000;
const BLOCK_RESERVED_SIGOPS: u32 = 400;
const MAX_BLOCKS: usize = 8;
type AuditPool = Vec<Option<ManuallyDrop<AuditTransaction>>>;
type ModifiedQueue = PriorityQueue<u32, TxPriority, U32HasherState>;
#[derive(Debug)]
struct TxPriority {
uid: u32,
order: u32,
score: f64,
}
impl PartialEq for TxPriority {
fn eq(&self, other: &Self) -> bool {
self.uid == other.uid
}
}
impl Eq for TxPriority {}
impl PartialOrd for TxPriority {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
partial_cmp_uid_score(
(self.uid, self.order, self.score),
(other.uid, other.order, other.score),
)
}
}
impl Ord for TxPriority {
fn cmp(&self, other: &Self) -> Ordering {
self.partial_cmp(other).expect("score will never be NaN")
}
}
/// Build projected mempool blocks using an approximation of the transaction selection algorithm from Bitcoin Core.
///
/// See `BlockAssembler` in Bitcoin Core's
/// [miner.cpp](https://github.com/bitcoin/bitcoin/blob/master/src/node/miner.cpp).
/// Ported from mempool backend's
/// [tx-selection-worker.ts](https://github.com/mempool/mempool/blob/master/backend/src/api/tx-selection-worker.ts).
//
// TODO: Make gbt smaller to fix these lints.
#[allow(clippy::too_many_lines)]
#[allow(clippy::cognitive_complexity)]
pub fn gbt(mempool: &mut ThreadTransactionsMap, accelerations: &[ThreadAcceleration], max_uid: usize) -> GbtResult {
let mut indexed_accelerations = Vec::with_capacity(max_uid + 1);
indexed_accelerations.resize(max_uid + 1, None);
for acceleration in accelerations {
indexed_accelerations[acceleration.uid as usize] = Some(acceleration);
}
info!("Initializing working vecs with uid capacity for {}", max_uid + 1);
let mempool_len = mempool.len();
let mut audit_pool: AuditPool = Vec::with_capacity(max_uid + 1);
audit_pool.resize(max_uid + 1, None);
let mut mempool_stack: Vec<u32> = Vec::with_capacity(mempool_len);
let mut clusters: Vec<Vec<u32>> = Vec::new();
let mut block_weights: Vec<u32> = Vec::new();
info!("Initializing working structs");
for (uid, tx) in &mut *mempool {
let acceleration = indexed_accelerations.get(*uid as usize);
let audit_tx = AuditTransaction::from_thread_transaction(tx, acceleration.copied());
// Safety: audit_pool and mempool_stack must always contain the same transactions
audit_pool[*uid as usize] = Some(ManuallyDrop::new(audit_tx));
mempool_stack.push(*uid);
}
info!("Building relatives graph & calculate ancestor scores");
for txid in &mempool_stack {
set_relatives(*txid, &mut audit_pool);
}
trace!("Post relative graph Audit Pool: {:#?}", audit_pool);
info!("Sorting by descending ancestor score");
let mut mempool_stack: Vec<(u32, u32, f64)> = mempool_stack
.into_iter()
.map(|txid| {
let atx = audit_pool
.get(txid as usize)
.and_then(Option::as_ref)
.expect("All txids are from audit_pool");
(txid, atx.order(), atx.score())
})
.collect();
mempool_stack.sort_unstable_by(|a, b| partial_cmp_uid_score(*a, *b).expect("Not NaN"));
let mut mempool_stack: Vec<u32> = mempool_stack.into_iter().map(|(txid, _, _)| txid).collect();
info!("Building blocks by greedily choosing the highest feerate package");
info!("(i.e. the package rooted in the transaction with the best ancestor score)");
let mut blocks: Vec<Vec<u32>> = Vec::new();
let mut block_weight: u32 = BLOCK_RESERVED_WEIGHT;
let mut block_sigops: u32 = BLOCK_RESERVED_SIGOPS;
// No need to be bigger than 4096 transactions for the per-block transaction Vec.
let initial_txes_per_block: usize = 4096.min(mempool_len);
let mut transactions: Vec<u32> = Vec::with_capacity(initial_txes_per_block);
let mut modified: ModifiedQueue = u32priority_queue_with_capacity(mempool_len);
let mut overflow: Vec<u32> = Vec::new();
let mut failures = 0;
while !mempool_stack.is_empty() || !modified.is_empty() {
// This trace log storm is big, so to make scrolling through
// Each iteration easier, leaving a bunch of empty rows
// And a header of ======
trace!("\n\n\n\n\n\n\n\n\n\n==================================");
trace!("mempool_array: {:#?}", mempool_stack);
trace!("clusters: {:#?}", clusters);
trace!("modified: {:#?}", modified);
trace!("audit_pool: {:#?}", audit_pool);
trace!("blocks: {:#?}", blocks);
trace!("block_weight: {:#?}", block_weight);
trace!("block_sigops: {:#?}", block_sigops);
trace!("transactions: {:#?}", transactions);
trace!("overflow: {:#?}", overflow);
trace!("failures: {:#?}", failures);
trace!("\n==================================");
let next_from_stack = next_valid_from_stack(&mut mempool_stack, &audit_pool);
let next_from_queue = next_valid_from_queue(&mut modified, &audit_pool);
if next_from_stack.is_none() && next_from_queue.is_none() {
info!("No transactions left! {:#?} in overflow", overflow.len());
} else {
let (next_tx, from_stack) = match (next_from_stack, next_from_queue) {
(Some(stack_tx), Some(queue_tx)) => match queue_tx.cmp(stack_tx) {
std::cmp::Ordering::Less => (stack_tx, true),
_ => (queue_tx, false),
},
(Some(stack_tx), None) => (stack_tx, true),
(None, Some(queue_tx)) => (queue_tx, false),
(None, None) => unreachable!(),
};
if from_stack {
mempool_stack.pop();
} else {
modified.pop();
}
if blocks.len() < (MAX_BLOCKS - 1)
&& ((block_weight + (4 * next_tx.ancestor_sigop_adjusted_vsize())
>= MAX_BLOCK_WEIGHT_UNITS)
|| (block_sigops + next_tx.ancestor_sigops() > BLOCK_SIGOPS))
{
// hold this package in an overflow list while we check for smaller options
overflow.push(next_tx.uid);
failures += 1;
} else {
let mut package: Vec<(u32, u32, usize)> = Vec::new();
let mut cluster: Vec<u32> = Vec::new();
let is_cluster: bool = !next_tx.ancestors.is_empty();
for ancestor_id in &next_tx.ancestors {
if let Some(Some(ancestor)) = audit_pool.get(*ancestor_id as usize) {
package.push((*ancestor_id, ancestor.order(), ancestor.ancestors.len()));
}
}
package.sort_unstable_by(|a, b| -> Ordering {
if a.2 != b.2 {
// order by ascending ancestor count
a.2.cmp(&b.2)
} else if a.1 != b.1 {
// tie-break by ascending partial txid
a.1.cmp(&b.1)
} else {
// tie-break partial txid collisions by ascending uid
a.0.cmp(&b.0)
}
});
package.push((next_tx.uid, next_tx.order(), next_tx.ancestors.len()));
let cluster_rate = next_tx.cluster_rate();
for (txid, _, _) in &package {
cluster.push(*txid);
if let Some(Some(tx)) = audit_pool.get_mut(*txid as usize) {
tx.used = true;
tx.set_dirty_if_different(cluster_rate);
transactions.push(tx.uid);
block_weight += tx.weight;
block_sigops += tx.sigops;
}
update_descendants(*txid, &mut audit_pool, &mut modified, cluster_rate);
}
if is_cluster {
clusters.push(cluster);
}
failures = 0;
}
}
// this block is full
let exceeded_package_tries =
failures > 1000 && block_weight > (MAX_BLOCK_WEIGHT_UNITS - BLOCK_RESERVED_WEIGHT);
let queue_is_empty = mempool_stack.is_empty() && modified.is_empty();
if (exceeded_package_tries || queue_is_empty) && blocks.len() < (MAX_BLOCKS - 1) {
// finalize this block
if transactions.is_empty() {
info!("trying to push an empty block! breaking loop! mempool {:#?} | modified {:#?} | overflow {:#?}", mempool_stack.len(), modified.len(), overflow.len());
break;
}
blocks.push(transactions);
block_weights.push(block_weight);
// reset for the next block
transactions = Vec::with_capacity(initial_txes_per_block);
block_weight = BLOCK_RESERVED_WEIGHT;
block_sigops = BLOCK_RESERVED_SIGOPS;
failures = 0;
// 'overflow' packages didn't fit in this block, but are valid candidates for the next
overflow.reverse();
for overflowed in &overflow {
if let Some(Some(overflowed_tx)) = audit_pool.get(*overflowed as usize) {
if overflowed_tx.modified {
modified.push(
*overflowed,
TxPriority {
uid: *overflowed,
order: overflowed_tx.order(),
score: overflowed_tx.score(),
},
);
} else {
mempool_stack.push(*overflowed);
}
}
}
overflow = Vec::new();
}
}
info!("add the final unbounded block if it contains any transactions");
if !transactions.is_empty() {
blocks.push(transactions);
block_weights.push(block_weight);
}
info!("make a list of dirty transactions and their new rates");
let mut rates: Vec<Vec<f64>> = Vec::new();
for (uid, thread_tx) in mempool {
// Takes ownership of the audit_tx and replaces with None
if let Some(Some(audit_tx)) = audit_pool.get_mut(*uid as usize).map(Option::take) {
trace!("txid: {}, is_dirty: {}", uid, audit_tx.dirty);
if audit_tx.dirty {
rates.push(vec![f64::from(*uid), audit_tx.effective_fee_per_vsize]);
thread_tx.effective_fee_per_vsize = audit_tx.effective_fee_per_vsize;
}
// Drops the AuditTransaction manually
// There are no audit_txs that are not in the mempool HashMap
// So there is guaranteed to be no memory leaks.
ManuallyDrop::into_inner(audit_tx);
}
}
trace!("\n\n\n\n\n====================");
trace!("blocks: {:#?}", blocks);
trace!("clusters: {:#?}", clusters);
trace!("rates: {:#?}\n====================\n\n\n\n\n", rates);
GbtResult {
blocks,
block_weights,
clusters,
rates,
overflow,
}
}
fn next_valid_from_stack<'a>(
mempool_stack: &mut Vec<u32>,
audit_pool: &'a AuditPool,
) -> Option<&'a AuditTransaction> {
while let Some(next_txid) = mempool_stack.last() {
match audit_pool.get(*next_txid as usize) {
Some(Some(tx)) if !tx.used && !tx.modified => {
return Some(tx);
}
_ => {
mempool_stack.pop();
}
}
}
None
}
fn next_valid_from_queue<'a>(
queue: &mut ModifiedQueue,
audit_pool: &'a AuditPool,
) -> Option<&'a AuditTransaction> {
while let Some((next_txid, _)) = queue.peek() {
match audit_pool.get(*next_txid as usize) {
Some(Some(tx)) if !tx.used => {
return Some(tx);
}
_ => {
queue.pop();
}
}
}
None
}
fn set_relatives(txid: u32, audit_pool: &mut AuditPool) {
let mut parents: HashSet<u32, U32HasherState> = u32hashset_new();
if let Some(Some(tx)) = audit_pool.get(txid as usize) {
if tx.relatives_set_flag {
return;
}
for input in &tx.inputs {
parents.insert(*input);
}
} else {
return;
}
let mut ancestors: HashSet<u32, U32HasherState> = u32hashset_new();
for parent_id in &parents {
set_relatives(*parent_id, audit_pool);
if let Some(Some(parent)) = audit_pool.get_mut(*parent_id as usize) {
// Safety: ancestors must always contain only txes in audit_pool
ancestors.insert(*parent_id);
parent.children.insert(txid);
for ancestor in &parent.ancestors {
ancestors.insert(*ancestor);
}
}
}
let mut total_fee: u64 = 0;
let mut total_sigop_adjusted_weight: u32 = 0;
let mut total_sigop_adjusted_vsize: u32 = 0;
let mut total_sigops: u32 = 0;
for ancestor_id in &ancestors {
if let Some(ancestor) = audit_pool
.get(*ancestor_id as usize)
.expect("audit_pool contains all ancestors")
{
total_fee += ancestor.fee;
total_sigop_adjusted_weight += ancestor.sigop_adjusted_weight;
total_sigop_adjusted_vsize += ancestor.sigop_adjusted_vsize;
total_sigops += ancestor.sigops;
} else { todo!() };
}
if let Some(Some(tx)) = audit_pool.get_mut(txid as usize) {
tx.set_ancestors(
ancestors,
total_fee,
total_sigop_adjusted_weight,
total_sigop_adjusted_vsize,
total_sigops,
);
}
}
// iterate over remaining descendants, removing the root as a valid ancestor & updating the ancestor score
fn update_descendants(
root_txid: u32,
audit_pool: &mut AuditPool,
modified: &mut ModifiedQueue,
cluster_rate: f64,
) {
let mut visited: HashSet<u32, U32HasherState> = u32hashset_new();
let mut descendant_stack: Vec<u32> = Vec::new();
let root_fee: u64;
let root_sigop_adjusted_weight: u32;
let root_sigop_adjusted_vsize: u32;
let root_sigops: u32;
if let Some(Some(root_tx)) = audit_pool.get(root_txid as usize) {
for descendant_id in &root_tx.children {
if !visited.contains(descendant_id) {
descendant_stack.push(*descendant_id);
visited.insert(*descendant_id);
}
}
root_fee = root_tx.fee;
root_sigop_adjusted_weight = root_tx.sigop_adjusted_weight;
root_sigop_adjusted_vsize = root_tx.sigop_adjusted_vsize;
root_sigops = root_tx.sigops;
} else {
return;
}
while let Some(next_txid) = descendant_stack.pop() {
if let Some(Some(descendant)) = audit_pool.get_mut(next_txid as usize) {
// remove root tx as ancestor
let old_score = descendant.remove_root(
root_txid,
root_fee,
root_sigop_adjusted_weight,
root_sigop_adjusted_vsize,
root_sigops,
cluster_rate,
);
// add to priority queue or update priority if score has changed
if descendant.score() < old_score {
descendant.modified = true;
modified.push_decrease(
descendant.uid,
TxPriority {
uid: descendant.uid,
order: descendant.order(),
score: descendant.score(),
},
);
} else if descendant.score() > old_score {
descendant.modified = true;
modified.push_increase(
descendant.uid,
TxPriority {
uid: descendant.uid,
order: descendant.order(),
score: descendant.score(),
},
);
}
// add this node's children to the stack
for child_id in &descendant.children {
if !visited.contains(child_id) {
descendant_stack.push(*child_id);
visited.insert(*child_id);
}
}
}
}
}

184
rust/gbt/src/lib.rs Normal file
View File

@@ -0,0 +1,184 @@
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![warn(clippy::nursery)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::float_cmp)]
use napi::bindgen_prelude::Result;
use napi_derive::napi;
use thread_transaction::ThreadTransaction;
use thread_acceleration::ThreadAcceleration;
use tracing::{debug, info, trace};
use tracing_log::LogTracer;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
mod audit_transaction;
mod gbt;
mod thread_transaction;
mod thread_acceleration;
mod u32_hasher_types;
use u32_hasher_types::{u32hashmap_with_capacity, U32HasherState};
/// This is the initial capacity of the `GbtGenerator` struct's inner `HashMap`.
///
/// Note: This doesn't *have* to be a power of 2. (uwu)
const STARTING_CAPACITY: usize = 1_048_576;
type ThreadTransactionsMap = HashMap<u32, ThreadTransaction, U32HasherState>;
#[napi]
pub struct GbtGenerator {
thread_transactions: Arc<Mutex<ThreadTransactionsMap>>,
}
#[napi::module_init]
fn init() {
// Set all `tracing` logs to print to STDOUT
// Note: Passing RUST_LOG env variable to the node process
// will change the log level for the rust module.
tracing::subscriber::set_global_default(
FmtSubscriber::builder()
.with_env_filter(EnvFilter::from_default_env())
.with_ansi(
// Default to no-color logs.
// Setting RUST_LOG_COLOR to 1 or true|TRUE|True etc.
// will enable color
std::env::var("RUST_LOG_COLOR")
.map(|s| ["1", "true"].contains(&&*s.to_lowercase()))
.unwrap_or(false),
)
.finish(),
)
.expect("Logging subscriber failed");
// Convert all `log` logs into `tracing` events
LogTracer::init().expect("Legacy log subscriber failed");
}
#[napi]
impl GbtGenerator {
#[napi(constructor)]
#[allow(clippy::new_without_default)]
#[must_use]
pub fn new() -> Self {
debug!("Created new GbtGenerator");
Self {
thread_transactions: Arc::new(Mutex::new(u32hashmap_with_capacity(STARTING_CAPACITY))),
}
}
/// # Errors
///
/// Rejects if the thread panics or if the Mutex is poisoned.
#[napi]
pub async fn make(&self, mempool: Vec<ThreadTransaction>, accelerations: Vec<ThreadAcceleration>, max_uid: u32) -> Result<GbtResult> {
trace!("make: Current State {:#?}", self.thread_transactions);
run_task(
Arc::clone(&self.thread_transactions),
accelerations,
max_uid as usize,
move |map| {
for tx in mempool {
map.insert(tx.uid, tx);
}
},
)
.await
}
/// # Errors
///
/// Rejects if the thread panics or if the Mutex is poisoned.
#[napi]
pub async fn update(
&self,
new_txs: Vec<ThreadTransaction>,
remove_txs: Vec<u32>,
accelerations: Vec<ThreadAcceleration>,
max_uid: u32,
) -> Result<GbtResult> {
trace!("update: Current State {:#?}", self.thread_transactions);
run_task(
Arc::clone(&self.thread_transactions),
accelerations,
max_uid as usize,
move |map| {
for tx in new_txs {
map.insert(tx.uid, tx);
}
for txid in &remove_txs {
map.remove(txid);
}
},
)
.await
}
}
/// The result from calling the gbt function.
///
/// This tuple contains the following:
/// blocks: A 2D Vector of transaction IDs (u32), the inner Vecs each represent a block.
/// block_weights: A Vector of total weights per block.
/// clusters: A 2D Vector of transaction IDs representing clusters of dependent mempool transactions
/// rates: A Vector of tuples containing transaction IDs (u32) and effective fee per vsize (f64)
#[napi(constructor)]
pub struct GbtResult {
pub blocks: Vec<Vec<u32>>,
pub block_weights: Vec<u32>,
pub clusters: Vec<Vec<u32>>,
pub rates: Vec<Vec<f64>>, // Tuples not supported. u32 fits inside f64
pub overflow: Vec<u32>,
}
/// All on another thread, this runs an arbitrary task in between
/// taking the lock and running gbt.
///
/// Rather than filling / updating the `HashMap` on the main thread,
/// this allows for `HashMap` modifying tasks to be run before running and returning gbt results.
///
/// `thread_transactions` is a cloned `Arc` of the `Mutex` for the `HashMap` state.
/// `callback` is a `'static + Send` `FnOnce` closure/function that takes a mutable reference
/// to the `HashMap` as the only argument. (A move closure is recommended to meet the bounds)
async fn run_task<F>(
thread_transactions: Arc<Mutex<ThreadTransactionsMap>>,
accelerations: Vec<ThreadAcceleration>,
max_uid: usize,
callback: F,
) -> Result<GbtResult>
where
F: FnOnce(&mut ThreadTransactionsMap) + Send + 'static,
{
debug!("Spawning thread...");
let handle = napi::tokio::task::spawn_blocking(move || {
debug!(
"Getting lock for thread_transactions from thread {:?}...",
std::thread::current().id()
);
let mut map = thread_transactions
.lock()
.map_err(|_| napi::Error::from_reason("THREAD_TRANSACTIONS Mutex poisoned"))?;
callback(&mut map);
info!("Starting gbt algorithm for {} elements...", map.len());
let result = gbt::gbt(&mut map, &accelerations, max_uid);
info!("Finished gbt algorithm for {} elements...", map.len());
debug!(
"Releasing lock for thread_transactions from thread {:?}...",
std::thread::current().id()
);
drop(map);
Ok(result)
});
handle
.await
.map_err(|_| napi::Error::from_reason("thread panicked"))?
}

View File

@@ -0,0 +1,8 @@
use napi_derive::napi;
#[derive(Debug)]
#[napi(object)]
pub struct ThreadAcceleration {
pub uid: u32,
pub delta: f64, // fee delta
}

View File

@@ -0,0 +1,13 @@
use napi_derive::napi;
#[derive(Debug)]
#[napi(object)]
pub struct ThreadTransaction {
pub uid: u32,
pub order: u32,
pub fee: f64,
pub weight: u32,
pub sigops: u32,
pub effective_fee_per_vsize: f64,
pub inputs: Vec<u32>,
}

View File

@@ -0,0 +1,132 @@
use priority_queue::PriorityQueue;
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
hash::{BuildHasher, Hasher},
};
/// This is the only way to create a `HashMap` with the `U32HasherState` and capacity
pub fn u32hashmap_with_capacity<V>(capacity: usize) -> HashMap<u32, V, U32HasherState> {
HashMap::with_capacity_and_hasher(capacity, U32HasherState(()))
}
/// This is the only way to create a `PriorityQueue` with the `U32HasherState` and capacity
pub fn u32priority_queue_with_capacity<V: Ord>(
capacity: usize,
) -> PriorityQueue<u32, V, U32HasherState> {
PriorityQueue::with_capacity_and_hasher(capacity, U32HasherState(()))
}
/// This is the only way to create a `HashSet` with the `U32HasherState`
pub fn u32hashset_new() -> HashSet<u32, U32HasherState> {
HashSet::with_hasher(U32HasherState(()))
}
/// A private unit type is contained so no one can make an instance of it.
#[derive(Clone)]
pub struct U32HasherState(());
impl Debug for U32HasherState {
fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}
impl BuildHasher for U32HasherState {
type Hasher = U32Hasher;
fn build_hasher(&self) -> Self::Hasher {
U32Hasher(0)
}
}
/// This also can't be created outside this module due to private field.
pub struct U32Hasher(u32);
impl Hasher for U32Hasher {
fn finish(&self) -> u64 {
// Safety: Two u32s next to each other will make a u64
bytemuck::cast([self.0, 0])
}
fn write(&mut self, bytes: &[u8]) {
// Assert in debug builds (testing too) that only 4 byte keys (u32, i32, f32, etc.) run
debug_assert!(bytes.len() == 4);
// Safety: We know that the size of the key is 4 bytes
// We also know that the only way to get an instance of HashMap using this "hasher"
// is through the public functions in this module which set the key type to u32.
self.0 = *bytemuck::from_bytes(bytes);
}
}
#[cfg(test)]
mod tests {
use super::U32HasherState;
use priority_queue::PriorityQueue;
use std::collections::HashMap;
#[test]
fn test_hashmap() {
let mut hm: HashMap<u32, String, U32HasherState> = HashMap::with_hasher(U32HasherState(()));
// Testing basic operations with the custom hasher
hm.insert(0, String::from("0"));
hm.insert(42, String::from("42"));
hm.insert(256, String::from("256"));
hm.insert(u32::MAX, String::from("MAX"));
hm.insert(u32::MAX >> 2, String::from("MAX >> 2"));
assert_eq!(hm.get(&0), Some(&String::from("0")));
assert_eq!(hm.get(&42), Some(&String::from("42")));
assert_eq!(hm.get(&256), Some(&String::from("256")));
assert_eq!(hm.get(&u32::MAX), Some(&String::from("MAX")));
assert_eq!(hm.get(&(u32::MAX >> 2)), Some(&String::from("MAX >> 2")));
assert_eq!(hm.get(&(u32::MAX >> 4)), None);
assert_eq!(hm.get(&3), None);
assert_eq!(hm.get(&43), None);
}
#[test]
fn test_priority_queue() {
let mut pq: PriorityQueue<u32, i32, U32HasherState> =
PriorityQueue::with_hasher(U32HasherState(()));
// Testing basic operations with the custom hasher
assert_eq!(pq.push(1, 5), None);
assert_eq!(pq.push(2, -10), None);
assert_eq!(pq.push(3, 7), None);
assert_eq!(pq.push(4, 20), None);
assert_eq!(pq.push(u32::MAX, -42), None);
assert_eq!(pq.push_increase(1, 4), Some(4));
assert_eq!(pq.push_increase(2, -8), Some(-10));
assert_eq!(pq.push_increase(3, 5), Some(5));
assert_eq!(pq.push_increase(4, 21), Some(20));
assert_eq!(pq.push_increase(u32::MAX, -99), Some(-99));
assert_eq!(pq.push_increase(42, 1337), None);
assert_eq!(pq.push_decrease(1, 4), Some(5));
assert_eq!(pq.push_decrease(2, -10), Some(-8));
assert_eq!(pq.push_decrease(3, 5), Some(7));
assert_eq!(pq.push_decrease(4, 20), Some(21));
assert_eq!(pq.push_decrease(u32::MAX, 100), Some(100));
assert_eq!(pq.push_decrease(69, 420), None);
assert_eq!(pq.peek(), Some((&42, &1337)));
assert_eq!(pq.pop(), Some((42, 1337)));
assert_eq!(pq.peek(), Some((&69, &420)));
assert_eq!(pq.pop(), Some((69, 420)));
assert_eq!(pq.peek(), Some((&4, &20)));
assert_eq!(pq.pop(), Some((4, 20)));
assert_eq!(pq.peek(), Some((&3, &5)));
assert_eq!(pq.pop(), Some((3, 5)));
assert_eq!(pq.peek(), Some((&1, &4)));
assert_eq!(pq.pop(), Some((1, 4)));
assert_eq!(pq.peek(), Some((&2, &-10)));
assert_eq!(pq.pop(), Some((2, -10)));
assert_eq!(pq.peek(), Some((&u32::MAX, &-42)));
assert_eq!(pq.pop(), Some((u32::MAX, -42)));
assert_eq!(pq.peek(), None);
assert_eq!(pq.pop(), None);
}
}