"""
This module contains the `Worker` class and related objects.

See the guide for how to use [workers](/guide/workers).

"""

from __future__ import annotations

import asyncio
import enum
import inspect
from contextvars import ContextVar
from threading import Event
from time import monotonic
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Callable,
    Coroutine,
    Generic,
    TypeVar,
    Union,
    cast,
)

import rich.repr
from typing_extensions import TypeAlias

from textual.message import Message

if TYPE_CHECKING:
    from textual.app import App
    from textual.dom import DOMNode


# Holds the worker running in the current context. It is set by a `Worker`
# when it begins running its work (in the worker's task, and again inside the
# executor thread for thread workers), and read back by `get_current_worker`
# and `Worker.wait`.
active_worker: ContextVar[Worker] = ContextVar("active_worker")
"""Currently active worker context var."""


class NoActiveWorker(Exception):
    """There is no active worker.

    Raised by `get_current_worker` when the `active_worker` context var has
    no value in the current context.
    """


class WorkerError(Exception):
    """A worker related error.

    Base class for `WorkerFailed`, `DeadlockError`, and `WorkerCancelled`.
    It is also raised directly when a worker is asked to run work of an
    unsupported kind, or is waited on before it has been started.
    """


class WorkerFailed(WorkerError):
    """Raised when the work raised an exception instead of completing.

    The exception raised by the work is kept on the `error` attribute.
    """

    def __init__(self, error: BaseException) -> None:
        """Initialize the exception.

        Args:
            error: The exception raised by the work.
        """
        message = f"Worker raised exception: {error!r}"
        super().__init__(message)
        self.error = error


class DeadlockError(WorkerError):
    """The operation would result in a deadlock.

    Raised by `Worker.wait` when it is called from within the very worker
    that is being waited on.
    """


class WorkerCancelled(WorkerError):
    """The worker was cancelled and did not complete.

    Raised by `Worker.wait` when the worker being waited on ends up in the
    cancelled state.
    """


def get_current_worker() -> Worker:
    """Look up the worker that is running in the current context.

    Raises:
        NoActiveWorker: If no worker is active in the current context.

    Returns:
        The active Worker instance.
    """
    no_worker_message = "There is no active worker in this task or thread."
    try:
        current = active_worker.get()
    except LookupError:
        # The context var was never set here; hide the LookupError so that
        # callers only see the worker-specific exception.
        raise NoActiveWorker(no_worker_message) from None
    else:
        return current


@enum.unique
class WorkerState(enum.Enum):
    """The states a worker can be in.

    `PENDING` and `RUNNING` describe a worker that has not finished;
    `CANCELLED`, `ERROR`, and `SUCCESS` describe a worker that has stopped.
    """

    PENDING = 1
    """The worker has been created, but is not yet running."""
    RUNNING = 2
    """The worker is currently running."""
    CANCELLED = 3
    """The worker has stopped because it was cancelled."""
    ERROR = 4
    """The worker has stopped because it exited with an error."""
    SUCCESS = 5
    """The worker has stopped after completing successfully."""


# Type variable for the value produced by a worker's work.
ResultType = TypeVar("ResultType")


# The kinds of work a worker accepts: a coroutine function taking no
# arguments, a plain callable taking no arguments, or an awaitable object.
WorkType: TypeAlias = Union[
    Callable[[], Coroutine[None, None, ResultType]],
    Callable[[], ResultType],
    Awaitable[ResultType],
]
"""Type used for [workers](/guide/workers/)."""


class _ReprText:
    """Shim to insert a word into the Worker's repr."""

    def __init__(self, text: str) -> None:
        self.text = text

    def __repr__(self) -> str:
        return self.text


