import re
from typing import Dict, Any, Optional

import pandas as pd
from pydantic import Field
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
# from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseIOSchema
# NEW (v2.0)
from atomic_agents import AtomicAgent, AgentConfig, BaseIOSchema, BaseTool, BaseToolConfig
from atomic_agents.context import ChatHistory, SystemPromptGenerator, BaseDynamicContextProvider


# Define the input schema for the SQL Query Agent.
class SQLQueryInputSchema(BaseIOSchema):
    """Schema for user input requiring SQL query generation and summarization."""
    # NOTE(review): pydantic copies a model's class docstring into its JSON schema
    # description, so the docstring above is presumably part of what the LLM sees —
    # treat it as prompt text when editing. TODO confirm against atomic_agents.
    # Required free-form request from the user (no default value).
    user_input: str = Field(..., description="User's query in natural language.")
    # Optional extra constraints; None when the caller supplies no filters.
    additional_filters: Optional[Dict[str, Any]] = Field(
        None, description="Optional dictionary of filtering conditions, passed as context to the LLM."
    )

# Define the output schema for the SQL Query Agent.
class SQLQueryOutputSchema(BaseIOSchema):
    """
    Schema for summarizing the query results.

    Attributes:
      summary: A concise and structured summary of the records as per user requirements.
    """
    # NOTE(review): the docstring above documents only `summary`, but the schema has
    # three required fields (below). It is left untouched because pydantic exposes the
    # class docstring as the schema description, so it is presumably prompt text.
    # SQL text produced for the user's request.
    generated_query: str = Field(..., description="The generated SQL query.")
    # Text rendering of the first rows of the result set.
    dataframe_preview: str = Field(..., description="Preview of retrieved records in pandas DataFrame format.")
    # Plain-language digest of the retrieved data.
    summary: str = Field(..., description="A summarized version of the retrieved data.")

