diff --git a/src/psbt/mod.rs b/src/psbt/mod.rs index f291bb1b..d83f18a3 100644 --- a/src/psbt/mod.rs +++ b/src/psbt/mod.rs @@ -37,3 +37,66 @@ impl PsbtUtils for PSBT { } } } + +#[cfg(test)] +mod test { + use crate::bitcoin::consensus::deserialize; + use crate::bitcoin::TxIn; + use crate::psbt::PSBT; + use crate::wallet::test::{get_funded_wallet, get_test_wpkh}; + use crate::wallet::AddressIndex; + use crate::SignOptions; + + // from bip 174 + const PSBT_STR: &str = "cHNidP8BAKACAAAAAqsJSaCMWvfEm4IS9Bfi8Vqz9cM9zxU4IagTn4d6W3vkAAAAAAD+////qwlJoIxa98SbghL0F+LxWrP1wz3PFTghqBOfh3pbe+QBAAAAAP7///8CYDvqCwAAAAAZdqkUdopAu9dAy+gdmI5x3ipNXHE5ax2IrI4kAAAAAAAAGXapFG9GILVT+glechue4O/p+gOcykWXiKwAAAAAAAEHakcwRAIgR1lmF5fAGwNrJZKJSGhiGDR9iYZLcZ4ff89X0eURZYcCIFMJ6r9Wqk2Ikf/REf3xM286KdqGbX+EhtdVRs7tr5MZASEDXNxh/HupccC1AaZGoqg7ECy0OIEhfKaC3Ibi1z+ogpIAAQEgAOH1BQAAAAAXqRQ1RebjO4MsRwUPJNPuuTycA5SLx4cBBBYAFIXRNTfy4mVAWjTbr6nj3aAfuCMIAAAA"; + + #[test] + #[should_panic(expected = "InputIndexOutOfRange")] + fn test_psbt_malformed_psbt_input_legacy() { + let psbt_bip: PSBT = deserialize(&base64::decode(PSBT_STR).unwrap()).unwrap(); + let (wallet, _, _) = get_funded_wallet(get_test_wpkh()); + let send_to = wallet.get_address(AddressIndex::New).unwrap(); + let mut builder = wallet.build_tx(); + builder.add_recipient(send_to.script_pubkey(), 10_000); + let (mut psbt, _) = builder.finish().unwrap(); + psbt.inputs.push(psbt_bip.inputs[0].clone()); + let options = SignOptions { + trust_witness_utxo: true, + assume_height: None, + }; + let _ = wallet.sign(&mut psbt, options).unwrap(); + } + + #[test] + #[should_panic(expected = "InputIndexOutOfRange")] + fn test_psbt_malformed_psbt_input_segwit() { + let psbt_bip: PSBT = deserialize(&base64::decode(PSBT_STR).unwrap()).unwrap(); + let (wallet, _, _) = get_funded_wallet(get_test_wpkh()); + let send_to = wallet.get_address(AddressIndex::New).unwrap(); + let mut builder = wallet.build_tx(); + builder.add_recipient(send_to.script_pubkey(), 10_000); + let (mut psbt, _) = builder.finish().unwrap(); + psbt.inputs.push(psbt_bip.inputs[1].clone()); + let options = SignOptions { + trust_witness_utxo: true, + assume_height: None, + }; + let _ = wallet.sign(&mut psbt, options).unwrap(); + } + + #[test] + #[should_panic(expected = "InputIndexOutOfRange")] + fn test_psbt_malformed_tx_input() { + let (wallet, _, _) = get_funded_wallet(get_test_wpkh()); + let send_to = wallet.get_address(AddressIndex::New).unwrap(); + let mut builder = wallet.build_tx(); + builder.add_recipient(send_to.script_pubkey(), 10_000); + let (mut psbt, _) = builder.finish().unwrap(); + psbt.global.unsigned_tx.input.push(TxIn::default()); + let options = SignOptions { + trust_witness_utxo: true, + assume_height: None, + }; + let _ = wallet.sign(&mut psbt, options).unwrap(); + } +} diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 4ddd7d36..ac32cb63 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -61,6 +61,7 @@ use crate::descriptor::{ }; use crate::error::Error; use crate::psbt::PsbtUtils; +use crate::signer::SignerError; use crate::types::*; const CACHE_ADDR_BATCH_SIZE: u32 = 100; @@ -927,7 +928,10 @@ where let mut finished = true; for (n, input) in tx.input.iter().enumerate() { - let psbt_input = &psbt.inputs[n]; + let psbt_input = &psbt + .inputs + .get(n) + .ok_or(Error::Signer(SignerError::InputIndexOutOfRange))?; if psbt_input.final_script_sig.is_some() || psbt_input.final_script_witness.is_some() { continue; } @@ -1497,7 +1501,7 @@ where } #[cfg(test)] -mod test { +pub(crate) mod test { use std::str::FromStr; use bitcoin::{util::psbt, Network}; diff --git a/src/wallet/signer.rs b/src/wallet/signer.rs index 3198a61f..04f4d2ad 100644 --- a/src/wallet/signer.rs +++ b/src/wallet/signer.rs @@ -476,7 +476,7 @@ impl ComputeSighash for Legacy { psbt: &psbt::PartiallySignedTransaction, input_index: usize, ) -> Result<(SigHash, SigHashType), SignerError> { - if input_index >= psbt.inputs.len() { + if input_index >= psbt.inputs.len() || input_index >= psbt.global.unsigned_tx.input.len() { return Err(SignerError::InputIndexOutOfRange); } @@ -524,7 +524,7 @@ impl ComputeSighash for Segwitv0 { psbt: &psbt::PartiallySignedTransaction, input_index: usize, ) -> Result<(SigHash, SigHashType), SignerError> { - if input_index >= psbt.inputs.len() { + if input_index >= psbt.inputs.len() || input_index >= psbt.global.unsigned_tx.input.len() { return Err(SignerError::InputIndexOutOfRange); }