import instructor
from pydantic import BaseModel, Field
from typing import Optional, Type, Generator, AsyncGenerator, get_args
from atomic_agents.context.chat_history import ChatHistory
from atomic_agents.context.system_prompt_generator import (
    BaseDynamicContextProvider,
    SystemPromptGenerator,
)
from atomic_agents.base.base_io_schema import BaseIOSchema

from instructor.dsl.partial import PartialBase
from jiter import from_json


def model_from_chunks_patched(cls, json_chunks, **kwargs):
    """Yield progressively more complete partial models from streamed JSON text.

    Replacement for ``PartialBase.model_from_chunks`` (installed below as a
    classmethod). Every incoming fragment is appended to a running buffer, the
    whole buffer is re-parsed with jiter's ``from_json`` in
    ``partial_mode="trailing-strings"``, and the parsed value is validated
    against ``cls.get_partial_model()``.

    Args:
        cls: The ``PartialBase`` subclass the classmethod is bound to.
        json_chunks: Iterable of string fragments of one JSON document.
        **kwargs: Forwarded unchanged to ``partial_model.model_validate``.

    Yields:
        One validated partial-model instance per fragment received.
    """
    validator = cls.get_partial_model()
    buffer = ""
    for fragment in json_chunks:
        buffer = buffer + fragment
        # Nothing accumulated yet (e.g. only empty fragments): parse an empty object instead.
        raw = (buffer if buffer else "{}").encode()
        parsed = from_json(raw, partial_mode="trailing-strings")
        yield validator.model_validate(parsed, strict=None, **kwargs)


async def model_from_chunks_async_patched(cls, json_chunks, **kwargs):
    """Asynchronously yield progressively more complete partial models.

    Replacement for ``PartialBase.model_from_chunks_async`` (installed below as
    a classmethod). Fragments from the async iterable are appended to a running
    buffer, the full buffer is re-parsed with jiter's ``from_json`` in
    ``partial_mode="trailing-strings"``, and the result is validated against
    ``cls.get_partial_model()``.

    Args:
        cls: The ``PartialBase`` subclass the classmethod is bound to.
        json_chunks: Async iterable of string fragments of one JSON document.
        **kwargs: Forwarded unchanged to ``partial_model.model_validate``.

    Yields:
        One validated partial-model instance per fragment received.
    """
    validator = cls.get_partial_model()
    buffer = ""
    async for fragment in json_chunks:
        buffer = buffer + fragment
        # Nothing accumulated yet (e.g. only empty fragments): parse an empty object instead.
        raw = (buffer if buffer else "{}").encode()
        parsed = from_json(raw, partial_mode="trailing-strings")
        yield validator.model_validate(parsed, strict=None, **kwargs)


# Process-wide monkeypatch: replace instructor's chunk parsers on PartialBase
# with the jiter-based versions defined above, for both sync and async streams.
setattr(PartialBase, "model_from_chunks", classmethod(model_from_chunks_patched))
setattr(PartialBase, "model_from_chunks_async", classmethod(model_from_chunks_async_patched))


class BasicChatInputSchema(BaseIOSchema):
    """This schema represents the input from the user to the AI agent."""

    # Default input schema: AtomicAgent.input_schema falls back to this class
    # when the agent was not created through a parameterized alias.
    # NOTE(review): the class docstring and the Field description are runtime
    # data (pydantic normally surfaces them in the model's JSON schema, which
    # BaseIOSchema presumably relies on) — treat edits to them as behavior changes.
    chat_message: str = Field(
        ...,  # Ellipsis default: the field is required
        description="The chat message sent by the user to the assistant.",
    )


class BasicChatOutputSchema(BaseIOSchema):
    """This schema represents the response generated by the chat agent."""

    # Default output schema: AtomicAgent.output_schema falls back to this class
    # when the agent was not created through a parameterized alias.
    # NOTE(review): the class docstring and the Field description are runtime
    # data (pydantic normally surfaces them in the model's JSON schema, which is
    # presumably what the LLM sees as the response format) — treat edits to them
    # as behavior changes.
    chat_message: str = Field(
        ...,  # Ellipsis default: the field is required
        description=(
            "The chat message exchanged between the user and the chat agent. "
            "This contains the markdown-enabled response generated by the chat agent."
        ),
    )


