from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from sqlalchemy.orm import Session
from typing import Optional, List
from datetime import date
import json
import logging

logger = logging.getLogger(__name__)

from src.utils.db import get_db
from src.marketing.apps.hwGpt.thread_service import ThreadService
from src.marketing.apps.hwGpt.thread_history_service import ThreadHistoryService
# Temporarily disabled to avoid import issues
# from src.marketing.apps.hwGpt.websocket_manager import ChatbotWebSocketManager
from src.marketing.apps.hwGpt import schema

# Router holding the hwGpt thread/message REST endpoints and the marketing
# chat WebSocket defined below; presumably mounted by the app (with a prefix)
# elsewhere — confirm.
router = APIRouter()
# TODO: re-enable once the ChatbotWebSocketManager import issues are resolved.
# While disabled, GET /chat/connections/stats returns a static "disabled" payload.
# chatbot_manager = ChatbotWebSocketManager()


# Thread Management Endpoints
@router.post("/threads", response_model=schema.ThreadResponse)
def create_thread(
    thread_data: schema.ThreadCreateRequest,
    db: Session = Depends(get_db)
):
    """Create a new chat thread.

    Args:
        thread_data: Validated request body forwarded to ``ThreadService.create_thread``.
        db: Request-scoped SQLAlchemy session.

    Returns:
        Whatever ``ThreadService.create_thread`` returns, serialized as ``ThreadResponse``.

    Raises:
        HTTPException: re-raised unchanged if the service raises one itself;
            any other exception is logged and reported as a 400 carrying its text.
    """
    try:
        thread_service = ThreadService(db)
        return thread_service.create_thread(thread_data)
    except HTTPException:
        # Keep the status code the service layer chose instead of coercing it to 400.
        raise
    except Exception as e:
        # Log with traceback; previously the failure left no server-side trace.
        logger.exception("Failed to create thread")
        raise HTTPException(status_code=400, detail=str(e)) from e


# ``{thread_id:int}`` uses Starlette's integer path convertor so this route only
# matches numeric IDs. FastAPI matches routes in registration order, and without
# the convertor this route captures ``GET /threads/stats`` (registered later in
# this module) and rejects it with a 422, making the stats endpoint unreachable.
# The documented URL (``/threads/{thread_id}``) is unchanged.
@router.get("/threads/{thread_id:int}", response_model=schema.ThreadResponse)
def get_thread(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    db: Session = Depends(get_db)
):
    """Get a specific thread by ID.

    Args:
        thread_id: Numeric thread ID from the path.
        user_id: Requesting user's ID, forwarded to ``ThreadService.get_thread``.
        db: Request-scoped SQLAlchemy session.

    Returns:
        The thread returned by ``ThreadService.get_thread``.

    Raises:
        HTTPException: 404 if the service returns no thread.
    """
    thread_service = ThreadService(db)
    thread = thread_service.get_thread(thread_id, user_id)

    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    return thread


# NOTE(review): ``get_thread_messages`` further down registers the same
# ``GET /threads/{thread_id}/messages`` path. FastAPI dispatches to the first
# matching route, so requests to this URL are always served by this handler.
@router.get("/threads/{thread_id}/messages", response_model=schema.ThreadWithMessagesResponse)
def get_thread_with_messages(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    limit: int = Query(50, ge=1, description="Maximum number of messages to return"),
    offset: int = Query(0, ge=0, description="Number of messages to skip"),
    db: Session = Depends(get_db)
):
    """Get thread with its messages.

    ``limit`` must be >= 1 and ``offset`` >= 0; out-of-range values are rejected
    with a 422 instead of being passed to the service's paging query.

    Args:
        thread_id: Thread ID from the path.
        user_id: Requesting user's ID, forwarded to the service.
        limit: Maximum number of messages to return (default 50).
        offset: Number of messages to skip (default 0).
        db: Request-scoped SQLAlchemy session.

    Returns:
        The result of ``ThreadService.get_thread_with_messages``.

    Raises:
        HTTPException: 404 if the service returns nothing.
    """
    thread_service = ThreadService(db)
    thread_with_messages = thread_service.get_thread_with_messages(
        thread_id, user_id, limit, offset
    )

    if not thread_with_messages:
        raise HTTPException(status_code=404, detail="Thread not found")

    return thread_with_messages


