import json
import logging
from typing import Dict, Set, Optional, Any
from fastapi import WebSocket, WebSocketDisconnect
from sqlalchemy.orm import Session

from src.marketing.apps.hwGpt.llm_service import ChatbotLLMService
from src.marketing.apps.hwGpt.thread_service import ThreadService
from src.marketing.apps.hwGpt.thread_history_service import ThreadHistoryService
from src.utils.db import get_db_session

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Track live chatbot WebSocket connections and each connection's selected thread.

    Connections are keyed by ``"{user_id}_{store_id}_{branch_id or 'main'}"``, so
    at most one socket is registered per (user, store, branch) triple; a new
    ``connect`` for the same triple replaces the previously registered socket
    in the registry (the old socket is not closed here).
    """

    def __init__(self):
        # connection key -> live WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # connection key -> currently selected thread id (None until one is chosen)
        self.user_threads: Dict[str, Optional[int]] = {}
        # NOTE(review): not used by any method of this class; kept in case
        # external code reads it — confirm and remove if unused.
        self.llm_service = ChatbotLLMService()

    @staticmethod
    def _connection_key(user_id: str, store_id: str, branch_id: Optional[str] = None) -> str:
        """Build the registry key for a (user, store, branch) triple."""
        return f"{user_id}_{store_id}_{branch_id or 'main'}"

    def _remove_connection(self, connection_key: str, websocket: Optional[WebSocket] = None) -> None:
        """Drop ``connection_key`` from both registries and log the disconnect.

        When ``websocket`` is given, the entry is removed only if it still refers
        to that socket, so cleaning up a failed socket never evicts a newer
        connection that has since been registered under the same key.
        """
        current = self.active_connections.get(connection_key)
        if websocket is not None and current is not None and current is not websocket:
            return

        self.active_connections.pop(connection_key, None)
        self.user_threads.pop(connection_key, None)

        logger.info("WebSocket disconnected: %s", connection_key)

    async def connect(self, websocket: WebSocket, user_id: str, store_id: str, branch_id: Optional[str] = None):
        """Accept ``websocket``, register it, and send a ``connection_established`` frame.

        The connection starts with no selected thread (``user_threads[key] = None``).
        """
        await websocket.accept()

        connection_key = self._connection_key(user_id, store_id, branch_id)
        if connection_key in self.active_connections:
            logger.warning("Replacing existing WebSocket connection: %s", connection_key)
        self.active_connections[connection_key] = websocket
        self.user_threads[connection_key] = None

        logger.info("WebSocket connected: %s", connection_key)

        # Send welcome message
        welcome_message = {
            "type": "connection_established",
            "message": "Connected to hwGpt chatbot",
            "user_id": user_id,
            "store_id": store_id,
            "branch_id": branch_id
        }

        await websocket.send_text(json.dumps(welcome_message))

    def disconnect(
        self,
        user_id: str,
        store_id: str,
        branch_id: Optional[str] = None,
        websocket: Optional[WebSocket] = None,
    ):
        """Remove the connection registered for (user, store, branch).

        Args:
            user_id, store_id, branch_id: identify the connection key.
            websocket: optional; when provided, the entry is removed only if it
                still refers to this socket (guards against evicting a reconnect).
        """
        self._remove_connection(self._connection_key(user_id, store_id, branch_id), websocket)

    async def send_personal_message(self, message: str, user_id: str, store_id: str, branch_id: Optional[str] = None):
        """Send ``message`` to one connection; a failing socket is unregistered.

        Does nothing if no connection is registered for the key. Send errors are
        logged, not raised.
        """
        connection_key = self._connection_key(user_id, store_id, branch_id)
        websocket = self.active_connections.get(connection_key)
        if websocket is None:
            return

        try:
            await websocket.send_text(message)
        except Exception as e:
            logger.error("Error sending message to %s: %s", connection_key, e)
            # Remove broken connection (only if it has not been replaced meanwhile)
            self._remove_connection(connection_key, websocket)

    async def broadcast(self, message: str):
        """Send ``message`` to every registered connection.

        Sockets whose send fails are logged and unregistered afterwards.
        """
        failed = []

        # Iterate a snapshot: each await may let other coroutines connect or
        # disconnect, which would otherwise mutate the dict mid-iteration.
        for connection_key, connection in list(self.active_connections.items()):
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error("Error broadcasting to %s: %s", connection_key, e)
                failed.append((connection_key, connection))

        # Clean up by key directly; re-parsing the key is ambiguous when ids contain "_".
        for connection_key, connection in failed:
            self._remove_connection(connection_key, connection)


class ChatbotWebSocketManager:
    """Route chatbot WebSocket frames to handlers and reply with JSON text frames.

    Frame types handled by ``handle_websocket_connection``: ``chat``,
    ``thread_operation`` and ``ping``. Each chat/thread request opens its own
    DB session via ``get_db_session()`` and closes it when done.
    """

    def __init__(self):
        self.connection_manager = ConnectionManager()
        self.llm_service = ChatbotLLMService()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _connection_key(user_id: str, store_id: str, branch_id: Optional[str] = None) -> str:
        """Build the key ``ConnectionManager`` uses for a (user, store, branch) connection."""
        return f"{user_id}_{store_id}_{branch_id or 'main'}"

    @staticmethod
    def _optional_int(value: Any) -> Optional[int]:
        """Return ``int(value)`` for truthy values and ``None`` otherwise."""
        return int(value) if value else None

    @staticmethod
    def _json_default(value: Any) -> Any:
        """``json.dumps`` fallback for date/time-like values (anything with ``isoformat()``).

        Payloads built from service/ORM objects may carry datetime fields, which
        the stdlib encoder rejects. Any other unsupported type still raises
        ``TypeError`` so unexpected objects are not silently stringified.
        """
        isoformat = getattr(value, "isoformat", None)
        if callable(isoformat):
            return isoformat()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    async def _send_json(self, websocket: WebSocket, payload: Dict[str, Any]) -> None:
        """Serialize ``payload`` to JSON and send it as a text frame."""
        await websocket.send_text(json.dumps(payload, default=self._json_default))

    @staticmethod
    def _create_thread(db: Session, thread_fields: Dict[str, Any]):
        """Create a thread through ``ThreadService`` from ``ThreadCreateRequest`` fields."""
        from src.marketing.apps.hwGpt import schema
        thread_request = schema.ThreadCreateRequest(**thread_fields)
        return ThreadService(db).create_thread(thread_request)

    # ----------------------------------------------------------------- handlers

    async def handle_chat_message(
        self,
        websocket: WebSocket,
        user_id: str,
        store_id: str,
        branch_id: Optional[str] = None,
        message_data: Optional[Dict[str, Any]] = None
    ):
        """Handle one ``chat`` frame and send the resulting frames to ``websocket``.

        Args:
            websocket: socket to reply on.
            user_id, store_id, branch_id: identify the connection; converted to
                ``int`` for the service calls (a falsy ``branch_id`` becomes ``None``).
            message_data: parsed client frame. Reads ``"message"`` (required,
                non-blank string) and optional ``"thread_id"``. ``None`` is treated
                as an empty frame.

        Replies with ``thread_created`` (only when no thread id was supplied),
        then ``user_message_received`` + ``ai_response`` on success, or an
        ``error`` frame otherwise. ``WebSocketDisconnect`` is re-raised so the
        connection loop can stop; other exceptions are logged and reported.
        """
        message_data = message_data or {}
        try:
            user_message = message_data.get("message") or ""
            if not isinstance(user_message, str):
                await self._send_json(websocket, {
                    "type": "error",
                    "message": "Message must be a string"
                })
                return
            if not user_message.strip():
                await self._send_json(websocket, {
                    "type": "error",
                    "message": "Message cannot be empty"
                })
                return

            thread_id = message_data.get("thread_id")
            if thread_id:
                try:
                    thread_id = int(thread_id)
                except (TypeError, ValueError):
                    await self._send_json(websocket, {
                        "type": "error",
                        "message": "Invalid thread_id"
                    })
                    return

            user_id_int = int(user_id)
            store_id_int = int(store_id)
            branch_id_int = self._optional_int(branch_id)

            db = get_db_session()
            try:
                # No thread supplied: start a new marketing thread for this message.
                if not thread_id:
                    thread = self._create_thread(db, {
                        "store_id": store_id_int,
                        "branch_id": branch_id_int,
                        "user_id": user_id_int,
                        "title": f"Marketing Strategy - {user_message[:30]}...",
                        "description": "Marketing-focused chat thread",
                        "thread_type": "marketing"
                    })
                    thread_id = thread.id

                    await self._send_json(websocket, {
                        "type": "thread_created",
                        "thread_id": thread_id,
                        "title": thread.title
                    })

                # NOTE(review): process_chat_message is a synchronous call made from
                # this coroutine, so it blocks the event loop (and every other
                # connection) for its whole duration. If it performs network I/O
                # (likely for an LLM call — confirm), consider offloading it to a
                # worker thread after checking the DB session's thread affinity.
                result = self.llm_service.process_chat_message(
                    db=db,
                    thread_id=thread_id,
                    user_id=user_id_int,
                    user_message=user_message,
                    store_id=store_id_int,
                    branch_id=branch_id_int
                )

                if not result["success"]:
                    await self._send_json(websocket, {
                        "type": "error",
                        "message": result.get("response", "Failed to process message"),
                        "details": result.get("error", "Unknown error")
                    })
                    return

                user_part = result["user_message"]
                ai_part = result["ai_response"]

                await self._send_json(websocket, {
                    "type": "user_message_received",
                    "message_id": user_part["id"],
                    "content": user_part["content"],
                    "timestamp": user_part["created_at"]
                })

                await self._send_json(websocket, {
                    "type": "ai_response",
                    "message_id": ai_part["id"],
                    "content": ai_part["content"],
                    "timestamp": ai_part["created_at"],
                    "tokens_used": ai_part["tokens_used"],
                    "model_used": ai_part["model_used"],
                    "thread_id": thread_id
                })

                # Remember the active thread for this connection.
                connection_key = self._connection_key(user_id, store_id, branch_id)
                self.connection_manager.user_threads[connection_key] = thread_id

                logger.info(
                    "Chat message processed successfully: Thread %s, User %s", thread_id, user_id
                )
            finally:
                db.close()

        except WebSocketDisconnect:
            # The client is gone; replying would fail again. Let the loop handle it.
            raise
        except Exception as e:
            logger.exception("Error handling chat message: %s", e)
            # NOTE(review): str(e) is sent to the client and may expose internals.
            await self._send_json(websocket, {
                "type": "error",
                "message": "Internal server error",
                "details": str(e)
            })

    async def handle_thread_operations(
        self,
        websocket: WebSocket,
        user_id: str,
        store_id: str,
        operation: str,
        data: Dict[str, Any],
        connection_key: Optional[str] = None,
    ):
        """Handle a ``thread_operation`` frame.

        Supported operations: ``create_thread``, ``list_threads``,
        ``switch_thread`` and ``get_suggestions``. Unknown operations and a
        ``switch_thread`` without ``thread_id`` get an ``error`` frame.

        Args:
            operation: operation name from the client frame.
            data: operation arguments from the client frame (``None`` -> ``{}``).
            connection_key: key of the calling connection in
                ``connection_manager``; ``switch_thread`` records the selected
                thread under it. When omitted, the key is rebuilt from
                ``data["branch_id"]`` (``'main'`` if absent/empty).
        """
        data = data or {}
        try:
            if operation not in ("create_thread", "list_threads", "switch_thread", "get_suggestions"):
                await self._send_json(websocket, {
                    "type": "error",
                    "message": f"Unknown thread operation: {operation}"
                })
                return

            db = get_db_session()
            try:
                if operation == "create_thread":
                    thread = self._create_thread(db, {
                        "store_id": int(store_id),
                        "branch_id": self._optional_int(data.get("branch_id")),
                        "user_id": int(user_id),
                        "title": data.get("title", "New Chat Thread"),
                        "description": data.get("description", ""),
                        "thread_type": data.get("thread_type", "chat")
                    })

                    await self._send_json(websocket, {
                        "type": "thread_created",
                        "thread": {
                            "id": thread.id,
                            "title": thread.title,
                            "description": thread.description,
                            "thread_type": thread.thread_type
                        }
                    })

                elif operation == "list_threads":
                    listing = ThreadService(db).list_user_threads(
                        user_id=int(user_id),
                        store_id=int(store_id),
                        page=data.get("page", 1),
                        page_size=data.get("page_size", 20)
                    )

                    await self._send_json(websocket, {
                        "type": "threads_listed",
                        "threads": [item.dict() for item in listing.threads],
                        "total": listing.total_threads,
                        "page": listing.page,
                        "page_size": listing.page_size
                    })

                elif operation == "switch_thread":
                    requested_thread_id = data.get("thread_id")
                    if not requested_thread_id:
                        await self._send_json(websocket, {
                            "type": "error",
                            "message": "thread_id is required for switch_thread"
                        })
                        return

                    thread_id = int(requested_thread_id)
                    messages = ThreadHistoryService(db).get_thread_messages(
                        thread_id=thread_id,
                        user_id=int(user_id),
                        limit=50
                    )

                    # Record the selection under the caller's real connection key so
                    # no stray (e.g. "..._None") entries accumulate in user_threads.
                    key = connection_key or self._connection_key(
                        user_id, store_id, data.get("branch_id")
                    )
                    self.connection_manager.user_threads[key] = thread_id

                    await self._send_json(websocket, {
                        "type": "thread_switched",
                        "thread_id": requested_thread_id,
                        "messages": [msg.dict() for msg in messages]
                    })

                elif operation == "get_suggestions":
                    store_persona = self.llm_service.get_store_persona_context(
                        db, int(store_id), self._optional_int(data.get("branch_id"))
                    )

                    suggestions = self.llm_service.get_chat_suggestions(
                        store_persona, data.get("context", "general")
                    )

                    await self._send_json(websocket, {
                        "type": "suggestions",
                        "suggestions": suggestions
                    })
            finally:
                db.close()

        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.exception("Error handling thread operation %s: %s", operation, e)
            # NOTE(review): str(e) is sent to the client and may expose internals.
            await self._send_json(websocket, {
                "type": "error",
                "message": f"Failed to {operation}",
                "details": str(e)
            })

    async def handle_websocket_connection(self, websocket: WebSocket, user_id: str, store_id: str, branch_id: Optional[str] = None):
        """Run the full lifecycle of one client socket: connect, serve frames, clean up.

        Malformed frames (invalid JSON or a non-object payload) get an ``error``
        reply and the loop keeps serving; the loop ends on ``WebSocketDisconnect``
        or any other unexpected error, after which the connection is unregistered.
        """
        connection_key = self._connection_key(user_id, store_id, branch_id)

        try:
            await self.connection_manager.connect(websocket, user_id, store_id, branch_id)

            while True:
                raw_frame = await websocket.receive_text()

                try:
                    message_data = json.loads(raw_frame)
                except json.JSONDecodeError:
                    await self._send_json(websocket, {
                        "type": "error",
                        "message": "Invalid JSON payload"
                    })
                    continue

                if not isinstance(message_data, dict):
                    await self._send_json(websocket, {
                        "type": "error",
                        "message": "Message must be a JSON object"
                    })
                    continue

                message_type = message_data.get("type", "chat")

                if message_type == "chat":
                    await self.handle_chat_message(
                        websocket, user_id, store_id, branch_id, message_data
                    )
                elif message_type == "thread_operation":
                    await self.handle_thread_operations(
                        websocket,
                        user_id,
                        store_id,
                        message_data.get("operation"),
                        message_data.get("data") or {},
                        connection_key=connection_key,
                    )
                elif message_type == "ping":
                    # Connection health check
                    await self._send_json(websocket, {"type": "pong"})
                else:
                    await self._send_json(websocket, {
                        "type": "error",
                        "message": f"Unknown message type: {message_type}"
                    })

        except WebSocketDisconnect:
            logger.info("WebSocket disconnected: %s", connection_key)
        except Exception as e:
            logger.exception("WebSocket error: %s", e)
        finally:
            # Only tear down the registry entry if it still belongs to this socket;
            # a reconnect with the same user/store/branch may have replaced it.
            current = self.connection_manager.active_connections.get(connection_key)
            if current is None or current is websocket:
                self.connection_manager.disconnect(user_id, store_id, branch_id)

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return counts of active connections / tracked threads and the active keys."""
        active = self.connection_manager.active_connections
        return {
            "active_connections": len(active),
            "user_threads": len(self.connection_manager.user_threads),
            "connection_keys": list(active.keys())
        }
