"""Provides a tree widget."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, Iterable, NewType, TypeVar, cast

import rich.repr
from rich.style import NULL_STYLE, Style
from rich.text import Text, TextType

from textual import events, on
from textual._immutable_sequence_view import ImmutableSequenceView
from textual._loop import loop_last
from textual._segment_tools import line_pad
from textual.binding import Binding, BindingType
from textual.cache import LRUCache
from textual.geometry import Region, Size, clamp
from textual.message import Message
from textual.reactive import reactive, var
from textual.scroll_view import ScrollView
from textual.strip import Strip

if TYPE_CHECKING:
    from typing_extensions import Self, TypeAlias

NodeID = NewType("NodeID", int)
"""The type of an ID applied to a [TreeNode][textual.widgets._tree.TreeNode]."""

TreeDataType = TypeVar("TreeDataType")
"""The type of the data for a given instance of a [Tree][textual.widgets.Tree]."""

EventTreeDataType = TypeVar("EventTreeDataType")
"""The type of the data for a given instance of a [Tree][textual.widgets.Tree].

Similar to [TreeDataType][textual.widgets._tree.TreeDataType] but used for
``Tree`` messages.
"""

# Key type of the tree's rendered-line cache (`Tree._line_cache`).
LineCacheKey: TypeAlias = "tuple[int | tuple, ...]"

# Style carrying `{"toggle": True}` metadata. `Tree.render_label` adds it to the
# base style of the expand/collapse icon shown in front of expandable nodes.
TOGGLE_STYLE = Style.from_meta({"toggle": True})


class RemoveRootError(Exception):
    """Exception raised when trying to remove the root node of a tree.

    Raised by the `remove` method of a
    [`TreeNode`][textual.widgets.tree.TreeNode] when that node is the root.
    """


class UnknownNodeID(Exception):
    """Exception raised when referring to an unknown [`TreeNode`][textual.widgets.tree.TreeNode] ID.

    Raised by `Tree.get_node_by_id` when no node is registered under the
    requested ID.
    """


class AddNodeError(Exception):
    """Exception raised when there is an error with a request to add a node.

    Raised by `TreeNode.add` when both `before` and `after` are given, or when
    the node passed as `before`/`after` is not a child of the node being added to.
    """


@dataclass
class _TreeLine(Generic[TreeDataType]):
    """One line of the tree, described by the chain of nodes that leads to it."""

    # Nodes leading to this line; the final entry is the node the line shows.
    path: list[TreeNode[TreeDataType]]
    # NOTE(review): presumably marks the last line among its siblings - confirm
    # against the code that builds the lines.
    last: bool

    @property
    def node(self) -> TreeNode[TreeDataType]:
        """The node associated with this line (the final entry of `path`)."""
        path = self.path
        return path[-1]

    def _get_guide_width(self, guide_depth: int, show_root: bool) -> int:
        """Get the cell width of the guides (indentation) for this line.

        Args:
            guide_depth: The guide depth (cells in the indentation).
            show_root: Whether the root node is shown. The result is the same
                for either value.

        Returns:
            Width in cells.
        """
        # One indentation step for every path entry after the first.
        levels_below_first = max(0, len(self.path) - 1)
        return levels_below_first * guide_depth


class TreeNodes(ImmutableSequenceView["TreeNode[TreeDataType]"]):
    """An immutable collection of `TreeNode`.

    This is the type returned by `TreeNode.children` and `TreeNode.siblings`,
    which wrap a list of nodes in it.
    """


@rich.repr.auto
class TreeNode(Generic[TreeDataType]):
    """An object that represents a "node" in a tree control.

    A node has a label, optional user data and an ordered list of child nodes.
    The public methods that change the structure or expanded state of the node
    also ask the owning tree to invalidate itself so the change is displayed.
    """

    def __init__(
        self,
        tree: Tree[TreeDataType],
        parent: TreeNode[TreeDataType] | None,
        id: NodeID,
        label: Text,
        data: TreeDataType | None = None,
        *,
        expanded: bool = True,
        allow_expand: bool = True,
    ) -> None:
        """Initialise the node.

        Args:
            tree: The tree that the node is being attached to.
            parent: The parent node that this node is being attached to.
            id: The ID of the node.
            label: The label for the node.
            data: Optional data to associate with the node.
            expanded: Should the node be attached in an expanded state?
            allow_expand: Should the node allow being expanded by the user?
        """
        self._tree = tree
        self._parent = parent
        self._id = id
        self._label = tree.process_label(label)
        self.data = data
        """Optional data associated with the tree node."""
        self._expanded = expanded
        self._children: list[TreeNode[TreeDataType]] = []

        # Backing fields for the `_hover` and `_selected` properties.
        self._hover_ = False
        self._selected_ = False
        self._allow_expand = allow_expand
        # Incremented every time the state of the node changes.
        # NOTE(review): presumably read elsewhere to detect stale renders - confirm.
        self._updates: int = 0
        # Line number of the node, -1 while it is not displayed (see `line`).
        self._line: int = -1

    def __rich_repr__(self) -> rich.repr.Result:
        """Yield the plain label text and the data for the Rich repr."""
        yield self._label.plain
        yield self.data

    def _reset(self) -> None:
        """Clear the hover and selected flags and record an update."""
        self._hover_ = False
        self._selected_ = False
        self._updates += 1

    @property
    def tree(self) -> Tree[TreeDataType]:
        """The tree that this node is attached to."""
        return self._tree

    @property
    def children(self) -> TreeNodes[TreeDataType]:
        """The child nodes of a TreeNode."""
        return TreeNodes(self._children)

    @property
    def siblings(self) -> TreeNodes[TreeDataType]:
        """The siblings of this node (includes self).

        A node without a parent is its own only sibling.
        """
        if self.parent is None:
            return TreeNodes([self])
        else:
            return self.parent.children

    @property
    def line(self) -> int:
        """The line number for this node, or -1 if it is not displayed."""
        return self._line

    @property
    def _hover(self) -> bool:
        """Check if the mouse is over the node."""
        return self._hover_

    @_hover.setter
    def _hover(self, hover: bool) -> None:
        """Set the hover flag and record an update."""
        self._updates += 1
        self._hover_ = hover

    @property
    def _selected(self) -> bool:
        """Check if the node is selected."""
        return self._selected_

    @_selected.setter
    def _selected(self, selected: bool) -> None:
        """Set the selected flag and record an update."""
        self._updates += 1
        self._selected_ = selected

    @property
    def id(self) -> NodeID:
        """The ID of the node."""
        return self._id

    @property
    def parent(self) -> TreeNode[TreeDataType] | None:
        """The parent of the node, or `None` if it has no parent."""
        return self._parent

    @property
    def next_sibling(self) -> TreeNode[TreeDataType] | None:
        """The next sibling below the node, or `None` if this is the last sibling."""
        siblings = self.siblings
        index = siblings.index(self) + 1
        try:
            return siblings[index]
        except IndexError:
            return None

    @property
    def previous_sibling(self) -> TreeNode[TreeDataType] | None:
        """The previous sibling above the node, or `None` if this is the first sibling."""
        siblings = self.siblings
        index = siblings.index(self) - 1
        # A negative index would wrap around to the end of the sequence, so the
        # first sibling is handled explicitly.
        if index < 0:
            return None
        try:
            return siblings[index]
        except IndexError:
            return None

    @property
    def is_expanded(self) -> bool:
        """Is the node expanded?"""
        return self._expanded

    @property
    def is_collapsed(self) -> bool:
        """Is the node collapsed?"""
        return not self._expanded

    @property
    def is_last(self) -> bool:
        """Is this the last child node of its parent?

        A node without a parent is always considered last.
        """
        if self._parent is None:
            return True
        return bool(
            self._parent._children and self._parent._children[-1] == self,
        )

    @property
    def is_root(self) -> bool:
        """Is this node the root of the tree?"""
        return self == self._tree.root

    @property
    def allow_expand(self) -> bool:
        """Is this node allowed to expand?"""
        return self._allow_expand

    @allow_expand.setter
    def allow_expand(self, allow_expand: bool) -> None:
        """Set whether the node may be expanded and record an update."""
        self._allow_expand = allow_expand
        self._updates += 1

    def _expand(self, expand_all: bool) -> None:
        """Mark the node as expanded (its children are shown).

        Posts a `Tree.NodeExpanded` message, with the tree set as its sender,
        for this node and for every node expanded recursively.

        Note:
            This does not invalidate the tree. Call `expand` or `expand_all`
            for that.

        Args:
            expand_all: If `True` expand all offspring at all depths.
        """
        self._expanded = True
        self._updates += 1
        self._tree.post_message(Tree.NodeExpanded(self).set_sender(self._tree))
        if expand_all:
            for child in self.children:
                child._expand(expand_all)

    def expand(self) -> Self:
        """Expand the node (show its children).

        Returns:
            The `TreeNode` instance.
        """
        self._expand(False)
        self._tree._invalidate()
        return self

    def expand_all(self) -> Self:
        """Expand the node (show its children) and all those below it.

        Returns:
            The `TreeNode` instance.
        """
        self._expand(True)
        self._tree._invalidate()
        return self

    def _collapse(self, collapse_all: bool) -> None:
        """Mark the node as collapsed (its children are hidden).

        Posts a `Tree.NodeCollapsed` message, with the tree set as its sender,
        for this node and for every node collapsed recursively.

        Note:
            This does not invalidate the tree. Call `collapse` or
            `collapse_all` for that.

        Args:
            collapse_all: If `True` collapse all offspring at all depths.
        """
        self._expanded = False
        self._updates += 1
        self._tree.post_message(Tree.NodeCollapsed(self).set_sender(self._tree))
        if collapse_all:
            for child in self.children:
                child._collapse(collapse_all)

    def collapse(self) -> Self:
        """Collapse the node (hide its children).

        Returns:
            The `TreeNode` instance.
        """
        self._collapse(False)
        self._tree._invalidate()
        return self

    def collapse_all(self) -> Self:
        """Collapse the node (hide its children) and all those below it.

        Returns:
            The `TreeNode` instance.
        """
        self._collapse(True)
        self._tree._invalidate()
        return self

    def toggle(self) -> Self:
        """Toggle the node's expanded state.

        Returns:
            The `TreeNode` instance.
        """
        if self._expanded:
            self.collapse()
        else:
            self.expand()
        return self

    def toggle_all(self) -> Self:
        """Toggle the node's expanded state and make all those below it match.

        Returns:
            The `TreeNode` instance.
        """
        if self._expanded:
            self.collapse_all()
        else:
            self.expand_all()
        return self

    @property
    def label(self) -> TextType:
        """The label for the node."""
        return self._label

    @label.setter
    def label(self, new_label: TextType) -> None:
        """Set a new label; equivalent to calling `set_label`."""
        self.set_label(new_label)

    def set_label(self, label: TextType) -> None:
        """Set a new label for the node.

        The label is passed through the tree's `process_label` before it is
        stored, and a refresh of the node is scheduled on the tree.

        Args:
            label: A ``str`` or ``Text`` object with the new label.
        """
        self._updates += 1
        text_label = self._tree.process_label(label)
        self._label = text_label
        self._tree.call_later(self._tree._refresh_node, self)

    def add(
        self,
        label: TextType,
        data: TreeDataType | None = None,
        *,
        before: int | TreeNode[TreeDataType] | None = None,
        after: int | TreeNode[TreeDataType] | None = None,
        expand: bool = False,
        allow_expand: bool = True,
    ) -> TreeNode[TreeDataType]:
        """Add a node to the sub-tree.

        Without `before` or `after` the new node is appended after the
        existing children.

        Args:
            label: The new node's label.
            data: Data associated with the new node.
            before: Optional index or `TreeNode` to add the node before.
            after: Optional index or `TreeNode` to add the node after.
            expand: Node should be expanded.
            allow_expand: Allow user to expand the node via keyboard or mouse.

        Returns:
            A new Tree node

        Raises:
            AddNodeError: If there is a problem with the addition request.
            TypeError: If `before` or `after` is neither an index nor a `TreeNode`.

        Note:
            Only one of `before` or `after` can be provided. If both are
            provided a `AddNodeError` will be raised.
        """
        if before is not None and after is not None:
            raise AddNodeError("Unable to add a node both before and after a node")

        # Default position: after the current last child.
        insert_index: int = len(self.children)

        if before is not None:
            if isinstance(before, int):
                insert_index = before
            elif isinstance(before, TreeNode):
                try:
                    insert_index = self.children.index(before)
                except ValueError:
                    raise AddNodeError(
                        "The node specified for `before` is not a child of this node"
                    )
            else:
                raise TypeError(
                    "`before` argument must be an index or a TreeNode object to add before"
                )

        if after is not None:
            if isinstance(after, int):
                insert_index = after + 1
                # A negative `after` counts from the end; adding the child count
                # keeps `after=-1` meaning "after the last child" rather than
                # wrapping around to index 0.
                if after < 0:
                    insert_index += len(self.children)
            elif isinstance(after, TreeNode):
                try:
                    insert_index = self.children.index(after) + 1
                except ValueError:
                    raise AddNodeError(
                        "The node specified for `after` is not a child of this node"
                    )
            else:
                raise TypeError(
                    "`after` argument must be an index or a TreeNode object to add after"
                )

        text_label = self._tree.process_label(label)
        # The tree creates the node so that it is given an ID and registered.
        node = self._tree._add_node(self, text_label, data)
        node._expanded = expand
        node._allow_expand = allow_expand
        self._updates += 1
        self._children.insert(insert_index, node)
        self._tree._invalidate()

        return node

    def add_leaf(
        self,
        label: TextType,
        data: TreeDataType | None = None,
        *,
        before: int | TreeNode[TreeDataType] | None = None,
        after: int | TreeNode[TreeDataType] | None = None,
    ) -> TreeNode[TreeDataType]:
        """Add a 'leaf' node (a node that can not expand).

        Args:
            label: Label for the node.
            data: Optional data.
            before: Optional index or `TreeNode` to add the node before.
            after: Optional index or `TreeNode` to add the node after.

        Returns:
            New node.

        Raises:
            AddNodeError: If there is a problem with the addition request.

        Note:
            Only one of `before` or `after` can be provided. If both are
            provided a `AddNodeError` will be raised.
        """
        node = self.add(
            label,
            data,
            before=before,
            after=after,
            expand=False,
            allow_expand=False,
        )
        return node

    def _remove_children(self) -> None:
        """Remove child nodes of this node.

        Note:
            This is the internal support method for `remove_children`. Call
            `remove_children` to ensure the tree gets refreshed.
        """
        # Iterate in reverse: each `_remove` call deletes the child from this
        # very list.
        for child in reversed(self._children):
            child._remove()

    def _remove(self) -> None:
        """Remove the current node and all its children.

        The node is deleted from its parent's list of children and from the
        tree's registry of nodes.

        Note:
            This is the internal support method for `remove`. Call `remove`
            to ensure the tree gets refreshed.
        """
        self._remove_children()
        # Only a node with a parent can be removed this way.
        assert self._parent is not None
        del self._parent._children[self._parent._children.index(self)]
        del self._tree._tree_nodes[self.id]

    def remove(self) -> None:
        """Remove this node from the tree.

        Raises:
            RemoveRootError: If there is an attempt to remove the root.
        """
        if self.is_root:
            raise RemoveRootError("Attempt to remove the root node of a Tree.")
        self._remove()
        self._tree._invalidate()

    def remove_children(self) -> None:
        """Remove any child nodes of this node."""
        self._remove_children()
        self._tree._invalidate()

    def refresh(self) -> None:
        """Initiate a refresh (repaint) of this node."""
        self._updates += 1
        self._tree._refresh_line(self._line)


class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
    """A widget for displaying and navigating data in a tree."""

    ICON_NODE = "▶ "
    """Unicode 'icon' to use for an expandable node."""
    ICON_NODE_EXPANDED = "▼ "
    """Unicode 'icon' to use for an expanded node."""

    BINDINGS: ClassVar[list[BindingType]] = [
        Binding("shift+left", "cursor_parent", "Cursor to parent", show=False),
        Binding(
            "shift+right",
            "cursor_parent_next_sibling",
            "Cursor to next ancestor",
            show=False,
        ),
        Binding(
            "shift+up",
            "cursor_previous_sibling",
            "Cursor to previous sibling",
            show=False,
        ),
        Binding(
            "shift+down",
            "cursor_next_sibling",
            "Cursor to next sibling",
            show=False,
        ),
        Binding("enter", "select_cursor", "Select", show=False),
        Binding("space", "toggle_node", "Toggle", show=False),
        Binding(
            "shift+space", "toggle_expand_all", "Expand or collapse all", show=False
        ),
        Binding("up", "cursor_up", "Cursor Up", show=False),
        Binding("down", "cursor_down", "Cursor Down", show=False),
    ]
    """
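    | Key(s) | Description |
    | :- | :- |
    | shift+left | Move the cursor to the parent. |
    | shift+right | Move the cursor to the next ancestor. |
    | shift+up | Move the cursor to the previous sibling. |
    | shift+down | Move the cursor to the next sibling. |
    | enter | Select the current item. |
    | space | Toggle the expanded/collapsed state of the current item. |
    | shift+space | Expand or collapse all. |
    | up | Move the cursor up. |
    | down | Move the cursor down. |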
    """

    COMPONENT_CLASSES: ClassVar[set[str]] = {
        "tree--cursor",
        "tree--guides",
        "tree--guides-hover",
        "tree--guides-selected",
        "tree--highlight",
        "tree--highlight-line",
        "tree--label",
    }
    """
    | Class | Description |
    | :- | :- |
    | `tree--cursor` | Targets the cursor. |
    | `tree--guides` | Targets the indentation guides. |
    | `tree--guides-hover` | Targets the indentation guides under the cursor. |
    | `tree--guides-selected` | Targets the indentation guides that are selected. |
    | `tree--highlight` | Targets the highlighted items. |
    | `tree--highlight-line` | Targets the lines under the cursor. |
    | `tree--label` | Targets the (text) labels of the items. |
    """

    DEFAULT_CSS = """
    Tree {
        background: $surface;
        color: $foreground;

        & > .tree--label {}
        & > .tree--guides {
            color: $surface-lighten-2;
        }
        & > .tree--guides-hover {
            color: $surface-lighten-2;
        }
        & > .tree--guides-selected {
            color: $block-cursor-blurred-background;
        }
        & > .tree--cursor {
            text-style: $block-cursor-blurred-text-style;
            background: $block-cursor-blurred-background;
        }
        & > .tree--highlight {}
        & > .tree--highlight-line {
            background: $block-hover-background;
        }

        &:focus {
            background-tint: $foreground 5%;
            & > .tree--cursor {
                color: $block-cursor-foreground;
                background: $block-cursor-background;
                text-style: $block-cursor-text-style;
            }
            & > .tree--guides {
                color: $surface-lighten-3;
            }
            & > .tree--guides-hover {
                color: $surface-lighten-3;
            }
            & > .tree--guides-selected {
                color: $block-cursor-background;
            }
        }

        &:light {
            /* In light mode the guides are darker*/
            & > .tree--guides {
                color: $surface-darken-1;
            }
            & > .tree--guides-hover {
                color: $block-cursor-background;
            }
            & > .tree--guides-selected {
                color: $block-cursor-background;
            }
        }

        &:ansi {
            color: ansi_default;
            & > .tree--guides {
                color: ansi_green;
            }
            &:nocolor > .tree--cursor{
                text-style: reverse;
            }
        }
    }

    """

    show_root = reactive(True)
    """Show the root of the tree."""
    hover_line = var(-1)
    """The line number under the mouse pointer, or -1 if not under the mouse pointer."""
    cursor_line = var(-1, always_update=True)
    """The line with the cursor, or -1 if no cursor."""
    show_guides = reactive(True)
    """Enable display of tree guide lines."""
    guide_depth = reactive(4, init=False)
    """The indent depth of tree nodes."""
    auto_expand = var(True)
    """Auto expand tree nodes when they are selected."""
    center_scroll = var(False)
    """Keep selected node in the center of the control, where possible."""

    # Guide glyph sets, keyed by the name of the line style. Each entry holds
    # four two-character strings, in this order: blank, vertical line, corner
    # and tee.
    LINES: dict[str, tuple[str, str, str, str]] = {
        "default": (
            "  ",
            "│ ",
            "└─",
            "├─",
        ),
        "bold": (
            "  ",
            "┃ ",
            "┗━",
            "┣━",
        ),
        "double": (
            "  ",
            "║ ",
            "╚═",
            "╠═",
        ),
    }

    class NodeCollapsed(Generic[EventTreeDataType], Message):
        """Message posted when a node has been collapsed.

        Handle it with `on_tree_node_collapsed`, either in a subclass of `Tree`
        or in a parent node in the DOM.
        """

        def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
            """Initialise the message.

            Args:
                node: The node that was collapsed.
            """
            self.node: TreeNode[EventTreeDataType] = node
            """The node that was collapsed."""
            super().__init__()

        @property
        def control(self) -> Tree[EventTreeDataType]:
            """The tree that sent the message (the tree owning the node)."""
            collapsed_node = self.node
            return collapsed_node.tree

    class NodeExpanded(Generic[EventTreeDataType], Message):
        """Message posted when a node has been expanded.

        Handle it with `on_tree_node_expanded`, either in a subclass of `Tree`
        or in a parent node in the DOM.
        """

        def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
            """Initialise the message.

            Args:
                node: The node that was expanded.
            """
            self.node: TreeNode[EventTreeDataType] = node
            """The node that was expanded."""
            super().__init__()

        @property
        def control(self) -> Tree[EventTreeDataType]:
            """The tree that sent the message (the tree owning the node)."""
            expanded_node = self.node
            return expanded_node.tree

    class NodeHighlighted(Generic[EventTreeDataType], Message):
        """Message posted when a node has been highlighted.

        Handle it with `on_tree_node_highlighted`, either in a subclass of
        `Tree` or in a parent node in the DOM.
        """

        def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
            """Initialise the message.

            Args:
                node: The node that was highlighted.
            """
            self.node: TreeNode[EventTreeDataType] = node
            """The node that was highlighted."""
            super().__init__()

        @property
        def control(self) -> Tree[EventTreeDataType]:
            """The tree that sent the message (the tree owning the node)."""
            highlighted_node = self.node
            return highlighted_node.tree

    class NodeSelected(Generic[EventTreeDataType], Message):
        """Message posted when a node has been selected.

        Handle it with `on_tree_node_selected`, either in a subclass of `Tree`
        or in a parent node in the DOM.
        """

        def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
            """Initialise the message.

            Args:
                node: The node that was selected.
            """
            self.node: TreeNode[EventTreeDataType] = node
            """The node that was selected."""
            super().__init__()

        @property
        def control(self) -> Tree[EventTreeDataType]:
            """The tree that sent the message (the tree owning the node)."""
            selected_node = self.node
            return selected_node.tree

    def __init__(
        self,
        label: TextType,
        data: TreeDataType | None = None,
        *,
        name: str | None = None,
        id: str | None = None,
        classes: str | None = None,
        disabled: bool = False,
    ) -> None:
        """Initialise a Tree.

        The root node is created here from `label` and `data`.

        Args:
            label: The label of the root node of the tree.
            data: The optional data to associate with the root node of the tree.
            name: The name of the Tree.
            id: The ID of the tree in the DOM.
            classes: The CSS classes of the tree.
            disabled: Whether the tree is disabled or not.
        """

        text_label = self.process_label(label)

        # Counter incremented by `_add_node`, `clear` and `_invalidate`.
        self._updates = 0
        # Every node created through `_add_node`, keyed by its ID
        # (see `get_node_by_id`).
        self._tree_nodes: dict[NodeID, TreeNode[TreeDataType]] = {}
        # The next ID that `_new_id` will hand out.
        self._current_id = 0
        # The three attributes above must exist before the root is created,
        # because `_add_node` reads and updates them.
        self.root = self._add_node(None, text_label, data)
        """The root node of the tree."""
        # Cache of rendered lines, limited to 1024 entries.
        self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1024)
        # Cached list of tree lines; `None` means it has been invalidated
        # (see `_clear_line_cache`).
        self._tree_lines_cached: list[_TreeLine[TreeDataType]] | None = None
        self._cursor_node: TreeNode[TreeDataType] | None = None

        super().__init__(name=name, id=id, classes=classes, disabled=disabled)

    def add_json(self, json_data: object, node: TreeNode | None = None) -> None:
        """Adds JSON data to a node.

        Dictionaries and lists become branches labelled ``{} name`` and
        ``[] name`` with one child node per item (list items are named by
        their index). Any other value becomes a node that can not be expanded,
        labelled ``name=<repr of value>``, or just the repr when there is no
        name.

        Args:
            json_data: An object decoded from JSON.
            node: Node to add data to. Defaults to the root node.

        """

        if node is None:
            node = self.root

        from rich.highlighter import ReprHighlighter

        highlighter = ReprHighlighter()

        def add_node(name: str, node: TreeNode, data: object) -> None:
            """Adds a node to the tree.

            Args:
                name: Name of the node.
                node: Parent node.
                data: Data associated with the node.
            """
            if isinstance(data, dict):
                node.set_label(Text(f"{{}} {name}"))
                for key, value in data.items():
                    new_node = node.add("")
                    add_node(key, new_node, value)
            elif isinstance(data, list):
                node.set_label(Text(f"[] {name}"))
                for index, value in enumerate(data):
                    new_node = node.add("")
                    add_node(str(index), new_node, value)
            else:
                node.allow_expand = False
                if name:
                    # Style the key with a (text, style) pair instead of
                    # interpolating it into console markup: a key containing
                    # square brackets (e.g. "[red]" or "[/b]") must be shown
                    # literally, not parsed as a markup tag.
                    label = Text.assemble(
                        (str(name), "bold"), "=", highlighter(repr(data))
                    )
                else:
                    label = Text(repr(data))
                node.set_label(label)

        add_node("", node, json_data)

    @property
    def cursor_node(self) -> TreeNode[TreeDataType] | None:
        """The currently selected node, or ``None`` if no selection.

        This is a read-only view of the private `_cursor_node` attribute.
        """
        return self._cursor_node

    @property
    def last_line(self) -> int:
        """The index of the final line of the tree (the line count minus one)."""
        line_count = len(self._tree_lines)
        return line_count - 1

    def process_label(self, label: TextType) -> Text:
        """Turn a `str` or `Text` value into the `Text` used as a label.

        A `str` is parsed as markup; a `Text` is used as it is. Only the first
        item produced by splitting the text is kept.

        Subclasses may override this to change how labels are rendered.

        Args:
            label: Label.

        Returns:
            A Rich Text object.
        """
        text_label = Text.from_markup(label) if isinstance(label, str) else label
        label_lines = text_label.split()
        return label_lines[0]

    def _add_node(
        self,
        parent: TreeNode[TreeDataType] | None,
        label: Text,
        data: TreeDataType | None,
        expand: bool = False,
    ) -> TreeNode[TreeDataType]:
        """Create a node with a fresh ID and register it with the tree.

        The node is not inserted into the parent's list of children here.

        Args:
            parent: The parent of the new node, or `None` for a node without parent.
            label: The label of the new node.
            data: Data associated with the new node.
            expand: Should the node start in an expanded state?

        Returns:
            The new node.
        """
        new_node = TreeNode(
            self,
            parent,
            self._new_id(),
            label,
            data,
            expanded=expand,
        )
        self._updates += 1
        self._tree_nodes[new_node._id] = new_node
        return new_node

    def render_label(
        self, node: TreeNode[TreeDataType], base_style: Style, style: Style
    ) -> Text:
        """Build the text shown for a node. Override to customise how labels look.

        Expandable nodes get an icon in front of the label: `ICON_NODE_EXPANDED`
        when expanded, `ICON_NODE` otherwise. Nodes that can not expand get no icon.

        Args:
            node: A tree node.
            base_style: The base style of the widget.
            style: The additional style for the label.

        Returns:
            A Rich Text object containing the label.
        """
        # Work on a copy so that the label stored on the node is left unstyled.
        styled_label = node._label.copy()
        styled_label.stylize(style)

        if not node._allow_expand:
            icon_part = ("", base_style)
        else:
            icon = self.ICON_NODE_EXPANDED if node.is_expanded else self.ICON_NODE
            icon_part = (icon, base_style + TOGGLE_STYLE)

        return Text.assemble(icon_part, styled_label)

    def get_label_width(self, node: TreeNode[TreeDataType]) -> int:
        """Get the width of a node's label.

        By default the label is rendered with `render_label` (using null
        styles) and its cell length is returned. A sub-class may override this
        method when it can work the width out more efficiently.

        Args:
            node: A node.

        Returns:
            Width in cells.
        """
        rendered_label = self.render_label(node, NULL_STYLE, NULL_STYLE)
        width = rendered_label.cell_len
        return width

    def _clear_line_cache(self) -> None:
        """Drop the cached list of tree lines and empty the rendered-line cache."""
        self._tree_lines_cached = None
        self._line_cache.clear()

    def clear(self) -> Self:
        """Clear all nodes under root.

        The root is replaced by a fresh node that keeps the label, data and
        expanded state of the old root. Node IDs start again from zero and the
        registry of nodes is emptied, so `get_node_by_id` no longer returns
        nodes that were cleared.

        Returns:
            The `Tree` instance.
        """
        self._clear_line_cache()
        old_root = self.root
        # Forget every node of the old tree. Without this the registry keeps
        # the old nodes alive and, because IDs restart at zero, the ID of the
        # new root would still look up the old root.
        self._tree_nodes.clear()
        self._current_id = 0
        # `_add_node` gives the new root an ID, registers it and records the update.
        self.root = self._add_node(
            None,
            old_root._label,
            old_root.data,
            expand=old_root.is_expanded,
        )
        self.refresh()
        return self

    def reset(self, label: TextType, data: TreeDataType | None = None) -> Self:
        """Clear the tree, then give the root node a new label and new data.

        Args:
            label: The label for the root node.
            data: Optional data for the root node.

        Returns:
            The `Tree` instance.
        """
        self.clear()
        new_root = self.root
        new_root.label = label
        new_root.data = data
        return self

    def move_cursor(
        self, node: TreeNode[TreeDataType] | None, animate: bool = False
    ) -> None:
        """Move the cursor to the given node, or reset cursor.

        When a node is given and the tree has a cursor node afterwards, the
        tree scrolls to that cursor node. The scroll is only animated when the
        cursor moved by more than one line.

        Args:
            node: A tree node, or None to reset cursor.
            animate: Enable animation
        """
        line_before = self.cursor_line
        self.cursor_line = node._line if node is not None else -1
        if node is None or self.cursor_node is None:
            return
        moved_far = abs(self.cursor_line - line_before) > 1
        self.scroll_to_node(self.cursor_node, animate=animate and moved_far)

    def move_cursor_to_line(self, line: int, animate: bool = False) -> None:
        """Move the cursor to the given line.

        Nothing happens if the cursor is already on that line.

        Args:
            line: The line number (negative indexes are offsets from the last line).
            animate: Enable scrolling animation.

        Raises:
            IndexError: If the line doesn't exist.
        """
        if self.cursor_line == line:
            return
        try:
            node = self._tree_lines[line].node
        except IndexError:
            # Raise a clean error without chaining the internal list-index
            # error, in the same way as `get_node_by_id`.
            raise IndexError(f"No line no. {line} in the tree") from None
        self.move_cursor(node, animate=animate)

    def select_node(self, node: TreeNode[TreeDataType] | None) -> None:
        """Move the cursor to a node and select it, or reset the cursor.

        A `Tree.NodeSelected` message is posted when a node is given.

        Args:
            node: A tree node to move the cursor to and select, or None to reset cursor.
        """
        self.move_cursor(node)
        if node is None:
            return
        self.post_message(Tree.NodeSelected(node))

    def unselect(self) -> None:
        """Hide and reset the cursor.

        `cursor_line` is set to -1 through `set_reactive` and the tree is
        then invalidated.
        """
        no_cursor = -1
        self.set_reactive(Tree.cursor_line, no_cursor)
        self._invalidate()

    @on(NodeSelected)
    def _expand_node_on_select(self, event: NodeSelected[TreeDataType]) -> None:
        """When a node is selected, toggle it if `auto_expand` is True."""
        if not self.auto_expand:
            return
        self._toggle_node(event.node)

    def get_node_at_line(self, line_no: int) -> TreeNode[TreeDataType] | None:
        """Look up the node shown on a given line.

        Args:
            line_no: A line number.

        Returns:
            A tree node, or ``None`` if there is no node at that line.
        """
        try:
            tree_line = self._tree_lines[line_no]
        except IndexError:
            return None
        return tree_line.node

    def get_node_by_id(self, node_id: NodeID) -> TreeNode[TreeDataType]:
        """Look up a tree node by its ID.

        Args:
            node_id: The ID of the node to get.

        Returns:
            The node associated with that ID.

        Raises:
            UnknownNodeID: Raised if the `TreeNode` ID is unknown.
        """
        try:
            found_node = self._tree_nodes[node_id]
        except KeyError:
            raise UnknownNodeID(f"Unknown NodeID ({node_id}) in tree") from None
        else:
            return found_node

    def validate_cursor_line(self, value: int) -> int:
        """Keep the cursor line within the range of lines in the tree.

        Args:
            value: The proposed cursor line.

        Returns:
            The value clamped with bounds of 0 and the index of the last line.
        """
        last_index = len(self._tree_lines) - 1
        return clamp(value, 0, last_index)

    def validate_guide_depth(self, value: int) -> int:
        """Keep the guide depth within a sensible range.

        Args:
            value: The proposed guide depth.

        Returns:
            The value clamped with bounds of 2 and 10.
        """
        minimum_depth, maximum_depth = 2, 10
        return clamp(value, minimum_depth, maximum_depth)

    def _invalidate(self) -> None:
        """Discard cached rendering state and request a refresh with layout."""
        self._clear_line_cache()
        # The update counter is part of the line cache key, so bumping it
        # makes sure previously rendered lines are not reused.
        self._updates = self._updates + 1
        self.root._reset()
        self.refresh(layout=True)

    def _on_mouse_move(self, event: events.MouseMove) -> None:
        """Track which line the mouse is hovering over.

        Args:
            event: The mouse move event; the hovered line is read from the
                `"line"` entry of its style metadata.
        """
        meta = event.style.meta
        # Fall back to -1 when the style under the mouse carries no line number.
        self.hover_line = meta["line"] if meta and "line" in meta else -1

    def _on_leave(self, _: events.Leave) -> None:
        """Clear the hover line when the mouse leaves.

        Args:
            _: The leave event (unused).
        """
        # Otherwise the hover effect would linger after the mouse has gone.
        self.hover_line = -1

    def _new_id(self) -> NodeID:
        """Allocate the next node ID.

        Returns:
            A node ID that this method has not handed out before.
        """
        allocated = self._current_id
        # Advance the counter so the next call yields a different ID.
        self._current_id = allocated + 1
        return NodeID(allocated)

    def _get_node(self, line: int) -> TreeNode[TreeDataType] | None:
        """Get the node on a given line, treating -1 as "no line".

        Args:
            line: A line number, or -1.

        Returns:
            The node on that line, or `None` if `line` is -1 or the tree has
            no such line.
        """
        found: TreeNode[TreeDataType] | None = None
        if line != -1:
            try:
                tree_line = self._tree_lines[line]
            except IndexError:
                pass
            else:
                found = tree_line.node
        return found

    def _get_label_region(self, line: int) -> Region | None:
        """Compute the region covered by the label of the node on a line.

        Useful, for example, when scrolling to a line so that the label is
        visible once the scroll has finished.

        Args:
            line: A line number.

        Returns:
            The region occupied by the label, or `None` if the line is not
            in the tree.
        """
        try:
            tree_line = self._tree_lines[line]
        except IndexError:
            return None
        # The label starts where the guides end and is one line tall.
        guide_width = tree_line._get_guide_width(self.guide_depth, self.show_root)
        label_width = self.get_label_width(tree_line.node)
        return Region(guide_width, line, label_width, 1)

    def watch_hover_line(self, previous_hover_line: int, hover_line: int) -> None:
        """Update hover flags and repaint when the hovered line changes.

        Args:
            previous_hover_line: The line that was hovered before the change.
            hover_line: The line that is hovered now.
        """
        # The old line loses its hover flag, the new one gains it; each node
        # is refreshed before its flag is updated.
        transitions = ((previous_hover_line, False), (hover_line, True))
        for line, hovering in transitions:
            node = self._get_node(line)
            if node is None:
                continue
            self._refresh_node(node)
            node._hover = hovering

    def watch_cursor_line(self, previous_line: int, line: int) -> None:
        """Update selection state and repaint when the cursor line changes.

        Posts a `NodeHighlighted` message when the cursor lands on a node
        that differs from the node on the previous line.

        Args:
            previous_line: The cursor line before the change.
            line: The cursor line after the change.
        """
        old_node = self._get_node(previous_line)
        new_node = self._get_node(line)

        # Clear the selected flag on every node that may currently hold it.
        for stale_node in (self.cursor_node, old_node):
            if stale_node is not None:
                stale_node._selected = False

        if new_node is not None:
            new_node._selected = True
        self._cursor_node = new_node

        if previous_line == line:
            # The line did not change, so there is nothing to repaint.
            return

        # Repaint the old cursor node first, then the new one.
        for changed_node in (old_node, new_node):
            if changed_node is not None:
                self._refresh_node(changed_node)

        if new_node is not None and old_node != new_node:
            self.post_message(self.NodeHighlighted(new_node))

    def watch_guide_depth(self, guide_depth: int) -> None:
        """Invalidate cached state when the guide depth changes.

        Args:
            guide_depth: The new guide depth (unused).
        """
        self._invalidate()

    def watch_show_root(self, show_root: bool) -> None:
        """Reset the cursor and invalidate cached state when `show_root` changes.

        Args:
            show_root: The new value of `show_root` (unused).
        """
        # The cursor line is assigned -1 before the caches are invalidated.
        self.cursor_line = -1
        self._invalidate()

    def scroll_to_line(self, line: int, animate: bool = True) -> None:
        """Scroll so that the label on the given line is in view.

        Does nothing when no label region can be found for the line.

        Args:
            line: A line number.
            animate: Enable animation.
        """
        label_region = self._get_label_region(line)
        if label_region is None:
            return
        self.scroll_to_region(
            label_region,
            animate=animate,
            force=True,
            center=self.center_scroll,
            origin_visible=False,
            # Scrolling the X axis is quite jarring, and rarely necessary.
            x_axis=False,
        )

    def scroll_to_node(
        self, node: TreeNode[TreeDataType], animate: bool = True
    ) -> None:
        """Scroll so that the given node is in view.

        Args:
            node: Node to scroll in to view.
            animate: Animate scrolling.
        """
        node_line = node._line
        if node_line == -1:
            # The node has no line assigned, so there is nowhere to scroll to.
            return
        self.scroll_to_line(node_line, animate=animate)

    def _refresh_line(self, line: int) -> None:
        """Repaint a single line of the tree.

        Args:
            line: Line number.
        """
        # Offset the line number by the vertical scroll position.
        widget_y = line - self.scroll_offset.y
        self.refresh(Region(0, widget_y, self.size.width, 1))

    def _refresh_node_line(self, line: int) -> None:
        """Refresh the node on the given line, if there is one.

        Args:
            line: A line number.
        """
        node = self._get_node(line)
        if node is None:
            return
        self._refresh_node(node)

    def _refresh_node(self, node: TreeNode[TreeDataType]) -> None:
        """Repaint every visible line whose path contains the given node.

        This covers the node's own line and the lines of its descendants.

        Args:
            node: A tree node.
        """
        first_visible = self.scroll_offset.y
        end_visible = first_visible + self.size.height
        # Only lines inside the current viewport need repainting.
        viewport_lines = self._tree_lines[first_visible:end_visible]
        for line_no, tree_line in enumerate(viewport_lines, first_visible):
            if node in tree_line.path:
                self._refresh_line(line_no)

    @property
    def _tree_lines(self) -> list[_TreeLine[TreeDataType]]:
        """The lines of the tree, rebuilt on access if the cached list is missing."""
        cached_lines = self._tree_lines_cached
        if cached_lines is None:
            self._build()
            cached_lines = self._tree_lines_cached
            assert cached_lines is not None
        return cached_lines

    async def _on_idle(self, event: events.Idle) -> None:
        """Rebuild the tree lines on idle if they are not currently cached.

        Args:
            event: The idle event (unused).
        """
        async with self.lock:
            # Reading the property is enough: it builds the lines when needed.
            _ = self._tree_lines

    def _build(self) -> None:
        """Rebuild the tree lines by walking the expanded nodes.

        Besides caching the new list of lines, this updates each visited
        node's `_line`, the widget's virtual size, and the cursor line.
        """
        built: list[_TreeLine[TreeDataType]] = []
        root = self.root

        def visit(
            ancestors: list[TreeNode[TreeDataType]],
            node: TreeNode[TreeDataType],
            is_last: bool,
        ) -> None:
            """Add a line for `node`, then recurse into its children if expanded."""
            node_path = [*ancestors, node]
            node._line = len(built)
            built.append(_TreeLine(node_path, is_last))
            if not node._expanded:
                return
            for child_is_last, child in loop_last(node._children):
                visit(node_path, child, child_is_last)

        if self.show_root:
            visit([], root, True)
        else:
            # With the root hidden, its children become the top-level lines.
            for top_level_node in self.root._children:
                visit([], top_level_node, True)
        # Cache the lines before measuring widths below.
        self._tree_lines_cached = built

        guide_depth = self.guide_depth
        show_root = self.show_root
        get_label_width = self.get_label_width

        if built:
            # The virtual width is that of the widest line (label plus guides).
            width = max(
                get_label_width(tree_line.node)
                + tree_line._get_guide_width(guide_depth, show_root)
                for tree_line in built
            )
        else:
            width = self.size.width

        self.virtual_size = Size(width, len(built))
        if self.cursor_line == -1:
            return
        # Keep the cursor on its node, which may now be on a different line.
        if self.cursor_node is not None:
            self.cursor_line = self.cursor_node._line
        if self.cursor_line >= len(built):
            self.cursor_line = -1

    def render_lines(self, crop: Region) -> list[Strip]:
        """Render the lines within a region of the widget.

        Args:
            crop: The region to render.

        Returns:
            The strips produced by the base class implementation.
        """
        # Capture the pseudo class state once per render; it forms part of
        # the per-line cache key.
        self._pseudo_class_state = self.get_pseudo_class_state()
        return super().render_lines(crop)

    def render_line(self, y: int) -> Strip:
        """Render one visible line of the widget.

        Args:
            y: The Y coordinate of the line within the widget.

        Returns:
            The strip for the tree line at `y` plus the vertical scroll
            offset, cropped to the horizontally visible range.
        """
        visible_width = self.size.width
        offset_x, offset_y = self.scroll_offset
        return self._render_line(
            y + offset_y,
            offset_x,
            offset_x + visible_width,
            self.rich_style,
        )

    def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip:
        """Render one line of the tree as a strip, using the line cache where possible.

        Args:
            y: Index of the line within the tree lines (`render_line` passes
                the widget row plus the vertical scroll offset).
            x1: Start of the horizontal range to crop the rendered strip to.
            x2: End of the horizontal range to crop the rendered strip to.
            base_style: Style used for blank lines and for lines that are
                not hovered.

        Returns:
            The rendered strip cropped to `x1`..`x2`, or a blank strip of the
            widget's width if `y` is past the last line of the tree.
        """
        tree_lines = self._tree_lines
        width = self.size.width

        # Lines past the end of the tree are rendered as blank space.
        if y >= len(tree_lines):
            return Strip.blank(width, base_style)

        line = tree_lines[y]

        # The line counts as hovered if any node on its path is hovered.
        is_hover = self.hover_line >= 0 and any(node._hover for node in line.path)

        # Everything that affects the rendered output goes in to the key, so a
        # change to any of these values results in a fresh render.
        cache_key = (
            y,
            is_hover,
            width,
            self._updates,
            self._pseudo_class_state,
            tuple(node._updates for node in line.path),
        )
        if cache_key in self._line_cache:
            strip = self._line_cache[cache_key]
        else:
            # Allow tree guides to be explicitly disabled by setting color to transparent
            base_hidden = self.get_component_styles("tree--guides").color.a == 0
            hover_hidden = self.get_component_styles("tree--guides-hover").color.a == 0
            selected_hidden = (
                self.get_component_styles("tree--guides-selected").color.a == 0
            )

            # The hover and selected guide styles are layered over the base guide style.
            base_guide_style = self.get_component_rich_style(
                "tree--guides", partial=True
            )
            guide_hover_style = base_guide_style + self.get_component_rich_style(
                "tree--guides-hover", partial=True
            )
            guide_selected_style = base_guide_style + self.get_component_rich_style(
                "tree--guides-selected", partial=True
            )

            # Initial state comes from the first node on the path; it is
            # accumulated down the path in the loop below.
            hover = line.path[0]._hover
            selected = line.path[0]._selected and self.has_focus

            def get_guides(style: Style, hidden: bool) -> tuple[str, str, str, str]:
                """Get the guide strings for a given style.

                Args:
                    style: A Style object.
                    hidden: Switch to hide guides (make them invisible).

                Returns:
                    Strings for space, vertical, terminator and cross.
                """
                lines: tuple[Iterable[str], Iterable[str], Iterable[str], Iterable[str]]
                if self.show_guides and not hidden:
                    # Bold and double-underline styles select alternative line sets.
                    lines = self.LINES["default"]
                    if style.bold:
                        lines = self.LINES["bold"]
                    elif style.underline2:
                        lines = self.LINES["double"]
                else:
                    lines = ("  ", "  ", "  ", "  ")

                # Each guide is the first character, the second character
                # repeated to fill the configured depth, then a space.
                guide_depth = max(0, self.guide_depth - 2)
                guide_lines = tuple(
                    f"{characters[0]}{characters[1] * guide_depth} "
                    for characters in lines
                )
                return cast("tuple[str, str, str, str]", guide_lines)

            if is_hover:
                line_style = self.get_component_rich_style("tree--highlight-line")
            else:
                line_style = base_style

            # Attach the line number as style meta; mouse handlers read it back.
            line_style += Style(meta={"line": y})

            guides = Text(style=line_style)
            guides_append = guides.append

            guide_style = base_guide_style

            hidden = True
            for node in line.path[1:]:
                # Selected takes precedence over hover, which takes precedence
                # over the base guide style.
                hidden = base_hidden
                if hover:
                    guide_style = guide_hover_style
                    hidden = hover_hidden
                if selected:
                    guide_style = guide_selected_style
                    hidden = selected_hidden

                space, vertical, _, _ = get_guides(guide_style, hidden)
                guide = space if node.is_last else vertical
                # The final node's guide is the terminator or cross added below.
                if node != line.path[-1]:
                    guides_append(guide, style=guide_style)
                hover = hover or node._hover
                selected = (selected or node._selected) and self.has_focus

            # Only lines with more than one node on their path get a connector.
            if len(line.path) > 1:
                _, _, terminator, cross = get_guides(guide_style, hidden)
                if line.last:
                    guides.append(terminator, style=guide_style)
                else:
                    guides.append(cross, style=guide_style)

            # Highlight and cursor styles are layered on the label style for
            # the hovered line and the cursor line respectively.
            label_style = self.get_component_rich_style("tree--label", partial=True)
            if self.hover_line == y:
                label_style += self.get_component_rich_style(
                    "tree--highlight", partial=True
                )
            if self.cursor_line == y:
                label_style += self.get_component_rich_style(
                    "tree--cursor", partial=False
                )

            # Work on a copy of the rendered label before adding the node ID meta.
            label = self.render_label(line.path[-1], line_style, label_style).copy()
            label.stylize(Style(meta={"node": line.node._id}))
            guides.append(label)

            # Pad the line out to the wider of the virtual width and the widget width.
            segments = list(guides.render(self.app.console))
            pad_width = max(self.virtual_size.width, width)
            segments = line_pad(segments, 0, pad_width - guides.cell_len, line_style)
            strip = self._line_cache[cache_key] = Strip(segments)

        # The cache holds the full line; crop to the requested range on the way out.
        strip = strip.crop(x1, x2)
        return strip

    def _on_resize(self, event: events.Resize) -> None:
        """Grow the line cache for the new height and invalidate cached state.

        Args:
            event: The resize event carrying the new size.
        """
        new_height = event.size.height
        self._line_cache.grow(new_height)
        self._invalidate()

    def _toggle_node(self, node: TreeNode[TreeDataType]) -> None:
        """Collapse the node if it is expanded, otherwise expand it.

        Nodes that do not allow expansion are left untouched.

        Args:
            node: The node to toggle.
        """
        if not node.allow_expand:
            return
        toggle = node.collapse if node.is_expanded else node.expand
        toggle()

    async def _on_click(self, event: events.Click) -> None:
        """Handle a click on a line of the tree.

        A click on a span carrying the `"toggle"` meta toggles the node on
        that line; any other click on a line moves the cursor there and runs
        the `select_cursor` action.

        Args:
            event: The click event; the line is read from its style metadata.
        """
        async with self.lock:
            meta = event.style.meta
            if "line" not in meta:
                # The click was not on a line of the tree.
                return
            clicked_line = meta["line"]
            if not meta.get("toggle", False):
                self.cursor_line = clicked_line
                await self.run_action("select_cursor")
                return
            node = self.get_node_at_line(clicked_line)
            if node is not None:
                self._toggle_node(node)

    def notify_style_update(self) -> None:
        """Invalidate cached state in response to a style update."""
        self._invalidate()

    def action_cursor_up(self) -> None:
        """Move the cursor up one node.

        When there is no cursor (line -1) it is placed on the last line.
        """
        self.cursor_line = (
            self.last_line if self.cursor_line == -1 else self.cursor_line - 1
        )
        self.scroll_to_line(self.cursor_line, animate=False)

    def action_cursor_down(self) -> None:
        """Move the cursor down one node.

        When there is no cursor (line -1) it is placed on the first line.
        """
        self.cursor_line = 0 if self.cursor_line == -1 else self.cursor_line + 1
        self.scroll_to_line(self.cursor_line, animate=False)

    def action_page_down(self) -> None:
        """Move the cursor down a page's-worth of nodes.

        A page is one line less than the height of the scrollable content
        region. When there is no cursor, paging starts from the first line.
        """
        if self.cursor_line == -1:
            self.cursor_line = 0
        page_step = self.scrollable_content_region.height - 1
        self.cursor_line = self.cursor_line + page_step
        self.scroll_to_line(self.cursor_line)

    def action_page_up(self) -> None:
        """Move the cursor up a page's-worth of nodes.

        A page is one line less than the height of the scrollable content
        region. When there is no cursor, paging starts from the last line.
        """
        if self.cursor_line == -1:
            self.cursor_line = self.last_line
        page_step = self.scrollable_content_region.height - 1
        self.cursor_line = self.cursor_line - page_step
        self.scroll_to_line(self.cursor_line)

    def action_scroll_home(self) -> None:
        """Move the cursor to the first line of the tree and scroll to it."""
        first_line = 0
        self.cursor_line = first_line
        self.scroll_to_line(self.cursor_line)

    def action_scroll_end(self) -> None:
        """Move the cursor to the last line of the tree and scroll to it.

        Note:
            Here bottom means vertically, not branch depth.
        """
        final_line = self.last_line
        self.cursor_line = final_line
        self.scroll_to_line(self.cursor_line)

    def action_toggle_node(self) -> None:
        """Toggle the expanded state of the target node.

        Does nothing when there is no cursor or the cursor line is not in
        the tree.
        """
        cursor_line = self.cursor_line
        if cursor_line < 0:
            # A negative cursor line means there is no cursor. Indexing the
            # tree lines with it would wrap around and toggle a node at the
            # end of the tree, so bail out as `action_select_cursor` does.
            return
        try:
            line = self._tree_lines[cursor_line]
        except IndexError:
            # The cursor points past the end of the tree: nothing to toggle.
            return
        self._toggle_node(line.path[-1])

    def action_select_cursor(self) -> None:
        """Post a selected message for the node under the cursor.

        Does nothing when there is no cursor or the cursor line is not in
        the tree.

        Note:
            If `auto_expand` is `True` use of this action on a non-leaf node
            will cause both an expand/collapse event to occur, as well as a
            selected event.
        """
        cursor_line = self.cursor_line
        if cursor_line < 0:
            return
        try:
            tree_line = self._tree_lines[cursor_line]
        except IndexError:
            return
        target_node = tree_line.path[-1]
        self.post_message(Tree.NodeSelected(target_node))

    def action_cursor_parent(self) -> None:
        """Move the cursor to the parent of the node under the cursor.

        Does nothing when there is no cursor node or it has no parent.
        """
        cursor_node = self.cursor_node
        if cursor_node is None:
            return
        if cursor_node.parent is None:
            return
        self.move_cursor(cursor_node.parent, animate=True)

    def action_cursor_parent_next_sibling(self) -> None:
        """Move the cursor to the next sibling of the cursor node's parent.

        Does nothing when there is no cursor node or it has no parent.
        """
        cursor_node = self.cursor_node
        if cursor_node is None:
            return
        if cursor_node.parent is None:
            return
        target = cursor_node.parent.next_sibling
        self.move_cursor(target, animate=True)

    def action_cursor_previous_sibling(self) -> None:
        """Move the cursor to the previous sibling, or to the parent if there is none."""
        cursor_node = self.cursor_node
        if cursor_node is None:
            return
        target = cursor_node.previous_sibling
        if target is None:
            # No earlier sibling: fall back to the parent.
            target = cursor_node.parent
        self.move_cursor(target, animate=True)

    def action_cursor_next_sibling(self) -> None:
        """Move the cursor to the next sibling, or to the parent's next sibling if there is none."""
        cursor_node = self.cursor_node
        if cursor_node is None:
            return
        target = cursor_node.next_sibling
        if target is None:
            # No later sibling: continue from the parent's next sibling,
            # provided there is a parent at all.
            if cursor_node.parent is None:
                return
            target = cursor_node.parent.next_sibling
        self.move_cursor(target, animate=True)

    def action_toggle_expand_all(self) -> None:
        """Expand or collapse all siblings of the cursor node.

        If every sibling is collapsed they are all expanded; otherwise they
        are all collapsed. Siblings that do not allow expansion are skipped.
        Does nothing when there is no cursor node or it has no parent.
        """
        cursor_node = self.cursor_node
        if cursor_node is None or cursor_node.parent is None:
            return

        siblings = cursor_node.siblings
        # Decide the direction once, before any sibling changes state.
        expanding = all(sibling.is_collapsed for sibling in siblings)
        for sibling in siblings:
            if not sibling.allow_expand:
                continue
            if expanding:
                sibling.expand()
            else:
                sibling.collapse()

        # Put the cursor back on the same node once the refresh has happened.
        self.call_after_refresh(self.move_cursor, cursor_node, animate=False)
