from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Iterable

import rich.repr

from textual.css._help_renderables import HelpText
from textual.css.styles import Styles
from textual.css.tokenize import Token
from textual.css.types import Specificity3

if TYPE_CHECKING:
    from typing import Callable

    from typing_extensions import Self

    from textual.dom import DOMNode


class SelectorType(Enum):
    """The kinds of simple selector that may appear in a CSS rule."""

    UNIVERSAL = 1
    """The universal selector, written `*`."""
    TYPE = 2
    """A type selector, e.g. `Label`."""
    CLASS = 3
    """A class selector, e.g. `.loaded`."""
    ID = 4
    """An ID selector, e.g. `#main`."""
    NESTED = 5
    """Placeholder for the nesting operator, i.e. `&`."""


class CombinatorType(Enum):
    """The ways a selector can relate to the selector before it."""

    SAME = 1
    """The selector applies to the same node as the previous selector."""
    DESCENDENT = 2
    """The selector applies to a descendant of the previous selector."""
    CHILD = 3
    """The selector applies to an immediate child of the previous selector."""


def _check_universal(name: str, node: DOMNode) -> bool:
    """Test a node against the universal selector.

    The selector matches every node except those carrying the
    `-textual-system` class.

    Args:
        name: Selector name (not consulted by this check).
        node: The DOM node to test.

    Returns:
        `True` if the node matches, otherwise `False`.
    """
    is_system_node = node.has_class("-textual-system")
    return not is_system_node


def _check_type(name: str, node: DOMNode) -> bool:
    """Test a node against a type selector.

    Args:
        name: The type name the selector refers to.
        node: The DOM node to test.

    Returns:
        `True` if `name` is one of the node's CSS type names.
    """
    type_names = node._css_type_names
    return name in type_names


def _check_class(name: str, node: DOMNode) -> bool:
    """Test a node against a class selector.

    Args:
        name: The class name the selector refers to (without the leading dot).
        node: The DOM node to test.

    Returns:
        `True` if `name` is among the node's classes.
    """
    node_classes = node._classes
    return name in node_classes


def _check_id(name: str, node: DOMNode) -> bool:
    """Test a node against an ID selector.

    Args:
        name: The ID the selector refers to (without the leading hash).
        node: The DOM node to test.

    Returns:
        `True` if the node's ID equals `name`.
    """
    node_id = node.id
    return node_id == name


# Dispatch table: maps each selector type to the function that tests whether a
# node matches a selector of that type. Every function takes (name, node).
_CHECKS = {
    SelectorType.UNIVERSAL: _check_universal,
    SelectorType.TYPE: _check_type,
    SelectorType.CLASS: _check_class,
    SelectorType.ID: _check_id,
    # The nesting placeholder reuses the universal check.
    SelectorType.NESTED: _check_universal,
}