@rich.repr.auto(angular=True)
class Worker(Generic[ResultType]):
    """A class to manage concurrent work (either a task or a thread).

    A worker wraps a piece of work (a coroutine function, an awaitable, or,
    for thread workers, a plain callable), runs it in an asyncio task, and
    keeps track of its state, progress, result, and error. Every change of
    state is posted to the node that created the worker as a
    `Worker.StateChanged` message.
    """

    @rich.repr.auto
    class StateChanged(Message, bubble=False, namespace="worker"):
        """The worker state changed."""

        def __init__(self, worker: Worker, state: WorkerState) -> None:
            """Initialize the StateChanged message.

            Args:
                worker: The worker object.
                state: New state.
            """
            self.worker = worker
            self.state = state
            super().__init__()

        def __rich_repr__(self) -> rich.repr.Result:
            yield self.worker
            yield self.state

    def __init__(
        self,
        node: DOMNode,
        work: WorkType,
        *,
        name: str = "",
        group: str = "default",
        description: str = "",
        exit_on_error: bool = True,
        thread: bool = False,
    ) -> None:
        """Initialize a Worker.

        Posts a `StateChanged` message with the initial `PENDING` state to `node`.

        Args:
            node: The widget, screen, or App that initiated the work.
            work: A callable, coroutine, or other awaitable object to run in the worker.
            name: Name of the worker (short string to help identify when debugging).
            group: The worker group.
            description: Description of the worker (longer string with more details).
            exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions.
            thread: Mark the worker as a thread worker.
        """
        self._node = node
        self._work = work
        self.name = name
        self.group = group
        self.description = description
        self.exit_on_error = exit_on_error
        self.cancelled_event: Event = Event()
        """A threading event set when the worker is cancelled."""
        self._thread_worker = thread
        self._state = WorkerState.PENDING
        self._error: BaseException | None = None
        self._completed_steps: int = 0
        self._total_steps: int | None = None
        self._cancelled: bool = False
        self._created_time = monotonic()
        self._result: ResultType | None = None
        self._task: asyncio.Task | None = None
        # The `state` setter only posts a message when the state changes, so
        # the initial state has to be announced explicitly.
        self._node.post_message(self.StateChanged(self, self._state))

    def __rich_repr__(self) -> rich.repr.Result:
        yield _ReprText(self.state.name)
        yield "name", self.name, ""
        yield "group", self.group, "default"
        yield "description", self.description, ""
        yield "progress", round(self.progress, 1), 0.0

    @property
    def node(self) -> DOMNode:
        """The node where this worker was run from."""
        return self._node

    @property
    def state(self) -> WorkerState:
        """The current state of the worker."""
        return self._state

    @state.setter
    def state(self, state: WorkerState) -> None:
        """Set the state, and post a `StateChanged` message if it changed."""
        changed = state != self._state
        self._state = state
        if changed:
            self._node.post_message(self.StateChanged(self, state))

    @property
    def is_cancelled(self) -> bool:
        """Has the work been cancelled?

        Note that cancelled work may still be running.
        """
        return self._cancelled

    @property
    def is_running(self) -> bool:
        """Is the task running?"""
        return self.state == WorkerState.RUNNING

    @property
    def is_finished(self) -> bool:
        """Has the task finished (cancelled, error, or success)?"""
        return self.state in (
            WorkerState.CANCELLED,
            WorkerState.ERROR,
            WorkerState.SUCCESS,
        )

    @property
    def completed_steps(self) -> int:
        """The number of completed steps."""
        return self._completed_steps

    @property
    def total_steps(self) -> int | None:
        """The number of total steps, or None if indeterminate."""
        return self._total_steps

    @property
    def progress(self) -> float:
        """Progress as a percentage.

        If the total steps is None (or zero), then this will return 0. The
        percentage will be clamped between 0 and 100.
        """
        if not self._total_steps:
            return 0.0
        percentage = (self._completed_steps / self._total_steps) * 100.0
        # Clamp with float bounds so the result is always a float, as annotated.
        return max(0.0, min(100.0, percentage))

    @property
    def result(self) -> ResultType | None:
        """The result of the worker, or `None` if there is no result."""
        return self._result

    @property
    def error(self) -> BaseException | None:
        """The exception raised by the worker, or `None` if there was no error."""
        return self._error

    def update(
        self, completed_steps: int | None = None, total_steps: int | None = -1
    ) -> None:
        """Update the number of completed steps.

        Args:
            completed_steps: The number of steps to add to the completed steps,
                or `None` to not change.
            total_steps: The total number of steps, `None` for indeterminate,
                or -1 to leave unchanged. Negative totals other than -1 are
                stored as 0.
        """
        if completed_steps is not None:
            self._completed_steps += completed_steps
        if total_steps != -1:
            self._total_steps = None if total_steps is None else max(0, total_steps)

    def advance(self, steps: int = 1) -> None:
        """Advance the number of completed steps.

        Args:
            steps: Number of steps to advance.
        """
        self._completed_steps += steps

    @staticmethod
    def _is_coroutine_function(work: object) -> bool:
        """Check if `work` is a coroutine function, or wraps one.

        Wrappers are recognized by a `func` attribute holding the wrapped
        callable (as `functools.partial` objects have).

        Args:
            work: The work to inspect.

        Returns:
            `True` if `work`, or `work.func` when present, is a coroutine function.
        """
        return inspect.iscoroutinefunction(work) or inspect.iscoroutinefunction(
            getattr(work, "func", None)
        )

    async def _run_threaded(self) -> ResultType:
        """Run a threaded worker.

        The work is run in the event loop's default executor. Coroutine
        functions and awaitables are driven by `asyncio.run` within the
        executor thread.

        Raises:
            WorkerError: If the work is not a coroutine function, an
                awaitable, or a callable.

        Returns:
            Return value of the work.
        """

        def run_awaitable(work: Awaitable[ResultType]) -> ResultType:
            """Set the active worker and await the awaitable."""

            async def do_work() -> ResultType:
                active_worker.set(self)
                return await work

            return asyncio.run(do_work())

        def run_coroutine(
            work: Callable[[], Coroutine[None, None, ResultType]],
        ) -> ResultType:
            """Set the active worker and await coroutine."""
            return run_awaitable(work())

        def run_callable(work: Callable[[], ResultType]) -> ResultType:
            """Set the active worker, and call the callable."""
            active_worker.set(self)
            return work()

        # Coroutine functions are tested first, since they are also callable.
        if self._is_coroutine_function(self._work):
            runner = run_coroutine
        elif inspect.isawaitable(self._work):
            runner = run_awaitable
        elif callable(self._work):
            runner = run_callable
        else:
            raise WorkerError("Unsupported attempt to run a thread worker")

        # `get_running_loop` raises if there is no running loop, so the
        # returned loop never needs a `None` check.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, runner, self._work)

    async def _run_async(self) -> ResultType:
        """Run an async worker.

        Raises:
            WorkerError: If the work is neither a coroutine function nor an
                awaitable.

        Returns:
            Return value of the work.
        """
        if self._is_coroutine_function(self._work):
            return await self._work()
        elif inspect.isawaitable(self._work):
            return await self._work
        elif callable(self._work):
            raise WorkerError("Request to run a non-async function as an async worker")
        raise WorkerError("Unsupported attempt to run an async worker")

    async def run(self) -> ResultType:
        """Run the work.

        Implement this method in a subclass, or pass a callable to the constructor.

        Returns:
            Return value of the work.
        """
        return await (
            self._run_threaded() if self._thread_worker else self._run_async()
        )

    async def _run(self, app: App) -> None:
        """Run the worker.

        Runs the work and records the outcome: the result on success, or the
        exception on cancellation or failure. The state is updated to match.
        If the work fails and `exit_on_error` is set, the failure is passed
        to the app wrapped in a `WorkerFailed`.

        Args:
            app: App instance.
        """
        with app._context():
            active_worker.set(self)

            self.state = WorkerState.RUNNING
            app.log.worker(self)
            try:
                self._result = await self.run()
            except asyncio.CancelledError as error:
                # Record the error before changing state, so that `error` is
                # already available once the new state can be observed.
                self._error = error
                self.state = WorkerState.CANCELLED
                app.log.worker(self)
            except Exception as error:
                self._error = error
                self.state = WorkerState.ERROR
                app.log.worker(self, "failed", repr(error))
                from rich.traceback import Traceback

                app.log.worker(Traceback())
                if self.exit_on_error:
                    worker_failed = WorkerFailed(self._error)
                    app._handle_exception(worker_failed)
            else:
                self.state = WorkerState.SUCCESS
                app.log.worker(self)

    def _start(
        self, app: App, done_callback: Callable[[Worker], None] | None = None
    ) -> None:
        """Start the worker.

        Does nothing if the worker has already been started.

        Args:
            app: An app instance.
            done_callback: A callback to call when the task is done.
        """
        if self._task is not None:
            return
        self.state = WorkerState.RUNNING
        self._task = asyncio.create_task(self._run(app))

        def task_done_callback(_task: asyncio.Task) -> None:
            """Run the callback.

            Called by `Task.add_done_callback`.

            Args:
                _task: The worker's task.
            """
            if done_callback is not None:
                done_callback(self)

        self._task.add_done_callback(task_done_callback)

    def cancel(self) -> None:
        """Cancel the task.

        Flags the worker as cancelled, cancels the asyncio task if the worker
        has been started, and sets `cancelled_event`.
        """
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()
        self.cancelled_event.set()

    async def wait(self) -> ResultType:
        """Wait for the work to complete.

        Raises:
            DeadlockError: If called from within the worker itself.
            WorkerError: If the worker has not been started.
            WorkerFailed: If the Worker raised an exception.
            WorkerCancelled: If the Worker was cancelled before it completed.

        Returns:
            The return value of the work.
        """
        try:
            if active_worker.get() is self:
                raise DeadlockError(
                    "Can't call worker.wait from within the worker function!"
                )
        except LookupError:
            # Not in a worker
            pass

        if self.state == WorkerState.PENDING:
            raise WorkerError("Worker must be started before calling this method.")
        if self._task is not None:
            try:
                await self._task
            except asyncio.CancelledError as error:
                # Record the error before changing state (see `_run`).
                self._error = error
                self.state = WorkerState.CANCELLED
        if self.state == WorkerState.ERROR:
            # `_run` always records the exception when it sets the ERROR state.
            assert self._error is not None
            raise WorkerFailed(self._error)
        elif self.state == WorkerState.CANCELLED:
            raise WorkerCancelled("Worker was cancelled, and did not complete.")
        return cast("ResultType", self._result)
