import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from itertools import chain
from types import GenericAlias
from typing import Annotated, Any, ForwardRef, cast, get_args, get_origin, get_type_hints

import pydantic_core
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    RootModel,
    WithJsonSchema,
    create_model,
)
from pydantic._internal._typing_extra import eval_type_backport
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind
from pydantic_core import PydanticUndefined

from mcp.server.fastmcp.exceptions import InvalidSignature
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.server.fastmcp.utilities.types import Image
from mcp.types import ContentBlock, TextContent

logger = get_logger(__name__)


class StrictJsonSchema(GenerateJsonSchema):
    """JSON schema generator that turns every schema warning into an error.

    Raising instead of warning lets callers detect, during schema generation,
    types that cannot be serialized.
    """

    def emit_warning(self, kind: JsonSchemaWarningKind, detail: str) -> None:
        """Raise ``ValueError`` describing the warning rather than emitting it."""
        message = f"JSON schema warning: {kind} - {detail}"
        raise ValueError(message)


class ArgModelBase(BaseModel):
    """A model representing the arguments to a function."""

    def model_dump_one_level(self) -> dict[str, Any]:
        """Dump the model's fields into a dict without recursing into the values.

        Values such as sub-models are returned as-is instead of being converted
        to plain data. A field that declares an alias is keyed by that alias;
        every other field is keyed by its own name.
        """
        return {
            (info.alias or name): getattr(self, name)
            for name, info in self.__class__.model_fields.items()
        }

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


class FuncMetadata(BaseModel):
    # Bundles the argument model built for a function with the optional
    # model/schema pair describing its structured output.
    #   arg_model:     model class used to validate incoming call arguments
    #   output_schema: JSON schema of the structured output; None means the
    #                  result is returned as unstructured content only
    #   output_model:  model class used to validate and dump structured results
    #   wrap_output:   whether a result is wrapped as {"result": ...} before
    #                  being validated against output_model
    arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
    output_schema: dict[str, Any] | None = None
    output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
    wrap_output: bool = False

    async def call_fn_with_arg_validation(
        self,
        fn: Callable[..., Any | Awaitable[Any]],
        fn_is_async: bool,
        arguments_to_validate: dict[str, Any],
        arguments_to_pass_directly: dict[str, Any] | None,
    ) -> Any:
        """Validate the arguments and invoke ``fn`` with them.

        String values in ``arguments_to_validate`` first get a chance to be
        decoded from JSON, the result is validated against ``arg_model``, and
        ``arguments_to_pass_directly`` (if any) is merged on top without being
        validated. The call result is awaited when ``fn_is_async`` is true.
        """
        candidate_args = self.pre_parse_json(arguments_to_validate)
        validated_model = self.arg_model.model_validate(candidate_args)
        call_kwargs = validated_model.model_dump_one_level()

        # Directly-passed arguments skip validation and override validated ones.
        if arguments_to_pass_directly:
            call_kwargs.update(arguments_to_pass_directly)

        outcome = fn(**call_kwargs)
        if fn_is_async:
            outcome = await outcome
        return outcome

    def convert_result(self, result: Any) -> Any:
        """
        Shape a function's return value for the lowlevel server tool call handler.

        - Without an output schema, only the unstructured content is returned.
        - With an output schema, the result is validated against the output model
            and a ``(unstructured_content, structured_content)`` tuple is returned,
            where the structured part is a JSON-mode dump using aliases.

        Note: unstructured content is produced here **even though the lowlevel
        server tool call handler has its own generic backwards-compatible
        serialization of structured content**. FastMCP's ad hoc conversion of
        return values into unstructured output has to be retained for backwards
        compatibility, whereas the lowlevel server simply serializes the
        structured output.
        """
        content = _convert_to_content(result)
        if self.output_schema is None:
            return content

        payload = {"result": result} if self.wrap_output else result
        assert self.output_model is not None, "Output model must be set if output schema is defined"
        structured = self.output_model.model_validate(payload).model_dump(mode="json", by_alias=True)
        return (content, structured)

    def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
        """Decode JSON that was passed as a string for non-string fields.

        Returns a dict with exactly the same keys as ``data``; a value is
        replaced by its decoded form only when decoding is appropriate.

        This covers inputs such as `["a", "b", "c"]` arriving as a JSON string
        instead of a real list. Claude desktop is prone to this - in fact it
        seems incapable of NOT doing it. Sub-models likewise tend to arrive as
        JSON-encoded objects, which are decoded here as well.
        """
        parsed = data.copy()  # shallow copy; the input is never mutated

        # Accept both the field name and its alias (if any) as an input key.
        field_for_key: dict[str, FieldInfo] = {}
        for name, info in self.arg_model.model_fields.items():
            field_for_key[name] = info
            if info.alias:
                field_for_key[info.alias] = info

        for key, raw in data.items():
            info = field_for_key.get(key)
            if info is None:
                continue  # key does not correspond to a known field
            if not isinstance(raw, str) or info.annotation is str:
                continue  # only strings aimed at non-str fields are candidates
            try:
                decoded = json.loads(raw)
            except json.JSONDecodeError:
                continue  # not JSON - leave the value untouched
            if isinstance(decoded, (str, int, float)):
                # A raw value like `"hello"` would decode to just `hello`,
                # silently changing the input, so scalar results are skipped.
                continue
            parsed[key] = decoded

        assert parsed.keys() == data.keys()
        return parsed

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