@dataclass
class Selector:
    """Represents a CSS selector.

    Some examples of selectors:

    *
    Header.title
    App > Content
    """

    name: str
    """Name of the selector, without any `.` or `#` prefix."""
    combinator: CombinatorType = CombinatorType.DESCENDENT
    """How this selector is combined with the previous selector."""
    type: SelectorType = SelectorType.TYPE
    """The kind of selector."""
    pseudo_classes: set[str] = field(default_factory=set)
    """Names of the pseudo classes attached to this selector."""
    specificity: Specificity3 = field(default_factory=lambda: (0, 0, 0))
    """Specificity of this selector as an (ID, class, type) triple."""
    advance: int = 1
    """Assigned by `SelectorSet`: 0 when the following selector uses the
    `SAME` combinator, otherwise 1."""

    def __post_init__(self) -> None:
        """Bind the check function for this selector's type to its name.

        The binding is made once at construction; reassigning `type` or
        `name` afterwards does not rebind it.
        """
        self._check: Callable[[DOMNode], bool] = partial(_CHECKS[self.type], self.name)

    @property
    def css(self) -> str:
        """Rebuilds the selector as it would appear in CSS.

        Pseudo classes are appended in sorted order so the output is stable.
        """
        pseudo_suffix = "".join(f":{name}" for name in sorted(self.pseudo_classes))
        selector_type = self.type
        if selector_type == SelectorType.UNIVERSAL:
            # Keep the pseudo classes, so that `*:hover` is not rendered as `*`.
            return f"*{pseudo_suffix}"
        if selector_type == SelectorType.NESTED:
            # The nesting placeholder is the `&` operator, not an ID.
            return f"&{pseudo_suffix}"
        if selector_type == SelectorType.CLASS:
            return f".{self.name}{pseudo_suffix}"
        if selector_type == SelectorType.ID:
            return f"#{self.name}{pseudo_suffix}"
        return f"{self.name}{pseudo_suffix}"

    def _add_pseudo_class(self, pseudo_class: str) -> None:
        """Adds a pseudo class and updates specificity.

        Each call increments the middle (class) component of the specificity.

        Args:
            pseudo_class: Name of pseudo class.
        """
        self.pseudo_classes.add(pseudo_class)
        specificity1, specificity2, specificity3 = self.specificity
        self.specificity = (specificity1, specificity2 + 1, specificity3)

    def check(self, node: DOMNode) -> bool:
        """Check if a given node matches the selector.

        The type-specific check runs first; pseudo classes are only consulted
        when the selector has any.

        Args:
            node: A DOM node.

        Returns:
            True if the selector matches, otherwise False.
        """
        if not self._check(node):
            return False
        pseudo_classes = self.pseudo_classes
        if not pseudo_classes:
            return True
        return node.has_pseudo_classes(pseudo_classes)


@dataclass
class Declaration:
    """A single, not yet processed, CSS declaration."""

    token: Token
    """A token associated with the declaration."""
    name: str
    """The name of the declaration."""
    tokens: list[Token] = field(default_factory=list)
    """Further tokens belonging to the declaration (a fresh list per instance)."""


@rich.repr.auto(angular=True)
@dataclass
class SelectorSet:
    """A set of selectors associated with a rule set."""

    selectors: list[Selector] = field(default_factory=list)
    """The selectors, in the order they appear in the CSS."""
    specificity: Specificity3 = (0, 0, 0)
    """Combined specificity of all selectors as an (ID, class, type) triple."""

    def __post_init__(self) -> None:
        """Set `advance` on every selector that is followed by another.

        A selector gets `advance == 0` when the selector after it uses the
        `SAME` combinator, otherwise 1. The last selector is left untouched.
        """
        same = CombinatorType.SAME
        for selector, next_selector in zip(self.selectors, self.selectors[1:]):
            selector.advance = int(next_selector.combinator != same)

    @property
    def css(self) -> str:
        """The selectors rebuilt as they would appear in CSS."""
        return RuleSet._selector_to_css(self.selectors)

    @property
    def is_simple(self) -> bool:
        """Are all the selectors simple (i.e. only dependent on static DOM state)."""
        simple_types = {SelectorType.ID, SelectorType.TYPE}
        return all(
            (selector.type in simple_types and not selector.pseudo_classes)
            for selector in self.selectors
        )

    def __rich_repr__(self) -> rich.repr.Result:
        """Yield the CSS text and the specificity for Rich's repr protocol."""
        yield self.css
        yield None, self.specificity

    def _total_specificity(self) -> Self:
        """Calculate total specificity of selectors.

        The result is stored on `self.specificity`.

        Returns:
            Self.
        """
        id_total = class_total = type_total = 0
        for selector in self.selectors:
            _id, _class, _type = selector.specificity
            id_total += _id
            class_total += _class
            type_total += _type
        self.specificity = (id_total, class_total, type_total)
        return self

    @classmethod
    def from_selectors(cls, selectors: list[list[Selector]]) -> Iterable[SelectorSet]:
        """Build a selector set from each list of selectors.

        Args:
            selectors: Lists of selectors, one list per selector set.

        Yields:
            A selector set for each list, with its specificity totalled.
        """
        for selector_list in selectors:
            # Use `cls` so subclasses construct their own type, and reuse
            # `_total_specificity` rather than repeating the summation here.
            yield cls(selector_list)._total_specificity()


