from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generic

from pydantic import BaseModel

from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.session import (
    BaseSession,
    ReceiveNotificationT,
    ReceiveRequestT,
    SendNotificationT,
    SendRequestT,
    SendResultT,
)
from mcp.types import ProgressToken


class Progress(BaseModel):
    progress: float
    total: float | None


@dataclass
class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]):
    session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
    progress_token: ProgressToken
    total: float | None
    current: float = field(default=0.0, init=False)

    async def progress(self, amount: float, message: str | None = None) -> None:
        self.current += amount

        await self.session.send_progress_notification(
            self.progress_token, self.current, total=self.total, message=message
        )


@contextmanager
def progress(
    ctx: RequestContext[
        BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
        LifespanContextT,
    ],
    total: float | None = None,
) -> Generator[
    ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
    None,
]:
    if ctx.meta is None or ctx.meta.progressToken is None:
        raise ValueError("No progress token provided")

    progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
    try:
        yield progress_ctx
    finally:
        pass
