Refactor rust-gbt

junderw
2024-03-10 13:27:09 +09:00
parent 7bedb9488b
commit 92a5fc8159
22 changed files with 187 additions and 540 deletions

backend/.gitignore

@@ -54,3 +54,6 @@ Thumbs.db
# package folder (npm run package output)
/package
# Rust GBT folder (We build externally first)
/rust-gbt

backend/npm_package_rm_build_deps.sh

@@ -6,7 +6,4 @@ cd package/node_modules
rm -r \
typescript \
@typescript-eslint \
@napi-rs \
./rust-gbt/src \
./rust-gbt/Cargo.toml \
./rust-gbt/build.rs
@napi-rs

backend/package-lock.json

@@ -7,6 +7,7 @@
"": {
"name": "mempool-backend",
"version": "3.0.0-dev",
"hasInstallScript": true,
"license": "GNU Affero General Public License v3.0",
"dependencies": {
"@babel/core": "^7.24.0",
@@ -42,6 +43,13 @@
"ts-node": "^10.9.1"
}
},
"../rust/gbt": {
"version": "3.0.1",
"extraneous": true,
"engines": {
"node": ">= 12"
}
},
"node_modules/@aashutoshrathi/word-wrap": {
"version": "1.2.6",
"resolved": "https://registry.npmjs.org/@aashutoshrathi/word-wrap/-/word-wrap-1.2.6.tgz",
@@ -1499,21 +1507,6 @@
"node": ">=6"
}
},
"node_modules/@napi-rs/cli": {
"version": "2.18.0",
"resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.0.tgz",
"integrity": "sha512-lfSRT7cs3iC4L+kv9suGYQEezn5Nii7Kpu+THsYVI0tA1Vh59LH45p4QADaD7hvIkmOz79eEGtoKQ9nAkAPkzA==",
"bin": {
"napi": "scripts/index.js"
},
"engines": {
"node": ">= 10"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Brooooooklyn"
}
},
"node_modules/@noble/hashes": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.0.tgz",
@@ -7668,10 +7661,6 @@
"rust-gbt": {
"name": "gbt",
"version": "3.0.1",
"hasInstallScript": true,
"dependencies": {
"@napi-rs/cli": "2.18.0"
},
"engines": {
"node": ">= 12"
}
@@ -8774,11 +8763,6 @@
"resolved": "https://registry.npmjs.org/@mempool/electrum-client/-/electrum-client-1.1.9.tgz",
"integrity": "sha512-mlvPiCzUlaETpYW3i6V87A24jjMYgsebaXtUo3WQyyLnYUuxs0KiXQ2mnKh3h15j8Xg/hfxeGIi+5OC9u0nftQ=="
},
"@napi-rs/cli": {
"version": "2.18.0",
"resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.0.tgz",
"integrity": "sha512-lfSRT7cs3iC4L+kv9suGYQEezn5Nii7Kpu+THsYVI0tA1Vh59LH45p4QADaD7hvIkmOz79eEGtoKQ9nAkAPkzA=="
},
"@noble/hashes": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.3.0.tgz",
@@ -12701,10 +12685,7 @@
}
},
"rust-gbt": {
"version": "file:rust-gbt",
"requires": {
"@napi-rs/cli": "2.18.0"
}
"version": "file:rust-gbt"
},
"safe-buffer": {
"version": "5.2.1",

backend/package.json

@@ -22,10 +22,12 @@
"main": "index.ts",
"scripts": {
"tsc": "./node_modules/typescript/bin/tsc -p tsconfig.build.json",
"build": "npm run rust-build && npm run tsc && npm run create-resources",
"build": "npm run tsc && npm run create-resources",
"clean": "rm -rf ./dist/ ./node_modules/ ./package/ ./rust-gbt/",
"create-resources": "cp ./src/tasks/price-feeds/mtgox-weekly.json ./dist/tasks && node dist/api/fetch-version.js",
"package": "./npm_package.sh",
"package-rm-build-deps": "./npm_package_rm_build_deps.sh",
"preinstall": "cd ../rust/gbt && npm run build-release && npm run to-backend",
"start": "node --max-old-space-size=2048 dist/index.js",
"start-production": "node --max-old-space-size=16384 dist/index.js",
"reindex-updated-pools": "npm run start-production --update-pools",
@@ -34,9 +36,7 @@
"test:ci": "CI=true ./node_modules/.bin/jest --coverage",
"lint": "./node_modules/.bin/eslint . --ext .ts",
"lint:fix": "./node_modules/.bin/eslint . --ext .ts --fix",
"prettier": "./node_modules/.bin/prettier --write \"src/**/*.{js,ts}\"",
"rust-clean": "cd rust-gbt && rm -f *.node index.d.ts index.js && rm -rf target && cd ../",
"rust-build": "npm run rust-clean && cd rust-gbt && npm run build-release"
"prettier": "./node_modules/.bin/prettier --write \"src/**/*.{js,ts}\""
},
"dependencies": {
"@babel/core": "^7.24.0",

backend/rust-gbt/.gitignore

@@ -1,4 +0,0 @@
*.node
**/node_modules
**/.DS_Store
npm-debug.log*

backend/rust-gbt/Cargo.toml

@@ -1,25 +0,0 @@
[package]
name = "gbt"
version = "1.0.0"
description = "An efficient re-implementation of the getBlockTemplate algorithm in Rust"
authors = ["mononaut"]
edition = "2021"
publish = false
[lib]
crate-type = ["cdylib"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
priority-queue = "2.0.2"
bytes = "1.4.0"
napi = { version = "2.16.0", features = ["napi8", "tokio_rt"] }
napi-derive = "2.16.0"
bytemuck = "1.13.1"
tracing = "0.1.36"
tracing-log = "0.2.0"
tracing-subscriber = { version = "0.3.15", features = ["env-filter"]}
[build-dependencies]
napi-build = "2.1.2"

backend/rust-gbt/README.md

@@ -1,123 +0,0 @@
# gbt
**gbt:** Rust implementation of the getBlockTemplate algorithm
This project was bootstrapped by [napi](https://www.npmjs.com/package/@napi-rs/cli).
## Installing gbt
Installing gbt requires a [supported version of Node and Rust](https://github.com/napi-rs/napi-rs#platform-support); if you don't have Rust yet, see the [installation guide](https://www.rust-lang.org/tools/install).
You can install the project with npm. In the project directory, run:
```sh
$ npm install
```
This fully installs the project, including installing any dependencies and running the build.
## Building gbt
If you have already installed the project and only want to run the build, run:
```sh
$ npm run build
```
This command uses the [napi build](https://www.npmjs.com/package/@napi-rs/cli) utility to run the Rust build and copy the built library into `./gbt.[TARGET_TRIPLE].node`.
## Exploring gbt
After building gbt, you can explore its exports at the Node REPL:
```sh
$ npm install
$ node
> const { GbtGenerator } = require('.')
> typeof GbtGenerator
'function'
```
## Available Scripts
In the project directory, you can run:
### `npm install`
Installs the project, including running `npm run build-release`.
### `npm run build`
Builds the Node addon (`gbt.[TARGET_TRIPLE].node`) from source.
Additional [`cargo build`](https://doc.rust-lang.org/cargo/commands/cargo-build.html) arguments may be passed to `npm run build` and `npm run build-*` commands. For example, to enable a [cargo feature](https://doc.rust-lang.org/cargo/reference/features.html):
```
npm run build -- --feature=beetle
```
#### `npm run build-debug`
Alias for `npm run build`.
#### `npm run build-release`
Same as [`npm run build`](#npm-run-build), but builds the module with the [`release`](https://doc.rust-lang.org/cargo/reference/profiles.html#release) profile. Release builds compile slower but run faster.
### `npm test`
Runs the unit tests by calling `cargo test`. You can learn more about [adding tests to your Rust code](https://doc.rust-lang.org/book/ch11-01-writing-tests.html) from the [Rust book](https://doc.rust-lang.org/book/).
## Project Layout
The directory structure of this project is:
```
gbt/
├── Cargo.toml
├── README.md
├── gbt.[TARGET_TRIPLE].node
├── package.json
├── src/
| └── lib.rs
└── target/
```
### Cargo.toml
The Cargo [manifest file](https://doc.rust-lang.org/cargo/reference/manifest.html), which informs the `cargo` command.
### README.md
This file.
### gbt.\[TARGET_TRIPLE\].node
The Node addon—i.e., a binary Node module—generated by building the project. This is the main module for this package, as dictated by the `"main"` key in `package.json`.
Under the hood, a [Node addon](https://nodejs.org/api/addons.html) is a [dynamically-linked shared object](https://en.wikipedia.org/wiki/Library_(computing)#Shared_libraries). The `"build"` script produces this file by copying it from within the `target/` directory, which is where the Rust build produces the shared object.
### package.json
The npm [manifest file](https://docs.npmjs.com/cli/v7/configuring-npm/package-json), which informs the `npm` command.
### src/
The directory tree containing the Rust source code for the project.
### src/lib.rs
The Rust library's main module.
### target/
Binary artifacts generated by the Rust build.
## Learn More
To learn more about napi-rs, see the [Napi-RS documentation](https://napi.rs/docs/introduction/getting-started).
To learn more about Rust, see the [Rust documentation](https://www.rust-lang.org).
To learn more about Node, see the [Node documentation](https://nodejs.org).

backend/rust-gbt/build.rs

@@ -1,3 +0,0 @@
fn main() {
napi_build::setup();
}

backend/rust-gbt/index.d.ts

@@ -1,50 +0,0 @@
/* tslint:disable */
/* eslint-disable */
/* auto-generated by NAPI-RS */
export interface ThreadTransaction {
uid: number
order: number
fee: number
weight: number
sigops: number
effectiveFeePerVsize: number
inputs: Array<number>
}
export interface ThreadAcceleration {
uid: number
delta: number
}
export class GbtGenerator {
constructor()
/**
* # Errors
*
* Rejects if the thread panics or if the Mutex is poisoned.
*/
make(mempool: Array<ThreadTransaction>, accelerations: Array<ThreadAcceleration>, maxUid: number): Promise<GbtResult>
/**
* # Errors
*
* Rejects if the thread panics or if the Mutex is poisoned.
*/
update(newTxs: Array<ThreadTransaction>, removeTxs: Array<number>, accelerations: Array<ThreadAcceleration>, maxUid: number): Promise<GbtResult>
}
/**
* The result from calling the gbt function.
*
* It contains the following:
* blocks: A 2D Vector of transaction IDs (u32), the inner Vecs each represent a block.
* block_weights: A Vector of total weights per block.
* clusters: A 2D Vector of transaction IDs representing clusters of dependent mempool transactions
* rates: A Vector of tuples containing transaction IDs (u32) and effective fee per vsize (f64)
*/
export class GbtResult {
blocks: Array<Array<number>>
blockWeights: Array<number>
clusters: Array<Array<number>>
rates: Array<Array<number>>
overflow: Array<number>
constructor(blocks: Array<Array<number>>, blockWeights: Array<number>, clusters: Array<Array<number>>, rates: Array<Array<number>>, overflow: Array<number>)
}
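
A note on the `rates` field above: napi-rs does not map Rust tuples into TypeScript, so each `(uid, rate)` pair crosses the boundary as a two-element `Array<number>` (the Rust side comments "u32 fits inside f64"). Since an `f64` has a 53-bit mantissa, every `u32` round-trips exactly. A minimal sketch of that encoding, with hypothetical `encode_rate`/`decode_rate` helpers that are not part of the crate:

```rust
// Hypothetical helpers illustrating how a (u32, f64) pair is packed into
// a Vec<f64> for the `rates` field; not part of the actual crate.
fn encode_rate(uid: u32, rate: f64) -> Vec<f64> {
    vec![f64::from(uid), rate]
}

fn decode_rate(pair: &[f64]) -> (u32, f64) {
    // Every u32 is exactly representable in an f64 (53-bit mantissa),
    // so the cast back recovers the original uid losslessly.
    (pair[0] as u32, pair[1])
}

fn main() {
    let pair = encode_rate(u32::MAX, 12.5);
    assert_eq!(decode_rate(&pair), (u32::MAX, 12.5));
}
```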

backend/rust-gbt/index.js

@@ -1,301 +0,0 @@
/* tslint:disable */
/* eslint-disable */
/* prettier-ignore */
/* auto-generated by NAPI-RS */
const { existsSync, readFileSync } = require('fs')
const { join } = require('path')
const { platform, arch } = process
let nativeBinding = null
let localFileExisted = false
let loadError = null
function isMusl() {
// For Node 10
if (!process.report || typeof process.report.getReport !== 'function') {
try {
const lddPath = require('child_process').execSync('which ldd').toString().trim()
return readFileSync(lddPath, 'utf8').includes('musl')
} catch (e) {
return true
}
} else {
const { glibcVersionRuntime } = process.report.getReport().header
return !glibcVersionRuntime
}
}
switch (platform) {
case 'android':
switch (arch) {
case 'arm64':
localFileExisted = existsSync(join(__dirname, 'gbt.android-arm64.node'))
try {
if (localFileExisted) {
nativeBinding = require('./gbt.android-arm64.node')
} else {
nativeBinding = require('gbt-android-arm64')
}
} catch (e) {
loadError = e
}
break
case 'arm':
localFileExisted = existsSync(join(__dirname, 'gbt.android-arm-eabi.node'))
try {
if (localFileExisted) {
nativeBinding = require('./gbt.android-arm-eabi.node')
} else {
nativeBinding = require('gbt-android-arm-eabi')
}
} catch (e) {
loadError = e
}
break
default:
throw new Error(`Unsupported architecture on Android ${arch}`)
}
break
case 'win32':
switch (arch) {
case 'x64':
localFileExisted = existsSync(
join(__dirname, 'gbt.win32-x64-msvc.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.win32-x64-msvc.node')
} else {
nativeBinding = require('gbt-win32-x64-msvc')
}
} catch (e) {
loadError = e
}
break
case 'ia32':
localFileExisted = existsSync(
join(__dirname, 'gbt.win32-ia32-msvc.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.win32-ia32-msvc.node')
} else {
nativeBinding = require('gbt-win32-ia32-msvc')
}
} catch (e) {
loadError = e
}
break
case 'arm64':
localFileExisted = existsSync(
join(__dirname, 'gbt.win32-arm64-msvc.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.win32-arm64-msvc.node')
} else {
nativeBinding = require('gbt-win32-arm64-msvc')
}
} catch (e) {
loadError = e
}
break
default:
throw new Error(`Unsupported architecture on Windows: ${arch}`)
}
break
case 'darwin':
localFileExisted = existsSync(join(__dirname, 'gbt.darwin-universal.node'))
try {
if (localFileExisted) {
nativeBinding = require('./gbt.darwin-universal.node')
} else {
nativeBinding = require('gbt-darwin-universal')
}
break
} catch {}
switch (arch) {
case 'x64':
localFileExisted = existsSync(join(__dirname, 'gbt.darwin-x64.node'))
try {
if (localFileExisted) {
nativeBinding = require('./gbt.darwin-x64.node')
} else {
nativeBinding = require('gbt-darwin-x64')
}
} catch (e) {
loadError = e
}
break
case 'arm64':
localFileExisted = existsSync(
join(__dirname, 'gbt.darwin-arm64.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.darwin-arm64.node')
} else {
nativeBinding = require('gbt-darwin-arm64')
}
} catch (e) {
loadError = e
}
break
default:
throw new Error(`Unsupported architecture on macOS: ${arch}`)
}
break
case 'freebsd':
if (arch !== 'x64') {
throw new Error(`Unsupported architecture on FreeBSD: ${arch}`)
}
localFileExisted = existsSync(join(__dirname, 'gbt.freebsd-x64.node'))
try {
if (localFileExisted) {
nativeBinding = require('./gbt.freebsd-x64.node')
} else {
nativeBinding = require('gbt-freebsd-x64')
}
} catch (e) {
loadError = e
}
break
case 'linux':
switch (arch) {
case 'x64':
if (isMusl()) {
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-x64-musl.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-x64-musl.node')
} else {
nativeBinding = require('gbt-linux-x64-musl')
}
} catch (e) {
loadError = e
}
} else {
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-x64-gnu.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-x64-gnu.node')
} else {
nativeBinding = require('gbt-linux-x64-gnu')
}
} catch (e) {
loadError = e
}
}
break
case 'arm64':
if (isMusl()) {
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-arm64-musl.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-arm64-musl.node')
} else {
nativeBinding = require('gbt-linux-arm64-musl')
}
} catch (e) {
loadError = e
}
} else {
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-arm64-gnu.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-arm64-gnu.node')
} else {
nativeBinding = require('gbt-linux-arm64-gnu')
}
} catch (e) {
loadError = e
}
}
break
case 'arm':
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-arm-gnueabihf.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-arm-gnueabihf.node')
} else {
nativeBinding = require('gbt-linux-arm-gnueabihf')
}
} catch (e) {
loadError = e
}
break
case 'riscv64':
if (isMusl()) {
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-riscv64-musl.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-riscv64-musl.node')
} else {
nativeBinding = require('gbt-linux-riscv64-musl')
}
} catch (e) {
loadError = e
}
} else {
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-riscv64-gnu.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-riscv64-gnu.node')
} else {
nativeBinding = require('gbt-linux-riscv64-gnu')
}
} catch (e) {
loadError = e
}
}
break
case 's390x':
localFileExisted = existsSync(
join(__dirname, 'gbt.linux-s390x-gnu.node')
)
try {
if (localFileExisted) {
nativeBinding = require('./gbt.linux-s390x-gnu.node')
} else {
nativeBinding = require('gbt-linux-s390x-gnu')
}
} catch (e) {
loadError = e
}
break
default:
throw new Error(`Unsupported architecture on Linux: ${arch}`)
}
break
default:
throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`)
}
if (!nativeBinding) {
if (loadError) {
throw loadError
}
throw new Error(`Failed to load native binding`)
}
const { GbtGenerator, GbtResult } = nativeBinding
module.exports.GbtGenerator = GbtGenerator
module.exports.GbtResult = GbtResult

backend/rust-gbt/package-lock.json

@@ -1,34 +0,0 @@
{
"name": "gbt",
"version": "3.0.1",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "gbt",
"version": "3.0.1",
"hasInstallScript": true,
"dependencies": {
"@napi-rs/cli": "2.18.0"
},
"engines": {
"node": ">= 12"
}
},
"node_modules/@napi-rs/cli": {
"version": "2.18.0",
"resolved": "https://registry.npmjs.org/@napi-rs/cli/-/cli-2.18.0.tgz",
"integrity": "sha512-lfSRT7cs3iC4L+kv9suGYQEezn5Nii7Kpu+THsYVI0tA1Vh59LH45p4QADaD7hvIkmOz79eEGtoKQ9nAkAPkzA==",
"bin": {
"napi": "scripts/index.js"
},
"engines": {
"node": ">= 10"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/Brooooooklyn"
}
}
}
}

backend/rust-gbt/package.json

@@ -1,33 +0,0 @@
{
"name": "gbt",
"version": "3.0.1",
"description": "An efficient re-implementation of the getBlockTemplate algorithm in Rust",
"main": "index.js",
"types": "index.d.ts",
"scripts": {
"artifacts": "napi artifacts",
"build": "napi build --platform",
"build-debug": "npm run build",
"build-release": "npm run build -- --release --strip",
"install": "npm run build-release",
"prepublishOnly": "napi prepublish -t npm",
"test": "cargo test"
},
"author": "mononaut",
"napi": {
"name": "gbt",
"triples": {
"defaults": false,
"additional": [
"x86_64-unknown-linux-gnu",
"x86_64-unknown-freebsd"
]
}
},
"dependencies": {
"@napi-rs/cli": "2.18.0"
},
"engines": {
"node": ">= 12"
}
}

backend/rust-gbt/src/audit_transaction.rs

@@ -1,225 +0,0 @@
use crate::{
u32_hasher_types::{u32hashset_new, U32HasherState},
ThreadTransaction, thread_acceleration::ThreadAcceleration,
};
use std::{
cmp::Ordering,
collections::HashSet,
hash::{Hash, Hasher},
};
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, Debug)]
pub struct AuditTransaction {
pub uid: u32,
order: u32,
pub fee: u64,
pub weight: u32,
// exact sigop-adjusted weight
pub sigop_adjusted_weight: u32,
// sigop-adjusted vsize rounded up to the next integer
pub sigop_adjusted_vsize: u32,
pub sigops: u32,
adjusted_fee_per_vsize: f64,
pub effective_fee_per_vsize: f64,
pub dependency_rate: f64,
pub inputs: Vec<u32>,
pub relatives_set_flag: bool,
pub ancestors: HashSet<u32, U32HasherState>,
pub children: HashSet<u32, U32HasherState>,
ancestor_fee: u64,
ancestor_sigop_adjusted_weight: u32,
ancestor_sigop_adjusted_vsize: u32,
ancestor_sigops: u32,
// Safety: Must be private to prevent NaN breaking Ord impl.
score: f64,
pub used: bool,
/// whether this transaction has been moved to the "modified" priority queue
pub modified: bool,
pub dirty: bool,
}
impl Hash for AuditTransaction {
fn hash<H: Hasher>(&self, state: &mut H) {
self.uid.hash(state);
}
}
impl PartialEq for AuditTransaction {
fn eq(&self, other: &Self) -> bool {
self.uid == other.uid
}
}
impl Eq for AuditTransaction {}
#[inline]
pub fn partial_cmp_uid_score(a: (u32, u32, f64), b: (u32, u32, f64)) -> Option<Ordering> {
// If either score is NaN, the != comparison is true
// and partial_cmp will return None
if a.2 != b.2 {
// compare by score (sorts by ascending score)
a.2.partial_cmp(&b.2)
} else if a.1 != b.1 {
// tie-break by comparing partial txids (sorts by descending txid)
Some(b.1.cmp(&a.1))
} else {
// tie-break partial txid collisions by comparing uids (sorts by descending uid)
Some(b.0.cmp(&a.0))
}
}
impl PartialOrd for AuditTransaction {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
partial_cmp_uid_score(
(self.uid, self.order, self.score),
(other.uid, other.order, other.score),
)
}
}
impl Ord for AuditTransaction {
fn cmp(&self, other: &Self) -> Ordering {
// Safety: The only possible values for score are f64
// that are not NaN. This is because outside code can not
// freely assign score. Also, calc_new_score guarantees no NaN.
self.partial_cmp(other).expect("score will never be NaN")
}
}
#[inline]
fn calc_fee_rate(fee: u64, vsize: f64) -> f64 {
(fee as f64) / (if vsize == 0.0 { 1.0 } else { vsize })
}
impl AuditTransaction {
pub fn from_thread_transaction(tx: &ThreadTransaction, maybe_acceleration: Option<Option<&ThreadAcceleration>>) -> Self {
let fee_delta = match maybe_acceleration {
Some(Some(acceleration)) => acceleration.delta,
_ => 0.0
};
let fee = (tx.fee as u64) + (fee_delta as u64);
let is_adjusted = tx.weight < (tx.sigops * 20);
// sigop-adjusted vsize, rounded up to the nearest integer
let sigop_adjusted_vsize = ((tx.weight + 3) / 4).max(tx.sigops * 5);
let sigop_adjusted_weight = tx.weight.max(tx.sigops * 20);
let effective_fee_per_vsize = if is_adjusted || fee_delta > 0.0 {
calc_fee_rate(fee, f64::from(sigop_adjusted_weight) / 4.0)
} else {
tx.effective_fee_per_vsize
};
Self {
uid: tx.uid,
order: tx.order,
fee,
weight: tx.weight,
sigop_adjusted_weight,
sigop_adjusted_vsize,
sigops: tx.sigops,
adjusted_fee_per_vsize: calc_fee_rate(fee, f64::from(sigop_adjusted_vsize)),
effective_fee_per_vsize,
dependency_rate: f64::INFINITY,
inputs: tx.inputs.clone(),
relatives_set_flag: false,
ancestors: u32hashset_new(),
children: u32hashset_new(),
ancestor_fee: fee,
ancestor_sigop_adjusted_weight: sigop_adjusted_weight,
ancestor_sigop_adjusted_vsize: sigop_adjusted_vsize,
ancestor_sigops: tx.sigops,
score: 0.0,
used: false,
modified: false,
dirty: effective_fee_per_vsize != tx.effective_fee_per_vsize || fee_delta > 0.0,
}
}
#[inline]
pub const fn score(&self) -> f64 {
self.score
}
#[inline]
pub const fn order(&self) -> u32 {
self.order
}
#[inline]
pub const fn ancestor_sigop_adjusted_vsize(&self) -> u32 {
self.ancestor_sigop_adjusted_vsize
}
#[inline]
pub const fn ancestor_sigops(&self) -> u32 {
self.ancestor_sigops
}
#[inline]
pub fn cluster_rate(&self) -> f64 {
// Safety: self.ancestor_sigop_adjusted_weight can never be 0.
// Even if it could, as it approaches 0, the value inside the min() call
// grows, so if we think of 0 as "grew infinitely" then dependency_rate would be
// the smaller of the two. If either side is NaN, the other side is returned.
self.dependency_rate.min(calc_fee_rate(
self.ancestor_fee,
f64::from(self.ancestor_sigop_adjusted_weight) / 4.0,
))
}
pub fn set_dirty_if_different(&mut self, cluster_rate: f64) {
if self.effective_fee_per_vsize != cluster_rate {
self.effective_fee_per_vsize = cluster_rate;
self.dirty = true;
}
}
/// Safety: This function must NEVER set score to NaN.
#[inline]
fn calc_new_score(&mut self) {
self.score = self.adjusted_fee_per_vsize.min(calc_fee_rate(
self.ancestor_fee,
f64::from(self.ancestor_sigop_adjusted_vsize),
));
}
#[inline]
pub fn set_ancestors(
&mut self,
ancestors: HashSet<u32, U32HasherState>,
total_fee: u64,
total_sigop_adjusted_weight: u32,
total_sigop_adjusted_vsize: u32,
total_sigops: u32,
) {
self.ancestors = ancestors;
self.ancestor_fee = self.fee + total_fee;
self.ancestor_sigop_adjusted_weight =
self.sigop_adjusted_weight + total_sigop_adjusted_weight;
self.ancestor_sigop_adjusted_vsize = self.sigop_adjusted_vsize + total_sigop_adjusted_vsize;
self.ancestor_sigops = self.sigops + total_sigops;
self.calc_new_score();
self.relatives_set_flag = true;
}
#[inline]
pub fn remove_root(
&mut self,
root_txid: u32,
root_fee: u64,
root_sigop_adjusted_weight: u32,
root_sigop_adjusted_vsize: u32,
root_sigops: u32,
cluster_rate: f64,
) -> f64 {
let old_score = self.score();
self.dependency_rate = self.dependency_rate.min(cluster_rate);
if self.ancestors.remove(&root_txid) {
self.ancestor_fee -= root_fee;
self.ancestor_sigop_adjusted_weight -= root_sigop_adjusted_weight;
self.ancestor_sigop_adjusted_vsize -= root_sigop_adjusted_vsize;
self.ancestor_sigops -= root_sigops;
self.calc_new_score();
}
old_score
}
}
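
The sigop adjustment in `from_thread_transaction` deserves a worked example: a transaction is charged the larger of its weight-derived vsize (`ceil(weight / 4)`) and a floor of 5 vbytes (20 WU) per sigop, so sigop-heavy transactions rank lower for the same fee. A standalone sketch that mirrors the arithmetic above without reusing the crate's types:

```rust
// Standalone sketch of the sigop-adjusted sizing used above: the effective
// size is the larger of ceil(weight / 4) and 5 vbytes per sigop.
fn sigop_adjusted_vsize(weight: u32, sigops: u32) -> u32 {
    ((weight + 3) / 4).max(sigops * 5)
}

// Mirrors calc_fee_rate above: guard against division by zero.
fn calc_fee_rate(fee: u64, vsize: f64) -> f64 {
    (fee as f64) / (if vsize == 0.0 { 1.0 } else { vsize })
}

fn main() {
    // Weight-bound: ceil(565 / 4) = 142 vbytes beats 1 sigop * 5 vbytes.
    assert_eq!(sigop_adjusted_vsize(565, 1), 142);
    // Sigop-bound: 400 WU with 30 sigops is charged 150 vbytes, not 100.
    assert_eq!(sigop_adjusted_vsize(400, 30), 150);
    // The fee rate then uses the adjusted size: 3000 sat / 150 vB = 20 sat/vB.
    assert!((calc_fee_rate(3_000, 150.0) - 20.0).abs() < f64::EPSILON);
}
```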

backend/rust-gbt/src/gbt.rs

@@ -1,437 +0,0 @@
use priority_queue::PriorityQueue;
use std::{cmp::Ordering, collections::HashSet, mem::ManuallyDrop};
use tracing::{info, trace};
use crate::{
audit_transaction::{partial_cmp_uid_score, AuditTransaction},
u32_hasher_types::{u32hashset_new, u32priority_queue_with_capacity, U32HasherState},
GbtResult, ThreadTransactionsMap, thread_acceleration::ThreadAcceleration,
};
const MAX_BLOCK_WEIGHT_UNITS: u32 = 4_000_000 - 4_000;
const BLOCK_SIGOPS: u32 = 80_000;
const BLOCK_RESERVED_WEIGHT: u32 = 4_000;
const BLOCK_RESERVED_SIGOPS: u32 = 400;
const MAX_BLOCKS: usize = 8;
type AuditPool = Vec<Option<ManuallyDrop<AuditTransaction>>>;
type ModifiedQueue = PriorityQueue<u32, TxPriority, U32HasherState>;
#[derive(Debug)]
struct TxPriority {
uid: u32,
order: u32,
score: f64,
}
impl PartialEq for TxPriority {
fn eq(&self, other: &Self) -> bool {
self.uid == other.uid
}
}
impl Eq for TxPriority {}
impl PartialOrd for TxPriority {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
partial_cmp_uid_score(
(self.uid, self.order, self.score),
(other.uid, other.order, other.score),
)
}
}
impl Ord for TxPriority {
fn cmp(&self, other: &Self) -> Ordering {
self.partial_cmp(other).expect("score will never be NaN")
}
}
/// Build projected mempool blocks using an approximation of the transaction selection algorithm from Bitcoin Core.
///
/// See `BlockAssembler` in Bitcoin Core's
/// [miner.cpp](https://github.com/bitcoin/bitcoin/blob/master/src/node/miner.cpp).
/// Ported from mempool backend's
/// [tx-selection-worker.ts](https://github.com/mempool/mempool/blob/master/backend/src/api/tx-selection-worker.ts).
//
// TODO: Make gbt smaller to fix these lints.
#[allow(clippy::too_many_lines)]
#[allow(clippy::cognitive_complexity)]
pub fn gbt(mempool: &mut ThreadTransactionsMap, accelerations: &[ThreadAcceleration], max_uid: usize) -> GbtResult {
let mut indexed_accelerations = Vec::with_capacity(max_uid + 1);
indexed_accelerations.resize(max_uid + 1, None);
for acceleration in accelerations {
indexed_accelerations[acceleration.uid as usize] = Some(acceleration);
}
info!("Initializing working vecs with uid capacity for {}", max_uid + 1);
let mempool_len = mempool.len();
let mut audit_pool: AuditPool = Vec::with_capacity(max_uid + 1);
audit_pool.resize(max_uid + 1, None);
let mut mempool_stack: Vec<u32> = Vec::with_capacity(mempool_len);
let mut clusters: Vec<Vec<u32>> = Vec::new();
let mut block_weights: Vec<u32> = Vec::new();
info!("Initializing working structs");
for (uid, tx) in &mut *mempool {
let acceleration = indexed_accelerations.get(*uid as usize);
let audit_tx = AuditTransaction::from_thread_transaction(tx, acceleration.copied());
// Safety: audit_pool and mempool_stack must always contain the same transactions
audit_pool[*uid as usize] = Some(ManuallyDrop::new(audit_tx));
mempool_stack.push(*uid);
}
info!("Building relatives graph & calculate ancestor scores");
for txid in &mempool_stack {
set_relatives(*txid, &mut audit_pool);
}
trace!("Post relative graph Audit Pool: {:#?}", audit_pool);
info!("Sorting by descending ancestor score");
let mut mempool_stack: Vec<(u32, u32, f64)> = mempool_stack
.into_iter()
.map(|txid| {
let atx = audit_pool
.get(txid as usize)
.and_then(Option::as_ref)
.expect("All txids are from audit_pool");
(txid, atx.order(), atx.score())
})
.collect();
mempool_stack.sort_unstable_by(|a, b| partial_cmp_uid_score(*a, *b).expect("Not NaN"));
let mut mempool_stack: Vec<u32> = mempool_stack.into_iter().map(|(txid, _, _)| txid).collect();
info!("Building blocks by greedily choosing the highest feerate package");
info!("(i.e. the package rooted in the transaction with the best ancestor score)");
let mut blocks: Vec<Vec<u32>> = Vec::new();
let mut block_weight: u32 = BLOCK_RESERVED_WEIGHT;
let mut block_sigops: u32 = BLOCK_RESERVED_SIGOPS;
// No need to be bigger than 4096 transactions for the per-block transaction Vec.
let initial_txes_per_block: usize = 4096.min(mempool_len);
let mut transactions: Vec<u32> = Vec::with_capacity(initial_txes_per_block);
let mut modified: ModifiedQueue = u32priority_queue_with_capacity(mempool_len);
let mut overflow: Vec<u32> = Vec::new();
let mut failures = 0;
while !mempool_stack.is_empty() || !modified.is_empty() {
// This trace log storm is big, so to make scrolling through
// each iteration easier, leave a bunch of empty rows
// and a header of ======
trace!("\n\n\n\n\n\n\n\n\n\n==================================");
trace!("mempool_array: {:#?}", mempool_stack);
trace!("clusters: {:#?}", clusters);
trace!("modified: {:#?}", modified);
trace!("audit_pool: {:#?}", audit_pool);
trace!("blocks: {:#?}", blocks);
trace!("block_weight: {:#?}", block_weight);
trace!("block_sigops: {:#?}", block_sigops);
trace!("transactions: {:#?}", transactions);
trace!("overflow: {:#?}", overflow);
trace!("failures: {:#?}", failures);
trace!("\n==================================");
let next_from_stack = next_valid_from_stack(&mut mempool_stack, &audit_pool);
let next_from_queue = next_valid_from_queue(&mut modified, &audit_pool);
if next_from_stack.is_none() && next_from_queue.is_none() {
info!("No transactions left! {:#?} in overflow", overflow.len());
} else {
let (next_tx, from_stack) = match (next_from_stack, next_from_queue) {
(Some(stack_tx), Some(queue_tx)) => match queue_tx.cmp(stack_tx) {
std::cmp::Ordering::Less => (stack_tx, true),
_ => (queue_tx, false),
},
(Some(stack_tx), None) => (stack_tx, true),
(None, Some(queue_tx)) => (queue_tx, false),
(None, None) => unreachable!(),
};
if from_stack {
mempool_stack.pop();
} else {
modified.pop();
}
if blocks.len() < (MAX_BLOCKS - 1)
&& ((block_weight + (4 * next_tx.ancestor_sigop_adjusted_vsize())
>= MAX_BLOCK_WEIGHT_UNITS)
|| (block_sigops + next_tx.ancestor_sigops() > BLOCK_SIGOPS))
{
// hold this package in an overflow list while we check for smaller options
overflow.push(next_tx.uid);
failures += 1;
} else {
let mut package: Vec<(u32, u32, usize)> = Vec::new();
let mut cluster: Vec<u32> = Vec::new();
let is_cluster: bool = !next_tx.ancestors.is_empty();
for ancestor_id in &next_tx.ancestors {
if let Some(Some(ancestor)) = audit_pool.get(*ancestor_id as usize) {
package.push((*ancestor_id, ancestor.order(), ancestor.ancestors.len()));
}
}
package.sort_unstable_by(|a, b| -> Ordering {
if a.2 != b.2 {
// order by ascending ancestor count
a.2.cmp(&b.2)
} else if a.1 != b.1 {
// tie-break by ascending partial txid
a.1.cmp(&b.1)
} else {
// tie-break partial txid collisions by ascending uid
a.0.cmp(&b.0)
}
});
package.push((next_tx.uid, next_tx.order(), next_tx.ancestors.len()));
let cluster_rate = next_tx.cluster_rate();
for (txid, _, _) in &package {
cluster.push(*txid);
if let Some(Some(tx)) = audit_pool.get_mut(*txid as usize) {
tx.used = true;
tx.set_dirty_if_different(cluster_rate);
transactions.push(tx.uid);
block_weight += tx.weight;
block_sigops += tx.sigops;
}
update_descendants(*txid, &mut audit_pool, &mut modified, cluster_rate);
}
if is_cluster {
clusters.push(cluster);
}
failures = 0;
}
}
// this block is full
let exceeded_package_tries =
failures > 1000 && block_weight > (MAX_BLOCK_WEIGHT_UNITS - BLOCK_RESERVED_WEIGHT);
let queue_is_empty = mempool_stack.is_empty() && modified.is_empty();
if (exceeded_package_tries || queue_is_empty) && blocks.len() < (MAX_BLOCKS - 1) {
// finalize this block
if transactions.is_empty() {
info!("trying to push an empty block! breaking loop! mempool {:#?} | modified {:#?} | overflow {:#?}", mempool_stack.len(), modified.len(), overflow.len());
break;
}
blocks.push(transactions);
block_weights.push(block_weight);
// reset for the next block
transactions = Vec::with_capacity(initial_txes_per_block);
block_weight = BLOCK_RESERVED_WEIGHT;
block_sigops = BLOCK_RESERVED_SIGOPS;
failures = 0;
// 'overflow' packages didn't fit in this block, but are valid candidates for the next
overflow.reverse();
for overflowed in &overflow {
if let Some(Some(overflowed_tx)) = audit_pool.get(*overflowed as usize) {
if overflowed_tx.modified {
modified.push(
*overflowed,
TxPriority {
uid: *overflowed,
order: overflowed_tx.order(),
score: overflowed_tx.score(),
},
);
} else {
mempool_stack.push(*overflowed);
}
}
}
overflow = Vec::new();
}
}
info!("add the final unbounded block if it contains any transactions");
if !transactions.is_empty() {
blocks.push(transactions);
block_weights.push(block_weight);
}
info!("make a list of dirty transactions and their new rates");
let mut rates: Vec<Vec<f64>> = Vec::new();
for (uid, thread_tx) in mempool {
// Takes ownership of the audit_tx and replaces with None
if let Some(Some(audit_tx)) = audit_pool.get_mut(*uid as usize).map(Option::take) {
trace!("txid: {}, is_dirty: {}", uid, audit_tx.dirty);
if audit_tx.dirty {
rates.push(vec![f64::from(*uid), audit_tx.effective_fee_per_vsize]);
thread_tx.effective_fee_per_vsize = audit_tx.effective_fee_per_vsize;
}
// Drops the AuditTransaction manually
// Every audit_tx is also in the mempool HashMap,
// so there are guaranteed to be no memory leaks.
ManuallyDrop::into_inner(audit_tx);
}
}
trace!("\n\n\n\n\n====================");
trace!("blocks: {:#?}", blocks);
trace!("clusters: {:#?}", clusters);
trace!("rates: {:#?}\n====================\n\n\n\n\n", rates);
GbtResult {
blocks,
block_weights,
clusters,
rates,
overflow,
}
}
fn next_valid_from_stack<'a>(
mempool_stack: &mut Vec<u32>,
audit_pool: &'a AuditPool,
) -> Option<&'a AuditTransaction> {
while let Some(next_txid) = mempool_stack.last() {
match audit_pool.get(*next_txid as usize) {
Some(Some(tx)) if !tx.used && !tx.modified => {
return Some(tx);
}
_ => {
mempool_stack.pop();
}
}
}
None
}
fn next_valid_from_queue<'a>(
queue: &mut ModifiedQueue,
audit_pool: &'a AuditPool,
) -> Option<&'a AuditTransaction> {
while let Some((next_txid, _)) = queue.peek() {
match audit_pool.get(*next_txid as usize) {
Some(Some(tx)) if !tx.used => {
return Some(tx);
}
_ => {
queue.pop();
}
}
}
None
}
fn set_relatives(txid: u32, audit_pool: &mut AuditPool) {
let mut parents: HashSet<u32, U32HasherState> = u32hashset_new();
if let Some(Some(tx)) = audit_pool.get(txid as usize) {
if tx.relatives_set_flag {
return;
}
for input in &tx.inputs {
parents.insert(*input);
}
} else {
return;
}
let mut ancestors: HashSet<u32, U32HasherState> = u32hashset_new();
for parent_id in &parents {
set_relatives(*parent_id, audit_pool);
if let Some(Some(parent)) = audit_pool.get_mut(*parent_id as usize) {
// Safety: ancestors must always contain only txes in audit_pool
ancestors.insert(*parent_id);
parent.children.insert(txid);
for ancestor in &parent.ancestors {
ancestors.insert(*ancestor);
}
}
}
let mut total_fee: u64 = 0;
let mut total_sigop_adjusted_weight: u32 = 0;
let mut total_sigop_adjusted_vsize: u32 = 0;
let mut total_sigops: u32 = 0;
for ancestor_id in &ancestors {
if let Some(ancestor) = audit_pool
.get(*ancestor_id as usize)
.expect("audit_pool contains all ancestors")
{
total_fee += ancestor.fee;
total_sigop_adjusted_weight += ancestor.sigop_adjusted_weight;
total_sigop_adjusted_vsize += ancestor.sigop_adjusted_vsize;
total_sigops += ancestor.sigops;
} else { todo!() };
}
if let Some(Some(tx)) = audit_pool.get_mut(txid as usize) {
tx.set_ancestors(
ancestors,
total_fee,
total_sigop_adjusted_weight,
total_sigop_adjusted_vsize,
total_sigops,
);
}
}
// iterate over remaining descendants, removing the root as a valid ancestor & updating the ancestor score
fn update_descendants(
root_txid: u32,
audit_pool: &mut AuditPool,
modified: &mut ModifiedQueue,
cluster_rate: f64,
) {
let mut visited: HashSet<u32, U32HasherState> = u32hashset_new();
let mut descendant_stack: Vec<u32> = Vec::new();
let root_fee: u64;
let root_sigop_adjusted_weight: u32;
let root_sigop_adjusted_vsize: u32;
let root_sigops: u32;
if let Some(Some(root_tx)) = audit_pool.get(root_txid as usize) {
for descendant_id in &root_tx.children {
if !visited.contains(descendant_id) {
descendant_stack.push(*descendant_id);
visited.insert(*descendant_id);
}
}
root_fee = root_tx.fee;
root_sigop_adjusted_weight = root_tx.sigop_adjusted_weight;
root_sigop_adjusted_vsize = root_tx.sigop_adjusted_vsize;
root_sigops = root_tx.sigops;
} else {
return;
}
while let Some(next_txid) = descendant_stack.pop() {
if let Some(Some(descendant)) = audit_pool.get_mut(next_txid as usize) {
// remove root tx as ancestor
let old_score = descendant.remove_root(
root_txid,
root_fee,
root_sigop_adjusted_weight,
root_sigop_adjusted_vsize,
root_sigops,
cluster_rate,
);
// add to priority queue or update priority if score has changed
if descendant.score() < old_score {
descendant.modified = true;
modified.push_decrease(
descendant.uid,
TxPriority {
uid: descendant.uid,
order: descendant.order(),
score: descendant.score(),
},
);
} else if descendant.score() > old_score {
descendant.modified = true;
modified.push_increase(
descendant.uid,
TxPriority {
uid: descendant.uid,
order: descendant.order(),
score: descendant.score(),
},
);
}
// add this node's children to the stack
for child_id in &descendant.children {
if !visited.contains(child_id) {
descendant_stack.push(*child_id);
visited.insert(*child_id);
}
}
}
}
}
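
The selection loop in `gbt` above draws its next candidate from two sources: the tail of the score-sorted `mempool_stack` and the `modified` priority queue, taking whichever compares higher and breaking ties in favor of the queue (the `Ordering::Less => stack` arm). A reduced sketch of that merge, generic over any `Ord` candidate rather than the crate's `AuditTransaction`:

```rust
use std::cmp::Ordering;

// Reduced sketch of the two-source merge in gbt(): return the candidate that
// compares higher, plus a flag for "came from the stack". Ties favor the
// queue, matching the match arms in the selection loop above.
fn next_candidate<'a, T: Ord>(
    stack_top: Option<&'a T>,
    queue_top: Option<&'a T>,
) -> Option<(&'a T, bool)> {
    match (stack_top, queue_top) {
        (Some(s), Some(q)) => match q.cmp(s) {
            Ordering::Less => Some((s, true)),
            _ => Some((q, false)),
        },
        (Some(s), None) => Some((s, true)),
        (None, Some(q)) => Some((q, false)),
        (None, None) => None,
    }
}

fn main() {
    let stack = vec![1, 5];
    let queue = vec![7];
    let (best, from_stack) = next_candidate(stack.last(), queue.last()).unwrap();
    assert_eq!((*best, from_stack), (7, false));
}
```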

backend/rust-gbt/src/lib.rs

@@ -1,184 +0,0 @@
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![warn(clippy::nursery)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::float_cmp)]
use napi::bindgen_prelude::Result;
use napi_derive::napi;
use thread_transaction::ThreadTransaction;
use thread_acceleration::ThreadAcceleration;
use tracing::{debug, info, trace};
use tracing_log::LogTracer;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
mod audit_transaction;
mod gbt;
mod thread_transaction;
mod thread_acceleration;
mod u32_hasher_types;
use u32_hasher_types::{u32hashmap_with_capacity, U32HasherState};
/// This is the initial capacity of the `GbtGenerator` struct's inner `HashMap`.
///
/// Note: This doesn't *have* to be a power of 2. (uwu)
const STARTING_CAPACITY: usize = 1_048_576;
type ThreadTransactionsMap = HashMap<u32, ThreadTransaction, U32HasherState>;
#[napi]
pub struct GbtGenerator {
thread_transactions: Arc<Mutex<ThreadTransactionsMap>>,
}
#[napi::module_init]
fn init() {
// Set all `tracing` logs to print to STDOUT
// Note: Passing RUST_LOG env variable to the node process
// will change the log level for the rust module.
tracing::subscriber::set_global_default(
FmtSubscriber::builder()
.with_env_filter(EnvFilter::from_default_env())
.with_ansi(
// Default to no-color logs.
// Setting RUST_LOG_COLOR to 1 or true|TRUE|True etc.
// will enable color
std::env::var("RUST_LOG_COLOR")
.map(|s| ["1", "true"].contains(&&*s.to_lowercase()))
.unwrap_or(false),
)
.finish(),
)
.expect("Logging subscriber failed");
// Convert all `log` logs into `tracing` events
LogTracer::init().expect("Legacy log subscriber failed");
}
#[napi]
impl GbtGenerator {
#[napi(constructor)]
#[allow(clippy::new_without_default)]
#[must_use]
pub fn new() -> Self {
debug!("Created new GbtGenerator");
Self {
thread_transactions: Arc::new(Mutex::new(u32hashmap_with_capacity(STARTING_CAPACITY))),
}
}
/// # Errors
///
/// Rejects if the thread panics or if the Mutex is poisoned.
#[napi]
pub async fn make(&self, mempool: Vec<ThreadTransaction>, accelerations: Vec<ThreadAcceleration>, max_uid: u32) -> Result<GbtResult> {
trace!("make: Current State {:#?}", self.thread_transactions);
run_task(
Arc::clone(&self.thread_transactions),
accelerations,
max_uid as usize,
move |map| {
for tx in mempool {
map.insert(tx.uid, tx);
}
},
)
.await
}
/// # Errors
///
/// Rejects if the thread panics or if the Mutex is poisoned.
#[napi]
pub async fn update(
&self,
new_txs: Vec<ThreadTransaction>,
remove_txs: Vec<u32>,
accelerations: Vec<ThreadAcceleration>,
max_uid: u32,
) -> Result<GbtResult> {
trace!("update: Current State {:#?}", self.thread_transactions);
run_task(
Arc::clone(&self.thread_transactions),
accelerations,
max_uid as usize,
move |map| {
for tx in new_txs {
map.insert(tx.uid, tx);
}
for txid in &remove_txs {
map.remove(txid);
}
},
)
.await
}
}
/// The result from calling the gbt function.
///
/// It contains the following:
/// blocks: A 2D Vector of transaction IDs (u32), the inner Vecs each represent a block.
/// block_weights: A Vector of total weights per block.
/// clusters: A 2D Vector of transaction IDs representing clusters of dependent mempool transactions
/// rates: A Vector of tuples containing transaction IDs (u32) and effective fee per vsize (f64)
#[napi(constructor)]
pub struct GbtResult {
pub blocks: Vec<Vec<u32>>,
pub block_weights: Vec<u32>,
pub clusters: Vec<Vec<u32>>,
pub rates: Vec<Vec<f64>>, // Tuples not supported. u32 fits inside f64
pub overflow: Vec<u32>,
}
/// All on another thread, this runs an arbitrary task in between
/// taking the lock and running gbt.
///
/// Rather than filling / updating the `HashMap` on the main thread,
/// this allows for `HashMap` modifying tasks to be run before running and returning gbt results.
///
/// `thread_transactions` is a cloned `Arc` of the `Mutex` for the `HashMap` state.
/// `callback` is a `'static + Send` `FnOnce` closure/function that takes a mutable reference
/// to the `HashMap` as the only argument. (A move closure is recommended to meet the bounds)
async fn run_task<F>(
thread_transactions: Arc<Mutex<ThreadTransactionsMap>>,
accelerations: Vec<ThreadAcceleration>,
max_uid: usize,
callback: F,
) -> Result<GbtResult>
where
F: FnOnce(&mut ThreadTransactionsMap) + Send + 'static,
{
debug!("Spawning thread...");
let handle = napi::tokio::task::spawn_blocking(move || {
debug!(
"Getting lock for thread_transactions from thread {:?}...",
std::thread::current().id()
);
let mut map = thread_transactions
.lock()
.map_err(|_| napi::Error::from_reason("THREAD_TRANSACTIONS Mutex poisoned"))?;
callback(&mut map);
info!("Starting gbt algorithm for {} elements...", map.len());
let result = gbt::gbt(&mut map, &accelerations, max_uid);
info!("Finished gbt algorithm for {} elements...", map.len());
debug!(
"Releasing lock for thread_transactions from thread {:?}...",
std::thread::current().id()
);
drop(map);
Ok(result)
});
handle
.await
.map_err(|_| napi::Error::from_reason("thread panicked"))?
}
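
`run_task` above is the whole threading model in miniature: clone the `Arc`, move it onto a blocking thread, mutate the shared `HashMap` under the `Mutex`, run the heavy work, and surface panics and lock poisoning as promise rejections. A minimal standalone sketch of the same pattern, assuming plain `tokio` (with the `rt-multi-thread` and `macros` features) in place of napi's re-exported runtime, with the map length standing in for the `gbt` result:

```rust
// Minimal sketch of the run_task pattern on plain tokio; assumes
// tokio = { version = "1", features = ["rt-multi-thread", "macros"] }.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

async fn run_task<F>(
    state: Arc<Mutex<HashMap<u32, u64>>>,
    callback: F,
) -> Result<usize, String>
where
    F: FnOnce(&mut HashMap<u32, u64>) + Send + 'static,
{
    tokio::task::spawn_blocking(move || {
        // Lock poisoning becomes an error value instead of a panic,
        // mirroring the napi::Error::from_reason conversion above.
        let mut map = state.lock().map_err(|_| "Mutex poisoned".to_string())?;
        callback(&mut map);
        Ok(map.len()) // stand-in for the gbt() call
    })
    .await
    .map_err(|_| "thread panicked".to_string())?
}

#[tokio::main]
async fn main() {
    let state = Arc::new(Mutex::new(HashMap::new()));
    let len = run_task(Arc::clone(&state), |map| {
        map.insert(1, 100);
    })
    .await
    .unwrap();
    assert_eq!(len, 1);
}
```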

backend/rust-gbt/src/thread_acceleration.rs

@@ -1,8 +0,0 @@
use napi_derive::napi;
#[derive(Debug)]
#[napi(object)]
pub struct ThreadAcceleration {
pub uid: u32,
pub delta: f64, // fee delta
}

backend/rust-gbt/src/thread_transaction.rs

@@ -1,13 +0,0 @@
use napi_derive::napi;
#[derive(Debug)]
#[napi(object)]
pub struct ThreadTransaction {
pub uid: u32,
pub order: u32,
pub fee: f64,
pub weight: u32,
pub sigops: u32,
pub effective_fee_per_vsize: f64,
pub inputs: Vec<u32>,
}

backend/rust-gbt/src/u32_hasher_types.rs

@@ -1,132 +0,0 @@
use priority_queue::PriorityQueue;
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
hash::{BuildHasher, Hasher},
};
/// This is the only way to create a `HashMap` with the `U32HasherState` and capacity
pub fn u32hashmap_with_capacity<V>(capacity: usize) -> HashMap<u32, V, U32HasherState> {
HashMap::with_capacity_and_hasher(capacity, U32HasherState(()))
}
/// This is the only way to create a `PriorityQueue` with the `U32HasherState` and capacity
pub fn u32priority_queue_with_capacity<V: Ord>(
capacity: usize,
) -> PriorityQueue<u32, V, U32HasherState> {
PriorityQueue::with_capacity_and_hasher(capacity, U32HasherState(()))
}
/// This is the only way to create a `HashSet` with the `U32HasherState`
pub fn u32hashset_new() -> HashSet<u32, U32HasherState> {
HashSet::with_hasher(U32HasherState(()))
}
/// A private unit type is contained so no one can make an instance of it.
#[derive(Clone)]
pub struct U32HasherState(());
impl Debug for U32HasherState {
fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}
impl BuildHasher for U32HasherState {
type Hasher = U32Hasher;
fn build_hasher(&self) -> Self::Hasher {
U32Hasher(0)
}
}
/// This also can't be created outside this module due to private field.
pub struct U32Hasher(u32);
impl Hasher for U32Hasher {
fn finish(&self) -> u64 {
// Safety: Two u32s next to each other will make a u64
bytemuck::cast([self.0, 0])
}
fn write(&mut self, bytes: &[u8]) {
// Assert in debug builds (and in tests) that only 4-byte keys (u32, i32, f32, etc.) are used
debug_assert!(bytes.len() == 4);
// Safety: We know that the size of the key is 4 bytes
// We also know that the only way to get an instance of HashMap using this "hasher"
// is through the public functions in this module which set the key type to u32.
self.0 = *bytemuck::from_bytes(bytes);
}
}
#[cfg(test)]
mod tests {
use super::U32HasherState;
use priority_queue::PriorityQueue;
use std::collections::HashMap;
#[test]
fn test_hashmap() {
let mut hm: HashMap<u32, String, U32HasherState> = HashMap::with_hasher(U32HasherState(()));
// Testing basic operations with the custom hasher
hm.insert(0, String::from("0"));
hm.insert(42, String::from("42"));
hm.insert(256, String::from("256"));
hm.insert(u32::MAX, String::from("MAX"));
hm.insert(u32::MAX >> 2, String::from("MAX >> 2"));
assert_eq!(hm.get(&0), Some(&String::from("0")));
assert_eq!(hm.get(&42), Some(&String::from("42")));
assert_eq!(hm.get(&256), Some(&String::from("256")));
assert_eq!(hm.get(&u32::MAX), Some(&String::from("MAX")));
assert_eq!(hm.get(&(u32::MAX >> 2)), Some(&String::from("MAX >> 2")));
assert_eq!(hm.get(&(u32::MAX >> 4)), None);
assert_eq!(hm.get(&3), None);
assert_eq!(hm.get(&43), None);
}
#[test]
fn test_priority_queue() {
let mut pq: PriorityQueue<u32, i32, U32HasherState> =
PriorityQueue::with_hasher(U32HasherState(()));
// Testing basic operations with the custom hasher
assert_eq!(pq.push(1, 5), None);
assert_eq!(pq.push(2, -10), None);
assert_eq!(pq.push(3, 7), None);
assert_eq!(pq.push(4, 20), None);
assert_eq!(pq.push(u32::MAX, -42), None);
assert_eq!(pq.push_increase(1, 4), Some(4));
assert_eq!(pq.push_increase(2, -8), Some(-10));
assert_eq!(pq.push_increase(3, 5), Some(5));
assert_eq!(pq.push_increase(4, 21), Some(20));
assert_eq!(pq.push_increase(u32::MAX, -99), Some(-99));
assert_eq!(pq.push_increase(42, 1337), None);
assert_eq!(pq.push_decrease(1, 4), Some(5));
assert_eq!(pq.push_decrease(2, -10), Some(-8));
assert_eq!(pq.push_decrease(3, 5), Some(7));
assert_eq!(pq.push_decrease(4, 20), Some(21));
assert_eq!(pq.push_decrease(u32::MAX, 100), Some(100));
assert_eq!(pq.push_decrease(69, 420), None);
assert_eq!(pq.peek(), Some((&42, &1337)));
assert_eq!(pq.pop(), Some((42, 1337)));
assert_eq!(pq.peek(), Some((&69, &420)));
assert_eq!(pq.pop(), Some((69, 420)));
assert_eq!(pq.peek(), Some((&4, &20)));
assert_eq!(pq.pop(), Some((4, 20)));
assert_eq!(pq.peek(), Some((&3, &5)));
assert_eq!(pq.pop(), Some((3, 5)));
assert_eq!(pq.peek(), Some((&1, &4)));
assert_eq!(pq.pop(), Some((1, 4)));
assert_eq!(pq.peek(), Some((&2, &-10)));
assert_eq!(pq.pop(), Some((2, -10)));
assert_eq!(pq.peek(), Some((&u32::MAX, &-42)));
assert_eq!(pq.pop(), Some((u32::MAX, -42)));
assert_eq!(pq.peek(), None);
assert_eq!(pq.pop(), None);
}
}
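
The premise of this module: the u32 keys are already uniformly assigned uids, so a real hash function is pure overhead, and the "hasher" just records the key and widens it to u64. A self-contained sketch of the same identity-hash idea without the `bytemuck` dependency (the `IdentityU32Hasher`/`IdentityState` names are hypothetical, not the crate's private types):

```rust
use std::collections::HashMap;
use std::hash::{BuildHasher, Hasher};

// Identity "hasher" for 4-byte keys: store the key, widen it in finish().
struct IdentityU32Hasher(u32);

impl Hasher for IdentityU32Hasher {
    fn finish(&self) -> u64 {
        // Equivalent to the bytemuck::cast([self.0, 0]) above on little-endian.
        u64::from(self.0)
    }
    fn write(&mut self, bytes: &[u8]) {
        debug_assert_eq!(bytes.len(), 4);
        // from_ne_bytes matches the native-endian bytemuck::from_bytes above.
        self.0 = u32::from_ne_bytes(bytes.try_into().expect("4-byte key"));
    }
}

struct IdentityState;

impl BuildHasher for IdentityState {
    type Hasher = IdentityU32Hasher;
    fn build_hasher(&self) -> Self::Hasher {
        IdentityU32Hasher(0)
    }
}

fn main() {
    let mut map: HashMap<u32, &str, IdentityState> = HashMap::with_hasher(IdentityState);
    map.insert(7, "seven");
    assert_eq!(map.get(&7), Some(&"seven"));
}
```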