@router.get("/threads", response_model=schema.ThreadListResponse)
def list_user_threads(
    user_id: int = Query(..., description="User ID"),
    store_id: Optional[int] = Query(None, description="Filter by store ID"),
    branch_id: Optional[int] = Query(None, description="Filter by branch ID"),
    thread_type: Optional[str] = Query(None, description="Filter by thread type"),
    is_archived: Optional[bool] = Query(None, description="Filter by archive status"),
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, description="Number of threads per page"),
    db: Session = Depends(get_db)
):
    """List threads for a user with pagination and filters.

    ``page`` defaults to 1 (presumably 1-based — confirm against
    ``ThreadService.list_user_threads``). ``page`` and ``page_size`` must both
    be >= 1. Zero or negative values are rejected with a 422 rather than
    producing a negative offset downstream.

    Args:
        user_id: Owner of the threads.
        store_id, branch_id, thread_type, is_archived: Optional filters; ``None`` means unfiltered.
        page: Page number (default 1).
        page_size: Threads per page (default 20).
        db: Request-scoped SQLAlchemy session.

    Returns:
        The service result, serialized as ``ThreadListResponse``.
    """
    thread_service = ThreadService(db)
    return thread_service.list_user_threads(
        user_id, store_id, branch_id, thread_type, is_archived, page, page_size
    )


@router.put("/threads/{thread_id}", response_model=schema.ThreadResponse)
def update_thread(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    update_data: schema.ThreadUpdateRequest = None,
    db: Session = Depends(get_db)
):
    """Update a thread.

    The request body is optional. When it is omitted, a default-constructed
    ``ThreadUpdateRequest`` is passed to the service.

    Args:
        thread_id: Thread ID from the path.
        user_id: Requesting user's ID, forwarded to the service.
        update_data: Optional update payload.
        db: Request-scoped SQLAlchemy session.

    Returns:
        The updated thread from ``ThreadService.update_thread``.

    Raises:
        HTTPException: 404 if the service returns nothing.
    """
    # Test for "body omitted" explicitly rather than relying on model truthiness.
    if update_data is None:
        update_data = schema.ThreadUpdateRequest()

    thread_service = ThreadService(db)
    updated_thread = thread_service.update_thread(thread_id, user_id, update_data)

    if not updated_thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    return updated_thread


@router.delete("/threads/{thread_id}")
def delete_thread(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    db: Session = Depends(get_db)
):
    """Delete a thread (soft delete) via ``ThreadService.delete_thread``.

    Returns:
        A success payload when the service reports the thread was deleted.

    Raises:
        HTTPException: 404 when the service reports failure.
    """
    deleted = ThreadService(db).delete_thread(thread_id, user_id)
    if deleted:
        return {"success": True, "message": "Thread deleted successfully"}
    raise HTTPException(status_code=404, detail="Thread not found")


@router.post("/threads/{thread_id}/archive")
def archive_thread(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    db: Session = Depends(get_db)
):
    """Archive a thread through ``ThreadService.archive_thread``.

    Returns:
        A success payload when the service reports the thread was archived.

    Raises:
        HTTPException: 404 when the service reports failure.
    """
    archived = ThreadService(db).archive_thread(thread_id, user_id)
    if archived:
        return {"success": True, "message": "Thread archived successfully"}
    raise HTTPException(status_code=404, detail="Thread not found")


