diff --git a/src/embit/psbt.py b/src/embit/psbt.py index d20c88b..832d56e 100644 --- a/src/embit/psbt.py +++ b/src/embit/psbt.py @@ -504,6 +504,8 @@ def __init__(self, unknown: dict = {}, vout=None, compress=CompressMode.KEEP_ALL self.bip32_derivations = OrderedDict() self.taproot_bip32_derivations = OrderedDict() self.taproot_internal_key = None + self.dnssec_name = None + self.dnssec_proof = None self.parse_unknowns() def clear_metadata(self, compress=CompressMode.CLEAR_ALL): @@ -516,6 +518,8 @@ def clear_metadata(self, compress=CompressMode.CLEAR_ALL): self.bip32_derivations = OrderedDict() self.taproot_bip32_derivations = OrderedDict() self.taproot_internal_key = None + self.dnssec_name = None + self.dnssec_proof = None def update(self, other): self.value = other.value if other.value is not None else self.value @@ -526,6 +530,8 @@ def update(self, other): self.bip32_derivations.update(other.bip32_derivations) self.taproot_bip32_derivations.update(other.taproot_bip32_derivations) self.taproot_internal_key = other.taproot_internal_key + self.dnssec_name = other.dnssec_name or self.dnssec_name + self.dnssec_proof = other.dnssec_proof or self.dnssec_proof @property def vout(self): @@ -583,6 +589,18 @@ def read_value(self, stream, k): der = DerivationPath.read_from(b) self.taproot_bip32_derivations[pub] = (leaf_hashes, der) + # PSBT_OUT_DNSSEC_PROOF (0x35) + elif k == b"\x35": + if self.dnssec_proof is not None: + raise PSBTError("Duplicated DNSSEC proof") + if len(v) < 1: + raise PSBTError("Invalid DNSSEC proof (missing length)") + name_len = v[0] + if len(v) < 1 + name_len: + raise PSBTError("Invalid DNSSEC proof (truncated name)") + self.dnssec_name = v[1 : 1 + name_len] + self.dnssec_proof = v[1 + name_len :] + else: if k in self.unknown: raise PSBTError("Duplicated key") @@ -624,6 +642,18 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int: + derivation.serialize(), ) + # PSBT_OUT_DNSSEC_PROOF (BIP-353) + if self.dnssec_proof is not None and self.dnssec_name is not None: + if len(self.dnssec_name) > 255: + raise PSBTError("DNSSEC name too long") + value = ( + len(self.dnssec_name).to_bytes(1, "big") + + self.dnssec_name + + self.dnssec_proof + ) + r += ser_string(stream, b"\x35") + r += ser_string(stream, value) + # unknown for key in self.unknown: r += ser_string(stream, key)