def func_metadata(
    func: Callable[..., Any],
    skip_names: Sequence[str] = (),
    structured_output: bool | None = None,
) -> FuncMetadata:
    """Given a function, return metadata including a pydantic model representing its
    signature.

    The use case for this is
    ```
    meta = func_metadata(func)
    validated_args = meta.arg_model.model_validate(some_raw_data_dict)
    return func(**validated_args.model_dump_one_level())
    ```

    **critically** it also provides pre-parse helper to attempt to parse things from
    JSON.

    Args:
        func: The function to convert to a pydantic model
        skip_names: A list of parameter names to skip. These will not be included in
            the model.
        structured_output: Controls whether the tool's output is structured or unstructured
            - If None, auto-detects based on the function's return type annotation
            - If True, unconditionally creates a structured tool (return type annotation permitting)
            - If False, unconditionally creates an unstructured tool

        If structured, creates a Pydantic model for the function's result based on its annotation.
        Supports various return types:
            - BaseModel subclasses (used directly)
            - Primitive types (str, int, float, bool, bytes, None) - wrapped in a
                model with a 'result' field
            - TypedDict - converted to a Pydantic model with same fields
            - Dataclasses and other annotated classes - converted to Pydantic models
            - Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field

    Returns:
        A FuncMetadata object containing:
        - arg_model: A pydantic model representing the function's arguments
        - output_model: A pydantic model for the return type if output is structured
        - output_schema: The JSON schema generated from output_model, if any
        - wrap_output: Whether the return value has to be wrapped as
            {"result": ...} before being validated against output_model

    Raises:
        InvalidSignature: If a parameter name starts with '_', if an annotation
            given as a string cannot be evaluated, or if structured_output is
            True but the function has no return annotation or its return type
            cannot be turned into a structured-output model.
    """
    sig = _get_typed_signature(func)
    globalns = getattr(func, "__globals__", {})
    dynamic_pydantic_model_params: dict[str, Any] = {}
    for param in sig.parameters.values():
        # The leading-underscore check runs before skip_names, so it applies to
        # skipped parameters as well.
        if param.name.startswith("_"):
            raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
        if param.name in skip_names:
            continue

        # A parameter without a default becomes a field without a default.
        default = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
        annotation = param.annotation

        # `x: None` / `x: None = None`
        if annotation is None:
            annotation = Annotated[
                None,
                Field(default=default),
            ]

        # Untyped field: validated as Any, with the JSON schema overridden to
        # describe the parameter as a string.
        if annotation is inspect.Parameter.empty:
            annotation = Annotated[
                Any,
                Field(),
                WithJsonSchema({"title": param.name, "type": "string"}),
            ]

        field_info = FieldInfo.from_annotated_attribute(
            _get_typed_annotation(annotation, globalns),
            default,
        )

        # Check if the parameter name conflicts with BaseModel attributes
        # This is necessary because Pydantic warns about shadowing parent attributes
        if hasattr(BaseModel, param.name) and callable(getattr(BaseModel, param.name)):
            # Use an alias to avoid the shadowing warning
            field_info.alias = param.name
            field_info.validation_alias = param.name
            field_info.serialization_alias = param.name
            # Use a prefixed internal name
            internal_name = f"field_{param.name}"
            dynamic_pydantic_model_params[internal_name] = (field_info.annotation, field_info)
        else:
            dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)

    arguments_model = create_model(
        f"{func.__name__}Arguments",
        **dynamic_pydantic_model_params,
        __base__=ArgModelBase,
    )

    if structured_output is False:
        return FuncMetadata(arg_model=arguments_model)

    # set up structured output support based on return type annotation

    if sig.return_annotation is inspect.Signature.empty and structured_output is True:
        raise InvalidSignature(f"Function {func.__name__}: return annotation required for structured output")

    output_info = FieldInfo.from_annotation(_get_typed_annotation(sig.return_annotation, globalns))
    annotation = output_info.annotation

    output_model, output_schema, wrap_output = _try_create_model_and_schema(annotation, func.__name__, output_info)

    if output_model is None and structured_output is True:
        # Model creation failed or produced warnings - no structured output
        raise InvalidSignature(
            f"Function {func.__name__}: return type {annotation} is not serializable for structured output"
        )

    return FuncMetadata(
        arg_model=arguments_model,
        output_schema=output_schema,
        output_model=output_model,
        wrap_output=wrap_output,
    )