@router.post("/threads/{thread_id}/unarchive")
def unarchive_thread(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    db: Session = Depends(get_db)
):
    """Restore an archived thread through ``ThreadService.unarchive_thread``.

    Returns:
        A success payload when the service reports the thread was unarchived.

    Raises:
        HTTPException: 404 when the service reports failure.
    """
    restored = ThreadService(db).unarchive_thread(thread_id, user_id)
    if restored:
        return {"success": True, "message": "Thread unarchived successfully"}
    raise HTTPException(status_code=404, detail="Thread not found")


# Message Management Endpoints
@router.post("/messages", response_model=schema.MessageResponse)
def add_message(
    message_data: schema.MessageCreateRequest,
    db: Session = Depends(get_db)
):
    """Add a new message to a thread.

    Args:
        message_data: Validated request body forwarded to ``ThreadHistoryService.add_message``.
        db: Request-scoped SQLAlchemy session.

    Returns:
        The created message, serialized as ``MessageResponse``.

    Raises:
        HTTPException: re-raised unchanged if the service raises one itself;
            any other exception is logged and reported as a 400 carrying its text.
    """
    try:
        history_service = ThreadHistoryService(db)
        return history_service.add_message(message_data)
    except HTTPException:
        # Keep the status code the service layer chose instead of coercing it to 400.
        raise
    except Exception as e:
        # Log with traceback; previously the failure left no server-side trace.
        logger.exception("Failed to add message")
        raise HTTPException(status_code=400, detail=str(e)) from e


# ``{message_id:int}`` uses Starlette's integer path convertor so this route only
# matches numeric IDs. FastAPI matches routes in registration order, and without
# the convertor this route captures ``GET /messages/search`` and
# ``GET /messages/stats`` (both registered later in this module) and rejects
# them with a 422. The documented URL (``/messages/{message_id}``) is unchanged.
@router.get("/messages/{message_id:int}", response_model=schema.MessageResponse)
def get_message(
    message_id: int,
    user_id: int = Query(..., description="User ID"),
    db: Session = Depends(get_db)
):
    """Get a specific message by ID.

    Args:
        message_id: Numeric message ID from the path.
        user_id: Requesting user's ID, forwarded to ``ThreadHistoryService.get_message``.
        db: Request-scoped SQLAlchemy session.

    Returns:
        The message returned by the service.

    Raises:
        HTTPException: 404 if the service returns nothing.
    """
    history_service = ThreadHistoryService(db)
    message = history_service.get_message(message_id, user_id)

    if not message:
        raise HTTPException(status_code=404, detail="Message not found")

    return message


# NOTE(review): ``get_thread_with_messages`` above already registers
# ``GET /threads/{thread_id}/messages``. FastAPI dispatches to the first matching
# route, so this handler (and its before/after_message_id cursors) is never
# reached as written. Giving it a distinct path is a public-API decision —
# confirm the intended URL before changing it.
@router.get("/threads/{thread_id}/messages", response_model=List[schema.MessageResponse])
def get_thread_messages(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    limit: int = Query(50, description="Maximum number of messages to return"),
    offset: int = Query(0, description="Number of messages to skip"),
    before_message_id: Optional[int] = Query(None, description="Get messages before this ID"),
    after_message_id: Optional[int] = Query(None, description="Get messages after this ID"),
    db: Session = Depends(get_db)
):
    """Return a page of messages from a thread.

    Offset paging and the ``before_message_id`` / ``after_message_id`` cursors
    are all forwarded to ``ThreadHistoryService.get_thread_messages``.
    """
    service = ThreadHistoryService(db)
    messages = service.get_thread_messages(
        thread_id,
        user_id,
        limit,
        offset,
        before_message_id,
        after_message_id,
    )
    return messages