# Settings consumed by AtomicAgent.__init__. Only `client` is required; every
# other field has a default (None fields are replaced by fresh components there).
class AgentConfig(BaseModel):
    # Permit non-pydantic classes (such as the instructor client) as field types.
    model_config = {"arbitrary_types_allowed": True}

    client: instructor.client.Instructor = Field(
        ...,
        description="Client for interacting with the language model.",
    )
    model: str = Field(
        default="gpt-4o-mini",
        description="The model to use for generating responses.",
    )
    history: Optional[ChatHistory] = Field(
        default=None,
        description="History component for storing chat history.",
    )
    system_prompt_generator: Optional[SystemPromptGenerator] = Field(
        default=None,
        description="Component for generating system prompts.",
    )
    system_role: Optional[str] = Field(
        default="system",
        description="The role of the system in the conversation. None means no system prompt.",
    )
    model_api_parameters: Optional[dict] = Field(
        default=None,
        description="Additional parameters passed to the API provider.",
    )


class AtomicAgent[InputSchema: BaseIOSchema, OutputSchema: BaseIOSchema]:
    """
    Base class for chat agents.

    This class provides the core functionality for handling chat interactions, including managing history,
    generating system prompts, and obtaining responses from a language model.

    Type Parameters:
        InputSchema: Schema for the user input, must be a subclass of BaseIOSchema.
        OutputSchema: Schema for the agent's output, must be a subclass of BaseIOSchema.

    Attributes:
        client: Client for interacting with the language model.
        model (str): The model to use for generating responses.
        history (ChatHistory): History component for storing chat history.
        system_prompt_generator (SystemPromptGenerator): Component for generating system prompts.
        system_role (Optional[str]): The role of the system in the conversation. None means no system prompt.
        initial_history (ChatHistory): Initial state of the history.
        current_user_input (Optional[InputSchema]): The current user input being processed.
        model_api_parameters (dict): Additional parameters passed to the API provider.
            - Use this for parameters like 'temperature', 'max_tokens', etc.
    """

    def __init__(self, config: AgentConfig):
        """
        Initializes the AtomicAgent.

        Args:
            config (AgentConfig): Configuration for the chat agent.
        """
        self.client = config.client
        self.model = config.model
        self.history = config.history or ChatHistory()
        self.system_prompt_generator = config.system_prompt_generator or SystemPromptGenerator()
        self.system_role = config.system_role
        self.initial_history = self.history.copy()
        self.current_user_input = None
        self.model_api_parameters = config.model_api_parameters or {}

    def reset_history(self):
        """
        Resets the history to its initial state.
        """
        self.history = self.initial_history.copy()

    @property
    def input_schema(self) -> Type[BaseIOSchema]:
        if hasattr(self, "__orig_class__"):
            TI, _ = get_args(self.__orig_class__)
        else:
            TI = BasicChatInputSchema

        return TI

    @property
    def output_schema(self) -> Type[BaseIOSchema]:
        if hasattr(self, "__orig_class__"):
            _, TO = get_args(self.__orig_class__)
        else:
            TO = BasicChatOutputSchema

        return TO

    def _prepare_messages(self):
        if self.system_role is None:
            self.messages = []
        else:
            self.messages = [
                {
                    "role": self.system_role,
                    "content": self.system_prompt_generator.generate_prompt(),
                }
            ]

        self.messages += self.history.get_history()

    def run(self, user_input: Optional[InputSchema] = None) -> OutputSchema:
        """
        Runs the chat agent with the given user input synchronously.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Returns:
            OutputSchema: The response from the chat agent.
        """
        assert not isinstance(
            self.client, instructor.client.AsyncInstructor
        ), "The run method is not supported for async clients. Use run_async instead."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)

        self._prepare_messages()
        response = self.client.chat.completions.create(
            messages=self.messages,
            model=self.model,
            response_model=self.output_schema,
            **self.model_api_parameters,
        )
        self.history.add_message("assistant", response)

        return response

    def run_stream(self, user_input: Optional[InputSchema] = None) -> Generator[OutputSchema, None, OutputSchema]:
        """
        Runs the chat agent with the given user input, supporting streaming output.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Yields:
            OutputSchema: Partial responses from the chat agent.

        Returns:
            OutputSchema: The final response from the chat agent.
        """
        assert not isinstance(
            self.client, instructor.client.AsyncInstructor
        ), "The run_stream method is not supported for async clients. Use run_async instead."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)

        self._prepare_messages()

        response_stream = self.client.chat.completions.create_partial(
            model=self.model,
            messages=self.messages,
            response_model=self.output_schema,
            **self.model_api_parameters,
            stream=True,
        )

        for partial_response in response_stream:
            yield partial_response

        full_response_content = self.output_schema(**partial_response.model_dump())
        self.history.add_message("assistant", full_response_content)

        return full_response_content

    async def run_async(self, user_input: Optional[InputSchema] = None) -> OutputSchema:
        """
        Runs the chat agent asynchronously with the given user input.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Returns:
            OutputSchema: The response from the chat agent.

        Raises:
            NotAsyncIterableError: If used as an async generator (in an async for loop).
                                   Use run_async_stream() method instead for streaming responses.
        """
        assert isinstance(self.client, instructor.client.AsyncInstructor), "The run_async method is for async clients."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)

        self._prepare_messages()

        response = await self.client.chat.completions.create(
            model=self.model, messages=self.messages, response_model=self.output_schema, **self.model_api_parameters
        )

        self.history.add_message("assistant", response)
        return response

    async def run_async_stream(self, user_input: Optional[InputSchema] = None) -> AsyncGenerator[OutputSchema, None]:
        """
        Runs the chat agent asynchronously with the given user input, supporting streaming output.

        Args:
            user_input (Optional[InputSchema]): The input from the user. If not provided, skips adding to history.

        Yields:
            OutputSchema: Partial responses from the chat agent.
        """
        assert isinstance(self.client, instructor.client.AsyncInstructor), "The run_async method is for async clients."
        if user_input:
            self.history.initialize_turn()
            self.current_user_input = user_input
            self.history.add_message("user", user_input)

        self._prepare_messages()

        response_stream = self.client.chat.completions.create_partial(
            model=self.model,
            messages=self.messages,
            response_model=self.output_schema,
            **self.model_api_parameters,
            stream=True,
        )

        last_response = None
        async for partial_response in response_stream:
            last_response = partial_response
            yield partial_response

        if last_response:
            full_response_content = self.output_schema(**last_response.model_dump())
            self.history.add_message("assistant", full_response_content)

    def get_context_provider(self, provider_name: str) -> Type[BaseDynamicContextProvider]:
        """
        Retrieves a context provider by name.

        Args:
            provider_name (str): The name of the context provider.

        Returns:
            BaseDynamicContextProvider: The context provider if found.

        Raises:
            KeyError: If the context provider is not found.
        """
        if provider_name not in self.system_prompt_generator.context_providers:
            raise KeyError(f"Context provider '{provider_name}' not found.")
        return self.system_prompt_generator.context_providers[provider_name]

    def register_context_provider(self, provider_name: str, provider: BaseDynamicContextProvider):
        """
        Registers a new context provider.

        Args:
            provider_name (str): The name of the context provider.
            provider (BaseDynamicContextProvider): The context provider instance.
        """
        self.system_prompt_generator.context_providers[provider_name] = provider

    def unregister_context_provider(self, provider_name: str):
        """
        Unregisters an existing context provider.

        Args:
            provider_name (str): The name of the context provider to remove.
        """
        if provider_name in self.system_prompt_generator.context_providers:
            del self.system_prompt_generator.context_providers[provider_name]
        else:
            raise KeyError(f"Context provider '{provider_name}' not found.")


