"""Elicitation utilities for MCP servers."""

from __future__ import annotations

import types
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo

from mcp.server.session import ServerSession
from mcp.types import RequestId

# Type variable for the caller-supplied schema model. It is bound to BaseModel,
# so only pydantic model classes can be used as elicitation schemas.
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)


class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
    """Result when user accepts the elicitation."""

    # Constant tag identifying this outcome; callers can branch on it to tell
    # the three result classes apart.
    action: Literal["accept"] = "accept"
    # The response payload, as an instance of the schema model type parameter.
    data: ElicitSchemaModelT


class DeclinedElicitation(BaseModel):
    """Result when user declines the elicitation."""

    # Constant tag identifying this outcome; no payload accompanies a decline.
    action: Literal["decline"] = "decline"


class CancelledElicitation(BaseModel):
    """Result when user cancels the elicitation."""

    # Constant tag identifying this outcome; no payload accompanies a cancel.
    action: Literal["cancel"] = "cancel"


# The three possible outcomes of an elicitation. Each member carries a distinct
# literal `action` value ("accept", "decline", "cancel"), so callers can
# discriminate on `action` or with isinstance checks.
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation


# Primitive types allowed in elicitation schemas.
# Consulted by _is_primitive_field for both plain annotations and the members
# of union annotations, and interpolated into the TypeError message raised by
# _validate_elicitation_schema.
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)


def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
    """Reject schemas that declare anything other than primitive fields.

    Walks the model's fields in declaration order and stops at the first one
    that ``_is_primitive_field`` does not accept.

    Args:
        schema: The pydantic model class to inspect.

    Raises:
        TypeError: If a field is not a primitive type (or a union of primitive
            types and ``None``); the message names the first offending field.
    """
    # Lazily search for the first non-primitive field; None means all passed.
    field_name = next(
        (name for name, info in schema.model_fields.items() if not _is_primitive_field(info)),
        None,
    )
    if field_name is None:
        return

    raise TypeError(
        f"Elicitation schema field '{field_name}' must be a primitive type "
        f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
        f"Complex types like lists, dicts, or nested models are not allowed."
    )


def _is_primitive_field(field_info: FieldInfo) -> bool:
    """Tell whether a field's annotation is acceptable in an elicitation schema.

    Acceptable annotations are ``None``'s type, any type listed in
    ``_ELICITATION_PRIMITIVE_TYPES``, or a union (``Union[...]`` / ``X | Y``)
    whose members are all one of those.

    Args:
        field_info: The pydantic field description whose ``annotation`` is inspected.

    Returns:
        ``True`` when the annotation is acceptable, ``False`` otherwise.
    """

    def _allowed(candidate: object) -> bool:
        # A single (non-union) type is fine if it is NoneType or a listed primitive.
        return candidate is types.NoneType or candidate in _ELICITATION_PRIMITIVE_TYPES

    declared = field_info.annotation
    union_origin = get_origin(declared)

    # Unions are accepted only when every member is individually acceptable.
    if union_origin is Union or union_origin is types.UnionType:
        return all(_allowed(member) for member in get_args(declared))

    return _allowed(declared)


async def elicit_with_validation(
    session: ServerSession,
    message: str,
    schema: type[ElicitSchemaModelT],
    related_request_id: RequestId | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
    """Elicit information from the client/user with schema validation.

    This method can be used to interactively ask for additional information from the
    client within a tool's execution. The client might display the message to the
    user and collect a response according to the provided schema. Or in case a
    client is an agent, it might decide how to handle the elicitation -- either by asking
    the user or automatically generating a response.

    Args:
        session: Session whose ``elicit`` method is used to send the request.
        message: Message passed to the client with the request.
        schema: Pydantic model class describing the expected response. Its JSON
            schema is sent as ``requestedSchema`` and it is used to validate
            the returned content.
        related_request_id: Forwarded unchanged to ``session.elicit``.

    Returns:
        ``AcceptedElicitation`` wrapping the validated ``schema`` instance when
        the action is ``"accept"``, ``DeclinedElicitation`` for ``"decline"``,
        or ``CancelledElicitation`` for ``"cancel"``.

    Raises:
        TypeError: If ``schema`` declares a non-primitive field.
        ValueError: If the action is ``"accept"`` but no content was returned,
            or if the action is not one of the three expected values. Errors
            raised by ``schema.model_validate`` for content that does not match
            the schema propagate to the caller.
    """
    # Validate that schema only contains primitive types and fail loudly if not
    _validate_elicitation_schema(schema)

    result = await session.elicit(
        message=message,
        requestedSchema=schema.model_json_schema(),
        related_request_id=related_request_id,
    )

    if result.action == "accept":
        # Compare against None explicitly: an empty dict is a legitimate
        # response when every schema field has a default, and must be validated
        # rather than mistaken for "no content".
        if result.content is None:
            raise ValueError("Elicitation was accepted but no content was returned")
        # Validate and parse the content using the schema
        return AcceptedElicitation(data=schema.model_validate(result.content))
    if result.action == "decline":
        return DeclinedElicitation()
    if result.action == "cancel":
        return CancelledElicitation()

    # This should never happen, but handle it just in case
    raise ValueError(f"Unexpected elicitation action: {result.action}")
