from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from src.utils.db import SessionLocal
from src.core.chatbot_SQLAgent import SQLQueryAgent
from src.apps.auth.controller import get_current_user  # Assuming this is how we fetch the user
from src.apps.stores.models import Branch
import instructor
from pydantic import Field
import openai
import os
from src.utils.db import engine
from dotenv import load_dotenv

# Populate os.environ from a .env file, if one is present, before reading settings.
load_dotenv()

# Azure OpenAI connection settings. Each is None when the variable is unset.
API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION")

# Router that the endpoints below are registered on.
router = APIRouter()

# Azure OpenAI client, wrapped by `instructor` so it can be handed to the agent.
client = instructor.from_openai(
    openai.AzureOpenAI(
        azure_endpoint=ENDPOINT,
        azure_deployment=DEPLOYMENT,
        api_version=API_VERSION,
        api_key=API_KEY,
    )
)

# Module-wide agent shared by all requests; it runs queries against `engine`.
sql_agent = SQLQueryAgent(db_engine=engine, client=client)

# Request schema for POST /query.
# NOTE: a class docstring is deliberately not used here, because pydantic would
# copy it into the generated OpenAPI schema description.
class QueryRequest(BaseModel):
    # Natural-language question forwarded to sql_agent.process_query.
    question: str

# Response schema for POST /query.
class QueryResponse(BaseModel):
    # The agent's response converted via .model_dump(); its keys depend on the
    # model returned by SQLQueryAgent.process_query (not visible here).
    result: dict



@router.post("/query", response_model=QueryResponse)
def query_database(request: QueryRequest, current_user=Depends(get_current_user)):
    """
    Answer a natural-language question scoped to the caller's branch.

    Looks up the first ``Branch`` whose ``user_id`` equals
    ``current_user.user_id``. It then passes the question, together with
    ``{"branch_id": <that branch's id>}``, to the module-level ``sql_agent``.

    Args:
        request: Request body carrying the user's ``question``.
        current_user: User resolved by the ``get_current_user`` dependency.
            Only its ``user_id`` attribute is read here.

    Returns:
        QueryResponse whose ``result`` is ``model_dump()`` of the object
        returned by ``sql_agent.process_query``.

    Raises:
        HTTPException: 400 when no branch belongs to the current user.
            500 for any other error raised during the branch lookup or
            query processing.
    """
    try:
        with SessionLocal() as session:
            branch = (
                session.query(Branch)
                .filter(Branch.user_id == current_user.user_id)
                .first()
            )
            if branch is None:
                raise HTTPException(
                    status_code=400, detail="No branch found for the current user."
                )
            # Read the id while the session is still open, so we never touch
            # attributes of a detached ORM instance.
            branch_id = branch.branch_id

        response = sql_agent.process_query(request.question, {"branch_id": branch_id})
        return QueryResponse(result=response.model_dump())

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged. Without this clause the generic handler below would
        # rewrap them as a 500.
        raise
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text (DB/LLM details) to
        # API clients. Consider logging it and returning a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e

