diff --git a/multiaddr/codecs/certhash.py b/multiaddr/codecs/certhash.py new file mode 100644 index 0000000..e6b673c --- /dev/null +++ b/multiaddr/codecs/certhash.py @@ -0,0 +1,85 @@ +from typing import Any + +import multibase +import multihash + +from ..codecs import CodecBase + +SIZE = -1 +IS_PATH = False + + +class Codec(CodecBase): + """ + Codec for certificate hashes (certhash). + + A certhash is a multihash of a certificate, encoded as a multibase string + using the 'base64url' encoding. + """ + + SIZE = SIZE + IS_PATH = IS_PATH + + def validate(self, b: bytes) -> None: + """ + Validates that the byte representation is a valid multihash. + + Args: + b: The bytes to validate. + + Raises: + ValueError: If the bytes cannot be decoded as a multihash. + """ + try: + multihash.decode(b) + except Exception as e: + raise ValueError("Invalid certhash: not a valid multihash") from e + + def to_bytes(self, proto: Any, string: str) -> bytes: + """ + Converts the multibase string representation of a certhash to bytes. + + This involves decoding the multibase string and then validating that + the resulting bytes are a valid multihash. + + Args: + proto: The multiaddr protocol code (unused). + string: The string representation of the certhash. + + Returns: + The raw multihash bytes. + + Raises: + ValueError: If the string is not valid multibase or not a multihash. + """ + try: + # Decode the multibase string to get the raw multihash bytes. + decoded_bytes = multibase.decode(string) + except Exception as e: + raise ValueError(f"Failed to decode multibase string: {string}") from e + + # Validate that the decoded bytes are a valid multihash. + self.validate(decoded_bytes) + return decoded_bytes + + def to_string(self, proto: Any, buf: bytes) -> str: + """ + Converts the raw multihash bytes of a certhash to its string form. + + This involves validating the bytes first and then encoding them as a + 'base64url' multibase string. + + Args: + proto: The multiaddr protocol code (unused). + buf: The raw multihash bytes. + + Returns: + The multibase string representation of the certhash. + """ + # Validate the bytes before encoding. + self.validate(buf) + + # Encode the bytes using base64url, which is standard for certhash. + # The result from `multibase.encode` is bytes, so we decode to a string. + encoded_string = multibase.encode("base64url", buf) + return encoded_string.decode("utf-8") diff --git a/multiaddr/multiaddr.py b/multiaddr/multiaddr.py index 7453c4f..52307aa 100644 --- a/multiaddr/multiaddr.py +++ b/multiaddr/multiaddr.py @@ -20,17 +20,17 @@ def __init__(self, mapping: "Multiaddr") -> None: self._mapping = mapping super().__init__(mapping) - def __contains__(self, proto: object) -> bool: - proto = self._mapping.registry.find(proto) - return collections.abc.Sequence.__contains__(self, proto) - - def __getitem__(self, idx: int | slice) -> Any | Sequence[Any]: - if isinstance(idx, slice): - return list(self)[idx] - if idx < 0: - idx = len(self) + idx - for idx2, proto in enumerate(self): - if idx2 == idx: + def __contains__(self, value: object) -> bool: # type: ignore[override] + value = self._mapping.registry.find(value) + return collections.abc.Sequence.__contains__(self, value) + + def __getitem__(self, index: int | slice) -> Any | Sequence[Any]: + if isinstance(index, slice): + return list(self)[index] + if index < 0: + index = len(self) + index + for index2, proto in enumerate(self): + if index2 == index: return proto raise IndexError("Protocol list index out of range") @@ -49,26 +49,26 @@ def __init__(self, mapping: "Multiaddr") -> None: self._mapping = mapping super().__init__(mapping) - def __contains__(self, item: object) -> bool: + def __contains__(self, item: object) -> bool: # type: ignore[override] if not isinstance(item, tuple) or len(item) != 2: return False - proto, value = item + proto, item = item proto = self._mapping.registry.find(proto) - return collections.abc.Sequence.__contains__(self, (proto, value)) + return collections.abc.Sequence.__contains__(self, (proto, item)) @overload - def __getitem__(self, idx: int) -> tuple[Any, Any]: ... + def __getitem__(self, index: int) -> tuple[Any, Any]: ... @overload - def __getitem__(self, idx: slice) -> Sequence[tuple[Any, Any]]: ... + def __getitem__(self, index: slice) -> Sequence[tuple[Any, Any]]: ... - def __getitem__(self, idx: int | slice) -> tuple[Any, Any] | Sequence[tuple[Any, Any]]: - if isinstance(idx, slice): - return list(self)[idx] - if idx < 0: - idx = len(self) + idx + def __getitem__(self, index: int | slice) -> tuple[Any, Any] | Sequence[tuple[Any, Any]]: + if isinstance(index, slice): + return list(self)[index] + if index < 0: + index = len(self) + index for idx2, item in enumerate(self): - if idx2 == idx: + if idx2 == index: return item raise IndexError("Protocol item list index out of range") @@ -99,13 +99,13 @@ def __init__(self, mapping: "Multiaddr") -> None: def __contains__(self, value: object) -> bool: return collections.abc.Sequence.__contains__(self, value) - def __getitem__(self, idx: int | slice) -> Any | Sequence[Any]: - if isinstance(idx, slice): - return list(self)[idx] - if idx < 0: - idx = len(self) + idx + def __getitem__(self, index: int | slice) -> Any | Sequence[Any]: + if isinstance(index, slice): + return list(self)[index] + if index < 0: + index = len(self) + index for idx2, value in enumerate(self): - if idx2 == idx: + if idx2 == index: return value raise IndexError("Protocol value list index out of range") diff --git a/multiaddr/protocols.py b/multiaddr/protocols.py index c834754..b30090d 100644 --- a/multiaddr/protocols.py +++ b/multiaddr/protocols.py @@ -79,7 +79,10 @@ P_SNI = 0x01C1 P_NOISE = 0x01C6 P_WEBTRANSPORT = 0x01D1 +P_WEBRTC_DIRECT = 0x118 +P_WEBRTC = 0x119 P_MEMORY = 0x309 +P_CERTHASH = 0x1D2 class Protocol: @@ -150,6 +153,8 @@ def __repr__(self) -> str: Protocol(P_DNS4, "dns4", "domain"), Protocol(P_DNS6, "dns6", "domain"), Protocol(P_DNSADDR, "dnsaddr", "domain"), + Protocol(P_SNI, "sni", "domain"), + Protocol(P_NOISE, "noise", None), Protocol(P_SCTP, "sctp", "uint16be"), Protocol(P_UDT, "udt", None), Protocol(P_UTP, "utp", None), @@ -170,7 +175,10 @@ def __repr__(self) -> str: Protocol(P_P2P_CIRCUIT, "p2p-circuit", None), Protocol(P_WEBTRANSPORT, "webtransport", None), Protocol(P_UNIX, "unix", "fspath"), + Protocol(P_WEBRTC_DIRECT, "webrtc-direct", None), + Protocol(P_WEBRTC, "webrtc", None), Protocol(P_MEMORY, "memory", "memory"), + Protocol(P_CERTHASH, "certhash", "certhash"), ] diff --git a/newsfragments/181.feature.rst b/newsfragments/181.feature.rst new file mode 100644 index 0000000..2d728a5 --- /dev/null +++ b/newsfragments/181.feature.rst @@ -0,0 +1,7 @@ +Added the following protocols in reference with go-multiaddr + +- SNI: 0x01C1 +- NOISE: 0x01C6 +- CERTHASH: +- WEBRTC: +- WEBRTC-DIRECT: diff --git a/pyproject.toml b/pyproject.toml index 24d29ea..09e45da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "psutil", "py-cid >= 0.3.1", "py-multicodec >= 0.2.0", + "py-multibase", + "py-multihash", "trio-typing>=0.0.4", "trio>=0.26.0", "varint", diff --git a/tests/test_multiaddr.py b/tests/test_multiaddr.py index c3454f8..5062606 100644 --- a/tests/test_multiaddr.py +++ b/tests/test_multiaddr.py @@ -52,6 +52,8 @@ "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd:-1", "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd", "/onion3/vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy@:666", + "/ip4/127.0.0.1/udp/1234/quic-v1/webtransport/certhash/b2uaraocy6yrdblb4sfptaddgimjmmpy", + "/ip4/127.0.0.1/udp/1234/quic-v1/webtransport/certhash/b2uaraocy6yrdblb4sfptaddgimjmmpy/certhash/zQmbWTwYGcmdyK9CYfNBcfs9nhZs17a6FQ4Y8oea278xx41", "/udp/1234/sctp", "/udp/1234/udt/1234", "/udp/1234/utp/1234", @@ -101,10 +103,19 @@ def test_invalid(addr_str): "/ip4/127.0.0.1/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/tcp/1234", "/unix/a/b/c/d/e", "/unix/stdio", + "/ip4/127.0.0.1/tcp/127/noise", "/ip4/1.2.3.4/tcp/80/unix/a/b/c/d/e/f", "/ip4/127.0.0.1/p2p/QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC/tcp/1234/unix/stdio", "/dns/example.com", "/dns4/موقع.وزارة-الاتصالات.مصر", + "/ip4/127.0.0.1/tcp/443/tls/sni/example.com/http/http-path/foo", + "/memory/4", + "/http-path/tmp%2Fbar", + "/http-path/tmp%2Fbar%2Fbaz", + "/http-path/foo", + "/ip4/127.0.0.1/tcp/9090/http/p2p-webrtc-direct", + "/ip4/127.0.0.1/tcp/127/webrtc-direct", + "/ip4/127.0.0.1/tcp/127/webrtc", ], ) # nopep8 def test_valid(addr_str): diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 95d3995..6490179 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -1,8 +1,10 @@ +import multibase +import multihash import pytest import varint from multiaddr import Multiaddr, exceptions, protocols -from multiaddr.codecs import http_path, ipcidr, memory +from multiaddr.codecs import certhash, http_path, ipcidr, memory from multiaddr.exceptions import BinaryParseError, StringParseError @@ -367,3 +369,65 @@ def test_ipcidr_invalid_bytes_inputs(): with pytest.raises(ValueError): codec.validate(b"\x01\x02") + + +# --------CERT-HASH--------- + +VALID_MULTIHASH_BYTES = multihash.encode(b"hello world", "sha2-256") +VALID_CERTHASH_STRING = multibase.encode("base64url", VALID_MULTIHASH_BYTES).decode("utf-8") + +INVALID_BYTES = b"this is not a multihash" +INVALID_CONTENT_STRING = multibase.encode("base64url", INVALID_BYTES).decode("utf-8") + + +def test_certhash_valid_roundtrip(): + codec = certhash.Codec() + b = codec.to_bytes(None, VALID_CERTHASH_STRING) + assert isinstance(b, bytes) + assert b == VALID_MULTIHASH_BYTES + + +def test_certhash_invalid_multihash_bytes_raises(): + """ + Tests that calling to_string() with bytes that are not a valid + multihash raises a ValueError. + """ + codec = certhash.Codec() + with pytest.raises(ValueError): + codec.to_string(None, INVALID_BYTES) + + +def test_certhash_valid_multibase_but_invalid_content_raises(): + """ + Tests that to_bytes() raises an error if the string is valid multibase + but its decoded content is not a valid multihash. + """ + codec = certhash.Codec() + with pytest.raises(ValueError): + codec.to_bytes(None, INVALID_CONTENT_STRING) + + +def test_certhash_invalid_multibase_string_raises(): + """ + Tests that passing a string with an invalid multibase prefix or + encoding raises an error. + """ + codec = certhash.Codec() + # 'z' is a valid multibase prefix, but the content is not valid base58. + invalid_string = "z-this-is-not-valid" + with pytest.raises(Exception): # Catches errors from the multibase library + codec.to_bytes(None, invalid_string) + + +def test_certhash_memory_validate_function(): + """ + Directly tests the validate method. + """ + codec = certhash.Codec() + + # A valid multihash should not raise an error + codec.validate(VALID_MULTIHASH_BYTES) + + # Invalid bytes should raise a ValueError + with pytest.raises(ValueError): + codec.validate(INVALID_BYTES) diff --git a/tests/test_thin_waist_addresses.py b/tests/test_thin_waist_addresses.py index 7f8dca3..7bd2c33 100644 --- a/tests/test_thin_waist_addresses.py +++ b/tests/test_thin_waist_addresses.py @@ -1,7 +1,4 @@ -import pytest - from multiaddr import Multiaddr -from multiaddr.exceptions import StringParseError from multiaddr.utils import get_thin_waist_addresses @@ -22,12 +19,6 @@ def test_specific_address_override_port(): assert addrs == [Multiaddr("/ip4/123.123.123.123/tcp/100")] -def test_ignore_non_thin_waist(): - # Should raise StringParseError for unknown protocol (e.g. /webrtc) - with pytest.raises(StringParseError): - Multiaddr("/ip4/123.123.123.123/udp/1234/webrtc") - - def test_ipv4_wildcard(): input_addr = Multiaddr("/ip4/0.0.0.0/tcp/1234") addrs = get_thin_waist_addresses(input_addr)