@dataclass
class RuleSet:
    """A group of selector sets together with the styles declared for them."""

    # The comma separated selector sets of the rule.
    selector_set: list[SelectorSet] = field(default_factory=list)
    # The styles declared in the rule.
    styles: Styles = field(default_factory=Styles)
    # Errors as (token, message) pairs.
    # NOTE(review): presumably the token is where the error was found — confirm.
    errors: list[tuple[Token, str | HelpText]] = field(default_factory=list)

    # NOTE(review): presumably marks rules that come from default CSS — confirm.
    is_default_rules: bool = False
    # NOTE(review): presumably orders rules of equal specificity — confirm.
    tie_breaker: int = 0
    # Rendered name of the last selector in each selector set; see `_post_parse`.
    selector_names: set[str] = field(default_factory=set)
    # Every pseudo class used by any selector; see `_post_parse`.
    pseudo_classes: set[str] = field(default_factory=set)

    def __hash__(self) -> int:
        """Hash by identity.

        Note that the dataclass-generated `__eq__` still compares fields, so
        two equal rule sets will generally have different hashes.
        """
        return id(self)

    @classmethod
    def _selector_to_css(cls, selectors: list[Selector]) -> str:
        """Render a list of selectors as CSS text.

        Args:
            selectors: Selectors, in the order they appear in the CSS.

        Returns:
            The selectors joined by their combinators, with surrounding
            whitespace stripped.
        """
        tokens: list[str] = []
        for selector in selectors:
            # A `SAME` combinator adds no separator.
            if selector.combinator == CombinatorType.DESCENDENT:
                tokens.append(" ")
            elif selector.combinator == CombinatorType.CHILD:
                tokens.append(" > ")
            tokens.append(selector.css)

        return "".join(tokens).strip()

    @property
    def selectors(self) -> str:
        """All selector sets rendered as CSS, separated by commas."""
        return ", ".join(
            self._selector_to_css(selector_set.selectors)
            for selector_set in self.selector_set
        )

    @property
    def css(self) -> str:
        """Generate the CSS for this RuleSet.

        Returns:
            A string containing CSS code.
        """
        declarations = "\n".join(f"    {line}" for line in self.styles.css_lines)
        css = f"{self.selectors} {{\n{declarations}\n}}"
        return css

    def _post_parse(self) -> None:
        """Called after the RuleSet is parsed.

        Collects every pseudo class used by any selector into
        `pseudo_classes`, and the rendered name of the last selector of each
        selector set into `selector_names`.
        """
        class_type = SelectorType.CLASS
        id_type = SelectorType.ID
        type_type = SelectorType.TYPE
        universal_type = SelectorType.UNIVERSAL

        add_selector = self.selector_names.add
        add_pseudo_classes = self.pseudo_classes.update

        for selector_set in self.selector_set:
            selectors = selector_set.selectors
            if not selectors:
                # An empty selector set has no final selector to record;
                # skip it rather than fail with an IndexError.
                continue

            for selector in selectors:
                add_pseudo_classes(selector.pseudo_classes)

            # Only the last selector of the set is recorded by name.
            final_selector = selectors[-1]
            selector_type = final_selector.type
            if selector_type == universal_type:
                add_selector("*")
            elif selector_type == type_type:
                add_selector(final_selector.name)
            elif selector_type == class_type:
                add_selector(f".{final_selector.name}")
            elif selector_type == id_type:
                add_selector(f"#{final_selector.name}")