@router.put("/messages/{message_id}", response_model=schema.MessageResponse)
def update_message(
    message_id: int,
    user_id: int = Query(..., description="User ID"),
    update_data: schema.MessageUpdateRequest = None,
    db: Session = Depends(get_db)
):
    """Update a message.

    The request body is optional. When it is omitted, a default-constructed
    ``MessageUpdateRequest`` is passed to the service.

    Args:
        message_id: Message ID from the path.
        user_id: Requesting user's ID, forwarded to the service.
        update_data: Optional update payload.
        db: Request-scoped SQLAlchemy session.

    Returns:
        The updated message from ``ThreadHistoryService.update_message``.

    Raises:
        HTTPException: 404 if the service returns nothing.
    """
    # Test for "body omitted" explicitly rather than relying on model truthiness.
    if update_data is None:
        update_data = schema.MessageUpdateRequest()

    history_service = ThreadHistoryService(db)
    updated_message = history_service.update_message(message_id, user_id, update_data)

    if not updated_message:
        raise HTTPException(status_code=404, detail="Message not found")

    return updated_message


@router.delete("/messages/{message_id}")
def delete_message(
    message_id: int,
    user_id: int = Query(..., description="User ID"),
    db: Session = Depends(get_db)
):
    """Delete a message (soft delete) via ``ThreadHistoryService.delete_message``.

    Returns:
        A success payload when the service reports the message was deleted.

    Raises:
        HTTPException: 404 when the service reports failure.
    """
    deleted = ThreadHistoryService(db).delete_message(message_id, user_id)
    if deleted:
        return {"success": True, "message": "Message deleted successfully"}
    raise HTTPException(status_code=404, detail="Message not found")


# Search and Analytics Endpoints
# NOTE(review): routes are matched in registration order and
# ``GET /messages/{message_id}`` is registered earlier in this module. Unless that
# route only accepts numeric IDs, it captures ``/messages/search`` (failing int
# validation with a 422) and this handler is never reached.
@router.get("/messages/search", response_model=List[schema.MessageResponse])
def search_messages(
    query: str = Query(..., description="Search query"),
    user_id: int = Query(..., description="User ID"),
    thread_id: Optional[int] = Query(None, description="Filter by thread ID"),
    store_id: Optional[int] = Query(None, description="Filter by store ID"),
    role: Optional[str] = Query(None, description="Filter by message role"),
    limit: int = Query(50, description="Maximum number of results"),
    db: Session = Depends(get_db)
):
    """Search messages across threads.

    The query text and every optional filter are forwarded to
    ``ThreadHistoryService.search_messages``.
    """
    service = ThreadHistoryService(db)
    results = service.search_messages(
        user_id,
        query,
        thread_id,
        store_id,
        role,
        limit,
    )
    return results


@router.get("/threads/{thread_id}/messages/date-range", response_model=List[schema.MessageResponse])
def get_thread_messages_by_date_range(
    thread_id: int,
    user_id: int = Query(..., description="User ID"),
    start_date: date = Query(..., description="Start date"),
    end_date: date = Query(..., description="End date"),
    limit: int = Query(100, ge=1, description="Maximum number of messages to return"),
    db: Session = Depends(get_db)
):
    """Get messages from a thread within a date range.

    Args:
        thread_id: Thread ID from the path.
        user_id: Requesting user's ID, forwarded to the service.
        start_date: First date of the range.
        end_date: Last date of the range; must not precede ``start_date``.
        limit: Maximum number of messages to return (>= 1, default 100).
        db: Request-scoped SQLAlchemy session.

    Returns:
        The messages returned by ``ThreadHistoryService.get_thread_messages_by_date_range``.

    Raises:
        HTTPException: 400 if ``start_date`` is after ``end_date``.
    """
    # An inverted range is a client error; report it instead of silently
    # querying an empty window.
    if start_date > end_date:
        raise HTTPException(
            status_code=400, detail="start_date must be on or before end_date"
        )

    history_service = ThreadHistoryService(db)
    return history_service.get_thread_messages_by_date_range(
        thread_id, user_id, start_date, end_date, limit
    )


