from dataclasses import dataclass
from functools import partial
from typing import Any, Literal

from pydantic import BaseModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response

from mcp.server.auth.errors import (
    stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken


class RevocationRequest(BaseModel):
    """
    Form parameters accepted by the OAuth 2.0 token revocation endpoint.

    See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
    """

    # The token the client wants to revoke.
    token: str
    # Optional hint about which kind of token was submitted; when omitted the
    # handler decides the lookup order on its own.
    token_type_hint: Literal["access_token", "refresh_token"] | None = None
    # Identifier of the client making the request.
    client_id: str
    # Defaults to None so that a request which omits the field still passes
    # validation and reaches client authentication, instead of being rejected
    # up front with a "field required" validation error.
    client_secret: str | None = None


class RevocationErrorResponse(BaseModel):
    """Error body returned by the revocation endpoint when a request is rejected."""

    # Machine-readable error code; restricted to the two values listed here.
    error: Literal["invalid_request", "unauthorized_client"]
    # Optional human-readable detail about what went wrong.
    error_description: str | None = None


@dataclass
class RevocationHandler:
    """
    Request handler for the token revocation endpoint.

    Holds the authorization server provider used to look up and revoke tokens,
    and the authenticator used to verify the calling client.
    """

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    client_authenticator: ClientAuthenticator

    async def handle(self, request: Request) -> Response:
        """
        Process a single OAuth 2.0 token revocation request.

        Returns a 400 JSON error when the form body fails validation, a 401
        JSON error when client authentication fails, and otherwise an empty
        200 response, whether or not a token was actually revoked.
        """
        # Step 1: parse and validate the submitted form fields.
        try:
            raw_form = await request.form()
            params = RevocationRequest.model_validate(dict(raw_form))
        except ValidationError as validation_error:
            invalid_body = RevocationErrorResponse(
                error="invalid_request",
                error_description=stringify_pydantic_error(validation_error),
            )
            return PydanticJSONResponse(status_code=400, content=invalid_body)

        # Step 2: make sure the caller is who it claims to be.
        try:
            client = await self.client_authenticator.authenticate(params.client_id, params.client_secret)
        except AuthenticationError as auth_error:
            unauthorized_body = RevocationErrorResponse(
                error="unauthorized_client",
                error_description=auth_error.message,
            )
            return PydanticJSONResponse(status_code=401, content=unauthorized_body)

        # Step 3: look the token up. Access tokens are tried first unless the
        # client hinted that it is submitting a refresh token.
        find_access_token = self.provider.load_access_token
        find_refresh_token = partial(self.provider.load_refresh_token, client)
        if params.token_type_hint == "refresh_token":
            lookup_order = (find_refresh_token, find_access_token)
        else:
            lookup_order = (find_access_token, find_refresh_token)

        found: AccessToken | RefreshToken | None = None
        for lookup in lookup_order:
            found = await lookup(params.token)
            if found is not None:
                break

        # Step 4: revoke only a token that was found and belongs to this
        # client. An unknown token is not an error; per the RFC the endpoint
        # still answers with HTTP 200. The provider is not expected to perform
        # validation here that would surface as an error.
        if found and found.client_id == client.client_id:
            await self.provider.revoke_token(found)

        # Successful, empty response that must not be cached.
        no_cache_headers = {
            "Cache-Control": "no-store",
            "Pragma": "no-cache",
        }
        return Response(status_code=200, headers=no_cache_headers)
