diff --git a/bip-0375/validator/psbt_bip375.py b/bip-0375/validator/psbt_bip375.py new file mode 100644 index 00000000..daf36e99 --- /dev/null +++ b/bip-0375/validator/psbt_bip375.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +""" +BIP-375 PSBT map extensions + +BIP375PSBTMap (a PSBTMap subclass with BIP-375 field access helpers) +BIP375PSBT (a PSBT subclass that deserializes into BIP375PSBTMap instances) +""" + +from io import BytesIO +import struct +from typing import List, Optional, Tuple + +from deps.bitcoin_test.messages import CTransaction, deser_compact_size, from_binary +from deps.bitcoin_test.psbt import ( + PSBT, + PSBTMap, + PSBT_GLOBAL_VERSION, + PSBT_GLOBAL_INPUT_COUNT, + PSBT_GLOBAL_OUTPUT_COUNT, + PSBT_GLOBAL_UNSIGNED_TX, +) + +PSBT_GLOBAL_SP_ECDH_SHARE = 0x07 +PSBT_GLOBAL_SP_DLEQ = 0x08 + +PSBT_IN_SP_ECDH_SHARE = 0x1D +PSBT_IN_SP_DLEQ = 0x1E + +PSBT_OUT_SP_V0_INFO = 0x09 +PSBT_OUT_SP_V0_LABEL = 0x0A + + +class BIP375PSBTMap(PSBTMap): + """PSBTMap with BIP-375 field access helpers""" + + def __getitem__(self, key): + return self.map[key] + + def __contains__(self, key): + return key in self.map + + def get(self, key, default=None): + return self.map.get(key, default) + + def get_all_by_type(self, key_type: int) -> List[Tuple[bytes, bytes]]: + """ + Get all entries with the given key_type + + Returns list of (key_data, value_data) tuples. For single-byte keys (no + key_data), key_data is b''. + """ + result = [] + for key, value_data in self.map.items(): + if isinstance(key, int) and key == key_type: + result.append((b"", value_data)) + elif isinstance(key, bytes) and len(key) > 0 and key[0] == key_type: + result.append((key[1:], value_data)) + return result + + def get_by_key(self, key_type: int, key_data: bytes) -> Optional[bytes]: + """Get value_data for a specific key_type + key_data combination""" + if key_data == b"": + return self.map.get(key_type) + return self.map.get(bytes([key_type]) + key_data) + + +class BIP375PSBT(PSBT): + """PSBT that deserializes maps as BIP375PSBTMap instances""" + + def deserialize(self, f): + assert f.read(5) == b"psbt\xff" + self.g = from_binary(BIP375PSBTMap, f) + + self.version = 0 + if PSBT_GLOBAL_VERSION in self.g.map: + assert PSBT_GLOBAL_INPUT_COUNT in self.g.map + assert PSBT_GLOBAL_OUTPUT_COUNT in self.g.map + self.version = struct.unpack("