def _try_create_model_and_schema(
    annotation: Any, func_name: str, field_info: FieldInfo
) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]:
    """Build an output model and its JSON schema for a return annotation.

    Returns:
        tuple of (model or None, schema or None, wrap_output)
        Model and schema are both None when no model could be derived from the
        annotation or when strict schema generation failed.
        wrap_output is True if the result needs to be wrapped in {"result": ...}
    """

    def wrapped() -> tuple[type[BaseModel], bool]:
        # Model with a single 'result' field; the caller must wrap the value.
        return _create_wrapped_model(func_name, annotation, field_info), True

    candidate: type[BaseModel] | None = None
    needs_wrapping = False

    if annotation is None:
        # Bare `None` annotation
        candidate, needs_wrapping = wrapped()
    elif isinstance(annotation, GenericAlias):
        # Parameterized builtin generics such as list[str] or dict[str, int]
        type_args = get_args(annotation)
        if get_origin(annotation) is dict and len(type_args) == 2 and type_args[0] is str:
            # A str-keyed dict can be represented directly by a RootModel
            candidate = _create_dict_model(func_name, annotation)
        else:
            # Every other generic (including dicts with non-str keys) is wrapped
            candidate, needs_wrapping = wrapped()
    elif isinstance(annotation, type):
        klass = cast(type[Any], annotation)
        if issubclass(annotation, BaseModel):
            # Pydantic models are used as they are
            candidate = annotation
        elif hasattr(klass, "__annotations__") and issubclass(annotation, dict):
            # TypedDict: a dict subclass that carries __annotations__
            candidate = _create_model_from_typeddict(klass)
        elif annotation in (str, int, float, bool, bytes, type(None)):
            # Primitive types are wrapped
            candidate, needs_wrapping = wrapped()
        elif get_type_hints(klass):
            # Dataclasses and other classes with type hints are converted
            candidate = _create_model_from_class(klass)
        # Classes without type hints yield no model (candidate stays None)
    else:
        # Anything that is neither a GenericAlias nor a class, e.g. typing
        # constructs such as Union/Optional, is wrapped
        candidate, needs_wrapping = wrapped()

    if not candidate:
        return None, None, False

    # StrictJsonSchema raises where the default generator would only warn.
    try:
        schema = candidate.model_json_schema(schema_generator=StrictJsonSchema)
    except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e:
        # Expected failures when a type cannot be turned into a schema:
        # TypeError: Pydantic can't handle the type
        # ValueError: problems with the type definition (including our strict warnings)
        # SchemaError: Pydantic can't build a schema
        # ValidationError: validation fails
        logger.info(f"Cannot create schema for type {annotation} in {func_name}: {type(e).__name__}: {e}")
        return None, None, False

    return candidate, schema, needs_wrapping


def _create_model_from_class(cls: type[Any]) -> type[BaseModel]:
    """Build a Pydantic model mirroring the type-hinted attributes of a class.

    The generated model:
    - is named after the class
    - has one field per type-hinted attribute whose name does not start with '_'
    - uses a class-level attribute of the same name, when one exists, as the
      field's default; otherwise no default is supplied here
    - is configured with ``from_attributes=True``

    Precondition: cls must have type hints (i.e., get_type_hints(cls) is non-empty)
    """

    # Base class that carries the config for the generated model
    class BaseWithConfig(BaseModel):
        model_config = ConfigDict(from_attributes=True)

    field_definitions: dict[str, Any] = {}
    for attr_name, attr_type in get_type_hints(cls).items():
        if not attr_name.startswith("_"):
            info = FieldInfo.from_annotated_attribute(attr_type, getattr(cls, attr_name, PydanticUndefined))
            field_definitions[attr_name] = (info.annotation, info)

    return create_model(cls.__name__, **field_definitions, __base__=BaseWithConfig)