# Statistics Endpoints
# NOTE(review): routes are matched in registration order and
# ``GET /threads/{thread_id}`` is registered earlier in this module. Unless that
# route only accepts numeric IDs, it captures ``/threads/stats`` (failing int
# validation with a 422) and this handler is never reached.
@router.get("/threads/stats")
def get_thread_stats(
    user_id: int = Query(..., description="User ID"),
    store_id: Optional[int] = Query(None, description="Filter by store ID"),
    db: Session = Depends(get_db)
):
    """Return thread statistics for ``user_id``, optionally narrowed to one store."""
    stats = ThreadService(db).get_thread_stats(user_id, store_id)
    return stats


# NOTE(review): routes are matched in registration order and
# ``GET /messages/{message_id}`` is registered earlier in this module. Unless that
# route only accepts numeric IDs, it captures ``/messages/stats`` (failing int
# validation with a 422) and this handler is never reached.
@router.get("/messages/stats")
def get_message_stats(
    user_id: int = Query(..., description="User ID"),
    thread_id: Optional[int] = Query(None, description="Filter by thread ID"),
    store_id: Optional[int] = Query(None, description="Filter by store ID"),
    days: int = Query(30, description="Number of days to analyze"),
    db: Session = Depends(get_db)
):
    """Return message statistics for ``user_id`` over the last ``days`` days.

    The thread/store filters and the window are forwarded to
    ``ThreadHistoryService.get_message_stats``.
    """
    service = ThreadHistoryService(db)
    stats = service.get_message_stats(user_id, thread_id, store_id, days)
    return stats


# Single WebSocket Marketing Chat Endpoint
@router.websocket("/chat/marketing")
async def websocket_marketing_chat(websocket: WebSocket):
    """WebSocket endpoint for marketing-focused chat with persona integration.

    Protocol, as implemented here:
      * After accepting, the server sends a ``connection_established`` frame.
      * Each client frame must be a JSON object with ``store_id``, ``user_id``
        and ``message``. Optional keys:
          - ``branch_id`` and ``thread_id``.
          - ``type``: ``"image"`` selects image generation; anything else is
            treated as text. Defaults to ``"text"``.
          - ``session_id``: forwarded for image requests only. Defaults to
            ``user_<user_id>_store_<store_id>``.
      * Malformed frames get a ``{"type": "error", ...}`` reply and the loop
        continues.
      * Every valid frame gets exactly one reply built by
        ``_process_marketing_chat_message``.
    """
    await websocket.accept()

    try:
        welcome_message = {
            "type": "connection_established",
            "message": "Connected to hwGpt marketing chat",
            "status": "ready"
        }
        await websocket.send_text(json.dumps(welcome_message))

        while True:
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Invalid JSON format"
                }))
                continue

            # Valid JSON that is not an object (list, number, string, null) would
            # otherwise fail on ``.get`` with AttributeError and tear down the socket.
            if not isinstance(message, dict):
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Message must be a JSON object"
                }))
                continue

            store_id = message.get("store_id")
            branch_id = message.get("branch_id")
            user_id = message.get("user_id")
            thread_id = message.get("thread_id")
            user_message = message.get("message")
            message_type = message.get("type", "text")  # Default to "text" if not specified
            session_id = message.get("session_id", f"user_{user_id}_store_{store_id}")

            logger.info("Received message: %s", message)

            if not store_id or not user_id or not user_message:
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Missing required fields: store_id, user_id, or message"
                }))
                continue

            response = await _process_marketing_chat_message(
                user_id,
                user_message,
                store_id,
                branch_id,
                thread_id,
                message_type,
                session_id,
            )
            await websocket.send_text(json.dumps(response))

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
        try:
            await websocket.close()
        except Exception:
            # The socket may already be closed (e.g. the failure was a send
            # after the peer went away); closing again would raise.
            logger.debug("WebSocket close after error failed", exc_info=True)


