"""
An encoding / decoding format suitable for serializing data structures to binary.

This is based on https://en.wikipedia.org/wiki/Bencode with some extensions.

The following data types may be encoded:

- None
- int
- bool
- bytes
- str
- list
- tuple
- dict

"""

from __future__ import annotations

from typing import Any, Callable


class DecodeError(Exception):
    """A problem decoding data."""


def dump(data: object) -> bytes:
    """Encodes a data structure in to bytes.

    Args:
        data: Data structure

    Returns:
        A byte string encoding the data.
    """

    def encode_none(_datum: None) -> bytes:
        """
        Encodes a None value.

        Args:
            datum: Always None.

        Returns:
            None encoded.
        """
        return b"N"

    def encode_bool(datum: bool) -> bytes:
        """
        Encode a boolean value.

        Args:
            datum: The boolean value to encode.

        Returns:
            The encoded bytes.
        """
        return b"T" if datum else b"F"

    def encode_int(datum: int) -> bytes:
        """
        Encode an integer value.

        Args:
            datum: The integer value to encode.

        Returns:
            The encoded bytes.
        """
        return b"i%ie" % datum

    def encode_bytes(datum: bytes) -> bytes:
        """
        Encode a bytes value.

        Args:
            datum: The bytes value to encode.

        Returns:
            The encoded bytes.
        """
        return b"%i:%s" % (len(datum), datum)

    def encode_string(datum: str) -> bytes:
        """
        Encode a string value.

        Args:
            datum: The string value to encode.

        Returns:
            The encoded bytes.
        """
        encoded_data = datum.encode("utf-8")
        return b"s%i:%s" % (len(encoded_data), encoded_data)

    def encode_list(datum: list) -> bytes:
        """
        Encode a list value.

        Args:
            datum: The list value to encode.

        Returns:
            The encoded bytes.
        """
        return b"l%se" % b"".join(encode(element) for element in datum)

    def encode_tuple(datum: tuple) -> bytes:
        """
        Encode a tuple value.

        Args:
            datum: The tuple value to encode.

        Returns:
            The encoded bytes.
        """
        return b"t%se" % b"".join(encode(element) for element in datum)

    def encode_dict(datum: dict) -> bytes:
        """
        Encode a dictionary value.

        Args:
            datum: The dictionary value to encode.

        Returns:
            The encoded bytes.
        """
        return b"d%se" % b"".join(
            b"%s%s" % (encode(key), encode(value)) for key, value in datum.items()
        )

    ENCODERS: dict[type, Callable[[Any], Any]] = {
        type(None): encode_none,
        bool: encode_bool,
        int: encode_int,
        bytes: encode_bytes,
        str: encode_string,
        list: encode_list,
        tuple: encode_tuple,
        dict: encode_dict,
    }

    def encode(datum: object) -> bytes:
        """Recursively encode data.

        Args:
            datum: Data suitable for encoding.

        Raises:
            TypeError: If `datum` is not one of the supported types.

        Returns:
            Encoded data bytes.
        """
        try:
            decoder = ENCODERS[type(datum)]
        except KeyError:
            raise TypeError("Can't encode {datum!r}") from None
        return decoder(datum)

    return encode(data)


def load(encoded: bytes) -> object:
    """Load an encoded data structure from bytes.

    Args:
        encoded: Encoded data in bytes.

    Raises:
        DecodeError: If an error was encountered decoding the string.

    Returns:
        Decoded data.
    """
    if not isinstance(encoded, bytes):
        raise TypeError("must be bytes")
    max_position = len(encoded)
    position = 0

    def get_byte() -> bytes:
        """Get an encoded byte and advance position.

        Raises:
            DecodeError: If the end of the data was reached

        Returns:
            A bytes object with a single byte.
        """
        nonlocal position
        if position >= max_position:
            raise DecodeError("More data expected")
        character = encoded[position : position + 1]
        position += 1
        return character

    def peek_byte() -> bytes:
        """Get the byte at the current position, but don't advance position.

        Returns:
            A bytes object with a single byte.
        """
        return encoded[position : position + 1]

    def get_bytes(size: int) -> bytes:
        """Get a number of bytes of encode data.

        Args:
            size: Number of bytes to retrieve.

        Raises:
            DecodeError: If there aren't enough bytes.

        Returns:
            A bytes object.
        """
        nonlocal position
        bytes_data = encoded[position : position + size]
        if len(bytes_data) != size:
            raise DecodeError(b"Missing bytes in {bytes_data!r}")
        position += size
        return bytes_data

    def decode_int() -> int:
        """Decode an int from the encoded data.

        Returns:
            An integer.
        """
        int_bytes = b""
        while (byte := get_byte()) != b"e":
            int_bytes += byte
        return int(int_bytes)

    def decode_bytes(size_bytes: bytes) -> bytes:
        """Decode a bytes string from the encoded data.

        Returns:
            A bytes object.
        """
        while (byte := get_byte()) != b":":
            size_bytes += byte
        bytes_string = get_bytes(int(size_bytes))
        return bytes_string

    def decode_string() -> str:
        """Decode a (utf-8 encoded) string from the encoded data.

        Returns:
            A string.
        """
        size_bytes = b""
        while (byte := get_byte()) != b":":
            size_bytes += byte
        bytes_string = get_bytes(int(size_bytes))
        decoded_string = bytes_string.decode("utf-8", errors="replace")
        return decoded_string

    def decode_list() -> list[object]:
        """Decode a list.

        Returns:
            A list of data.
        """
        elements: list[object] = []
        add_element = elements.append
        while peek_byte() != b"e":
            add_element(decode())
        get_byte()
        return elements

    def decode_tuple() -> tuple[object, ...]:
        """Decode a tuple.

        Returns:
            A tuple of decoded data.
        """
        elements: list[object] = []
        add_element = elements.append
        while peek_byte() != b"e":
            add_element(decode())
        get_byte()
        return tuple(elements)

    def decode_dict() -> dict[object, object]:
        """Decode a dict.

        Returns:
            A dict of decoded data.
        """
        elements: dict[object, object] = {}
        add_element = elements.__setitem__
        while peek_byte() != b"e":
            add_element(decode(), decode())
        get_byte()
        return elements

    DECODERS = {
        b"i": decode_int,
        b"s": decode_string,
        b"l": decode_list,
        b"t": decode_tuple,
        b"d": decode_dict,
        b"T": lambda: True,
        b"F": lambda: False,
        b"N": lambda: None,
    }

    def decode() -> object:
        """Recursively decode data.

        Returns:
            Decoded data.
        """
        decoder = DECODERS.get(initial := get_byte(), None)
        if decoder is None:
            return decode_bytes(initial)
        return decoder()

    return decode()