class SQLQueryAgent:
    """
    Natural-language front end over the feedback tables described in the prompt.

    Wraps an ``AtomicAgent`` that turns a user's question into SQL, runs that SQL
    through the supplied SQLAlchemy engine, and then asks the same agent to
    summarize the rows that came back.

    Because the SQL text is produced by an LLM from user-controlled input, it is
    treated as untrusted: ``execute_query`` only runs a single read-only
    statement. This check is defense in depth, not a sandbox; the database
    credentials behind ``db_engine`` should still belong to a read-only role.
    """

    # A statement must open with SELECT or WITH (leading parentheses allowed).
    _READ_PREFIX = re.compile(r"[\s(]*(?:select|with)\b", re.IGNORECASE)

    # Keywords that indicate a write, DDL or privilege change. ``into`` covers
    # PostgreSQL's ``SELECT ... INTO new_table``; ``update`` also rejects
    # ``SELECT ... FOR UPDATE`` locking reads.
    _WRITE_KEYWORDS = re.compile(
        r"\b(?:insert|update|delete|drop|alter|truncate|create|grant|revoke|merge|into)\b",
        re.IGNORECASE,
    )

    # Spans that are data or commentary rather than SQL keywords.
    _LITERALS_AND_COMMENTS = re.compile(
        r"'(?:[^']|'')*'"      # single-quoted literal ('' is an escaped quote)
        r'|"(?:[^"]|"")*"'     # double-quoted identifier
        r"|--[^\n]*"           # line comment
    )

    # Syntax the simple pattern above cannot lex reliably (backslash escapes,
    # dollar-quoting, nestable block comments). When any of these is present the
    # raw text is inspected instead, which is stricter but cannot be fooled.
    _AMBIGUOUS_QUOTING = ("\\", "$", "/*")

    def __init__(self, client, db_engine: Engine, model: str = "gpt-4o-mini", temperature: float = 0.2):
        """
        Initialize the SQLQueryAgent with schema constraints.

        Args:
            client: The LLM client (e.g., from instructor.from_openai).
            db_engine: SQLAlchemy database engine for connecting to PostgreSQL.
            model: The identifier for the model to use.
            temperature: Sampling temperature for the LLM.
        """
        # Schema description handed to the model so it knows the table layout.
        # NOTE: every string below is prompt text, i.e. runtime behaviour. It is
        # kept verbatim (including its spelling) on purpose.
        database_schema = """
        Your task is to generate optimized SQL queries based on the following database schema:

        Table: feedbacks
        - feedback_id (Integer, PK)
        - store_id (Integer, FK -> stores.store_id)
        - datasource_id (Integer, FK -> datasource.ds_id)
        - branch_id (Integer, FK -> branches.branch_id)
        - customer_name (String)
        - feedback_posting_date (DateTime)
        - feedback_source (String)  # e.g., Google, Yelp, Hubwallet
        - feedback_rating (Integer)  # Rating between 1-5
        - feedback_type (String)  # e.g., text, voice
        - original_content (Text)  # Raw feedback data
        - transcription (Text)  # If voice feedback is provided
        - sentiment (String)  # positive, neutral, negative, mixed
        - confidence_score (Float)
        - emotion (String)  # e.g., frustration, happiness, disappointment
        - arousal (String)  # passive, active
        - created_at (DateTime)

        Table: emotions
        - emotion_id (Integer, PK)
        - store_id (Integer, FK -> stores.store_id)
        - datasource_id (Integer, FK -> datasource.ds_id)
        - branch_id (Integer, FK -> branches.branch_id)
        - feedback_id (Integer, FK -> feedbacks.feedback_id)
        - emotions (String)  # e.g., happiness, sadness, anger
        - created_at (DateTime)

        Table: words
        - word_id (Integer, PK)
        - store_id (Integer, FK -> stores.store_id)
        - datasource_id (Integer, FK -> datasource.ds_id)
        - branch_id (Integer, FK -> branches.branch_id)
        - feedback_id (Integer, FK -> feedbacks.feedback_id)
        - words (String)
        - sentiment (String)  # positive, neutral, negative, mixed
        - created_at (DateTime)

        Table: review_topics
        - rt_id (Integer, PK)
        - store_id (Integer, FK -> stores.store_id)
        - datasource_id (Integer, FK -> datasource.ds_id)
        - branch_id (Integer, FK -> branches.branch_id)
        - feedback_id (Integer, FK -> feedbacks.feedback_id)
        - topic_id (Integer, FK -> topics.topic_id)
        - topic_name (String)  # e.g., service, food, ambiance
        - topic_sentiment (String)  # positive, neutral, negative, mixed
        - created_at (DateTime)

        -- All queries must include a WHERE clause to filter by branch_id.
        -- branch_id should always be included in every SQL query.
        """

        system_prompt_generator = SystemPromptGenerator(
            background=[
                "You are an expert SQL query generator and database analyst.",
                "You understand database schemas and can write optimized SQL queries.",
                "You help users extract insights from structured data.",
                database_schema
            ],
            steps=[
                "Receive the user query in natural language.",
                "Analyze additional filters provided and include them in the query if relevant.",
                "Generate a structured SQL query using the appropriate schema for the database.",
                "Ensure the query always includes a WHERE clause with branch_id.",
                "Execute the SQL query on the PostgreSQL database and retrieve the records.",
                "Convert the retrieved data into a structured pandas DataFrame.",
                "Summarize the extracted records into a concise, structured response for the user. Do not mention branch_id, any SQL queires or any techinal term in the summary"
            ],
            output_instructions=[
                "Return only a JSON object with the following keys: 'generated_query', 'dataframe_preview', and 'summary'.",
                "Ensure the summary is user-friendly, concise, and includes key insights from the retrieved data. Do not mention branch_id, any SQL queires or any techinal term in the summary"
            ]
        )

        config = AgentConfig(
            client=client,
            model=model,
            temperature=temperature,
            system_prompt_generator=system_prompt_generator
        )
        self.agent = AtomicAgent[SQLQueryInputSchema,SQLQueryOutputSchema](config=config)
        self.db_engine = db_engine  # Store the database engine

    @classmethod
    def _ensure_read_only(cls, query: str) -> str:
        """
        Return ``query`` normalized for execution, or raise if it is not a single read.

        Args:
            query: SQL text, typically produced by the LLM.

        Returns:
            The statement with surrounding whitespace and trailing semicolons removed.

        Raises:
            ValueError: If the text is empty, holds more than one statement, does
                not start with SELECT/WITH, or contains a data-modifying keyword.
        """
        statement = (query or "").strip().rstrip(";").rstrip()
        if not statement:
            raise ValueError("Refusing to execute an empty SQL query.")

        # Hide literals, quoted identifiers and line comments so that words such
        # as "update" inside a search pattern do not trigger a rejection. Fall
        # back to the raw text whenever the quoting cannot be lexed reliably.
        if any(marker in statement for marker in cls._AMBIGUOUS_QUOTING):
            inspected = statement
        else:
            inspected = cls._LITERALS_AND_COMMENTS.sub(" ", statement)

        if ";" in inspected:
            raise ValueError("Refusing to execute more than one SQL statement.")
        if cls._READ_PREFIX.match(inspected) is None:
            raise ValueError("Only SELECT queries may be executed.")
        if cls._WRITE_KEYWORDS.search(inspected) is not None:
            raise ValueError("Refusing to execute a query that modifies data or schema.")
        return statement

    def execute_query(self, query: str) -> pd.DataFrame:
        """
        Execute the generated SQL query on PostgreSQL and return results as a pandas DataFrame.

        The query is validated first; nothing is sent to the database unless it
        is a single read-only statement.

        Args:
            query: The SQL query to be executed.

        Returns:
            A pandas DataFrame containing the retrieved records.

        Raises:
            ValueError: If ``query`` is empty or is not a single read-only statement.
        """
        statement = self._ensure_read_only(query)
        with self.db_engine.connect() as connection:
            result = connection.execute(text(statement))
            records = result.fetchall()
            columns = list(result.keys())  # Column names, in result order
            return pd.DataFrame(records, columns=columns)

    def process_query(self, user_input: str, additional_filters: Optional[Dict[str, Any]] = None) -> SQLQueryOutputSchema:
        """
        Process user input, generate an SQL query, execute it, retrieve results, and summarize.

        Args:
            user_input: The user's query in natural language.
            additional_filters: Optional dictionary of filters, passed to the LLM.

        Returns:
            An instance of SQLQueryOutputSchema holding the generated query, a
            text preview of the first rows actually retrieved, and the summary
            produced from the retrieved data.

        Raises:
            ValueError: If the generated SQL is empty or not a single read-only statement.
        """
        input_data = SQLQueryInputSchema(user_input=user_input, additional_filters=additional_filters)

        # Step 1: Generate the SQL query. Only `generated_query` is taken from this
        # first response; the preview and summary are rebuilt below from real rows.
        result = self.agent.run(input_data)
        generated_query = result.generated_query  # LLM decides WHERE clauses

        # Step 2: Execute the query and retrieve data
        df = self.execute_query(generated_query)

        # Step 3: Convert DataFrame to preview format (first few rows as a string)
        dataframe_preview = df.head().to_string(index=False)

        # Step 4: Summarize the retrieved records
        summary_input = f"Summarize the following dataset:\n{df.to_string(index=False)}  as per user's question: {input_data.user_input}. IMPORTANT: The answer to the user's question is in the provided dataset."
        summary_result = self.agent.run(SQLQueryInputSchema(user_input=summary_input))

        return SQLQueryOutputSchema(
            generated_query=generated_query,
            dataframe_preview=dataframe_preview,
            summary=summary_result.summary
        )