async def _process_marketing_chat_message(
    user_id,
    user_message,
    store_id,
    branch_id,
    thread_id,
    message_type,
    session_id,
):
    """Run one validated chat frame through ``ChatbotLLMService``; return the reply dict.

    Any failure is logged and turned into a fallback ``"response"`` payload
    that carries the error text. This covers:
      * a failed import or DB-session creation;
      * an ID that ``int()`` cannot convert;
      * an exception from the LLM call.
    As a result, a single bad frame does not close the socket. The DB session
    is closed on every path.
    """
    db = None
    try:
        # Imported inside the try (as before), so an import failure yields the
        # fallback reply rather than an unhandled error.
        from src.marketing.apps.hwGpt.llm_service import ChatbotLLMService
        from src.utils.db import get_db_session

        db = get_db_session()
        llm_service = ChatbotLLMService()

        # Convert in the same order the IDs appear in the service calls, so the
        # first invalid ID is the one reported in the fallback's ``error``.
        user_id_int = int(user_id)
        store_id_int = int(store_id)
        branch_id_int = int(branch_id) if branch_id else None
        thread_id_int = int(thread_id) if thread_id else None

        if message_type == "image":
            llm_response = await llm_service.process_image_generation_request(
                db,
                user_id_int,
                user_message,
                store_id_int,
                branch_id_int,
                thread_id_int,
                session_id,
            )
            response_type, reply_kind = "image_generation_response", "image"
        else:
            llm_response = await llm_service.process_chat_message(
                db,
                user_id_int,
                user_message,
                store_id_int,
                branch_id_int,
                thread_id_int,
            )
            response_type, reply_kind = "llm_response", "text"

        return {
            "type": response_type,
            "message": llm_response,
            "user_id": user_id,
            "store_id": store_id,
            "branch_id": branch_id,
            "thread_id": thread_id,
            "message_type": reply_kind,
        }
    except Exception as e:
        logger.exception("LLM error: %s", e)
        # Fallback to simple response
        return {
            "type": "response",
            "message": f"Received: {user_message} (LLM temporarily unavailable)",
            "user_id": user_id,
            "store_id": store_id,
            "branch_id": branch_id,
            "thread_id": thread_id,
            "error": str(e),
        }
    finally:
        # Close the session on every path, including when the LLM call raised
        # (the original only closed it after a successful call).
        if db is not None:
            try:
                db.close()
            except Exception:
                logger.exception("Failed to close DB session")

@router.get("/chat/connections/stats")
def get_chat_connection_stats():
    """Report WebSocket connection statistics.

    Connection tracking is currently switched off (the ChatbotWebSocketManager
    import is commented out in this module). This endpoint therefore always
    returns a fixed "disabled" payload.
    """
    disabled_payload = {
        "status": "disabled",
        "message": "Connection stats temporarily unavailable",
    }
    return disabled_payload


@router.post("/chat/suggestions")
def get_chat_suggestions(
    store_id: int = Query(..., description="Store ID"),
    branch_id: Optional[int] = Query(None, description="Branch ID"),
    context: str = Query("general", description="Context for suggestions"),
    db: Session = Depends(get_db)
):
    """Get chat suggestions based on store persona.

    Fetches the persona via ``ChatbotLLMService.get_store_persona_context`` and
    passes it, together with ``context``, to ``get_chat_suggestions``.

    Returns:
        Dict with ``success``, ``suggestions``, and the echoed ``store_id``,
        ``branch_id`` and ``context``.

    Raises:
        HTTPException: re-raised unchanged if the service raises one itself;
            any other failure is logged and reported as a 500 carrying its text.
    """
    try:
        from src.marketing.apps.hwGpt.llm_service import ChatbotLLMService
        llm_service = ChatbotLLMService()

        store_persona = llm_service.get_store_persona_context(db, store_id, branch_id)
        suggestions = llm_service.get_chat_suggestions(store_persona, context)
    except HTTPException:
        # Keep the status code chosen upstream instead of coercing it to 500.
        raise
    except Exception as e:
        # Log with traceback; previously a 500 left no server-side trace.
        logger.exception("Failed to build chat suggestions for store_id=%s", store_id)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "suggestions": suggestions,
        "store_id": store_id,
        "branch_id": branch_id,
        "context": context
    }