if __name__ == "__main__":
    # Interactive demo: chat with an AtomicAgent from the terminal.
    # `instructor` is already imported at module level; only demo-specific imports live here.
    import asyncio
    import json

    from openai import OpenAI, AsyncOpenAI
    from rich import box
    from rich.console import Console
    from rich.live import Live
    from rich.panel import Panel
    from rich.syntax import Syntax
    from rich.table import Table

    def _create_schema_table(title: str, schema: Type[BaseModel]) -> Table:
        """Create a table displaying schema information.

        Args:
            title (str): Title of the table
            schema (Type[BaseModel]): Schema to display

        Returns:
            Table: Rich table with one row (name, annotation, description) per schema field
        """
        schema_table = Table(title=title, box=box.ROUNDED)
        schema_table.add_column("Field", style="cyan")
        schema_table.add_column("Type", style="magenta")
        schema_table.add_column("Description", style="green")

        for field_name, field in schema.model_fields.items():
            schema_table.add_row(field_name, str(field.annotation), field.description or "")

        return schema_table

    def _create_config_table(agent: AtomicAgent) -> Table:
        """Create a table displaying agent configuration.

        Args:
            agent (AtomicAgent): Agent instance

        Returns:
            Table: Rich table with the model name and the component class names
        """
        info_table = Table(title="Agent Configuration", box=box.ROUNDED)
        info_table.add_column("Property", style="cyan")
        info_table.add_column("Value", style="yellow")

        info_table.add_row("Model", agent.model)
        info_table.add_row("History", type(agent.history).__name__)
        info_table.add_row("System Prompt Generator", type(agent.system_prompt_generator).__name__)

        return info_table

    def display_agent_info(agent: AtomicAgent):
        """Print the agent's schemas, configuration and a rendered system prompt."""
        console = Console()
        console.print(
            Panel.fit(
                "[bold blue]Agent Information[/bold blue]",
                border_style="blue",
                padding=(1, 1),
            )
        )

        console.print(_create_schema_table("Input Schema", agent.input_schema))
        console.print(_create_schema_table("Output Schema", agent.output_schema))
        console.print(_create_config_table(agent))

        system_prompt = agent.system_prompt_generator.generate_prompt()
        console.print(
            Panel(
                Syntax(system_prompt, "markdown", theme="monokai", line_numbers=True),
                title="Sample System Prompt",
                border_style="green",
                expand=False,
            )
        )

    async def chat_loop(streaming: bool = False):
        """Interactive chat loop with the AI agent.

        Type 'exit' (or send EOF, e.g. Ctrl-D) to quit.

        Args:
            streaming (bool): Whether to use streaming mode for responses
        """
        # run_async_stream requires an async instructor client; run() requires a sync one.
        client = instructor.from_openai(AsyncOpenAI() if streaming else OpenAI())
        config = AgentConfig(client=client, model="gpt-4o-mini")
        agent = AtomicAgent[BasicChatInputSchema, BasicChatOutputSchema](config)

        # Display agent information before starting the chat
        display_agent_info(agent)

        console = Console()
        console.print(
            Panel.fit(
                "[bold blue]Interactive Chat Mode[/bold blue]\n"
                f"[cyan]Streaming: {streaming}[/cyan]\n"
                "Type 'exit' to quit",
                border_style="blue",
                padding=(1, 1),
            )
        )

        while True:
            try:
                user_message = console.input("\n[bold green]You:[/bold green] ")
            except EOFError:
                # End of input (Ctrl-D / closed stdin) ends the session instead of crashing.
                console.print("\n[yellow]Goodbye![/yellow]")
                break

            if user_message.strip().lower() == "exit":
                console.print("[yellow]Goodbye![/yellow]")
                break

            user_input = agent.input_schema(chat_message=user_message)

            console.print("[bold blue]Assistant:[/bold blue]")
            if streaming:
                with Live(console=console, refresh_per_second=4) as live:
                    # run_async_stream (not run_async) is the async-generator variant.
                    async for partial_response in agent.run_async_stream(user_input):
                        live.update(json.dumps(partial_response.model_dump(), indent=2))
            else:
                response = agent.run(user_input)
                console.print(json.dumps(response.model_dump(), indent=2))

    console = Console()
    console.print("\n[bold]Starting chat loop...[/bold]")
    asyncio.run(chat_loop(streaming=True))