def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]:
    """Build a Pydantic model with the same name and keys as a TypedDict.

    Keys missing from the TypedDict's ``__required_keys__`` get ``None`` as
    their default; when that attribute is absent every key is treated as
    required.
    """
    hints = get_type_hints(td_type)
    required = getattr(td_type, "__required_keys__", set(hints))

    definitions: dict[str, Any] = {}
    for key, key_type in hints.items():
        info = FieldInfo.from_annotation(key_type)
        if key not in required:
            # A default of None makes the key optional in the Pydantic model.
            # Dumping with exclude_unset=True gives TypedDict semantics back.
            info.default = None
        definitions[key] = (info.annotation, info)

    return create_model(td_type.__name__, **definitions, __base__=BaseModel)


def _create_wrapped_model(func_name: str, annotation: Any, field_info: FieldInfo) -> type[BaseModel]:
    """Build a ``<func_name>Output`` model with a single 'result' field.

    Used for return types that are wrapped rather than used directly
    (primitive types, generic types like list/dict, etc.).
    """
    # Pydantic needs type(None) rather than None as the field's type
    result_type = type(None) if annotation is None else annotation
    return create_model(f"{func_name}Output", result=(result_type, field_info), __base__=BaseModel)


def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]:
    """Return a RootModel subclass whose root type is the given dict annotation.

    After creation the class is renamed to ``<func_name>DictOutput``.
    """
    output_name = f"{func_name}DictOutput"

    class DictModel(RootModel[dict_annotation]):
        pass

    # Rename after creation so the class carries the function's name
    DictModel.__name__ = DictModel.__qualname__ = output_name
    return DictModel


def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]:
        try:
            return eval_type_backport(value, globalns, localns), True
        except NameError:
            return value, False

    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)
        annotation, status = try_eval_type(annotation, globalns, globalns)

        # This check and raise could perhaps be skipped, and we (FastMCP) just call
        # model_rebuild right before using it 🤷
        if status is False:
            raise InvalidSignature(f"Unable to evaluate type annotation {annotation}")

    return annotation


def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return the signature of ``call`` with forward references evaluated.

    Every parameter annotation and the return annotation are passed through
    ``_get_typed_annotation``, using the callable's ``__globals__`` (or an
    empty dict when it has none) as the namespace.
    """
    raw_signature = inspect.signature(call)
    namespace = getattr(call, "__globals__", {})

    resolved_params = [
        original.replace(annotation=_get_typed_annotation(original.annotation, namespace))
        for original in raw_signature.parameters.values()
    ]
    resolved_return = _get_typed_annotation(raw_signature.return_annotation, namespace)
    return inspect.Signature(resolved_params, return_annotation=resolved_return)


def _convert_to_content(
    result: Any,
) -> Sequence[ContentBlock]:
    """
    Turn a function result into a list of content blocks.

    - ``None`` yields an empty list.
    - A content block is returned as the only element.
    - An ``Image`` is converted with ``to_image_content()``.
    - A list or tuple is converted item by item and flattened.
    - A string becomes text content as-is; anything else is serialized to
      indented JSON first (falling back to ``str``) and becomes text content.

    Note: This conversion logic comes from previous versions of FastMCP and is being
    retained for purposes of backwards compatibility. It produces different unstructured
    output than the lowlevel server tool call handler, which just serializes structured
    content verbatim.
    """
    blocks: list[Any]
    if result is None:
        blocks = []
    elif isinstance(result, ContentBlock):
        blocks = [result]
    elif isinstance(result, Image):
        blocks = [result.to_image_content()]
    elif isinstance(result, (list, tuple)):
        # Convert every item recursively, then flatten one level
        nested = map(_convert_to_content, result)  # type: ignore
        blocks = list(chain.from_iterable(nested))
    else:
        if isinstance(result, str):
            text = result
        else:
            text = pydantic_core.to_json(result, fallback=str, indent=2).decode()
        blocks = [TextContent(type="text", text=text)]
    return blocks
